spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mln...@apache.org
Subject spark git commit: [SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label
Date Fri, 13 May 2016 07:08:20 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5b849766a -> 31f1aebbe


[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label

## What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for label

## How was this patch tested?

Unit tests

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12467 from BenFradet/SPARK-13961.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31f1aebb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31f1aebb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31f1aebb

Branch: refs/heads/master
Commit: 31f1aebbeb77b4eb1080f22c9bece7fafd8022f8
Parents: 5b84976
Author: BenFradet <benjamin.fradet@gmail.com>
Authored: Fri May 13 09:08:04 2016 +0200
Committer: Nick Pentreath <nick.pentreath@gmail.com>
Committed: Fri May 13 09:08:04 2016 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala |  4 +--
 .../org/apache/spark/ml/feature/RFormula.scala  |  4 +--
 .../DecisionTreeClassifierSuite.scala           |  2 +-
 .../ml/classification/GBTClassifierSuite.scala  |  2 +-
 .../LogisticRegressionSuite.scala               |  2 +-
 .../MultilayerPerceptronClassifierSuite.scala   |  2 +-
 .../ml/classification/NaiveBayesSuite.scala     |  2 +-
 .../ml/classification/OneVsRestSuite.scala      |  2 +-
 .../RandomForestClassifierSuite.scala           |  2 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala   | 10 ++++++-
 .../apache/spark/ml/feature/RFormulaSuite.scala | 30 ++++++++++++++++----
 .../regression/AFTSurvivalRegressionSuite.scala |  2 +-
 .../regression/DecisionTreeRegressorSuite.scala |  2 +-
 .../spark/ml/regression/GBTRegressorSuite.scala |  2 +-
 .../GeneralizedLinearRegressionSuite.scala      |  2 +-
 .../ml/regression/IsotonicRegressionSuite.scala |  2 +-
 .../ml/regression/LinearRegressionSuite.scala   |  2 +-
 .../regression/RandomForestRegressorSuite.scala |  2 +-
 .../apache/spark/ml/util/MLTestingUtils.scala   |  4 +--
 19 files changed, 53 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index cfecae7..29f55a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -80,7 +80,7 @@ final class ChiSqSelector(override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
     transformSchema(dataset.schema, logging = true)
-    val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
+    val input = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
     }
@@ -90,7 +90,7 @@ final class ChiSqSelector(override val uid: String)
 
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(labelCol))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5219680..a2f3d44 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -256,8 +256,8 @@ class RFormulaModel private[feature](
     val columnNames = schema.map(_.name)
     require(!columnNames.contains($(featuresCol)), "Features column already exists.")
     require(
-      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
-      "Label column already exists and is not of type DoubleType.")
+      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
+      "Label column already exists and is not of type NumericType.")
   }
 
   @Since("2.0.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index f94d336..91a947f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite
   test("should support all NumericType labels and not support other types") {
     val dt = new DecisionTreeClassifier().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
-      dt, isClassification = true, spark) { (expected, actual) =>
+      dt, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index c9453aa..5a5e5c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   test("should support all NumericType labels and not support other types") {
     val gbt = new GBTClassifier().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
-      gbt, isClassification = true, spark) { (expected, actual) =>
+      gbt, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index cb4d087..f127aa2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -938,7 +938,7 @@ class LogisticRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val lr = new LogisticRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
-      lr, isClassification = true, spark) { (expected, actual) =>
+      lr, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients.toArray === actual.coefficients.toArray)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 876e047..d5282e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
     val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
     MLTestingUtils.checkNumericTypes[
         MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
-      mpc, isClassification = true, spark) { (expected, actual) =>
+      mpc, spark) { (expected, actual) =>
         assert(expected.layers === actual.layers)
         assert(expected.weights === actual.weights)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 15d0358..2a05c44 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
   test("should support all NumericType labels and not support other types") {
     val nb = new NaiveBayes()
     MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
-      nb, isClassification = true, spark) { (expected, actual) =>
+      nb, spark) { (expected, actual) =>
         assert(expected.pi === actual.pi)
         assert(expected.theta === actual.theta)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 005d609..5044d40 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext
with Defau
   test("should support all NumericType labels and not support other types") {
     val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
     MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
-      ovr, isClassification = true, spark) { (expected, actual) =>
+      ovr, spark) { (expected, actual) =>
         val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
         val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
         assert(expectedModels.length === actualModels.length)

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 97f3fea..8002a2f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -189,7 +189,7 @@ class RandomForestClassifierSuite
   test("should support all NumericType labels and not support other types") {
     val rf = new RandomForestClassifier().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
-      rf, isClassification = true, spark) { (expected, actual) =>
+      rf, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 4c6d9c5..4fcc974 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -81,4 +81,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.selectedFeatures === instance.selectedFeatures)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val css = new ChiSqSelector()
+    MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector](
+      css, spark) { (expected, actual) =>
+        assert(expected.selectedFeatures === actual.selectedFeatures)
+      }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 46e7495..c623a62 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.DoubleType
 
 class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
@@ -68,9 +68,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with
Defaul
     assert(resultSchema.toString == model.transform(original).schema.toString)
   }
 
-  test("label column already exists but is not double type") {
+  test("label column already exists but is not numeric type") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
-    val original = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+    val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y")
     val model = formula.fit(original)
     intercept[IllegalArgumentException] {
       model.transformSchema(original.schema)
@@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with
Defaul
     ).toDF("id", "a", "b")
     val model = formula.fit(original)
     val result = model.transform(original)
-    val resultSchema = model.transformSchema(original.schema)
     val expected = spark.createDataFrame(
       Seq(
         ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
@@ -188,7 +187,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with
Defaul
       "vec2",
       Array[Attribute](
         NumericAttribute.defaultAttr,
-        NumericAttribute.defaultAttr)).toMetadata
+        NumericAttribute.defaultAttr)).toMetadata()
     val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
     val model = formula.fit(original)
     val result = model.transform(original)
@@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext
with Defaul
     val newModel = testDefaultReadWrite(model)
     checkModelData(model, newModel)
   }
+
+  test("should support all NumericType labels") {
+    val formula = new RFormula().setFormula("label ~ features")
+      .setLabelCol("x")
+      .setFeaturesCol("y")
+    val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark)
+    val expected = formula.fit(dfs(DoubleType))
+    val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t)))
+    actuals.foreach { actual =>
+      assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length)
+      expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach {
+        case (exTransformer, acTransformer) =>
+          assert(exTransformer.params === acTransformer.params)
+      }
+      assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
+      assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
+      assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index f8fc775..e4772df 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
   test("should support all NumericType labels") {
     val aft = new AFTSurvivalRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
-      aft, isClassification = false, spark) { (expected, actual) =>
+      aft, spark, isClassification = false) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index d9f26ad..2d30cbf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
   test("should support all NumericType labels and not support other types") {
     val dt = new DecisionTreeRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
-      dt, isClassification = false, spark) { (expected, actual) =>
+      dt, spark, isClassification = false) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index f6ea5bb..ac833b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   test("should support all NumericType labels and not support other types") {
     val gbt = new GBTRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
-      gbt, isClassification = false, spark) { (expected, actual) =>
+      gbt, spark, isClassification = false) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 161f8c8..3d9aeb8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite
     val glr = new GeneralizedLinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[
         GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
-      glr, isClassification = false, spark) { (expected, actual) =>
+      glr, spark, isClassification = false) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 9bf7542..bed4978 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -184,7 +184,7 @@ class IsotonicRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val ir = new IsotonicRegression()
     MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
-      ir, isClassification = false, spark) { (expected, actual) =>
+      ir, spark, isClassification = false) { (expected, actual) =>
         assert(expected.boundaries === actual.boundaries)
         assert(expected.predictions === actual.predictions)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 10f547b..a98227d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -1010,7 +1010,7 @@ class LinearRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val lr = new LinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
-      lr, isClassification = false, spark) { (expected, actual) =>
+      lr, spark, isClassification = false) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 72f3c65..7a3a369 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
   test("should support all NumericType labels and not support other types") {
     val rf = new RandomForestRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
-      rf, isClassification = false, spark) { (expected, actual) =>
+      rf, spark, isClassification = false) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/31f1aebb/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 4fe473b..ad7d2c9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite {
 
   def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
       estimator: T,
-      isClassification: Boolean,
-      spark: SparkSession)(check: (M, M) => Unit): Unit = {
+      spark: SparkSession,
+      isClassification: Boolean = true)(check: (M, M) => Unit): Unit = {
     val dfs = if (isClassification) {
       genClassifDFWithNumericLabelCol(spark)
     } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message