spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns
Date Fri, 06 Apr 2018 02:56:51 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4807d381b -> f2ac08795


[SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns

## What changes were proposed in this pull request?

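RFormula's `handleInvalid` Param is now forwarded to the VectorAssembler that RFormula builds internally. As a result, invalid values (such as nulls) in non-string columns are handled according to the selected mode ("error", "skip", or "keep"), not just invalid values in string columns.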

## How was this patch tested?

Added a unit test to RFormulaSuite and ran all existing tests for RFormula and VectorAssembler.

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>

Closes #20970 from yogeshg/spark_23562.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f2ac0879
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f2ac0879
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f2ac0879

Branch: refs/heads/master
Commit: f2ac0879561cde63ed4eb759f5efa0a5ce393a22
Parents: 4807d38
Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Authored: Thu Apr 5 19:55:42 2018 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Apr 5 19:55:42 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/RFormula.scala  |  1 +
 .../apache/spark/ml/feature/RFormulaSuite.scala | 23 ++++++++++++++++++++
 2 files changed, 24 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f2ac0879/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 22e7b8b..e214765 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     encoderStages += new VectorAssembler(uid)
       .setInputCols(encodedTerms.toArray)
       .setOutputCol($(featuresCol))
+      .setHandleInvalid($(handleInvalid))
     encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
     encoderStages += new ColumnPruner(tempColumns.toSet)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f2ac0879/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 27d570f..a250331 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
@@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         assert(features.toArray === a +: b.toArray)
     }
   }
+
+  test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.")
{
+    val d1 = Seq(
+      (1001L, "a"),
+      (1002L, "b")).toDF("id1", "c1")
+    val d2 = Seq[(java.lang.Long, String)](
+      (20001L, "x"),
+      (20002L, "y"),
+      (null, null)).toDF("id2", "c2")
+    val dataset = d1.crossJoin(d2)
+
+    def get_output(mode: String): DataFrame = {
+      val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode)
+      formula.fit(dataset).transform(dataset).select("features", "label")
+    }
+
+    assert(intercept[SparkException](get_output("error").collect())
+      .getMessage.contains("Encountered null while assembling a row"))
+    assert(get_output("skip").count() == 4)
+    assert(get_output("keep").count() == 6)
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message