spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From felixche...@apache.org
Subject spark git commit: [SPARK-20619][ML] StringIndexer supports multiple ways to order label
Date Fri, 12 May 2017 07:12:51 GMT
Repository: spark
Updated Branches:
  refs/heads/master 888b84abe -> af40bb115


[SPARK-20619][ML] StringIndexer supports multiple ways to order label

## What changes were proposed in this pull request?

StringIndexer maps labels to indices in descending order of label frequency.
Other orderings (e.g., alphabetical) may be needed during feature ETL; for example, the
ordering affects the results of one-hot encoding and RFormula.

This PR adds support for other ordering methods through a new parameter, `stringOrderType`,
which supports the following four options:
- 'frequencyDesc': descending order by label frequency (most frequent label assigned 0)
- 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0)
- 'alphabetDesc': descending alphabetical order
- 'alphabetAsc': ascending alphabetical order

The default remains descending order of label frequency, so existing programs are
unaffected.

## How was this patch tested?
Added a new unit test, "StringIndexer order types", to `StringIndexerSuite`.

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17879 from actuaryzhang/stringIndexer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af40bb11
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af40bb11
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af40bb11

Branch: refs/heads/master
Commit: af40bb1159b1f443bf44594c716d2f2dd3c98640
Parents: 888b84a
Author: Wayne Zhang <actuaryzhang@uber.com>
Authored: Fri May 12 00:12:47 2017 -0700
Committer: Felix Cheung <felixcheung@apache.org>
Committed: Fri May 12 00:12:47 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala | 55 +++++++++++++++++---
 .../spark/ml/feature/StringIndexerSuite.scala   | 23 ++++++++
 2 files changed, 71 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af40bb11/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 99321bc..b2dc4fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -59,6 +59,29 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
   @Since("1.6.0")
   def getHandleInvalid: String = $(handleInvalid)
 
+  /**
+   * Param for how to order labels of string column. The first label after ordering is assigned
+   * an index of 0.
+   * Options are:
+   *   - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0)
+   *   - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0)
+   *   - 'alphabetDesc': descending alphabetical order
+   *   - 'alphabetAsc': ascending alphabetical order
+   * Default is 'frequencyDesc'.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
+    "how to order labels of string column. " +
+    "The first label after ordering is assigned an index of 0. " +
+    s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
+    ParamValidators.inArray(StringIndexer.supportedStringOrderType))
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getStringOrderType: String = $(stringOrderType)
+
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     val inputColName = $(inputCol)
@@ -79,8 +102,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
 /**
  * A label indexer that maps a string column of labels to an ML column of label indices.
  * If the input column is numeric, we cast it to string and index the string values.
- * The indices are in [0, numLabels), ordered by label frequencies.
- * So the most frequent label gets index 0.
+ * The indices are in [0, numLabels). By default, this is ordered by label frequencies
+ * so the most frequent label gets index 0. The ordering behavior is controlled by
+ * setting `stringOrderType`.
  *
  * @see `IndexToString` for the inverse transformation
  */
@@ -97,6 +121,11 @@ class StringIndexer @Since("1.4.0") (
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
 
   /** @group setParam */
+  @Since("2.3.0")
+  def setStringOrderType(value: String): this.type = set(stringOrderType, value)
+  setDefault(stringOrderType, StringIndexer.frequencyDesc)
+
+  /** @group setParam */
   @Since("1.4.0")
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -107,11 +136,17 @@ class StringIndexer @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
-      .rdd
-      .map(_.getString(0))
-      .countByValue()
-    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
+    val values = dataset.na.drop(Array($(inputCol)))
+      .select(col($(inputCol)).cast(StringType))
+      .rdd.map(_.getString(0))
+    val labels = $(stringOrderType) match {
+      case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
+        .map(_._1).toArray
+      case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
+        .map(_._1).toArray
+      case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
+      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
+    }
     copyValues(new StringIndexerModel(uid, labels).setParent(this))
   }
 
@@ -131,6 +166,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
   private[feature] val KEEP_INVALID: String = "keep"
   private[feature] val supportedHandleInvalids: Array[String] =
     Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+  private[feature] val frequencyDesc: String = "frequencyDesc"
+  private[feature] val frequencyAsc: String = "frequencyAsc"
+  private[feature] val alphabetDesc: String = "alphabetDesc"
+  private[feature] val alphabetAsc: String = "alphabetAsc"
+  private[feature] val supportedStringOrderType: Array[String] =
+    Array(frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc)
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)

http://git-wip-us.apache.org/repos/asf/spark/blob/af40bb11/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5634d42..806a927 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -291,4 +291,27 @@ class StringIndexerSuite
       NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
     assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
   }
+
+  test("StringIndexer order types") {
+    val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b"))
+    val df = data.toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+
+    val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
+      Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
+      Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
+      Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
+
+    var idx = 0
+    for (orderType <- StringIndexer.supportedStringOrderType) {
+      val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
+      val output = transformed.select("id", "labelIndex").rdd.map { r =>
+        (r.getInt(0), r.getDouble(1))
+      }.collect().toSet
+      assert(output === expected(idx))
+      idx += 1
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message