spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication
Date Tue, 02 Feb 2016 19:16:49 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 4c28b4c8f -> 9c0cf22f7


[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication

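Fixes the problem by building the output schema with SchemaUtils.appendColumn, which rejects an
output column name that already exists in the input schema, and verifies the fix with a new test
in StopWordsRemoverSuite.
Also adds an optional parameter, nullable (Boolean, default false), to SchemaUtils.appendColumn
and deduplicates the SchemaUtils.appendColumn overloads.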

Author: Grzegorz Chilkiewicz <grzegorz.chilkiewicz@codilime.com>

Closes #10741 from grzegorz-chilkiewicz/master.

(cherry picked from commit b1835d727234fdff42aa8cadd17ddcf43b0bed15)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c0cf22f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c0cf22f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c0cf22f

Branch: refs/heads/branch-1.6
Commit: 9c0cf22f7681ae05d894ae05f6a91a9467787519
Parents: 4c28b4c
Author: Grzegorz Chilkiewicz <grzegorz.chilkiewicz@codilime.com>
Authored: Tue Feb 2 11:16:24 2016 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Tue Feb 2 11:16:44 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StopWordsRemover.scala   |  4 +---
 .../scala/org/apache/spark/ml/util/SchemaUtils.scala |  8 +++-----
 .../spark/ml/feature/StopWordsRemoverSuite.scala     | 15 +++++++++++++++
 3 files changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 3188085..d9a9049 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -148,9 +148,7 @@ class StopWordsRemover(override val uid: String)
     val inputType = schema($(inputCol)).dataType
     require(inputType.sameType(ArrayType(StringType)),
       s"Input type must be ArrayType(StringType) but got $inputType.")
-    val outputFields = schema.fields :+
-      StructField($(outputCol), inputType, schema($(inputCol)).nullable)
-    StructType(outputFields)
+    SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
   }
 
   override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76f6514..7decbbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -54,12 +54,10 @@ private[spark] object SchemaUtils {
   def appendColumn(
       schema: StructType,
       colName: String,
-      dataType: DataType): StructType = {
+      dataType: DataType,
+      nullable: Boolean = false): StructType = {
     if (colName.isEmpty) return schema
-    val fieldNames = schema.fieldNames
-    require(!fieldNames.contains(colName), s"Column $colName already exists.")
-    val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
-    StructType(outputFields)
+    appendColumn(schema, StructField(colName, dataType, nullable))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index fb217e0..a5b24c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
       .setCaseSensitive(true)
     testDefaultReadWrite(t)
   }
+
+  test("StopWordsRemover output column already exists") {
+    val outputCol = "expected"
+    val remover = new StopWordsRemover()
+      .setInputCol("raw")
+      .setOutputCol(outputCol)
+    val dataSet = sqlContext.createDataFrame(Seq(
+      (Seq("The", "the", "swift"), Seq("swift"))
+    )).toDF("raw", outputCol)
+
+    val thrown = intercept[IllegalArgumentException] {
+      testStopWordsRemover(remover, dataSet)
+    }
+    assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message