spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-10299][ML] word2vec should allow users to specify the window size
Date Wed, 09 Dec 2015 16:45:20 GMT
Repository: spark
Updated Branches:
  refs/heads/master 6e1c55eac -> 22b9a8740


[SPARK-10299][ML] word2vec should allow users to specify the window size

Currently word2vec has the window hard-coded at 5; some users may want different sizes (for
example, when using it on n-gram input or similar). The user request comes from
http://stackoverflow.com/questions/32231975/spark-word2vec-window-size.

Author: Holden Karau <holden@us.ibm.com>
Author: Holden Karau <holden@pigscanfly.ca>

Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/22b9a874
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/22b9a874
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/22b9a874

Branch: refs/heads/master
Commit: 22b9a8740d51289434553d19b6b1ac34aecdc09a
Parents: 6e1c55e
Author: Holden Karau <holden@us.ibm.com>
Authored: Wed Dec 9 16:45:13 2015 +0000
Committer: Sean Owen <sowen@cloudera.com>
Committed: Wed Dec 9 16:45:13 2015 +0000

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 15 +++++++
 .../apache/spark/mllib/feature/Word2Vec.scala   | 11 ++++-
 .../apache/spark/ml/feature/Word2VecSuite.scala | 43 ++++++++++++++++++--
 3 files changed, 65 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/22b9a874/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a8d61b6..f105a98 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -50,6 +50,17 @@ private[feature] trait Word2VecBase extends Params
   def getVectorSize: Int = $(vectorSize)
 
   /**
+   * The window size (context words from [-window, window]). Must be > 0. Default: 5.
+   * @group expertParam
+   */
+  final val windowSize = new IntParam(this, "windowSize",
+    "the window size (context words from [-window, window]) (> 0)", ParamValidators.gt(0))
+  setDefault(windowSize -> 5)
+
+  /** @group expertGetParam */
+  def getWindowSize: Int = $(windowSize)
+
+  /**
    * Number of partitions for sentences of words.
    * Default: 1
    * @group param
@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
   /** @group setParam */
   def setVectorSize(value: Int): this.type = set(vectorSize, value)
 
+  /** @group expertSetParam */
+  def setWindowSize(value: Int): this.type = set(windowSize, value)
+
   /** @group setParam */
   def setStepSize(value: Double): this.type = set(stepSize, value)
 
@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
       .setNumPartitions($(numPartitions))
       .setSeed($(seed))
       .setVectorSize($(vectorSize))
+      .setWindowSize($(windowSize))
       .fit(input)
     copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/22b9a874/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 23b1514..1f400e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -126,6 +126,16 @@ class Word2Vec extends Serializable with Logging {
   }
 
   /**
+   * Sets the window of words (default: 5). Must be positive.
+   */
+  @Since("1.6.0")
+  def setWindowSize(window: Int): this.type = {
+    require(window > 0, s"Window of words must be positive but got $window")
+    this.window = window
+    this
+  }
+
+  /**
    * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
    * model's vocabulary (default: 5).
    */
@@ -141,7 +151,7 @@ class Word2Vec extends Serializable with Logging {
   private val MAX_SENTENCE_LENGTH = 1000
 
   /** context words from [-window, window] */
-  private val window = 5
+  private var window = 5
 
   private var trainWordsCount = 0
   private var vocabSize = 0

http://git-wip-us.apache.org/repos/asf/spark/blob/22b9a874/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a773244..d561bbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
   }
 
   test("Word2Vec") {
-    val sqlContext = new SQLContext(sc)
+
+    val sqlContext = this.sqlContext
     import sqlContext.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("getVectors") {
 
-    val sqlContext = new SQLContext(sc)
+    val sqlContext = this.sqlContext
     import sqlContext.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
 
   test("findSynonyms") {
 
-    val sqlContext = new SQLContext(sc)
+    val sqlContext = this.sqlContext
     import sqlContext.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     expectedSimilarity.zip(similarity).map {
       case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
     }
+  }
+
+  test("window size") {
+
+    val sqlContext = this.sqlContext
+    import sqlContext.implicits._
+
+    val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+    val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setWindowSize(2)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .fit(docDF)
 
+    val (synonyms, similarity) = model.findSynonyms("a", 6).map {
+      case Row(w: String, sim: Double) => (w, sim)
+    }.collect().unzip
+
+    // Increase the window size
+    val biggerModel = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .setWindowSize(10)
+      .fit(docDF)
+
+    val (synonymsLarger, similarityLarger) = biggerModel.findSynonyms("a", 6).map {
+      case Row(w: String, sim: Double) => (w, sim)
+    }.collect().unzip
+    // The similarity scores should differ noticeably with the larger window
+    assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5)
   }
 
   test("Word2Vec read/write") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message