spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hol...@apache.org
Subject spark git commit: [SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API
Date Tue, 21 Nov 2017 18:54:09 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5855b5c03 -> 2d868d939


[SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python
API

## What changes were proposed in this pull request?

Add python api for VectorIndexerModel support handle unseen categories via handleInvalid.

## How was this patch tested?

doctest added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19753 from WeichenXu123/vector_indexer_invalid_py.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d868d93
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d868d93
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d868d93

Branch: refs/heads/master
Commit: 2d868d93987ea1757cc66cdfb534bc49794eb0d0
Parents: 5855b5c
Author: WeichenXu <weichen.xu@databricks.com>
Authored: Tue Nov 21 10:53:53 2017 -0800
Committer: Holden Karau <holdenkarau@google.com>
Committed: Tue Nov 21 10:53:53 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/VectorIndexer.scala |  7 +++--
 python/pyspark/ml/feature.py                    | 30 +++++++++++++++-----
 2 files changed, 27 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 3403ec4..e6ec4e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
    * Options are:
    * 'skip': filter out rows with invalid data.
    * 'error': throw an error.
-   * 'keep': put invalid data in a special additional bucket, at index numCategories.
+   * 'keep': put invalid data in a special additional bucket, at index of the number of
+   * categories of the feature.
    * Default value: "error"
    * @group param
    */
@@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data (unseen labels or NULL values). " +
    "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
-    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    "or 'keep' (put invalid data in a special additional bucket, at index of the " +
+    "number of categories of the feature).",
     ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
 
   setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
@@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
  *  - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
  *  - Specify certain features to not index, either via a parameter or via existing metadata.
  *  - Add warning if a categorical feature has only 1 category.
- *  - Add option for allowing unknown categories.
  */
 @Since("1.4.0")
 class VectorIndexer @Since("1.4.0") (

http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 232ae3e..608f2a5 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2490,7 +2490,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
 
 
 @inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+                    JavaMLWritable):
     """
     Class for indexing categorical feature columns in a dataset of `Vector`.
 
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
         do not recompute.
       - Specify certain features to not index, either via a parameter or via existing metadata.
       - Add warning if a categorical feature has only 1 category.
-      - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
     True
     >>> loadedModel.categoryMaps == model.categoryMaps
     True
+    >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
+    >>> indexer.getHandleInvalid()
+    'error'
+    >>> model3 = indexer.setHandleInvalid("skip").fit(df)
+    >>> model3.transform(dfWithInvalid).count()
+    0
+    >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
+    >>> model4.transform(dfWithInvalid).head().indexed
+    DenseVector([2.0, 1.0])
 
     .. versionadded:: 1.4.0
     """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
                          "(>= 2). If a feature is found to have > maxCategories values, then " +
                           "it is declared continuous.", typeConverter=TypeConverters.toInt)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
+                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
+                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
+                          "invalid data in a special additional bucket, at index of the number " +
+                          "of categories of the feature).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
+    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, maxCategories=20, inputCol=None, outputCol=None)
+        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
-        self._setDefault(maxCategories=20)
+        self._setDefault(maxCategories=20, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
+    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, maxCategories=20, inputCol=None, outputCol=None)
+        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorIndexer.
         """
         kwargs = self._input_kwargs


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message