spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-10023] [ML] [PySpark] Unified DecisionTreeParams checkpointInterval between Scala and Python API.
Date Fri, 11 Sep 2015 03:34:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0eabea8a0 -> 339a52714


[SPARK-10023] [ML] [PySpark] Unified DecisionTreeParams checkpointInterval between Scala and
Python API.

"checkpointInterval" is a member of DecisionTreeParams in the Scala API, which is inconsistent with
the Python API; we should unify them.
```
member of DecisionTreeParams <-> Scala API
shared param for all ML Transformer/Estimator <-> Python API
```
Proposal:
"checkpointInterval" is also used by ALS, so we make it a shared param in Scala.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8528 from yanboliang/spark-10023.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/339a5271
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/339a5271
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/339a5271

Branch: refs/heads/master
Commit: 339a527141984bfb182862b0987d3c4690c9ede1
Parents: 0eabea8
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Thu Sep 10 20:34:00 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Sep 10 20:34:00 2015 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |  1 +
 .../ml/param/shared/SharedParamsCodeGen.scala   |  3 +-
 .../spark/ml/param/shared/sharedParams.scala    |  4 +--
 .../org/apache/spark/ml/tree/treeParams.scala   | 32 +++++++-------------
 4 files changed, 16 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/339a5271/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6f70b96..0a75d5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasCheckpointInterval
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}

http://git-wip-us.apache.org/repos/asf/spark/blob/339a5271/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8c16c61..e9e99ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,7 +56,8 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
-      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
+      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " +
+        "the cache will get checkpointed every 10 iterations.",
         isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +

http://git-wip-us.apache.org/repos/asf/spark/blob/339a5271/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index c267689..3009217 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params {
 private[ml] trait HasCheckpointInterval extends Params {
 
   /**
-   * Param for checkpoint interval (>= 1).
+   * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations..
    * @group param
    */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1))
 
   /** @group getParam */
   final def getCheckpointInterval: Int = $(checkpointInterval)

http://git-wip-us.apache.org/repos/asf/spark/blob/339a5271/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index dbd8d31..d29f525 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree
 import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
+import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait DecisionTreeParams extends PredictorParams {
+private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
 
   /**
    * Maximum depth of the tree (>= 0).
@@ -96,21 +96,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
     " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
     " trees.")
 
-  /**
-   * Specifies how often to checkpoint the cached node IDs.
-   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
-   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
-   * [[org.apache.spark.SparkContext]].
-   * Must be >= 1.
-   * (default = 10)
-   * @group expertParam
-   */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
-    " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
-    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
-    " checkpoint directory is set in the SparkContext. Must be >= 1.",
-    ParamValidators.gtEq(1))
-
   setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
     maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
 
@@ -150,12 +135,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   /** @group expertGetParam */
   final def getCacheNodeIds: Boolean = $(cacheNodeIds)
 
-  /** @group expertSetParam */
+  /**
+   * Specifies how often to checkpoint the cached node IDs.
+   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+   * [[org.apache.spark.SparkContext]].
+   * Must be >= 1.
+   * (default = 10)
+   * @group expertSetParam
+   */
   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
 
-  /** @group expertGetParam */
-  final def getCheckpointInterval: Int = $(checkpointInterval)
-
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message