spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-12375][ML] VectorIndexerModel support handle unseen categories via handleInvalid
Date Wed, 15 Nov 2017 00:58:23 GMT
Repository: spark
Updated Branches:
  refs/heads/master 774398045 -> 1e6f76059


[SPARK-12375][ML] VectorIndexerModel support handle unseen categories via handleInvalid

## What changes were proposed in this pull request?

Support skip/error/keep strategy, similar to `StringIndexer`.
Implemented via `try...catch`, so that it can avoid possible performance impact.

## How was this patch tested?

Unit test added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19588 from WeichenXu123/handle_invalid_for_vector_indexer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e6f7605
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e6f7605
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e6f7605

Branch: refs/heads/master
Commit: 1e6f760593d81def059c514d34173bf2777d71ec
Parents: 7743980
Author: WeichenXu <weichen.xu@databricks.com>
Authored: Tue Nov 14 16:58:18 2017 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Tue Nov 14 16:58:18 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/VectorIndexer.scala | 92 +++++++++++++++++---
 .../spark/ml/feature/VectorIndexerSuite.scala   | 39 +++++++++
 2 files changed, 121 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index d371da7..3403ec4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.ml.feature
 
 import java.lang.{Double => JDouble, Integer => JInt}
-import java.util.{Map => JMap}
+import java.util.{Map => JMap, NoSuchElementException}
 
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute._
@@ -37,7 +38,27 @@ import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.util.collection.OpenHashSet
 
 /** Private trait for params for VectorIndexer and VectorIndexerModel */
-private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol {
+private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
+  with HasHandleInvalid {
+
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Note: this param only applies to categorical features, not continuous ones.
+   * Options are:
+   * 'skip': filter out rows with invalid data.
+   * 'error': throw an error.
+   * 'keep': put invalid data in a special additional bucket, at index numCategories.
+   * Default value: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
+
+  setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
 
   /**
    * Threshold for the number of values a categorical feature can take.
@@ -113,6 +134,10 @@ class VectorIndexer @Since("1.4.0") (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): VectorIndexerModel = {
     transformSchema(dataset.schema, logging = true)
@@ -148,6 +173,11 @@ class VectorIndexer @Since("1.4.0") (
 
 @Since("1.6.0")
 object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
 
   @Since("1.6.0")
   override def load(path: String): VectorIndexer = super.load(path)
@@ -287,9 +317,15 @@ class VectorIndexerModel private[ml] (
     while (featureIndex < numFeatures) {
       if (categoryMaps.contains(featureIndex)) {
         // categorical feature
-        val featureValues: Array[String] =
+        val rawFeatureValues: Array[String] =
           categoryMaps(featureIndex).toArray.sortBy(_._1).map(_._1).map(_.toString)
-        if (featureValues.length == 2) {
+
+        val featureValues = if (getHandleInvalid == VectorIndexer.KEEP_INVALID) {
+          (rawFeatureValues.toList :+ "__unknown").toArray
+        } else {
+          rawFeatureValues
+        }
+        if (featureValues.length == 2 && getHandleInvalid != VectorIndexer.KEEP_INVALID) {
           attrs(featureIndex) = new BinaryAttribute(index = Some(featureIndex),
             values = Some(featureValues))
         } else {
@@ -311,22 +347,39 @@ class VectorIndexerModel private[ml] (
   // TODO: Check more carefully about whether this whole class will be included in a closure.
 
   /** Per-vector transform function */
-  private val transformFunc: Vector => Vector = {
+  private lazy val transformFunc: Vector => Vector = {
     val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted
     val localVectorMap = categoryMaps
     val localNumFeatures = numFeatures
+    val localHandleInvalid = getHandleInvalid
     val f: Vector => Vector = { (v: Vector) =>
       assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" +
         s" $numFeatures but found length ${v.size}")
       v match {
         case dv: DenseVector =>
+          var hasInvalid = false
           val tmpv = dv.copy
           localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) =>
-            tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+            try {
+              tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex))
+            } catch {
+              case _: NoSuchElementException =>
+                localHandleInvalid match {
+                  case VectorIndexer.ERROR_INVALID =>
+                    throw new SparkException(s"VectorIndexer encountered invalid value " +
+                      s"${tmpv(featureIndex)} on feature index ${featureIndex}. To handle " +
+                      s"or skip invalid value, try setting VectorIndexer.handleInvalid.")
+                  case VectorIndexer.KEEP_INVALID =>
+                    tmpv.values(featureIndex) = categoryMap.size
+                  case VectorIndexer.SKIP_INVALID =>
+                    hasInvalid = true
+                }
+            }
           }
-          tmpv
+          if (hasInvalid) null else tmpv
         case sv: SparseVector =>
           // We use the fact that categorical value 0 is always mapped to index 0.
+          var hasInvalid = false
           val tmpv = sv.copy
           var catFeatureIdx = 0 // index into sortedCatFeatureIndices
           var k = 0 // index into non-zero elements of sparse vector
@@ -337,12 +390,26 @@ class VectorIndexerModel private[ml] (
             } else if (featureIndex > tmpv.indices(k)) {
               k += 1
             } else {
-              tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+              try {
+                tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k))
+              } catch {
+                case _: NoSuchElementException =>
+                  localHandleInvalid match {
+                    case VectorIndexer.ERROR_INVALID =>
+                      throw new SparkException(s"VectorIndexer encountered invalid value " +
+                        s"${tmpv.values(k)} on feature index ${featureIndex}. To handle " +
+                        s"or skip invalid value, try setting VectorIndexer.handleInvalid.")
+                    case VectorIndexer.KEEP_INVALID =>
+                      tmpv.values(k) = localVectorMap(featureIndex).size
+                    case VectorIndexer.SKIP_INVALID =>
+                      hasInvalid = true
+                  }
+              }
               catFeatureIdx += 1
               k += 1
             }
           }
