spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-9230] [ML] Support StringType features in RFormula
Date Tue, 28 Jul 2015 00:17:53 GMT
Repository: spark
Updated Branches:
  refs/heads/master dafe8d857 -> 8ddfa52c2


[SPARK-9230] [ML] Support StringType features in RFormula

This adds StringType feature support via OneHotEncoder. As part of this task it was necessary
to change RFormula to an Estimator, so that factor levels could be determined from the training
dataset.

Not sure if I am using uids correctly here, would be good to get reviewer help on that.
cc mengxr

Umbrella design doc: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit#

Author: Eric Liang <ekl@databricks.com>

Closes #7574 from ericl/string-features and squashes the following commits:

f99131a [Eric Liang] comments
0bf3c26 [Eric Liang] update docs
c302a2c [Eric Liang] fix tests
9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features
e713da3 [Eric Liang] comments
4d79193 [Eric Liang] revert to seq + distinct
169a085 [Eric Liang] tweak functional test
a230a47 [Eric Liang] Merge branch 'master' into string-features
72bd6f3 [Eric Liang] fix merge
d841cec [Eric Liang] Merge branch 'master' into string-features
5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015
b01c7c5 [Eric Liang] add test
8a637db [Eric Liang] encoder wip
a1d03f4 [Eric Liang] refactor into estimator


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8ddfa52c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8ddfa52c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8ddfa52c

