spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mln...@apache.org
Subject spark git commit: [SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set
Date Fri, 26 Jan 2018 10:23:20 GMT
Repository: spark
Updated Branches:
  refs/heads/master d1721816d -> cd3956df0


[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set

## What changes were proposed in this pull request?

Currently, behavior is inconsistent for transformers that support both single- and multi-column
params when both are set: in some cases an exception is thrown, in others only a warning is
logged. In the discussion at
https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049,
the decision was to throw an exception.

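This PR makes `Bucketizer` throw an exception instead of logging a warning. It also adds a shared
validation helper, `ParamValidators.checkSingleVsMultiColumnParams`, and a test utility,
`ParamsSuite.testExclusiveParams`.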

## How was this patch tested?

Modified the existing unit tests in `BucketizerSuite` and added a helper test method in `ParamsSuite`.

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #19993 from mgaido91/SPARK-22799.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cd3956df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cd3956df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cd3956df

Branch: refs/heads/master
Commit: cd3956df0f96dd416b6161bf7ce2962e06d0a62e
Parents: d172181
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Fri Jan 26 12:23:14 2018 +0200
Committer: Nick Pentreath <nickp@za.ibm.com>
Committed: Fri Jan 26 12:23:14 2018 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/feature/Bucketizer.scala    | 44 ++++++-------
 .../org/apache/spark/ml/param/params.scala      | 69 ++++++++++++++++++++
 .../spark/ml/feature/BucketizerSuite.scala      | 41 ++++++------
 .../org/apache/spark/ml/param/ParamsSuite.scala | 22 +++++++
 4 files changed, 131 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 8299a3e..c13bf47 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since
2.3.0,
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ *
+ * Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note
that
- * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed
and
- * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter
is
- * only used for single column usage, and `splitsArray` is for multiple columns.
+ * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown.
The
+ * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
+ * columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val
uid: String
   @Since("2.3.0")
   def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
 
-  /**
-   * Determines whether this `Bucketizer` is going to map multiple columns. If and only if
-   * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
-   * by `inputCol`. A warning will be printed if both are set.
-   */
-  private[feature] def isBucketizeMultipleColumns(): Boolean = {
-    if (isSet(inputCols) && isSet(inputCol)) {
-      logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this
" +
-        "`Bucketizer` only map one column specified by `inputCol`")
-      false
-    } else if (isSet(inputCols)) {
-      true
-    } else {
-      false
-    }
-  }
-
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema)
 
-    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+    val (inputColumns, outputColumns) = if (isSet(inputCols)) {
       ($(inputCols).toSeq, $(outputCols).toSeq)
     } else {
       (Seq($(inputCol)), Seq($(outputCol)))
@@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid:
String
       }
     }
 
-    val seqOfSplits = if (isBucketizeMultipleColumns()) {
+    val seqOfSplits = if (isSet(inputCols)) {
       $(splitsArray).toSeq
     } else {
       Seq($(splits))
@@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val
uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    if (isBucketizeMultipleColumns()) {
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+      Seq(outputCols, splitsArray))
+
+    if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length &&
+        getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params
" +
+        s"for multi-column transform.  Params (inputCols, outputCols, splitsArray) should
have " +
+        s"equal lengths, but they have different lengths: " +
+        s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
+
       var transformedSchema = schema
-      $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx)
=>
+      $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol),
idx) =>
         SchemaUtils.checkNumericType(transformedSchema, inputCol)
         transformedSchema = SchemaUtils.appendColumn(transformedSchema,
           prepOutputField($(splitsArray)(idx), outputCol))

http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 1b4b401..9a83a58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -249,6 +249,75 @@ object ParamValidators {
   def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T])
=>
     value.length > lowerBound
   }
+
+  /**
+   * Utility for Param validity checks for Transformers which have both single- and multi-column
+   * support.  This utility assumes that `inputCol` indicates single-column usage and
+   * that `inputCols` indicates multi-column usage.
+   *
+   * This checks to ensure that exactly one set of Params has been set, and it
+   * raises an `IllegalArgumentException` if not.
+   *
+   * @param singleColumnParams Params which should be set (or have defaults) if `inputCol`
has been
+   *                           set.  This does not need to include `inputCol`.
+   * @param multiColumnParams Params which should be set (or have defaults) if `inputCols`
has been
+   *                           set.  This does not need to include `inputCols`.
+   */
+  def checkSingleVsMultiColumnParams(
+      model: Params,
+      singleColumnParams: Seq[Param[_]],
+      multiColumnParams: Seq[Param[_]]): Unit = {
+    val name = s"${model.getClass.getSimpleName} $model"
+
+    def checkExclusiveParams(
+        isSingleCol: Boolean,
+        requiredParams: Seq[Param[_]],
+        excludedParams: Seq[Param[_]]): Unit = {
+      val badParamsMsgBuilder = new mutable.StringBuilder()
+
+      val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+        .map(_.name).mkString(", ")
+      if (mustUnsetParams.nonEmpty) {
+        badParamsMsgBuilder ++=
+          s"The following Params are not applicable and should not be set: $mustUnsetParams."
+      }
+
+      val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+        .map(_.name).mkString(", ")
+      if (mustSetParams.nonEmpty) {
+        badParamsMsgBuilder ++=
+          s"The following Params must be defined but are not set: $mustSetParams."
+      }
+
+      val badParamsMsg = badParamsMsgBuilder.toString()
+
+      if (badParamsMsg.nonEmpty) {
+        val errPrefix = if (isSingleCol) {
+          s"$name has the inputCol Param set for single-column transform."
+        } else {
+          s"$name has the inputCols Param set for multi-column transform."
+        }
+        throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+      }
+    }
+
+    val inputCol = model.getParam("inputCol")
+    val inputCols = model.getParam("inputCols")
+
+    if (model.isSet(inputCol)) {
+      require(!model.isSet(inputCols), s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+      checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+        excludedParams = multiColumnParams)
+    } else if (model.isSet(inputCols)) {
+      checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+        excludedParams = singleColumnParams)
+    } else {
+      throw new IllegalArgumentException(s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
+    }
+  }
 }
 
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int,
...

http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index d9c97ae..7403680 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)
 
-    assert(bucketizer1.isBucketizeMultipleColumns())
-
     bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
     BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
       Seq("result1", "result2"),
@@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setOutputCols(Array("result"))
       .setSplitsArray(Array(splits(0)))
 
-    assert(bucketizer2.isBucketizeMultipleColumns())
-
     withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
       intercept[SparkException] {
         bucketizer2.transform(badDF1).collect()
@@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)
 
-    assert(bucketizer.isBucketizeMultipleColumns())
-
     BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
       Seq("result1", "result2"),
       Seq("expected1", "expected2"))
@@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(splits)
 
-    assert(bucketizer.isBucketizeMultipleColumns())
-
     bucketizer.setHandleInvalid("keep")
     BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
       Seq("result1", "result2"),
@@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
       .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
-    assert(t.isBucketizeMultipleColumns())
     testDefaultReadWrite(t)
   }
 
@@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
       .setOutputCols(Array("result1", "result2"))
       .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
 
-    assert(bucket.isBucketizeMultipleColumns())
-
     val pl = new Pipeline()
       .setStages(Array(bucket))
       .fit(df)
@@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext
with Defa
     }
   }
 
-  test("Both inputCol and inputCols are set") {
-    val bucket = new Bucketizer()
-      .setInputCol("feature1")
-      .setOutputCol("result")
-      .setSplits(Array(-0.5, 0.0, 0.5))
-      .setInputCols(Array("feature1", "feature2"))
-
-    // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
-    assert(bucket.isBucketizeMultipleColumns() == false)
+  test("assert exception is thrown if both multi-column and single-column params are set")
{
+    val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
+      ("outputCols", Array("result1", "result2")))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)),
+      ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))))
+
+    // this should fail because at least one of inputCol and inputCols must be set
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"),
+      ("splits", Array(-0.5, 0.0, 0.5)))
+
+    // the following should fail because not all the params are set
+    ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"))
+    ParamsSuite.testExclusiveParams(new Bucketizer, df,
+      ("inputCols", Array("feature1", "feature2")),
+      ("outputCols", Array("result1", "result2")))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cd3956df/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 85198ad..36e0609 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.ml.param
 import java.io.{ByteArrayOutputStream, ObjectOutputStream}
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Transformer}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.MyParams
+import org.apache.spark.sql.Dataset
 
 class ParamsSuite extends SparkFunSuite {
 
@@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite {
     require(copyReturnType === obj.getClass,
       s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
   }
+
+  /**
+   * Checks that the class throws an exception in case multiple exclusive params are set.
+   * The params to be checked are passed as arguments with their value.
+   */
+  def testExclusiveParams(
+      model: Params,
+      dataset: Dataset[_],
+      paramsAndValues: (String, Any)*): Unit = {
+    val m = model.copy(ParamMap.empty)
+    paramsAndValues.foreach { case (paramName, paramValue) =>
+      m.set(m.getParam(paramName), paramValue)
+    }
+    intercept[IllegalArgumentException] {
+      m match {
+        case t: Transformer => t.transform(dataset)
+        case e: Estimator[_] => e.fit(dataset)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message