spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec
Date Thu, 02 Jul 2015 22:55:22 GMT
Repository: spark
Updated Branches:
  refs/heads/master fc7aebd94 -> 488bad319


[SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #6821 from yu-iskw/SPARK-7104 and squashes the following commits:

975136b [Yu ISHIKAWA] Organize import
0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs
cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save`
1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/488bad31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/488bad31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/488bad31

Branch: refs/heads/master
Commit: 488bad319a70975733e83c83490240a70beb0c90
Parents: fc7aebd
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Authored: Thu Jul 2 15:55:16 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Jul 2 15:55:16 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala |  3 +++
 python/pyspark/mllib/feature.py                 | 21 +++++++++++++++++++-
 2 files changed, 23 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/488bad31/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 458fab4..e628059 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,6 +28,7 @@ import scala.reflect.ClassTag
 
 import net.razorvine.pickle._
 
+import org.apache.spark.SparkContext
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.mllib.classification._
@@ -641,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable {
     def getVectors: JMap[String, JList[Float]] = {
       model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
     }
+
+    def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/488bad31/python/pyspark/mllib/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b513877..f921e3a 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -36,6 +36,7 @@ from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
 from pyspark.mllib.linalg import (
     Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
            'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
@@ -416,7 +417,7 @@ class IDF(object):
         return IDFModel(jmodel)
 
 
-class Word2VecModel(JavaVectorTransformer):
+class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
     """
     class for Word2Vec model
     """
@@ -455,6 +456,12 @@ class Word2VecModel(JavaVectorTransformer):
         """
         return self.call("getVectors")
 
+    @classmethod
+    def load(cls, sc, path):
+        jmodel = sc._jvm.org.apache.spark.mllib.feature \
+            .Word2VecModel.load(sc._jsc.sc(), path)
+        return Word2VecModel(jmodel)
+
 
 @ignore_unicode_prefix
 class Word2Vec(object):
@@ -488,6 +495,18 @@ class Word2Vec(object):
     >>> syms = model.findSynonyms(vec, 2)
     >>> [s[0] for s in syms]
     [u'b', u'c']
+
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = Word2VecModel.load(sc, path)
+    >>> model.transform("a") == sameModel.transform("a")
+    True
+    >>> from shutil import rmtree
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
     """
     def __init__(self):
         """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message