spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5895] [ML] Add VectorSlicer - updated
Date Thu, 06 Aug 2015 00:08:12 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 618dc63e7 -> 3b617e87c


[SPARK-5895] [ML] Add VectorSlicer - updated

Add VectorSlicer transformer to spark.ml, with features specified as either indices or names.
 Transfers feature attributes for selected features.

Updated version of [https://github.com/apache/spark/pull/5731]

CC: yinxusen This updates your PR.  You'll still be the primary author of this PR.

CC: mengxr

Author: Xusen Yin <yinxusen@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7972 from jkbradley/yinxusen-SPARK-5895 and squashes the following commits:

b16e86e [Joseph K. Bradley] fixed scala style
71c65d2 [Joseph K. Bradley] fix import order
86e9739 [Joseph K. Bradley] cleanups per code review
9d8d6f1 [Joseph K. Bradley] style fix
83bc2e9 [Joseph K. Bradley] Updated VectorSlicer
98c6939 [Xusen Yin] fix style error
ecbf2d3 [Xusen Yin] change interfaces and params
f6be302 [Xusen Yin] Merge branch 'master' into SPARK-5895
e4781f2 [Xusen Yin] fix commit error
fd154d7 [Xusen Yin] add test suite of vector slicer
17171f8 [Xusen Yin] fix slicer
9ab9747 [Xusen Yin] add vector slicer
aa5a0bf [Xusen Yin] add vector slicer

(cherry picked from commit a018b85716fd510ae95a3c66d676bbdb90f8d4e7)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b617e87
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b617e87
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b617e87

Branch: refs/heads/branch-1.5
Commit: 3b617e87cc8524a86a9d5c4a9971520b91119736
Parents: 618dc63
Author: Xusen Yin <yinxusen@gmail.com>
Authored: Wed Aug 5 17:07:55 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Aug 5 17:08:04 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/VectorSlicer.scala  | 170 +++++++++++++++++++
 .../apache/spark/ml/util/MetadataUtils.scala    |  17 ++
 .../org/apache/spark/mllib/linalg/Vectors.scala |  24 +++
 .../spark/ml/feature/VectorSlicerSuite.scala    | 109 ++++++++++++
 .../spark/mllib/linalg/VectorsSuite.scala       |   7 +
 5 files changed, 327 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3b617e87/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
new file mode 100644
index 0000000..772bebe
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * This class takes a feature vector and outputs a new feature vector with a subarray of the
+ * original features.
+ *
+ * The subset of features can be specified with either indices ([[setIndices()]])
+ * or names ([[setNames()]]).  At least one feature must be selected. Duplicate features
+ * are not allowed, so there can be no overlap between selected indices and names.
+ *
+ * The output vector will order features with the selected indices first (in the order given),
+ * followed by the selected names (in the order given).
+ */
+@Experimental
+final class VectorSlicer(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol {
+
+  def this() = this(Identifiable.randomUID("vectorSlicer"))
+
+  /**
+   * An array of indices to select features from a vector column.
+   * There can be no overlap with [[names]].
+   * @group param
+   */
+  val indices = new IntArrayParam(this, "indices",
+    "An array of indices to select features from a vector column." +
+      " There can be no overlap with names.", VectorSlicer.validIndices)
+
+  setDefault(indices -> Array.empty[Int])
+
+  /** @group getParam */
+  def getIndices: Array[Int] = $(indices)
+
+  /** @group setParam */
+  def setIndices(value: Array[Int]): this.type = set(indices, value)
+
+  /**
+   * An array of feature names to select features from a vector column.
+   * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
+   * There can be no overlap with [[indices]].
+   * @group param
+   */
+  val names = new StringArrayParam(this, "names",
+    "An array of feature names to select features from a vector column." +
+      " There can be no overlap with indices.", VectorSlicer.validNames)
+
+  setDefault(names -> Array.empty[String])
+
+  /** @group getParam */
+  def getNames: Array[String] = $(names)
+
+  /** @group setParam */
+  def setNames(value: Array[String]): this.type = set(names, value)
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def validateParams(): Unit = {
+    require($(indices).length > 0 || $(names).length > 0,
+      s"VectorSlicer requires that at least one feature be selected.")
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    // Validity checks
+    transformSchema(dataset.schema)
+    val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+    inputAttr.numAttributes.foreach { numFeatures =>
+      val maxIndex = $(indices).max
+      require(maxIndex < numFeatures,
+        s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
+    }
+
+    // Prepare output attributes
+    val inds = getSelectedFeatureIndices(dataset.schema)
+    val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
+      inds.map(index => attrs(index))
+    }
+    val outputAttr = selectedAttrs match {
+      case Some(attrs) => new AttributeGroup($(outputCol), attrs)
+      case None => new AttributeGroup($(outputCol), inds.length)
+    }
+
+    // Select features
+    val slicer = udf { vec: Vector =>
+      vec match {
+        case features: DenseVector => Vectors.dense(inds.map(features.apply))
+        case features: SparseVector => features.slice(inds)
+      }
+    }
+    dataset.withColumn($(outputCol),
+      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+  }
+
+  /** Get the feature indices in order: indices, names */
+  private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
+    val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+    val indFeatures = $(indices)
+    val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
+    lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
+      s" sets of features, but they overlap." +
+      s" indices: ${indFeatures.mkString("[", ",", "]")}." +
+      s" names: " +
+      nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
+    require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
+    indFeatures ++ nameFeatures
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+
+    if (schema.fieldNames.contains($(outputCol))) {
+      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
+    }
+    val numFeaturesSelected = $(indices).length + $(names).length
+    val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
+    val outputFields = schema.fields :+ outputAttr.toStructField()
+    StructType(outputFields)
+  }
+
+  override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
+}
+
+private[feature] object VectorSlicer {
+
+  /** Return true if given feature indices are valid */
+  def validIndices(indices: Array[Int]): Boolean = {
+    if (indices.isEmpty) {
+      true
+    } else {
+      indices.length == indices.distinct.length && indices.forall(_ >= 0)
+    }
+  }
+
+  /** Return true if given feature names are valid */
+  def validNames(names: Array[String]): Boolean = {
+    names.forall(_.nonEmpty) && names.length == names.distinct.length
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3b617e87/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 2a1db90..fcb517b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.ml.attribute._
+import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.types.StructField
 
 
@@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
     }
   }
 
+  /**
+   * Takes a Vector column and a list of feature names, and returns the corresponding list of
+   * feature indices in the column, in order.
+   * @param col  Vector column which must have feature names specified via attributes
+   * @param names  List of feature names
+   */
+  def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
+    require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
+      + s" to be Vector type, but it was type ${col.dataType} instead.")
+    val inputAttr = AttributeGroup.fromStructField(col)
+    names.map { name =>
+      require(inputAttr.hasAttr(name),
+        s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
+      inputAttr.getAttr(name).index.get
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3b617e87/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 96d1f48..86c461f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -766,6 +766,30 @@ class SparseVector(
       maxIdx
     }
   }
+
+  /**
+   * Create a slice of this vector based on the given indices.
+   * @param selectedIndices Unsorted list of indices into the vector.
+   *                        This does NOT do bound checking.
+   * @return  New SparseVector with values in the order specified by the given indices.
+   *
+   * NOTE: The API needs to be discussed before making this public.
+   *       Also, if we have a version assuming indices are sorted, we should optimize it.
+   */
+  private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
+    var currentIdx = 0
+    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
+      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
+      val i_v = if (iIdx >= 0) {
+        Iterator((currentIdx, this.values(iIdx)))
+      } else {
+        Iterator()
+      }
+      currentIdx += 1
+      i_v
+    }.unzip
+    new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
+  }
 }
 
 object SparseVector {

http://git-wip-us.apache.org/repos/asf/spark/blob/3b617e87/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
new file mode 100644
index 0000000..a6c2fba
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    val slicer = new VectorSlicer
+    ParamsSuite.checkParams(slicer)
+    assert(slicer.getIndices.length === 0)
+    assert(slicer.getNames.length === 0)
+    withClue("VectorSlicer should not have any features selected by default") {
+      intercept[IllegalArgumentException] {
+        slicer.validateParams()
+      }
+    }
+  }
+
+  test("feature validity checks") {
+    import VectorSlicer._
+    assert(validIndices(Array(0, 1, 8, 2)))
+    assert(validIndices(Array.empty[Int]))
+    assert(!validIndices(Array(-1)))
+    assert(!validIndices(Array(1, 2, 1)))
+
+    assert(validNames(Array("a", "b")))
+    assert(validNames(Array.empty[String]))
+    assert(!validNames(Array("", "b")))
+    assert(!validNames(Array("a", "b", "a")))
+  }
+
+  test("Test vector slicer") {
+    val sqlContext = new SQLContext(sc)
+
+    val data = Array(
+      Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
+      Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
+      Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
+      Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
+      Vectors.sparse(5, Seq())
+    )
+
+    // Expected after selecting indices 1, 4
+    val expected = Array(
+      Vectors.sparse(2, Seq((0, 2.3))),
+      Vectors.dense(2.3, 1.0),
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(-1.1, 3.3),
+      Vectors.sparse(2, Seq())
+    )
+
+    val defaultAttr = NumericAttribute.defaultAttr
+    val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
+    val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
+
+    val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
+    val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
+
+    val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
+    val df = sqlContext.createDataFrame(rdd,
+      StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
+
+    val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
+
+    def validateResults(df: DataFrame): Unit = {
+      df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 === vec2)
+      }
+      val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
+      val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+      assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
+      resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
+        assert(a === b)
+      }
+    }
+
+    vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
+    validateResults(vectorSlicer.transform(df))
+
+    vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
+    validateResults(vectorSlicer.transform(df))
+
+    vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
+    validateResults(vectorSlicer.transform(df))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3b617e87/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 1c37ea5..6508dde 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
     val sv1c = sv1.compressed.asInstanceOf[DenseVector]
     assert(sv1 === sv1c)
   }
+
+  test("SparseVector.slice") {
+    val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
+    assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
+    assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
+    assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message