spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer
Date Tue, 05 May 2015 19:34:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master ee374e89c -> 47728db7c


[SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer

This patch adds a one-hot encoder for categorical features. Planning to add documentation
and another test after getting feedback on the approach.

A couple choices made here:
* There's an `includeFirst` option which, if false, creates numCategories - 1 columns and,
if true, creates numCategories columns.  The default is true, which is the behavior in scikit-learn.
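* The user is expected to pass a `Seq` of category names when instantiating a `OneHotEncoder`.
 These can easily be obtained from a `StringIndexer`.  The names are used for the output column
names, which take the form colName_categoryName.  (Note: the squashed commit f383250, "Infer label
names automatically", supersedes this; the final code reads category names from the input
column's attribute metadata instead.)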

Author: Sandy Ryza <sandy@cloudera.com>

Closes #5500 from sryza/sandy-spark-5888 and squashes the following commits:

f383250 [Sandy Ryza] Infer label names automatically
6e257b9 [Sandy Ryza] Review comments
7c539cf [Sandy Ryza] Vector transformers
1c182dd [Sandy Ryza] SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/47728db7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/47728db7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/47728db7

Branch: refs/heads/master
Commit: 47728db7cfac995d9417cdf0e16d07391aabd581
Parents: ee374e8
Author: Sandy Ryza <sandy@cloudera.com>
Authored: Tue May 5 12:34:02 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue May 5 12:34:02 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/OneHotEncoder.scala | 107 +++++++++++++++++++
 .../spark/ml/feature/OneHotEncoderSuite.scala   |  80 ++++++++++++++
 2 files changed, 187 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/47728db7/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
new file mode 100644
index 0000000..46514ae
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+/**
+ * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
+ * at most a single one-value. By default, the binary vector has an element for each category,
+ * so with 5 categories, an input value of 2.0 would map to an output vector of
+ * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted,
+ * so the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input
+ * value of 0.0 would map to a vector of all zeros. Including the first category makes the
+ * vector columns linearly dependent because they sum up to one.
+ *
+ * The categories are read from the input column's attribute metadata, which must be a
+ * [[NominalAttribute]] or [[BinaryAttribute]] (e.g. the output of a [[StringIndexer]]).
+ */
+@AlphaComponent
+class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
+  with HasInputCol with HasOutputCol {
+
+  /**
+   * Whether to include a component in the encoded vectors for the first category, defaults to
+   * true.
+   * @group param
+   */
+  final val includeFirst: BooleanParam =
+    new BooleanParam(this, "includeFirst", "include first category")
+  setDefault(includeFirst -> true)
+
+  /**
+   * Category names of the input column, recorded by [[transformSchema]] and read by
+   * [[createTransformFunc]]; transformSchema must therefore run first.
+   */
+  private var categories: Array[String] = _
+
+  /** @group setParam */
+  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)
+
+  /** @group setParam */
+  override def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  override def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Checks that the input column is a Double column with nominal/binary metadata, records its
+   * category names, and appends a nominal output column whose values are those categories
+   * (minus the first one when includeFirst is false).
+   *
+   * @throws SparkException if the input column has no nominal/binary attribute metadata, or a
+   *                        nominal attribute defines neither values nor numValues
+   */
+  override def transformSchema(schema: StructType): StructType = {
+    val inputColName = $(inputCol)
+    val outputColName = $(outputCol)
+    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
+    val inputFields = schema.fields
+    require(inputFields.forall(_.name != outputColName),
+      s"Output column $outputColName already exists.")
+
+    categories = Attribute.fromStructField(schema(inputColName)) match {
+      case nominal: NominalAttribute =>
+        // Prefer explicit value names; fall back to "0", "1", ... when only the count is known.
+        nominal.values
+          .orElse(nominal.numValues.map(n => (0 until n).map(_.toString).toArray))
+          .getOrElse(throw new SparkException(s"OneHotEncoder input column $inputColName is " +
+            "nominal but defines neither values nor numValues"))
+      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
+      case _ =>
+        throw new SparkException(s"OneHotEncoder input column $inputColName is not nominal")
+    }
+
+    val attrValues = if ($(includeFirst)) categories else categories.drop(1)
+    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
+    StructType(inputFields :+ attr.toStructField())
+  }
+
+  /**
+   * Builds the per-row encoding function from the categories recorded by [[transformSchema]].
+   * Label i sets component i (or i - 1 when includeFirst is false, in which case label 0.0
+   * encodes to the all-zeros vector).
+   *
+   * The returned function throws IllegalArgumentException for a label that is not an integral
+   * index in [0, number of categories), instead of emitting an out-of-range sparse index.
+   */
+  protected override def createTransformFunc(): (Double) => Vector = {
+    // Copy everything the closure needs into locals so the UDF does not capture `this`.
+    val first = $(includeFirst)
+    val numCategories = categories.length
+    val vecLen = if (first) numCategories else numCategories - 1
+    // Reused for every output vector rather than allocated per row.
+    val oneValue = Array(1.0)
+    val emptyValues = Array[Double]()
+    val emptyIndices = Array[Int]()
+    label: Double => {
+      require(label >= 0.0 && label < numCategories && label == label.toInt,
+        s"OneHotEncoder expects a category index in [0, $numCategories) but got $label")
+      val index = label.toInt
+      if (first) {
+        Vectors.sparse(vecLen, Array(index), oneValue)
+      } else if (index != 0) {
+        Vectors.sparse(vecLen, Array(index - 1), oneValue)
+      } else {
+        Vectors.sparse(vecLen, emptyIndices, emptyValues)
+      }
+    }
+  }
+
+  /**
+   * Returns the data type of the output column: an MLlib vector.
+   */
+  protected def outputDataType: DataType = new VectorUDT
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/47728db7/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
new file mode 100644
index 0000000..92ec407
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+
+class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
+  private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  /** Indexes the string "label" column into a nominal "labelIndex" column via StringIndexer. */
+  def stringIndexed(): DataFrame = {
+    val data = sc.parallelize(
+      Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    indexer.transform(df)
+  }
+
+  /**
+   * Collects (id, full dense contents of "labelVec") pairs. Comparing the whole array also
+   * checks the vector length, which per-component lookups would not catch.
+   */
+  private def encodedRows(encoded: DataFrame): Set[(Int, List[Double])] = {
+    encoded.select("id", "labelVec").map { r =>
+      (r.getInt(0), r.get(1).asInstanceOf[Vector].toArray.toList)
+    }.collect().toSet
+  }
+
+  test("OneHotEncoder includeFirst = true") {
+    val encoder = new OneHotEncoder()
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    val output = encodedRows(encoder.transform(stringIndexed()))
+    // a -> 0, b -> 2, c -> 1
+    val expected = Set(
+      (0, List(1.0, 0.0, 0.0)), (1, List(0.0, 0.0, 1.0)), (2, List(0.0, 1.0, 0.0)),
+      (3, List(1.0, 0.0, 0.0)), (4, List(1.0, 0.0, 0.0)), (5, List(0.0, 1.0, 0.0)))
+    assert(output === expected)
+  }
+
+  test("OneHotEncoder includeFirst = false") {
+    val encoder = new OneHotEncoder()
+      .setIncludeFirst(false)
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    val output = encodedRows(encoder.transform(stringIndexed()))
+    // a -> 0, b -> 2, c -> 1; category 0 is dropped, so "a" encodes to all zeros
+    val expected = Set(
+      (0, List(0.0, 0.0)), (1, List(0.0, 1.0)), (2, List(1.0, 0.0)),
+      (3, List(0.0, 0.0)), (4, List(0.0, 0.0)), (5, List(1.0, 0.0)))
+    assert(output === expected)
+  }
+
+  test("OneHotEncoder rejects an input column without nominal attribute metadata") {
+    val df = sqlContext.createDataFrame(Seq((0, 1.0), (1, 0.0))).toDF("id", "labelIndex")
+    val encoder = new OneHotEncoder()
+      .setInputCol("labelIndex")
+      .setOutputCol("labelVec")
+    intercept[org.apache.spark.SparkException] {
+      encoder.transformSchema(df.schema)
+    }
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message