spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From cutl...@apache.org
Subject spark git commit: [SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels
Date Fri, 06 Apr 2018 18:51:43 GMT
Repository: spark
Updated Branches:
  refs/heads/master b6935ffb4 -> e99825058


[SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels

## What changes were proposed in this pull request?

The Scala StringIndexerModel has an alternate constructor that will create the model from
an array of label strings.  Add the corresponding Python API:

model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")

## How was this patch tested?

Add doctest and unit test.

Author: Huaxin Gao <huaxing@us.ibm.com>

Closes #20968 from huaxingao/spark-23828.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9982505
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9982505
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9982505

Branch: refs/heads/master
Commit: e998250588de0df250e2800278da4d3e3705c259
Parents: b6935ff
Author: Huaxin Gao <huaxing@us.ibm.com>
Authored: Fri Apr 6 11:51:36 2018 -0700
Committer: Bryan Cutler <cutlerb@gmail.com>
Committed: Fri Apr 6 11:51:36 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++++-----------
 python/pyspark/ml/tests.py   | 41 +++++++++++++++++-
 2 files changed, 104 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9982505/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index fcb0dfc..5a3e0dd 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2342,9 +2342,38 @@ class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
         return self._call_java("mean")
 
 
+class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol):
+    """
+    Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
+    """
+
+    stringOrderType = Param(Params._dummy(), "stringOrderType",
+                            "How to order labels of string column. The first label after " +
+                            "ordering is assigned an index of 0. Supported options: " +
+                            "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
+                            typeConverter=TypeConverters.toString)
+
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+                          "or NULL values) in features and label column of string type. " +
+                          "Options are 'skip' (filter out rows with invalid data), " +
+                          "error (throw an error), or 'keep' (put invalid data " +
+                          "in a special additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)
+
+    def __init__(self, *args):
+        super(_StringIndexerParams, self).__init__(*args)
+        self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
+
+    @since("2.3.0")
+    def getStringOrderType(self):
+        """
+        Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
+        """
+        return self.getOrDefault(self.stringOrderType)
+
+
 @inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
-                    JavaMLWritable):
+class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
     """
     A label indexer that maps a string column of labels to an ML column of label indices.
     If the input column is numeric, we cast it to string and index the string values.
@@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
     >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
     ...     key=lambda x: x[0])
     [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
+    >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
+    ...     inputCol="label", outputCol="indexed", handleInvalid="error")
+    >>> result = fromlabelsModel.transform(stringIndDf)
+    >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
+    ...     key=lambda x: x[0])
+    [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
 
     .. versionadded:: 1.4.0
     """
 
-    stringOrderType = Param(Params._dummy(), "stringOrderType",
-                            "How to order labels of string column. The first label after " +
-                            "ordering is assigned an index of 0. Supported options: " +
-                            "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
-                            typeConverter=TypeConverters.toString)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "or NULL values) in features and label column of string type. " +
-                          "Options are 'skip' (filter out rows with invalid data), " +
-                          "error (throw an error), or 'keep' (put invalid data " +
-                          "in a special additional bucket, at index numLabels).",
-                          typeConverter=TypeConverters.toString)
-
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):
@@ -2414,7 +2436,6 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
         """
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
-        self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -2440,21 +2461,33 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
         """
         return self._set(stringOrderType=value)
 
-    @since("2.3.0")
-    def getStringOrderType(self):
-        """
-        Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
-        """
-        return self.getOrDefault(self.stringOrderType)
-
 
-class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
     """
     Model fitted by :py:class:`StringIndexer`.
 
     .. versionadded:: 1.4.0
     """
 
+    @classmethod
+    @since("2.4.0")
+    def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
+        """
+        Construct the model directly from an array of label strings,
+        requires an active SparkContext.
+        """
+        sc = SparkContext._active_spark_context
+        java_class = sc._gateway.jvm.java.lang.String
+        jlabels = StringIndexerModel._new_java_array(labels, java_class)
+        model = StringIndexerModel._create_from_java_class(
+            "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+        model.setInputCol(inputCol)
+        if outputCol is not None:
+            model.setOutputCol(outputCol)
+        if handleInvalid is not None:
+            model.setHandleInvalid(handleInvalid)
+        return model
+
     @property
     @since("1.5.0")
     def labels(self):
@@ -2463,6 +2496,13 @@ class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
         """
         return self._call_java("labels")
 
+    @since("2.4.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
 
 @inherit_doc
 class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):

http://git-wip-us.apache.org/repos/asf/spark/blob/e9982505/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c2c4861..4ce5454 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -800,6 +800,43 @@ class FeatureTests(SparkSessionTestCase):
         expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
         self.assertEqual(actual2, expected2)
 
+    def test_string_indexer_from_labels(self):
+        model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
+                                               outputCol="indexed", handleInvalid="keep")
+        self.assertEqual(model.labels, ["a", "b", "c"])
+
+        df1 = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "c"),
+            (2, None),
+            (3, "b"),
+            (4, "b")], ["id", "label"])
+
+        result1 = model.transform(df1)
+        actual1 = result1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0),
+                     Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
+        self.assertEqual(actual1, expected1)
+
+        model_empty_labels = StringIndexerModel.from_labels(
+            [], inputCol="label", outputCol="indexed", handleInvalid="keep")
+        actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0),
+                     Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
+        self.assertEqual(actual2, expected2)
+
+        # Test model with default settings can transform
+        model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")
+        df2 = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "c"),
+            (2, "b"),
+            (3, "b"),
+            (4, "b")], ["id", "label"])
+        transformed_list = model_default.transform(df2)\
+            .select(model_default.getOrDefault(model_default.outputCol)).collect()
+        self.assertEqual(len(transformed_list), 5)
+
 
 class HasInducedError(Params):
 
@@ -2097,9 +2134,11 @@ class DefaultValuesTests(PySparkTestCase):
                     ParamTests.check_params(self, cls(), check_params_exist=False)
 
         # Additional classes that need explicit construction
-        from pyspark.ml.feature import CountVectorizerModel
+        from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
         ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
                                 check_params_exist=False)
+        ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'),
+                                check_params_exist=False)
 
 
 def _squared_distance(a, b):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message