spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages
Date Sat, 12 Aug 2017 06:57:16 GMT
Repository: spark
Updated Branches:
  refs/heads/master b0bdfce9c -> 35db3b9fe


[SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages

## What changes were proposed in this pull request?

Implemented a Python-only persistence framework for pipelines containing stages that cannot
be saved using Java.

## How was this patch tested?

Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded
the pipeline. The loaded pipeline was compared against the original using _compare_pipelines()
in tests.py.

Author: Ajay Saini <ajays725@gmail.com>

Closes #18888 from ajaysaini725/PythonPipelines.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35db3b9f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35db3b9f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35db3b9f

Branch: refs/heads/master
Commit: 35db3b9fe38dadfb8afb0b0857c09f83196398be
Parents: b0bdfce
Author: Ajay Saini <ajays725@gmail.com>
Authored: Fri Aug 11 23:57:08 2017 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Fri Aug 11 23:57:08 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/pipeline.py | 156 +++++++++++++++++++++++++++++++++++--
 python/pyspark/ml/tests.py    |  35 ++++++++-
 2 files changed, 183 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35db3b9f/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a8dc76b..0975302 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -16,6 +16,7 @@
 #
 
 import sys
+import os
 
 if sys.version > '3':
     basestring = str
@@ -23,7 +24,7 @@ if sys.version > '3':
 from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.common import inherit_doc
 
@@ -130,13 +131,16 @@ class Pipeline(Estimator, MLReadable, MLWritable):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineWriter(self)
 
     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineReader(cls)
 
     @classmethod
     def _from_java(cls, java_stage):
@@ -172,6 +176,76 @@ class Pipeline(Estimator, MLReadable, MLWritable):
 
 
 @inherit_doc
+class PipelineWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.getStages()
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineModelWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.stages
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineModelReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return PipelineModel(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
 class PipelineModel(Model, MLReadable, MLWritable):
     """
     Represents a compiled pipeline with transformers and fitted models.
@@ -204,13 +278,16 @@ class PipelineModel(Model, MLReadable, MLWritable):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineModelWriter(self)
 
     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineModelReader(cls)
 
     @classmethod
     def _from_java(cls, java_stage):
@@ -242,3 +319,72 @@ class PipelineModel(Model, MLReadable, MLWritable):
             JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj
+
+
+@inherit_doc
+class PipelineSharedReadWrite():
+    """
+    .. note:: DeveloperApi
+
+    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+    :py:class:`Pipeline` and :py:class:`PipelineModel`
+
+    .. versionadded:: 2.3.0
+    """
+
+    @staticmethod
+    def checkStagesForJava(stages):
+        return all(isinstance(stage, JavaMLWritable) for stage in stages)
+
+    @staticmethod
+    def validateStages(stages):
+        """
+        Check that every stage is MLWritable; raise ValueError naming the first one that is not.
+        """
+        for stage in stages:
+            if not isinstance(stage, MLWritable):
+                raise ValueError("Pipeline write will fail on this pipeline " +
+                                 "because stage %s of type %s is not MLWritable" %
+                                 (stage.uid, type(stage)))
+
+    @staticmethod
+    def saveImpl(instance, stages, sc, path):
+        """
+        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+        - save metadata to path/metadata
+        - save stages to stages/IDX_UID
+        """
+        stageUids = [stage.uid for stage in stages]
+        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
+        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        stagesDir = os.path.join(path, "stages")
+        for index, stage in enumerate(stages):
+            stage.write().save(PipelineSharedReadWrite
+                               .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+    @staticmethod
+    def load(metadata, sc, path):
+        """
+        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+        :return: (UID, list of stages)
+        """
+        stagesDir = os.path.join(path, "stages")
+        stageUids = metadata['paramMap']['stageUids']
+        stages = []
+        for index, stageUid in enumerate(stageUids):
+            stagePath = \
+                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stages.append(stage)
+        return (metadata['uid'], stages)
+
+    @staticmethod
+    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+        """
+        Get path for saving the given stage.
+        """
+        stageIdxDigits = len(str(numStages))
+        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
+        stagePath = os.path.join(stagesDir, stageDir)
+        return stagePath

http://git-wip-us.apache.org/repos/asf/spark/blob/35db3b9f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6aecc7f..0495973 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -123,7 +123,7 @@ class MockTransformer(Transformer, HasFake):
         return dataset
 
 
-class MockUnaryTransformer(UnaryTransformer):
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
 
     shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
                   "data in a DataFrame",
@@ -150,7 +150,7 @@ class MockUnaryTransformer(UnaryTransformer):
     def validateInputType(self, inputType):
         if inputType != DoubleType():
             raise TypeError("Bad input type: {}. ".format(inputType) +
-                            "Requires Integer.")
+                            "Requires Double.")
 
 
 class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ class PersistenceTest(SparkSessionTestCase):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaParams):
+        if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ class PersistenceTest(SparkSessionTestCase):
             except OSError:
                 pass
 
+    def test_python_transformer_pipeline_persistence(self):
+        """
+        Pipeline[MockUnaryTransformer, Binarizer]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.range(0, 10).toDF('input')
+            tf = MockUnaryTransformer(shiftVal=2)\
+                .setInputCol("input").setOutputCol("shiftedInput")
+            tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+            pl = Pipeline(stages=[tf, tf2])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
     def test_onevsrest(self):
         temp_path = tempfile.mkdtemp()
         df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message