spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-13761][ML] Remove remaining uses of validateParams
Date Thu, 17 Mar 2016 20:23:12 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4c08e2c08 -> b39e80d39


[SPARK-13761][ML] Remove remaining uses of validateParams

## What changes were proposed in this pull request?

Cleanups from [https://github.com/apache/spark/pull/11620]: remove the remaining uses of validateParams and move that functionality into transformSchema.
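For context, a minimal sketch of the pattern this change applies, using a hypothetical `MyStage` class (not an actual Spark class): parameter checks that previously lived in a standalone `validateParams()` override now run inside `transformSchema`, which every pipeline stage already calls before fitting or transforming.

```scala
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.sql.types.StructType

// Hypothetical stage illustrating the new pattern; not an actual Spark class.
class MyStage(override val uid: String) extends Params {
  val inputCol: Param[String] = new Param(this, "inputCol", "input column name")

  // Before this change, the require() below would have lived in a separate
  // validateParams() override that callers had to invoke explicitly.
  def transformSchema(schema: StructType): StructType = {
    require(isDefined(inputCol) && $(inputCol).nonEmpty,
      "inputCol must be set to a non-empty column name")
    schema
  }

  override def copy(extra: ParamMap): MyStage = defaultCopy(extra)
}
```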

## How was this patch tested?

Existing unit tests, modified to check using transformSchema instead of validateParams
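A sketch of the updated test pattern (assumes a ScalaTest suite where `cv` is a configured CrossValidator, `est` is its estimator, and `paramMaps` are valid param maps; this mirrors the CrossValidatorSuite change in the diff below): where a test previously called `validateParams()` directly, it now calls `transformSchema` with an empty schema and expects the same `IllegalArgumentException` for invalid param maps.

```scala
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.StructType

cv.transformSchema(new StructType()) // valid param maps: should pass

val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
cv.setEstimatorParamMaps(invalidParamMaps)
intercept[IllegalArgumentException] {
  cv.transformSchema(new StructType()) // empty inputCol now fails here
}
```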

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11790 from jkbradley/SPARK-13761-cleanup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b39e80d3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b39e80d3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b39e80d3

Branch: refs/heads/master
Commit: b39e80d39dae8e6779f9d78c1631a27585239032
Parents: 4c08e2c
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Thu Mar 17 13:23:07 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Mar 17 13:23:07 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala | 20 +-------------------
 .../spark/ml/tuning/TrainValidationSplit.scala  | 20 +-------------------
 .../spark/ml/tuning/ValidatorParams.scala       | 14 ++++++++++++++
 .../org/apache/spark/ml/param/ParamsSuite.scala |  5 -----
 .../org/apache/spark/ml/param/TestParams.scala  |  5 -----
 .../spark/ml/tuning/CrossValidatorSuite.scala   | 13 ++++++-------
 .../ml/tuning/TrainValidationSplitSuite.scala   | 11 +++++------
 7 files changed, 27 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 3d7a91d..963f81c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -332,11 +320,6 @@ class CrossValidatorModel private[ml] (
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
   @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
-  @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
@@ -344,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4587e25..70fa5f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -161,11 +149,6 @@ class TrainValidationSplitModel private[ml] (
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
   @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
-  @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 553f254..c004644 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
  * :: DeveloperApi ::
@@ -31,6 +32,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *
    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +42,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +53,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
   * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +61,14 @@ private[ml] trait ValidatorParams extends Params {
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 7488685..a3366c0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 9d23547..7d990ce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 56545de..7af3c6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/b39e80d3/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb8009..cf8dcef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)

