spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-9893] User guide with Java test suite for VectorSlicer
Date Fri, 21 Aug 2015 23:30:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master f01c4220d -> 630a994e6


[SPARK-9893] User guide with Java test suite for VectorSlicer

Add a user guide for `VectorSlicer`, with a Java test suite and a Python version of `VectorSlicer`.

Note that the Python version does not yet support selecting by names.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #8267 from yinxusen/SPARK-9893.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/630a994e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/630a994e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/630a994e

Branch: refs/heads/master
Commit: 630a994e6a9785d1704f8e7fb604f32f5dea24f8
Parents: f01c422
Author: Xusen Yin <yinxusen@gmail.com>
Authored: Fri Aug 21 16:30:12 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Aug 21 16:30:12 2015 -0700

----------------------------------------------------------------------
 docs/ml-features.md                             | 133 +++++++++++++++++++
 .../spark/ml/feature/JavaVectorSlicerSuite.java |  85 ++++++++++++
 2 files changed, 218 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/630a994e/docs/ml-features.md
----------------------------------------------------------------------
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 6309db9..642a4b4 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1477,6 +1477,139 @@ print(output.select("features", "clicked").first())
 </div>
 </div>
 
+# Feature Selectors
+
+## VectorSlicer
+
+`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a
+sub-array of the original features. It is useful for extracting features from a vector column.
+
+`VectorSlicer` accepts a vector column with specified indices, then outputs a new vector column
+whose values are selected via those indices. There are two types of indices:
+
+ 1. Integer indices that represent positions in the vector, set with `setIndices()`;
+
+ 2. String indices that represent the names of features in the vector, set with `setNames()`.
+ *This requires the vector column to have an `AttributeGroup`, since the implementation matches on
+ the name field of an `Attribute`.*
+
+Features can be specified by integer index, by string name, or by both at the same time.
+At least one feature must be selected. Duplicate features are not
+allowed, so there can be no overlap between selected indices and names. Note that if features
+are selected by name, an exception is thrown when the input attributes are empty.
+
+The output vector will order features with the selected indices first (in the order given),
+followed by the selected names (in the order given).
+
+**Examples**
+
+Suppose that we have a DataFrame with the column `userFeatures`:
+
+~~~
+ userFeatures     
+------------------
+ [0.0, 10.0, 0.5] 
+~~~
+
+`userFeatures` is a vector column that contains three user features. Assume that the first column
+of `userFeatures` is all zeros, so we want to remove it and select only the last two columns.
+The `VectorSlicer` selects the last two elements with `setIndices(1, 2)`, then produces a new vector
+column named `features`:
+
+~~~
+ userFeatures     | features
+------------------|-----------------------------
+ [0.0, 10.0, 0.5] | [10.0, 0.5]
+~~~
+
+Suppose also that `userFeatures` has input attributes, i.e.
+`["f1", "f2", "f3"]`; then we can use `setNames("f2", "f3")` to select them.
+
+~~~
+ userFeatures     | features
+------------------|-----------------------------
+ [0.0, 10.0, 0.5] | [10.0, 0.5]
+ ["f1", "f2", "f3"] | ["f2", "f3"]
+~~~
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+[`VectorSlicer`](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) takes an input
+column name with specified indices or names and an output column name.
+
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.feature.VectorSlicer
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+val data = Array(
+  Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+  Vectors.dense(-2.0, 2.3, 0.0)
+)
+
+val defaultAttr = NumericAttribute.defaultAttr
+val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
+val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])
+
+val dataRDD = sc.parallelize(data).map(v => Row(v))
+val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField())))
+
+val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features")
+
+slicer.setIndices(Array(1)).setNames(Array("f3"))
+// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3"))
+
+val output = slicer.transform(dataset)
+println(output.select("userFeatures", "features").first())
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+[`VectorSlicer`](api/java/org/apache/spark/ml/feature/VectorSlicer.html) takes an input column name
+with specified indices or names and an output column name.
+
+{% highlight java %}
+import java.util.Arrays;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+
+Attribute[] attrs = new Attribute[]{
+  NumericAttribute.defaultAttr().withName("f1"),
+  NumericAttribute.defaultAttr().withName("f2"),
+  NumericAttribute.defaultAttr().withName("f3")
+};
+AttributeGroup group = new AttributeGroup("userFeatures", attrs);
+
+JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+  RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
+  RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
+));
+
+DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+
+VectorSlicer vectorSlicer = new VectorSlicer()
+  .setInputCol("userFeatures").setOutputCol("features");
+
+vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
+// or vectorSlicer.setIndices(new int[]{1, 2}), or vectorSlicer.setNames(new String[]{"f2", "f3"})
+
+DataFrame output = vectorSlicer.transform(dataset);
+
+System.out.println(output.select("userFeatures", "features").first());
+{% endhighlight %}
+</div>
+</div>
+
 ## RFormula
 
 `RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula.

http://git-wip-us.apache.org/repos/asf/spark/blob/630a994e/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
new file mode 100644
index 0000000..56988b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.attribute.Attribute;
+import org.apache.spark.ml.attribute.AttributeGroup;
+import org.apache.spark.ml.attribute.NumericAttribute;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.StructType;
+
+
+public class JavaVectorSlicerSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void vectorSlice() {
+    Attribute[] attrs = new Attribute[]{
+      NumericAttribute.defaultAttr().withName("f1"),
+      NumericAttribute.defaultAttr().withName("f2"),
+      NumericAttribute.defaultAttr().withName("f3")
+    };
+    AttributeGroup group = new AttributeGroup("userFeatures", attrs);
+
+    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+      RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
+      RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
+    ));
+
+    DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+
+    VectorSlicer vectorSlicer = new VectorSlicer()
+      .setInputCol("userFeatures").setOutputCol("features");
+
+    vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
+
+    DataFrame output = vectorSlicer.transform(dataset);
+
+    for (Row r : output.select("userFeatures", "features").take(2)) {
+      Vector features = r.getAs(1);
+      Assert.assertEquals(2, features.size());
+    }
+  }
+}
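
The selection rules described in the user guide above (indices first, then names, at least one feature, no overlap, names require attributes) can be illustrated without Spark. The following is a minimal plain-Java sketch under stated assumptions: `SliceSketch` and its `slice` method are hypothetical names invented for this illustration, attribute names are passed as a plain `String[]` rather than an `AttributeGroup`, and the exception type is chosen for the sketch. It is not Spark's implementation of `VectorSlicer`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Plain-Java illustration of the documented selection rules; not Spark code. */
class SliceSketch {

  /**
   * Returns the selected values: features picked by index come first (in the
   * order given), followed by features picked by name (in the order given).
   */
  static double[] slice(double[] values, String[] attrNames, int[] indices, String[] names) {
    if (indices.length + names.length == 0) {
      throw new IllegalArgumentException("At least one feature must be selected");
    }
    if (names.length > 0 && (attrNames == null || attrNames.length == 0)) {
      throw new IllegalArgumentException("Selecting by name requires input attributes");
    }

    // Collect positions: explicit indices first, then names resolved to positions.
    List<Integer> selected = new ArrayList<>();
    for (int i : indices) {
      selected.add(i);
    }
    for (String name : names) {
      int pos = Arrays.asList(attrNames).indexOf(name);
      if (pos < 0) {
        throw new IllegalArgumentException("Unknown feature name: " + name);
      }
      selected.add(pos);
    }

    // Reject duplicates, including overlap between indices and names.
    Set<Integer> seen = new HashSet<>();
    double[] out = new double[selected.size()];
    for (int k = 0; k < out.length; k++) {
      int i = selected.get(k);
      if (!seen.add(i)) {
        throw new IllegalArgumentException("Duplicate feature: " + i);
      }
      out[k] = values[i];
    }
    return out;
  }

  public static void main(String[] args) {
    double[] row = {-2.0, 2.3, 0.0};
    String[] attrs = {"f1", "f2", "f3"};
    // Same selection as the test above: index 1 plus the name "f3".
    System.out.println(Arrays.toString(slice(row, attrs, new int[]{1}, new String[]{"f3"})));
    // prints [2.3, 0.0]
  }
}
```

In this sketch, selecting index 1 together with the name "f2" is rejected, because both refer to the same feature.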



