spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-7047] [ML] ml.Model optional parent support
Date Tue, 19 May 2015 17:55:35 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.4 8567d29ef -> 24cb323e7


[SPARK-7047] [ML] ml.Model optional parent support

Made Model.parent transient. Added Model.hasParent to check whether a model has a (non-null) parent.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5914 from jkbradley/parent-optional and squashes the following commits:

d501774 [Joseph K. Bradley] Made Model.parent transient.  Added Model.hasParent to test for null parent

(cherry picked from commit fb90273212dc7241c9a0c3446e25e0e0b9377750)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/24cb323e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/24cb323e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/24cb323e

Branch: refs/heads/branch-1.4
Commit: 24cb323e767a342496cf24e0d06398b5af38ac80
Parents: 8567d29
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Tue May 19 10:55:21 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue May 19 10:55:32 2015 -0700

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/Model.scala            | 5 ++++-
 .../spark/ml/classification/LogisticRegressionSuite.scala       | 1 +
 .../spark/ml/classification/RandomForestClassifierSuite.scala   | 2 ++
 3 files changed, 7 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/main/scala/org/apache/spark/ml/Model.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 7fd5153..70e7495 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer {
    * The parent estimator that produced this model.
    * Note: For ensembles' component Models, this value can be null.
    */
-  var parent: Estimator[M] = _
+  @transient var parent: Estimator[M] = _
 
   /**
    * Sets the parent of this model (Java API).
@@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer {
     this.asInstanceOf[M]
   }
 
+  /** Indicates whether this [[Model]] has a corresponding parent. */
+  def hasParent: Boolean = parent != null
+
   override def copy(extra: ParamMap): M = {
     // The default implementation of Params.copy doesn't work for models.
     throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")

http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 4376524..97f9749 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.getRawPredictionCol === "rawPrediction")
     assert(model.getProbabilityCol === "probability")
     assert(model.intercept !== 0.0)
+    assert(model.hasParent)
   }
 
   test("logistic regression doesn't fit intercept when fitIntercept is off") {

http://git-wip-us.apache.org/repos/asf/spark/blob/24cb323e/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 08f86fa..cdbbaca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -162,5 +162,7 @@ private object RandomForestClassifierSuite {
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
       oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.hasParent)
+    assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message