spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Date Tue, 03 Mar 2015 06:27:16 GMT
Repository: spark
Updated Branches:
  refs/heads/master 54d19689f -> 7e53a79c3


[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib

Similar to `MatrixFactorizationModel`, we only need wrappers to support save/load for tree
models in Python.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4854 from mengxr/SPARK-6097 and squashes the following commits:

4586a4d [Xiangrui Meng] fix more typos
8ebcac2 [Xiangrui Meng] fix python style
91172d8 [Xiangrui Meng] fix typos
201b3b9 [Xiangrui Meng] update user guide
b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e53a79c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e53a79c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e53a79c

Branch: refs/heads/master
Commit: 7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b
Parents: 54d1968
Author: Xiangrui Meng <meng@databricks.com>
Authored: Mon Mar 2 22:27:01 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Mar 2 22:27:01 2015 -0800

----------------------------------------------------------------------
 docs/mllib-decision-tree.md            | 16 ++++++++-----
 docs/mllib-ensembles.md                | 32 +++++++++++++++----------
 python/pyspark/mllib/recommendation.py |  9 +++----
 python/pyspark/mllib/tests.py          | 27 ++++++++++++++++++++-
 python/pyspark/mllib/tree.py           | 21 ++++++++++++----
 python/pyspark/mllib/util.py           | 37 +++++++++++++++++++++++++----
 6 files changed, 109 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/docs/mllib-decision-tree.md
----------------------------------------------------------------------
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 8e478ab..c1d0f8a 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -293,11 +293,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -317,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification tree model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = DecisionTreeModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 
@@ -440,11 +442,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -464,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression tree model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = DecisionTreeModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/docs/mllib-ensembles.md
----------------------------------------------------------------------
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index ec1ef38..cbfb682 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -202,10 +202,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -228,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification forest model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 
@@ -354,10 +356,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -380,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression forest model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 
@@ -581,10 +585,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file.
@@ -605,6 +607,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification GBT model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 
@@ -732,10 +738,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
 
 <div data-lang="python">
 
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file.
@@ -756,6 +760,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression GBT model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/python/pyspark/mllib/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 03d7d01..1a4527b 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -20,7 +20,7 @@ from collections import namedtuple
 from pyspark import SparkContext
 from pyspark.rdd import RDD
 from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
-from pyspark.mllib.util import Saveable, JavaLoader
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -41,7 +41,7 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
 
 
 @inherit_doc
-class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
+class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
 
     """A matrix factorisation model trained by regularized alternating
     least-squares.
@@ -92,7 +92,7 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
     0.43...
     >>> try:
     ...     os.removedirs(path)
-    ... except:
+    ... except OSError:
     ...     pass
     """
     def predict(self, user, product):
@@ -111,9 +111,6 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
     def productFeatures(self):
         return self.call("getProductFeatures")
 
-    def save(self, sc, path):
-        self.call("save", sc._jsc.sc(), path)
-
 
 class ALS(object):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 06207a0..5328d99 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -19,7 +19,9 @@
 Fuller unit tests for Python MLlib.
 """
 
+import os
 import sys
+import tempfile
 import array as pyarray
 
 from numpy import array, array_equal
@@ -195,7 +197,8 @@ class ListTests(PySparkTestCase):
 
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
-        from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
+        from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
+            RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
         data = [
             LabeledPoint(0.0, [1, 0, 0]),
             LabeledPoint(1.0, [0, 1, 1]),
@@ -205,6 +208,8 @@ class ListTests(PySparkTestCase):
         rdd = self.sc.parallelize(data)
         features = [p.features.tolist() for p in data]
 
+        temp_dir = tempfile.mkdtemp()
+
         lr_model = LogisticRegressionWithSGD.train(rdd)
         self.assertTrue(lr_model.predict(features[0]) <= 0)
         self.assertTrue(lr_model.predict(features[1]) > 0)
@@ -231,6 +236,11 @@ class ListTests(PySparkTestCase):
         self.assertTrue(dt_model.predict(features[2]) <= 0)
         self.assertTrue(dt_model.predict(features[3]) > 0)
 
+        dt_model_dir = os.path.join(temp_dir, "dt")
+        dt_model.save(self.sc, dt_model_dir)
+        same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
+        self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())
+
         rf_model = RandomForest.trainClassifier(
             rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
         self.assertTrue(rf_model.predict(features[0]) <= 0)
@@ -238,6 +248,11 @@ class ListTests(PySparkTestCase):
         self.assertTrue(rf_model.predict(features[2]) <= 0)
         self.assertTrue(rf_model.predict(features[3]) > 0)
 
+        rf_model_dir = os.path.join(temp_dir, "rf")
+        rf_model.save(self.sc, rf_model_dir)
+        same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
+        self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())
+
         gbt_model = GradientBoostedTrees.trainClassifier(
             rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(gbt_model.predict(features[0]) <= 0)
@@ -245,6 +260,16 @@ class ListTests(PySparkTestCase):
         self.assertTrue(gbt_model.predict(features[2]) <= 0)
         self.assertTrue(gbt_model.predict(features[3]) > 0)
 
+        gbt_model_dir = os.path.join(temp_dir, "gbt")
+        gbt_model.save(self.sc, gbt_model_dir)
+        same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
+        self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
+
+        try:
+            os.removedirs(temp_dir)
+        except OSError:
+            pass
+
     def test_regression(self):
         from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
             RidgeRegressionWithSGD

http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 73618f0..bf288d7 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -23,12 +23,13 @@ from pyspark import SparkContext, RDD
 from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
            'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']
 
 
-class TreeEnsembleModel(JavaModelWrapper):
+class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
     def predict(self, x):
         """
         Predict values for a single data point or an RDD of points using
@@ -66,7 +67,7 @@ class TreeEnsembleModel(JavaModelWrapper):
         return self._java_model.toDebugString()
 
 
-class DecisionTreeModel(JavaModelWrapper):
+class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     """
     .. note:: Experimental
 
@@ -103,6 +104,10 @@ class DecisionTreeModel(JavaModelWrapper):
         """ full model. """
         return self._java_model.toDebugString()
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.DecisionTreeModel"
+
 
 class DecisionTree(object):
     """
@@ -227,13 +232,17 @@ class DecisionTree(object):
 
 
 @inherit_doc
-class RandomForestModel(TreeEnsembleModel):
+class RandomForestModel(TreeEnsembleModel, JavaLoader):
     """
     .. note:: Experimental
 
     Represents a random forest model.
     """
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.RandomForestModel"
+
 
 class RandomForest(object):
     """
@@ -406,13 +415,17 @@ class RandomForest(object):
 
 
 @inherit_doc
-class GradientBoostedTreesModel(TreeEnsembleModel):
+class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
     """
     .. note:: Experimental
 
     Represents a gradient-boosted tree model.
     """
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+
 
 class GradientBoostedTrees(object):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/7e53a79c/python/pyspark/mllib/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 17d43ea..e877c72 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -18,7 +18,7 @@
 import numpy as np
 import warnings
 
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
 from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
 
@@ -191,6 +191,17 @@ class Saveable(object):
         raise NotImplementedError
 
 
+@inherit_doc
+class JavaSaveable(Saveable):
+    """
+    Mixin for models that provide save() through their Scala
+    implementation.
+    """
+
+    def save(self, sc, path):
+        self._java_model.save(sc._jsc.sc(), path)
+
+
 class Loader(object):
     """
     Mixin for classes which can load saved models from files.
@@ -210,6 +221,7 @@ class Loader(object):
         raise NotImplemented
 
 
+@inherit_doc
 class JavaLoader(Loader):
     """
     Mixin for classes which can load saved models using its Scala
@@ -217,13 +229,30 @@ class JavaLoader(Loader):
     """
 
     @classmethod
-    def load(cls, sc, path):
+    def _java_loader_class(cls):
+        """
+        Returns the full class name of the Java loader. The default
+        implementation replaces "pyspark" by "org.apache.spark" in
+        the Python full class name.
+        """
         java_package = cls.__module__.replace("pyspark", "org.apache.spark")
-        java_class = ".".join([java_package, cls.__name__])
+        return ".".join([java_package, cls.__name__])
+
+    @classmethod
+    def _load_java(cls, sc, path):
+        """
+        Load a Java model from the given path.
+        """
+        java_class = cls._java_loader_class()
         java_obj = sc._jvm
         for name in java_class.split("."):
             java_obj = getattr(java_obj, name)
-        return cls(java_obj.load(sc._jsc.sc(), path))
+        return java_obj.load(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = cls._load_java(sc, path)
+        return cls(java_model)
 
 
 def _test():


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message