-          tmpv
+          if (hasInvalid) null else tmpv
       }
     }
     f
@@ -362,7 +429,12 @@ class VectorIndexerModel private[ml] (
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
     val newCol = transformUDF(dataset($(inputCol)))
-    dataset.withColumn($(outputCol), newCol, newField.metadata)
+    val ds = dataset.withColumn($(outputCol), newCol, newField.metadata)
+    if (getHandleInvalid == VectorIndexer.SKIP_INVALID) {
+      ds.na.drop(Array($(outputCol)))
+    } else {
+      ds
+    }
   }
 
   @Since("1.4.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/1e6f7605/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index f2cca8a..69a7b75 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -38,6 +38,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
   // identical, of length 3
   @transient var densePoints1: DataFrame = _
   @transient var sparsePoints1: DataFrame = _
+  @transient var densePoints1TestInvalid: DataFrame = _
+  @transient var sparsePoints1TestInvalid: DataFrame = _
   @transient var point1maxes: Array[Double] = _
 
   // identical, of length 2
@@ -55,11 +57,19 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
       Vectors.dense(0.0, 1.0, 2.0),
       Vectors.dense(0.0, 0.0, -1.0),
       Vectors.dense(1.0, 3.0, 2.0))
+    val densePoints1SeqTestInvalid = densePoints1Seq ++ Seq(
+      Vectors.dense(10.0, 2.0, 0.0),
+      Vectors.dense(0.0, 10.0, 2.0),
+      Vectors.dense(1.0, 3.0, 10.0))
     val sparsePoints1Seq = Seq(
       Vectors.sparse(3, Array(0, 1), Array(1.0, 2.0)),
       Vectors.sparse(3, Array(1, 2), Array(1.0, 2.0)),
       Vectors.sparse(3, Array(2), Array(-1.0)),
       Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 2.0)))
+    val sparsePoints1SeqTestInvalid = sparsePoints1Seq ++ Seq(
+      Vectors.sparse(3, Array(0, 1), Array(10.0, 2.0)),
+      Vectors.sparse(3, Array(1, 2), Array(10.0, 2.0)),
+      Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 3.0, 10.0)))
     point1maxes = Array(1.0, 3.0, 2.0)
 
     val densePoints2Seq = Seq(
@@ -88,6 +98,8 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
 
     densePoints1 = densePoints1Seq.map(FeatureData).toDF()
     sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF()
+    densePoints1TestInvalid = densePoints1SeqTestInvalid.map(FeatureData).toDF()
+    sparsePoints1TestInvalid = sparsePoints1SeqTestInvalid.map(FeatureData).toDF()
     densePoints2 = densePoints2Seq.map(FeatureData).toDF()
     sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF()
     badPoints = badPointsSeq.map(FeatureData).toDF()
@@ -219,6 +231,33 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     checkCategoryMaps(densePoints2, maxCategories = 2, categoricalFeatures = Set(1, 3))
   }
 
+  test("handle invalid") {
+    for ((points, pointsTestInvalid) <- Seq((densePoints1, densePoints1TestInvalid),
+      (sparsePoints1, sparsePoints1TestInvalid))) {
+      val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error")
+      val model = vectorIndexer.fit(points)
+      intercept[SparkException] {
+        model.transform(pointsTestInvalid).collect()
+      }
+      val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip")
+      val model1 = vectorIndexer1.fit(points)
+      val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed")
+        .collect().map(_(0))
+      val transformed1 = model1.transform(points).select("indexed").collect().map(_(0))
+      assert(transformed1 === invalidTransformed1)
+
+      val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep")
+      val model2 = vectorIndexer2.fit(points)
+      val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed")
+        .collect().map(_(0))
+      assert(invalidTransformed2 === transformed1 ++ Array(
+        Vectors.dense(2.0, 2.0, 0.0),
+        Vectors.dense(0.0, 4.0, 2.0),
+        Vectors.dense(1.0, 3.0, 3.0))
+      )
+    }
+  }
+
   test("Maintain sparsity for sparse vectors") {
     def checkSparsity(data: DataFrame, maxCategories: Int): Unit = {
       val points = data.collect().map(_.getAs[Vector](0))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message