spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5886][ML] Add StringIndexer as a feature transformer
Date Mon, 13 Apr 2015 05:41:07 GMT
Repository: spark
Updated Branches:
  refs/heads/master d3792f549 -> 685ddcf52


[SPARK-5886][ML] Add StringIndexer as a feature transformer

This PR adds StringIndexer, which takes a column of string labels and outputs a double column
of label indices ordered by label frequency (the most frequent label gets index 0).

TODOs:
- [x] store feature to index map in output metadata

Author: Xiangrui Meng <meng@databricks.com>

Closes #4735 from mengxr/SPARK-5886 and squashes the following commits:

d82575f [Xiangrui Meng] fix test
700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer
16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
f8b30f4 [Xiangrui Meng] update label indexer to output metadata
e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2
d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap
748a69b [Xiangrui Meng] add contains to OpenHashMap
def3c5c [Xiangrui Meng] add LabelIndexer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/685ddcf5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/685ddcf5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/685ddcf5

Branch: refs/heads/master
Commit: 685ddcf5253c0ecb39853802431e22b0c7b61dee
Parents: d3792f5
Author: Xiangrui Meng <meng@databricks.com>
Authored: Sun Apr 12 22:41:05 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun Apr 12 22:41:05 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala | 126 +++++++++++++++++++
 .../spark/ml/feature/StringIndexerSuite.scala   |  52 ++++++++
 2 files changed, 178 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/685ddcf5/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
new file mode 100644
index 0000000..61e6742
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
+ */
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * Validates the input schema and returns it with the output column appended.
+   *
+   * @param schema input schema; `inputCol` is checked against [[StringType]] via
+   *               `checkInputColumn`, and `outputCol` must not already be present
+   * @param paramMap extra params; they take precedence over the embedded ones
+   * @return the input fields followed by a nominal-attribute field named `outputCol`
+   */
+  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    checkInputColumn(schema, map(inputCol), StringType)
+    val inputFields = schema.fields
+    val outputColName = map(outputCol)
+    require(inputFields.forall(_.name != outputColName),
+      s"Output column $outputColName already exists.")
+    // Only the column name goes into the attribute here; this trait has no access to the labels.
+    val attr = NominalAttribute.defaultAttr.withName(outputColName)
+    StructType(inputFields :+ attr.toStructField())
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A label indexer that maps a string column of labels to an ML column of label indices.
+ * The indices are in [0, numLabels), ordered by label frequencies, so the most frequent label
+ * gets index 0. Labels with equal frequency are ordered by label value, which keeps the
+ * assigned indices deterministic across runs.
+ */
+@AlphaComponent
+class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  // TODO: handle unseen labels
+
+  /**
+   * Counts the values of `inputCol` and fits a [[StringIndexerModel]] whose labels are ordered
+   * by descending count, with ties broken by ascending label.
+   */
+  override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
+    val map = this.paramMap ++ paramMap
+    val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
+    // `counts` is a Map with no defined order. Sorting by count alone would leave equally
+    // frequent labels (and their indices) in hash-iteration order, so break ties by the label.
+    // Wrapping the label in `Option` keeps a null label comparable instead of throwing.
+    val labels = counts.toSeq
+      .sortBy { case (label, count) => (-count, Option(label)) }
+      .map(_._1)
+      .toArray
+    val model = new StringIndexerModel(this, map, labels)
+    Params.inheritValues(map, this, model)
+    model
+  }
+
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[StringIndexer]].
+ *
+ * @param labels the fitted labels; the label at position i is mapped to index i
+ */
+@AlphaComponent
+class StringIndexerModel private[ml] (
+    override val parent: StringIndexer,
+    override val fittingParamMap: ParamMap,
+    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+
+  /** Lookup table from each label to its index, i.e. its position in `labels`. */
+  private val labelToIndex: OpenHashMap[String, Double] = {
+    val n = labels.length
+    val map = new OpenHashMap[String, Double](n)
+    var i = 0
+    while (i < n) {
+      map.update(labels(i), i)
+      i += 1
+    }
+    map
+  }
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Appends `outputCol`, holding the index of each value of `inputCol`. The column carries
+   * nominal-attribute metadata that lists the labels in index order.
+   *
+   * Evaluating the output column throws a [[SparkException]] ("Unseen label: ...") for any
+   * label that was not present when the model was fitted.
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    val map = this.paramMap ++ paramMap
+    // Read the field into a local so the UDF closure captures only the lookup table, rather than
+    // `this` (the whole model, including `parent` and the param maps).
+    val localLabelToIndex = labelToIndex
+    val indexer = udf { label: String =>
+      if (localLabelToIndex.contains(label)) {
+        localLabelToIndex(label)
+      } else {
+        // TODO: handle unseen labels
+        throw new SparkException(s"Unseen label: $label.")
+      }
+    }
+    val outputColName = map(outputCol)
+    val metadata = NominalAttribute.defaultAttr
+      .withName(outputColName).withValues(labels).toStructField().metadata
+    dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
+  }
+
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/685ddcf5/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
new file mode 100644
index 0000000..00b5d09
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.SQLContext
+
+class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
+  private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("StringIndexer") {
+    // Label counts: a -> 3, c -> 2, b -> 1. There are no ties, so the index order is unambiguous.
+    val data = sc.parallelize(
+      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("a", "c", "b"))
+    val output = transformed.select("id", "labelIndex").map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 0, b -> 2, c -> 1
+    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
+    assert(output === expected)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message