spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5885][MLLIB] Add VectorAssembler as a feature transformer
Date Mon, 13 Apr 2015 05:42:04 GMT
Repository: spark
Updated Branches:
  refs/heads/master 685ddcf52 -> 929404498


[SPARK-5885][MLLIB] Add VectorAssembler as a feature transformer

VectorAssembler merges multiple columns into a vector column. This PR contains content from
#5195.

~~carry ML attributes~~ (moved to a follow-up PR)

Author: Xiangrui Meng <meng@databricks.com>

Closes #5196 from mengxr/SPARK-5885 and squashes the following commits:

a52b101 [Xiangrui Meng] recognize more types
35daac2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5885
bb5e64b [Xiangrui Meng] add TODO for null
976a3d6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5885
0859311 [Xiangrui Meng] Revert "add CreateStruct"
29fb6ac [Xiangrui Meng] use CreateStruct
adb71c4 [Xiangrui Meng] Merge branch 'SPARK-6542' into SPARK-5885
85f3106 [Xiangrui Meng] add CreateStruct
4ff16ce [Xiangrui Meng] add VectorAssembler


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92940449
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92940449
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92940449

Branch: refs/heads/master
Commit: 929404498506c34180e2eaaa1a4d4a3c4ed51daa
Parents: 685ddcf
Author: Xiangrui Meng <meng@databricks.com>
Authored: Sun Apr 12 22:42:01 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun Apr 12 22:42:01 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/Identifiable.scala      |   2 +-
 .../spark/ml/feature/VectorAssembler.scala      | 111 +++++++++++++++++++
 .../apache/spark/ml/param/sharedParams.scala    |  10 ++
 .../spark/ml/feature/VectorAssemblerSuite.scala |  63 +++++++++++
 4 files changed, 185 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92940449/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
index cd84b05..a500906 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
@@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable {
    * random hex chars.
    */
   private[ml] val uid: String =
-    this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+    this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/92940449/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
new file mode 100644
index 0000000..d1b8f7e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * :: AlphaComponent ::
+ * A feature transformer that merges multiple columns into a vector column.
+ *
+ * Supported input column types are double, other native SQL types (cast to double before
+ * assembly) and [[VectorUDT]]. Input columns are assembled in the order given by `inputCols`.
+ * NOTE(review): string columns pass type validation; non-numeric strings presumably cast to
+ * null, which the assembly step rejects at runtime — confirm whether strings should be allowed.
+ */
+@AlphaComponent
+class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
+
+  /** @group setParam */
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Appends the assembled vector column to `dataset`.
+   *
+   * @throws IllegalArgumentException if an input column has an unsupported type or the output
+   *                                  column already exists
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    // Validate up front so misconfiguration fails with a descriptive IllegalArgumentException
+    // instead of a MatchError below or a silently duplicated output column.
+    transformSchema(dataset.schema, paramMap)
+    val map = this.paramMap ++ paramMap
+    val assembleFunc = udf { r: Row =>
+      VectorAssembler.assemble(r.toSeq: _*)
+    }
+    val schema = dataset.schema
+    val inputColNames = map(inputCols)
+    val args = inputColNames.map { c =>
+      schema(c).dataType match {
+        case DoubleType => UnresolvedAttribute(c)
+        case _: VectorUDT => UnresolvedAttribute(c)
+        // The cast needs a name to become a struct field; the uid suffix presumably avoids
+        // clashing with existing column names.
+        case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+        case other =>
+          throw new IllegalArgumentException(s"Data type $other is not supported.")
+      }
+    }
+    dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
+  }
+
+  /**
+   * Checks that every input column has a native or vector type and that the output column does
+   * not exist yet, then returns `schema` with a non-nullable vector output column appended.
+   *
+   * @throws IllegalArgumentException on an unsupported input type or an existing output column
+   */
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = this.paramMap ++ paramMap
+    val inputColNames = map(inputCols)
+    val outputColName = map(outputCol)
+    val inputDataTypes = inputColNames.map(name => schema(name).dataType)
+    inputDataTypes.foreach {
+      case _: NativeType =>
+      case _: VectorUDT =>
+      case other =>
+        throw new IllegalArgumentException(s"Data type $other is not supported.")
+    }
+    if (schema.fieldNames.contains(outputColName)) {
+      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
+    }
+    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
+  }
+}
+
+@AlphaComponent
+object VectorAssembler {
+
+  /**
+   * Concatenates the given values, in order, into one sparse vector.
+   *
+   * A `Double` occupies one slot and a [[Vector]] occupies `vec.size` slots; only non-zero
+   * values are recorded as active entries of the result.
+   *
+   * @param vv values to concatenate, each either a `Double` or a [[Vector]]
+   * @return a sparse vector whose size is the total number of slots occupied
+   * @throws SparkException if a value is null or of any other type
+   */
+  private[feature] def assemble(vv: Any*): Vector = {
+    val nzIndices = ArrayBuilder.make[Int]
+    val nzValues = ArrayBuilder.make[Double]
+    var offset = 0
+    for (item <- vv) {
+      item match {
+        case d: Double =>
+          if (d != 0.0) {
+            nzIndices += offset
+            nzValues += d
+          }
+          offset += 1
+        case vec: Vector =>
+          // Shift the vector's own indices by the number of slots already filled.
+          vec.foreachActive { (i, x) =>
+            if (x != 0.0) {
+              nzIndices += offset + i
+              nzValues += x
+            }
+          }
+          offset += vec.size
+        case null =>
+          // TODO: output Double.NaN?
+          throw new SparkException("Values to assemble cannot be null.")
+        case other =>
+          throw new SparkException(s"$other of type ${other.getClass.getName} is not supported.")
+      }
+    }
+    Vectors.sparse(offset, nzIndices.result(), nzValues.result())
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/92940449/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 0739fdb..07e6eb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -140,6 +140,16 @@ private[ml] trait HasInputCol extends Params {
   def getInputCol: String = get(inputCol)
 }
 
+private[ml] trait HasInputCols extends Params {
+  /** Param holding the names of the input columns. */
+  val inputCols: Param[Array[String]] =
+    new Param[Array[String]](this, "inputCols", "input column names")
+
+  /** @group getParam */
+  def getInputCols: Array[String] = get(inputCols)
+}
+
 private[ml] trait HasOutputCol extends Params {
   /**
    * param for output column name

http://git-wip-us.apache.org/repos/asf/spark/blob/92940449/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
new file mode 100644
index 0000000..57d0278
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkException
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * Tests for [[VectorAssembler]]: the `assemble` helper on raw values, and the transformer on a
+ * DataFrame mixing double, vector and long columns.
+ */
+class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("assemble") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
+    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+    val dv = Vectors.dense(2.0, 0.0)
+    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+    val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
+    assert(assemble(0.0, dv, 1.0, sv) ===
+      Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
+    // Int, String and null must be rejected both alone and after a valid value.
+    for (v <- Seq(1, "a", null)) {
+      intercept[SparkException](assemble(v))
+      intercept[SparkException](assemble(1.0, v))
+    }
+  }
+
+  test("VectorAssembler") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
+    )).toDF("id", "x", "y", "name", "z", "n")
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z", "n"))
+      .setOutputCol("features")
+    val rows = assembler.transform(df).select("features").collect()
+    // Without this check an empty result would pass without asserting anything below.
+    assert(rows.length === 1)
+    rows.foreach {
+      case Row(v: Vector) =>
+        // x = 0.0 takes slot 0 but is not stored; y fills slots 1-2; z's non-zero lands in
+        // slot 4; n (a Long) is cast to 10.0 in slot 5.
+        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message