spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array
Date Sat, 24 Sep 2016 07:07:04 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 9e91a1009 -> ed545763a


[SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null
string array

## What changes were proposed in this pull request?

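Make Word2Vec accept input columns of non-nullable string arrays (such as the output of NGram),
in addition to the nullable string arrays (such as the output of Tokenizer) it already accepted,
so that NGram output can be fed directly into Word2Vec.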

## How was this patch tested?

Jenkins tests, plus a new Word2VecSuite test that fits and applies Word2Vec on NGram output.

Author: Sean Owen <sowen@cloudera.com>

Closes #15179 from srowen/SPARK-10835.

(cherry picked from commit f3fe55439e4c865c26502487a1bccf255da33f4a)
Signed-off-by: Sean Owen <sowen@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ed545763
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ed545763
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ed545763

Branch: refs/heads/branch-2.0
Commit: ed545763adc3f50569581c9b017b396e8997ac31
Parents: 9e91a10
Author: Sean Owen <sowen@cloudera.com>
Authored: Sat Sep 24 08:06:41 2016 +0100
Committer: Sean Owen <sowen@cloudera.com>
Committed: Sat Sep 24 08:06:56 2016 +0100

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  |  3 ++-
 .../apache/spark/ml/feature/Word2VecSuite.scala | 21 ++++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ed545763/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 14c0512..d53f3df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+    val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
+    SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed545763/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 16c74f6..c8f1311 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.getVectors.collect() === instance.getVectors.collect())
   }
+
+  test("Word2Vec works with input that is non-nullable (NGram)") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val sentence = "a q s t q s t b b b s t m s t m q "
+    val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
+
+    val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams")
+    val ngramDF = ngram.transform(docDF)
+
+    val model = new Word2Vec()
+      .setVectorSize(2)
+      .setInputCol("ngrams")
+      .setOutputCol("result")
+      .fit(ngramDF)
+
+    // Just test that this transformation succeeds
+    model.transform(ngramDF).collect()
+  }
+
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message