spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mln...@apache.org
Subject spark git commit: [SPARK-18366][PYSPARK][ML] Add handleInvalid to Pyspark for QuantileDiscretizer and Bucketizer
Date Wed, 30 Nov 2016 09:33:37 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 5e4afbfb6 -> 7043c6b69


[SPARK-18366][PYSPARK][ML] Add handleInvalid to Pyspark for QuantileDiscretizer and Bucketizer

## What changes were proposed in this pull request?
Added the new handleInvalid param to the Python API of these transformers to maintain API parity.

## How was this patch tested?
Existing tests.
The new param is tested with new doctests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #15817 from techaddict/SPARK-18366.

(cherry picked from commit fe854f2e4fb2fa1a1c501f11030e36f489ca546f)
Signed-off-by: Nick Pentreath <nickp@za.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7043c6b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7043c6b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7043c6b6

Branch: refs/heads/branch-2.1
Commit: 7043c6b695f77741c5e97a322d9590bd714289de
Parents: 5e4afbf
Author: Sandeep Singh <sandeep@techaddict.me>
Authored: Wed Nov 30 11:33:15 2016 +0200
Committer: Nick Pentreath <nickp@za.ibm.com>
Committed: Wed Nov 30 11:33:30 2016 +0200

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 85 ++++++++++++++++++++++++++++++++-------
 1 file changed, 71 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7043c6b6/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index aada38d..1d62b32 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -125,10 +125,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
     """
     Maps a column of continuous features to a column of feature buckets.
 
-    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+    >>> df = spark.createDataFrame(values, ["values"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
     ...     inputCol="values", outputCol="buckets")
-    >>> bucketed = bucketizer.transform(df).collect()
+    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
+    >>> len(bucketed)
+    6
     >>> bucketed[0].buckets
     0.0
     >>> bucketed[1].buckets
@@ -144,6 +147,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
     >>> loadedBucketizer = Bucketizer.load(bucketizerPath)
     >>> loadedBucketizer.getSplits() == bucketizer.getSplits()
     True
+    >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
+    >>> len(bucketed)
+    4
 
     .. versionadded:: 1.4.0
     """
@@ -158,21 +164,28 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
               "splits specified will be treated as errors.",
               typeConverter=TypeConverters.toListFloat)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+                          "Options are skip (filter out rows with invalid values), " +
+                          "error (throw an error), or keep (keep invalid values in a special " +
+                          "additional bucket).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, splits=None, inputCol=None, outputCol=None):
+    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, splits=None, inputCol=None, outputCol=None)
+        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
+        self._setDefault(handleInvalid="error")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, splits=None, inputCol=None, outputCol=None):
+    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, splits=None, inputCol=None, outputCol=None)
+        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this Bucketizer.
         """
         kwargs = self.setParams._input_kwargs
@@ -192,6 +205,20 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         """
         return self.getOrDefault(self.splits)
 
+    @since("2.1.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("2.1.0")
+    def getHandleInvalid(self):
+        """
+        Gets the value of :py:attr:`handleInvalid` or its default value.
+        """
+        return self.getOrDefault(self.handleInvalid)
+
 
 @inherit_doc
 class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1157,12 +1184,17 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
     :py:attr:`relativeError` parameter.
     The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.
 
-    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+    >>> df = spark.createDataFrame(values, ["values"])
     >>> qds = QuantileDiscretizer(numBuckets=2,
-    ...     inputCol="values", outputCol="buckets", relativeError=0.01)
+    ...     inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
     >>> qds.getRelativeError()
     0.01
     >>> bucketizer = qds.fit(df)
+    >>> qds.setHandleInvalid("keep").fit(df).transform(df).count()
+    6
+    >>> qds.setHandleInvalid("skip").fit(df).transform(df).count()
+    4
     >>> splits = bucketizer.getSplits()
     >>> splits[0]
     -inf
@@ -1190,23 +1222,33 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
                           "Must be in the range [0, 1].",
                           typeConverter=TypeConverters.toFloat)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+                          "Options are skip (filter out rows with invalid values), " +
+                          "error (throw an error), or keep (keep invalid values in a special " +
+                          "additional bucket).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+    def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+                 handleInvalid="error"):
         """
-        __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+        __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+                 handleInvalid="error")
         """
         super(QuantileDiscretizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
                                             self.uid)
-        self._setDefault(numBuckets=2, relativeError=0.001)
+        self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+    def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+                  handleInvalid="error"):
         """
-        setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+        setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+                  handleInvalid="error")
         Set the params for the QuantileDiscretizer
         """
         kwargs = self.setParams._input_kwargs
@@ -1240,13 +1282,28 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
         """
         return self.getOrDefault(self.relativeError)
 
+    @since("2.1.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("2.1.0")
+    def getHandleInvalid(self):
+        """
+        Gets the value of :py:attr:`handleInvalid` or its default value.
+        """
+        return self.getOrDefault(self.handleInvalid)
+
     def _create_model(self, java_model):
         """
         Private method to convert the java_model to a Python model.
         """
         return Bucketizer(splits=list(java_model.getSplits()),
                           inputCol=self.getInputCol(),
-                          outputCol=self.getOutputCol())
+                          outputCol=self.getOutputCol(),
+                          handleInvalid=self.getHandleInvalid())
 
 
 @inherit_doc


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message