spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark
Date Thu, 07 Jan 2016 03:44:38 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 d10b9d572 -> 5e86c0cce


[SPARK-12016] [MLLIB] [PYSPARK] Wrap Word2VecModel when loading it in pyspark

JIRA: https://issues.apache.org/jira/browse/SPARK-12016

We should not directly use Word2VecModel in pyspark. We need to wrap it in a Word2VecModelWrapper
when loading it in pyspark.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #10100 from viirya/fix-load-py-wordvecmodel.

(cherry picked from commit b51a4cdff3a7e640a8a66f7a9c17021f3056fd34)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e86c0cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e86c0cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e86c0cc

Branch: refs/heads/branch-1.5
Commit: 5e86c0cce8486b156af8b8c9f0d0719d7d870548
Parents: d10b9d5
Author: Liang-Chi Hsieh <viirya@appier.com>
Authored: Mon Dec 14 09:59:42 2015 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Jan 6 19:44:34 2016 -0800

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 33 -----------
 .../mllib/api/python/Word2VecModelWrapper.scala | 62 ++++++++++++++++++++
 python/pyspark/mllib/feature.py                 |  6 +-
 3 files changed, 67 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e86c0cc/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2f8b5e5..649d2d4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -638,39 +638,6 @@ private[python] class PythonMLLibAPI extends Serializable {
     }
   }
 
-  private[python] class Word2VecModelWrapper(model: Word2VecModel) {
-    def transform(word: String): Vector = {
-      model.transform(word)
-    }
-
-    /**
-     * Transforms an RDD of words to its vector representation
-     * @param rdd an RDD of words
-     * @return an RDD of vector representations of words
-     */
-    def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
-      rdd.rdd.map(model.transform)
-    }
-
-    def findSynonyms(word: String, num: Int): JList[Object] = {
-      val vec = transform(word)
-      findSynonyms(vec, num)
-    }
-
-    def findSynonyms(vector: Vector, num: Int): JList[Object] = {
-      val result = model.findSynonyms(vector, num)
-      val similarity = Vectors.dense(result.map(_._2))
-      val words = result.map(_._1)
-      List(words, similarity).map(_.asInstanceOf[Object]).asJava
-    }
-
-    def getVectors: JMap[String, JList[Float]] = {
-      model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
-    }
-
-    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
-  }
-
   /**
    * Java stub for Python mllib DecisionTree.train().
    * This stub returns a handle to the Java object instead of the content of the Java object.

http://git-wip-us.apache.org/repos/asf/spark/blob/5e86c0cc/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
new file mode 100644
index 0000000..0f55980
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.feature.Word2VecModel
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+  * Wrapper around Word2VecModel to provide helper methods in Python
+  */
+private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+  def transform(word: String): Vector = {
+    model.transform(word)
+  }
+
+  /**
+   * Transforms an RDD of words to its vector representation
+   * @param rdd an RDD of words
+   * @return an RDD of vector representations of words
+   */
+  def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
+    rdd.rdd.map(model.transform)
+  }
+
+  def findSynonyms(word: String, num: Int): JList[Object] = {
+    val vec = transform(word)
+    findSynonyms(vec, num)
+  }
+
+  def findSynonyms(vector: Vector, num: Int): JList[Object] = {
+    val result = model.findSynonyms(vector, num)
+    val similarity = Vectors.dense(result.map(_._2))
+    val words = result.map(_._1)
+    List(words, similarity).map(_.asInstanceOf[Object]).asJava
+  }
+
+  def getVectors: JMap[String, JList[Float]] = {
+    model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+  }
+
+  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5e86c0cc/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index f921e3a..ba04240 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -460,7 +460,8 @@ class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
     def load(cls, sc, path):
         jmodel = sc._jvm.org.apache.spark.mllib.feature \
             .Word2VecModel.load(sc._jsc.sc(), path)
-        return Word2VecModel(jmodel)
+        model = sc._jvm.Word2VecModelWrapper(jmodel)
+        return Word2VecModel(model)
 
 
 @ignore_unicode_prefix
@@ -502,6 +503,9 @@ class Word2Vec(object):
     >>> sameModel = Word2VecModel.load(sc, path)
     >>> model.transform("a") == sameModel.transform("a")
     True
+    >>> syms = sameModel.findSynonyms("a", 2)
+    >>> [s[0] for s in syms]
+    [u'b', u'c']
     >>> from shutil import rmtree
     >>> try:
     ...     rmtree(path)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message