spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-13038][PYSPARK] Add load/save to pipeline
Date Wed, 16 Mar 2016 20:49:45 GMT
Repository: spark
Updated Branches:
  refs/heads/master c4bd57602 -> ae6c677c8


[SPARK-13038][PYSPARK] Add load/save to pipeline

## What changes were proposed in this pull request?

JIRA issue: https://issues.apache.org/jira/browse/SPARK-13038

1. Add load/save to PySpark Pipeline and PipelineModel

2. Add `_transfer_stage_to_java()` and `_transfer_stage_from_java()` for `JavaWrapper`.
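
A minimal usage sketch of the new API (hedged: the `HashingTF` stage and the paths are illustrative, and an active SparkContext is assumed):

```python
# Illustrative sketch of the save/load API added here; the stage choice and
# paths are arbitrary, and an active SparkContext is assumed.
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF

tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
pipeline = Pipeline(stages=[tf])

pipeline.save("/tmp/hashing-pipeline")           # shortcut for write().save(path)
loaded = Pipeline.load("/tmp/hashing-pipeline")  # shortcut for read().load(path)
assert loaded.uid == pipeline.uid

# A fitted PipelineModel round-trips the same way:
# model = pipeline.fit(df); model.save(path); PipelineModel.load(path)
```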

## How was this patch tested?

Tested with doctests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11683 from yinxusen/SPARK-13038-only.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ae6c677c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ae6c677c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ae6c677c

Branch: refs/heads/master
Commit: ae6c677c8a03174787be99af6238a5e1fbe4e389
Parents: c4bd576
Author: Xusen Yin <yinxusen@gmail.com>
Authored: Wed Mar 16 13:49:40 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Mar 16 13:49:40 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/__init__.py |   3 +-
 python/pyspark/ml/base.py     | 118 +++++++++++++++++++++
 python/pyspark/ml/pipeline.py | 208 ++++++++++++++++++++++---------------
 python/pyspark/ml/tests.py    |  45 +++++++-
 python/pyspark/ml/util.py     |   3 +
 python/pyspark/ml/wrapper.py  |  29 +++++-
 6 files changed, 317 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 327a11b..25cfac0 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
-from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel
+from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.pipeline import Pipeline, PipelineModel
 
 __all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
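
The import shuffle above is behavior-preserving; a sketch of the unchanged public surface:

```python
# The abstract classes move to ml/base.py but are re-exported here, so the
# public import surface is unchanged by this patch.
from pyspark.ml import Estimator, Model, Transformer  # now defined in ml/base.py
from pyspark.ml import Pipeline, PipelineModel        # still in ml/pipeline.py
```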

http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
new file mode 100644
index 0000000..a7a58e1
--- /dev/null
+++ b/python/pyspark/ml/base.py
@@ -0,0 +1,118 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark import since
+from pyspark.ml.param import Params
+from pyspark.mllib.common import inherit_doc
+
+
+@inherit_doc
+class Estimator(Params):
+    """
+    Abstract class for estimators that fit models to data.
+
+    .. versionadded:: 1.3.0
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def _fit(self, dataset):
+        """
+        Fits a model to the input dataset. This is called by the default implementation of fit.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :returns: fitted model
+        """
+        raise NotImplementedError()
+
+    @since("1.3.0")
+    def fit(self, dataset, params=None):
+        """
+        Fits a model to the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :param params: an optional param map that overrides embedded params. If a list/tuple of
+                       param maps is given, this calls fit on each param map and returns a list of
+                       models.
+        :returns: fitted model(s)
+        """
+        if params is None:
+            params = dict()
+        if isinstance(params, (list, tuple)):
+            return [self.fit(dataset, paramMap) for paramMap in params]
+        elif isinstance(params, dict):
+            if params:
+                return self.copy(params)._fit(dataset)
+            else:
+                return self._fit(dataset)
+        else:
+            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
+                             "but got %s." % type(params))
+
+
+@inherit_doc
+class Transformer(Params):
+    """
+    Abstract class for transformers that transform one dataset into another.
+
+    .. versionadded:: 1.3.0
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def _transform(self, dataset):
+        """
+        Transforms the input dataset.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :returns: transformed dataset
+        """
+        raise NotImplementedError()
+
+    @since("1.3.0")
+    def transform(self, dataset, params=None):
+        """
+        Transforms the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
+        :param params: an optional param map that overrides embedded params.
+        :returns: transformed dataset
+        """
+        if params is None:
+            params = dict()
+        if isinstance(params, dict):
+            if params:
+                return self.copy(params)._transform(dataset)
+            else:
+                return self._transform(dataset)
+        else:
+            raise ValueError("Params must be a param map but got %s." % type(params))
+
+
+@inherit_doc
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+
+    .. versionadded:: 1.4.0
+    """
+
+    __metaclass__ = ABCMeta
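
To make the `fit` dispatch above concrete, a hedged sketch (the estimator and toy DataFrame are illustrative, and an active SQLContext named `sqlContext` is assumed):

```python
# Sketch of Estimator.fit dispatch: a dict of params overrides embedded params
# via copy(); a list/tuple of param maps yields a list of fitted models.
# Assumes an active SQLContext named sqlContext; data and estimator are toys.
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import Vectors

df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(0.0, 1.0)), (0.0, Vectors.dense(1.0, 0.0))],
    ["label", "features"])
lr = LogisticRegression(maxIter=10)

model = lr.fit(df)                                        # embedded params only
tuned = lr.fit(df, {lr.maxIter: 5})                       # dict: copy + _fit
models = lr.fit(df, [{lr.maxIter: 5}, {lr.maxIter: 20}])  # list -> list of models
```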

http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 661074c..a1658b0 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -15,116 +15,77 @@
 # limitations under the License.
 #
 
-from abc import ABCMeta, abstractmethod
+import sys
 
+if sys.version > '3':
+    basestring = str
+
+from pyspark import SparkContext
 from pyspark import since
+from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.wrapper import JavaWrapper
 from pyspark.mllib.common import inherit_doc
 
 
-@inherit_doc
-class Estimator(Params):
+def _stages_java2py(java_stages):
     """
-    Abstract class for estimators that fit models to data.
-
-    .. versionadded:: 1.3.0
+    Transforms a list of Java stages into a list of Python stages.
+    :param java_stages: An array of Java stages.
+    :return: An array of Python stages.
     """
 
-    __metaclass__ = ABCMeta
+    return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
 
-    @abstractmethod
-    def _fit(self, dataset):
-        """
-        Fits a model to the input dataset. This is called by the
-        default implementation of fit.
 
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.DataFrame`
-        :returns: fitted model
-        """
-        raise NotImplementedError()
+def _stages_py2java(py_stages, cls):
+    """
+    Transforms a list of Python stages into a Java array of Java stages.
+    :param py_stages: An array of Python stages.
+    :return: A Java array of Java stages.
+    """
 
-    @since("1.3.0")
-    def fit(self, dataset, params=None):
-        """
-        Fits a model to the input dataset with optional parameters.
-
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.DataFrame`
-        :param params: an optional param map that overrides embedded
-                       params. If a list/tuple of param maps is given,
-                       this calls fit on each param map and returns a
-                       list of models.
-        :returns: fitted model(s)
-        """
-        if params is None:
-            params = dict()
-        if isinstance(params, (list, tuple)):
-            return [self.fit(dataset, paramMap) for paramMap in params]
-        elif isinstance(params, dict):
-            if params:
-                return self.copy(params)._fit(dataset)
-            else:
-                return self._fit(dataset)
-        else:
-            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
-                             "but got %s." % type(params))
+    for stage in py_stages:
+        assert isinstance(stage, JavaWrapper), \
+            "Python side implementation is not supported in the meta-PipelineStage currently."
+    gateway = SparkContext._gateway
+    java_stages = gateway.new_array(cls, len(py_stages))
+    for idx, stage in enumerate(py_stages):
+        java_stages[idx] = stage._transfer_stage_to_java()
+    return java_stages
 
 
 @inherit_doc
-class Transformer(Params):
+class PipelineMLWriter(JavaMLWriter, JavaWrapper):
     """
-    Abstract class for transformers that transform one dataset into
-    another.
-
-    .. versionadded:: 1.3.0
+    Private Pipeline utility class that can save ML instances through their Scala implementation.
     """
 
-    __metaclass__ = ABCMeta
-
-    @abstractmethod
-    def _transform(self, dataset):
-        """
-        Transforms the input dataset.
-
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.DataFrame`
-        :returns: transformed dataset
-        """
-        raise NotImplementedError()
-
-    @since("1.3.0")
-    def transform(self, dataset, params=None):
-        """
-        Transforms the input dataset with optional parameters.
-
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.DataFrame`
-        :param params: an optional param map that overrides embedded
-                       params.
-        :returns: transformed dataset
-        """
-        if params is None:
-            params = dict()
-        if isinstance(params, dict):
-            if params:
-                return self.copy(params,)._transform(dataset)
-            else:
-                return self._transform(dataset)
-        else:
-            raise ValueError("Params must be either a param map but got %s." % type(params))
+    def __init__(self, instance):
+        cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
+        self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
+        self._jwrite = self._java_obj.write()
 
 
 @inherit_doc
-class Model(Transformer):
+class PipelineMLReader(JavaMLReader):
     """
-    Abstract class for models that are fitted by estimators.
-
-    .. versionadded:: 1.4.0
+    Private utility class that can load Pipeline instances through their Scala implementation.
     """
 
-    __metaclass__ = ABCMeta
+    def load(self, path):
+        """Load the Pipeline instance from the input path."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+
+        java_obj = self._jread.load(path)
+        instance = self._clazz()
+        instance._resetUid(java_obj.uid())
+        instance.setStages(_stages_java2py(java_obj.getStages()))
+
+        return instance
 
 
 @inherit_doc
@@ -232,6 +193,59 @@ class Pipeline(Estimator):
         stages = [stage.copy(extra) for stage in that.getStages()]
         return that.setStages(stages)
 
+    @since("2.0.0")
+    def write(self):
+        """Returns a JavaMLWriter instance for this ML instance."""
+        return PipelineMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut for `write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns a JavaMLReader instance for this class."""
+        return PipelineMLReader(cls)
+
+    @classmethod
+    @since("2.0.0")
+    def load(cls, path):
+        """Reads an ML instance from the input path, a shortcut for `read().load(path)`."""
+        return cls.read().load(path)
+
+
+@inherit_doc
+class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+    """
+    Private PipelineModel utility class that can save ML instances through their Scala
+    implementation.
+    """
+
+    def __init__(self, instance):
+        cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
+                                            instance.uid,
+                                            _stages_py2java(instance.stages, cls))
+        self._jwrite = self._java_obj.write()
+
+
+@inherit_doc
+class PipelineModelMLReader(JavaMLReader):
+    """
+    Private utility class that can load PipelineModel instances through their Scala implementation.
+    """
+
+    def load(self, path):
+        """Load the PipelineModel instance from the input path."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        java_obj = self._jread.load(path)
+        instance = self._clazz(_stages_java2py(java_obj.stages()))
+        instance._resetUid(java_obj.uid())
+        return instance
+
 
 @inherit_doc
 class PipelineModel(Model):
@@ -262,3 +276,25 @@ class PipelineModel(Model):
             extra = dict()
         stages = [stage.copy(extra) for stage in self.stages]
         return PipelineModel(stages)
+
+    @since("2.0.0")
+    def write(self):
+        """Returns a JavaMLWriter instance for this ML instance."""
+        return PipelineModelMLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut for `write().save(path)`."""
+        self.write().save(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(cls):
+        """Returns a JavaMLReader instance for this class."""
+        return PipelineModelMLReader(cls)
+
+    @classmethod
+    @since("2.0.0")
+    def load(cls, path):
+        """Reads an ML instance from the input path, a shortcut for `read().load(path)`."""
+        return cls.read().load(path)
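
The `_stages_py2java` helper above leans on py4j's typed-array support; a hedged sketch of that pattern in isolation (assumes an active SparkContext so `_gateway` and `_jvm` are populated):

```python
# Sketch of the py4j pattern behind _stages_py2java: allocate a typed Java
# array through the gateway, then fill it with each stage's Java peer.
from pyspark import SparkContext

gateway = SparkContext._gateway
cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
java_stages = gateway.new_array(cls, 2)  # roughly `new PipelineStage[2]` on the JVM
# java_stages[0] = some_stage._transfer_stage_to_java()
```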

http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4da9a37..c76f893 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -37,7 +37,7 @@ else:
 from shutil import rmtree
 import tempfile
 
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.evaluation import RegressionEvaluator
@@ -499,6 +499,49 @@ class PersistenceTest(PySparkTestCase):
         except OSError:
             pass
 
+    def test_pipeline_persistence(self):
+        sqlContext = SQLContext(self.sc)
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+            tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+            pl = Pipeline(stages=[tf, pca])
+            model = pl.fit(df)
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self.assertEqual(loaded_pipeline.uid, pl.uid)
+            self.assertEqual(len(loaded_pipeline.getStages()), 2)
+
+            [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
+            self.assertIsInstance(loaded_tf, HashingTF)
+            self.assertEqual(loaded_tf.uid, tf.uid)
+            param = loaded_tf.getParam("numFeatures")
+            self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+
+            self.assertIsInstance(loaded_pca, PCA)
+            self.assertEqual(loaded_pca.uid, pca.uid)
+            self.assertEqual(loaded_pca.getK(), pca.getK())
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            [model_tf, model_pca] = model.stages
+            [loaded_model_tf, loaded_model_pca] = loaded_model.stages
+            self.assertEqual(model_tf.uid, loaded_model_tf.uid)
+            self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
+
+            self.assertEqual(model_pca.uid, loaded_model_pca.uid)
+            self.assertEqual(model_pca.pc, loaded_model_pca.pc)
+            self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
 
 class HasThrowableProperty(Params):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d7a813f..42801c9 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -161,6 +161,9 @@ class JavaMLReader(object):
         the Python full class name.
         """
         java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
+        if clazz.__name__ in ("Pipeline", "PipelineModel"):
+            # Remove the last package name "pipeline" for Pipeline and PipelineModel.
+            java_package = ".".join(java_package.split(".")[0:-1])
         return ".".join([java_package, clazz.__name__])
 
     @classmethod
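
The mapping above can be checked in isolation; a hedged standalone sketch of the `_java_loader_class` logic (the free function name is ours, the body mirrors the patched method):

```python
# Standalone sketch of JavaMLReader._java_loader_class: Pipeline and
# PipelineModel live directly in org.apache.spark.ml, so the trailing
# "pipeline" module segment is dropped.
def java_loader_class(clazz):
    java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
    if clazz.__name__ in ("Pipeline", "PipelineModel"):
        java_package = ".".join(java_package.split(".")[0:-1])
    return ".".join([java_package, clazz.__name__])

from pyspark.ml import Pipeline
print(java_loader_class(Pipeline))  # org.apache.spark.ml.Pipeline
```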

http://git-wip-us.apache.org/repos/asf/spark/blob/ae6c677c/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f8feaa1..0f7b5e9 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -19,8 +19,8 @@ from abc import ABCMeta, abstractmethod
 
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
+from pyspark.ml import Estimator, Transformer, Model
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Model
 from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
@@ -90,6 +90,33 @@ class JavaWrapper(Params):
         """
         return _jvm().org.apache.spark.ml.param.ParamMap()
 
+    def _transfer_stage_to_java(self):
+        self._transfer_params_to_java()
+        return self._java_obj
+
+    @staticmethod
+    def _transfer_stage_from_java(java_stage):
+        def __get_class(clazz):
+            """
+            Loads Python class from its name.
+            """
+            parts = clazz.split('.')
+            module = ".".join(parts[:-1])
+            m = __import__(module)
+            for comp in parts[1:]:
+                m = getattr(m, comp)
+            return m
+        stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
+        # Generate a default new instance from the stage_name class.
+        py_stage = __get_class(stage_name)()
+        assert isinstance(py_stage, JavaWrapper), \
+            "Python side implementation is not supported in the meta-PipelineStage currently."
+        # Load information from java_stage to the instance.
+        py_stage._java_obj = java_stage
+        py_stage._resetUid(java_stage.uid())
+        py_stage._transfer_params_from_java()
+        return py_stage
+
 
 @inherit_doc
 class JavaEstimator(Estimator, JavaWrapper):
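
The `__get_class` helper relies on `__import__` returning the top-level package; a hedged, pyspark-free sketch of the same walk (the helper name here is ours):

```python
# __import__("a.b") returns package "a", so the remaining dotted parts are
# resolved with getattr until the named class is reached.
def get_class(clazz):
    parts = clazz.split('.')
    m = __import__(".".join(parts[:-1]))
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m

OrderedDict = get_class("collections.OrderedDict")
assert OrderedDict is __import__("collections").OrderedDict
```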