Branch: refs/heads/master
Commit: 8ddfa52c208bf329c2b2c8909c6be04301e36083
Parents: dafe8d8
Author: Eric Liang <ekl@databricks.com>
Authored: Mon Jul 27 17:17:49 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Jul 27 17:17:49 2015 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/test_mllib.R                   |   6 +-
 .../org/apache/spark/ml/feature/RFormula.scala  | 133 ++++++++++++++-----
 .../spark/ml/feature/RFormulaParserSuite.scala  |   1 +
 .../apache/spark/ml/feature/RFormulaSuite.scala |  64 +++++----
 4 files changed, 142 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8ddfa52c/R/pkg/inst/tests/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index a492763..29152a1 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -35,8 +35,8 @@ test_that("glm and predict", {
 
 test_that("predictions match with native glm", {
   training <- createDataFrame(sqlContext, iris)
-  model <- glm(Sepal_Width ~ Sepal_Length, data = training)
+  model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
   vals <- collect(select(predict(model, training), "prediction"))
-  rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
-  expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddfa52c/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f7b46ef..0a95b1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,26 +17,42 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable.ArrayBuffer
 import scala.util.parsing.combinator.RegexParsers
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
 import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
 /**
+ * Base trait for [[RFormula]] and [[RFormulaModel]].
+ */
+private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
+  /** @group getParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group getParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  protected def hasLabelCol(schema: StructType): Boolean = {
+    schema.map(_.name).contains($(labelCol))
+  }
+}
+
+/**
  * :: Experimental ::
  * Implements the transforms required for fitting a dataset against an R model formula. Currently
  * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
  * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
  */
 @Experimental
-class RFormula(override val uid: String)
-  extends Transformer with HasFeaturesCol with HasLabelCol {
+class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
 
   def this() = this(Identifiable.randomUID("rFormula"))
 
@@ -62,19 +78,74 @@ class RFormula(override val uid: String)
   /** @group getParam */
   def getFormula: String = $(formula)
 
-  /** @group getParam */
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  override def fit(dataset: DataFrame): RFormulaModel = {
+    require(parsedFormula.isDefined, "Must call setFormula() first.")
+    // StringType terms and terms representing interactions need to be encoded before assembly.
+    // TODO(ekl) add support for feature interactions
+    var encoderStages = ArrayBuffer[PipelineStage]()
+    var tempColumns = ArrayBuffer[String]()
+    val encodedTerms = parsedFormula.get.terms.map { term =>
+      dataset.schema(term) match {
+        case column if column.dataType == StringType =>
+          val indexCol = term + "_idx_" + uid
+          val encodedCol = term + "_onehot_" + uid
+          encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
+          encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
+          tempColumns += indexCol
+          tempColumns += encodedCol
+          encodedCol
+        case _ =>
+          term
+      }
+    }
+    encoderStages += new VectorAssembler(uid)
+      .setInputCols(encodedTerms.toArray)
+      .setOutputCol($(featuresCol))
+    encoderStages += new ColumnPruner(tempColumns.toSet)
+    val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
+    copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this))
+  }
 
-  /** @group getParam */
-  def setLabelCol(value: String): this.type = set(labelCol, value)
+  // optimistic schema; does not contain any ML attributes
+  override def transformSchema(schema: StructType): StructType = {
+    if (hasLabelCol(schema)) {
+      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
+    } else {
+      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
+        StructField($(labelCol), DoubleType, true))
+    }
+  }
+
+  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+
+  override def toString: String = s"RFormula(${get(formula)})"
+}
+
+/**
+ * :: Experimental ::
+ * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
+ * @param parsedFormula a pre-parsed R formula.
+ * @param pipelineModel the fitted feature model, including factor to index mappings.
+ */
+@Experimental
+class RFormulaModel private[feature](
+    override val uid: String,
+    parsedFormula: ParsedRFormula,
+    pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase {
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    checkCanTransform(dataset.schema)
+    transformLabel(pipelineModel.transform(dataset))
+  }
 
   override def transformSchema(schema: StructType): StructType = {
     checkCanTransform(schema)
-    val withFeatures = transformFeatures.transformSchema(schema)
+    val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(schema)) {
       withFeatures
-    } else if (schema.exists(_.name == parsedFormula.get.label)) {
-      val nullable = schema(parsedFormula.get.label).dataType match {
+    } else if (schema.exists(_.name == parsedFormula.label)) {
+      val nullable = schema(parsedFormula.label).dataType match {
         case _: NumericType | BooleanType => false
         case _ => true
       }
@@ -86,24 +157,19 @@ class RFormula(override val uid: String)
     }
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
-    checkCanTransform(dataset.schema)
-    transformLabel(transformFeatures.transform(dataset))
-  }
-
-  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)
+  override def copy(extra: ParamMap): RFormulaModel = copyValues(
+    new RFormulaModel(uid, parsedFormula, pipelineModel))
 
-  override def toString: String = s"RFormula(${get(formula)})"
+  override def toString: String = s"RFormulaModel(${parsedFormula})"
 
   private def transformLabel(dataset: DataFrame): DataFrame = {
-    val labelName = parsedFormula.get.label
+    val labelName = parsedFormula.label
     if (hasLabelCol(dataset.schema)) {
       dataset
     } else if (dataset.schema.exists(_.name == labelName)) {
       dataset.schema(labelName).dataType match {
         case _: NumericType | BooleanType =>
           dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
-        // TODO(ekl) add support for string-type labels
         case other =>
           throw new IllegalArgumentException("Unsupported type for label: " + other)
       }
@@ -114,25 +180,32 @@ class RFormula(override val uid: String)
     }
   }
 
-  private def transformFeatures: Transformer = {
-    // TODO(ekl) add support for non-numeric features and feature interactions
-    new VectorAssembler(uid)
-      .setInputCols(parsedFormula.get.terms.toArray)
-      .setOutputCol($(featuresCol))
-  }
-
   private def checkCanTransform(schema: StructType) {
-    require(parsedFormula.isDefined, "Must call setFormula() first.")
     val columnNames = schema.map(_.name)
     require(!columnNames.contains($(featuresCol)), "Features column already exists.")
     require(
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+}
 
-  private def hasLabelCol(schema: StructType): Boolean = {
-    schema.map(_.name).contains($(labelCol))
+/**
+ * Utility transformer for removing temporary columns from a DataFrame.
+ * TODO(ekl) make this a public transformer
+ */
+private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
+  override val uid = Identifiable.randomUID("columnPruner")
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
+    dataset.select(columnsToKeep.map(dataset.col) : _*)
   }
+
+  override def transformSchema(schema: StructType): StructType = {
+    StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
+  }
+
+  override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
 }
 
 /**
@@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers {
   def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list }
 
   def formula: Parser[ParsedRFormula] =
-    (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
+    (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) }
 
   def parse(value: String): ParsedRFormula = parseAll(formula, value) match {
     case Success(result, _) => result

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddfa52c/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index c8d065f..c4b45ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite {
 
   test("parse simple formulas") {
     checkParse("y ~ x", "y", Seq("x"))
+    checkParse("y ~ x + x", "y", Seq("x"))
     checkParse("y ~   ._foo  ", "y", Seq("._foo"))
     checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8ddfa52c/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 79c4ccf..8148c55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     val formula = new RFormula().setFormula("id ~ v1 + v2")
     val original = sqlContext.createDataFrame(
       Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
-    val result = formula.transform(original)
-    val resultSchema = formula.transformSchema(original.schema)
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val resultSchema = model.transformSchema(original.schema)
     val expected = sqlContext.createDataFrame(
       Seq(
-        (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
-        (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+        (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
+        (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
       ).toDF("id", "v1", "v2", "features", "label")
     // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
     assert(result.schema.toString == resultSchema.toString)
     assert(resultSchema == expected.schema)
-    assert(result.collect().toSeq == expected.collect().toSeq)
+    assert(result.collect() === expected.collect())
   }
 
   test("features column already exists") {
     val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
     val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
     intercept[IllegalArgumentException] {
-      formula.transformSchema(original.schema)
+      formula.fit(original)
     }
     intercept[IllegalArgumentException] {
-      formula.transform(original)
+      formula.fit(original)
     }
   }
 
   test("label column already exists") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
     val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
-    val resultSchema = formula.transformSchema(original.schema)
+    val model = formula.fit(original)
+    val resultSchema = model.transformSchema(original.schema)
     assert(resultSchema.length == 3)
-    assert(resultSchema.toString == formula.transform(original).schema.toString)
+    assert(resultSchema.toString == model.transform(original).schema.toString)
   }
 
   test("label column already exists but is not double type") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
     val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+    val model = formula.fit(original)
     intercept[IllegalArgumentException] {
-      formula.transformSchema(original.schema)
+      model.transformSchema(original.schema)
     }
     intercept[IllegalArgumentException] {
-      formula.transform(original)
+      model.transform(original)
     }
   }
 
   test("allow missing label column for test datasets") {
     val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
     val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
-    val resultSchema = formula.transformSchema(original.schema)
+    val model = formula.fit(original)
+    val resultSchema = model.transformSchema(original.schema)
     assert(resultSchema.length == 3)
     assert(!resultSchema.exists(_.name == "label"))
-    assert(resultSchema.toString == formula.transform(original).schema.toString)
+    assert(resultSchema.toString == model.transform(original).schema.toString)
   }
 
-// TODO(ekl) enable after we implement string label support
-//  test("transform string label") {
-//    val formula = new RFormula().setFormula("name ~ id")
-//    val original = sqlContext.createDataFrame(
-//      Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
-//    val result = formula.transform(original)
-//    val resultSchema = formula.transformSchema(original.schema)
-//    val expected = sqlContext.createDataFrame(
-//      Seq(
-//        (1, "foo", Vectors.dense(Array(1.0)), 1.0),
-//        (2, "bar", Vectors.dense(Array(2.0)), 0.0),
-//        (3, "bar", Vectors.dense(Array(3.0)), 0.0))
-//      ).toDF("id", "name", "features", "label")
-//    assert(result.schema.toString == resultSchema.toString)
-//    assert(result.collect().toSeq == expected.collect().toSeq)
-//  }
+  test("encodes string terms") {
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val resultSchema = model.transformSchema(original.schema)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+        (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+        (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+        (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+      ).toDF("id", "a", "b", "features", "label")
+    assert(result.schema.toString == resultSchema.toString)
+    assert(result.collect() === expected.collect())
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message