spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [ML][MINOR] Separate estimator and model params for read/write test.
Date Wed, 08 Mar 2017 10:05:12 GMT
Repository: spark
Updated Branches:
  refs/heads/master 314e48a35 -> 1fa58868b


[ML][MINOR] Separate estimator and model params for read/write test.

## What changes were proposed in this pull request?
Since ```Estimator``` and ```Model``` do not always share the same params (see ```ALSParams```
and ```ALSModelParams```), we should pass the test params for the estimator and the model
separately to the function ```testEstimatorAndModelReadWrite```.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #17151 from yanboliang/test-rw.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1fa58868
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1fa58868
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1fa58868

Branch: refs/heads/master
Commit: 1fa58868bc6635ff2119264665bd3d00b4b1253a
Parents: 314e48a
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Wed Mar 8 02:05:01 2017 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Wed Mar 8 02:05:01 2017 -0800

----------------------------------------------------------------------
 .../DecisionTreeClassifierSuite.scala           |  8 +++--
 .../ml/classification/GBTClassifierSuite.scala  |  3 +-
 .../ml/classification/LinearSVCSuite.scala      |  2 +-
 .../LogisticRegressionSuite.scala               |  2 +-
 .../ml/classification/NaiveBayesSuite.scala     |  3 +-
 .../RandomForestClassifierSuite.scala           |  3 +-
 .../ml/clustering/BisectingKMeansSuite.scala    |  4 +--
 .../ml/clustering/GaussianMixtureSuite.scala    |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala       |  3 +-
 .../apache/spark/ml/clustering/LDASuite.scala   |  4 ++-
 .../BucketedRandomProjectionLSHSuite.scala      |  2 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala   |  3 +-
 .../spark/ml/feature/MinHashLSHSuite.scala      |  2 +-
 .../org/apache/spark/ml/fpm/FPGrowthSuite.scala |  4 +--
 .../spark/ml/recommendation/ALSSuite.scala      | 35 +++++++-------------
 .../regression/AFTSurvivalRegressionSuite.scala |  3 +-
 .../regression/DecisionTreeRegressorSuite.scala |  5 +--
 .../spark/ml/regression/GBTRegressorSuite.scala |  3 +-
 .../GeneralizedLinearRegressionSuite.scala      |  1 +
 .../ml/regression/IsotonicRegressionSuite.scala |  2 +-
 .../ml/regression/LinearRegressionSuite.scala   |  2 +-
 .../regression/RandomForestRegressorSuite.scala |  3 +-
 .../spark/ml/util/DefaultReadWriteTest.scala    | 14 ++++----
 23 files changed, 59 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7f..10de503 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
-      checkModelData)
+      allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943..0cddb37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index fe47176..4c63a2a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext
with Defau
     }
     val svm = new LinearSVC()
     testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
-      checkModelData)
+      LinearSVCSuite.allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958..affaa57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
     }
     val lr = new LogisticRegression()
     testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
-      checkModelData)
+      LogisticRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991..4d5d299 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       assert(model.theta === model2.theta)
     }
     val nb = new NaiveBayes()
-    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+      NaiveBayesSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585..c3003ce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1..200a892 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val bisectingKMeans = new BisectingKMeans()
-    testEstimatorAndModelReadWrite(
-      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+      BisectingKMeansSuite.allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b..61da897 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
-    testEstimatorAndModelReadWrite(gm, dataset,
+    testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
       GaussianMixtureSuite.allParamSettings, checkModelData)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f..ca05b9c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with
DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+      KMeansSuite.allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fb..75aa0be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
         Vectors.dense(model2.getDocConcentration) absTol 1e-6)
     }
     val lda = new LDA()
-    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+      LDASuite.allParamSettings, checkModelData)
   }
 
   test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     }
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab93768..91eac9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
     }
     val mh = new BucketedRandomProjectionLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d5..d6925da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.selectedFeatures === model2.selectedFeatures)
     }
     val nb = new ChiSqSelector
-    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+      ChiSqSelectorSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and not support other types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf..a2f0093 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with
Defa
     }
     val mh = new MinHashLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c7461..076d55c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with
Defaul
         model2.freqItemsets.sort("items").collect())
     }
     val fPGrowth = new FPGrowth()
-    testEstimatorAndModelReadWrite(
-      fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e494ea8..a177ed1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -518,37 +518,26 @@ class ALSSuite
   }
 
   test("read/write") {
-    import ALSSuite._
-    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
-    val als = new ALS()
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      als.set(als.getParam(p), v)
-    }
     val spark = this.spark
     import spark.implicits._
-    val model = als.fit(ratings.toDF())
-
-    // Test Estimator save/load
-    val als2 = testDefaultReadWrite(als)
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      val param = als.getParam(p)
-      assert(als.get(param).get === als2.get(param).get)
-    }
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
 
-    // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    allModelParamSettings.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
-    assert(model.rank === model2.rank)
     def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
       df.select("id", "features").collect().map { case r =>
         (r.getInt(0), r.getAs[Array[Float]](1))
       }.toSet
     }
-    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
-    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+    def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+      assert(model.rank === model2.rank)
+      assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+      assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+    }
+
+    val als = new ALS()
+    testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+      allModelParamSettings, checkModelData)
   }
 
   test("input type validation") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0a..708185a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
     }
     val aft = new AFTSurvivalRegression()
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
-      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+      AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+      checkModelData)
   }
 
   test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e..0e91284 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
     testEstimatorAndModelReadWrite(dt, categoricalData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
     testEstimatorAndModelReadWrite(dt, continuousData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a..03c2f97 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a7..4019117 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite
 
     val glr = new GeneralizedLinearRegression()
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2ac..f41a360 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite
 
     val ir = new IsotonicRegression()
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
-      checkModelData)
+      IsotonicRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b2..6a51e75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
     }
     val lr = new LinearRegression()
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      checkModelData)
+      LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f..3bf0445 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1fa58868/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b872..bfe8f12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    *  - Check Params on Estimator and Model
    *  - Compare model data
    *
-   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
    *
    * @param estimator  Estimator to test
    * @param dataset  Dataset to pass to [[Estimator.fit()]]
-   * @param testParams  Set of [[Param]] values to set in estimator
+   * @param testEstimatorParams  Set of [[Param]] values to set in estimator
+   * @param testModelParams Set of [[Param]] values to set in model
    * @param checkModelData  Method which takes the original and loaded [[Model]] and compares their
    *                        data.  This method does not need to check [[Param]] values.
    * @tparam E  Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
       dataset: Dataset[_],
-      testParams: Map[String, Any],
+      testEstimatorParams: Map[String, Any],
+      testModelParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
     }
     val model = estimator.fit(dataset)
 
     // Test Estimator save/load
     val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
       assert(estimator.get(param).get === estimator2.get(param).get)
     }
 
     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    testModelParams.foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message