spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [SPARK-18401][SPARKR][ML] SparkR random forest should support output original label.
Date Fri, 11 Nov 2016 01:13:18 GMT
Repository: spark
Updated Branches:
  refs/heads/master a3356343c -> 5ddf69470


[SPARK-18401][SPARKR][ML] SparkR random forest should support output original label.

## What changes were proposed in this pull request?
SparkR ```spark.randomForest``` classification prediction should output the original label rather
than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291).

## How was this patch tested?
Add unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15842 from yanboliang/spark-18401.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ddf6947
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ddf6947
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ddf6947

Branch: refs/heads/master
Commit: 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95
Parents: a335634
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Thu Nov 10 17:13:10 2016 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Thu Nov 10 17:13:10 2016 -0800

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R          | 24 +++++++++++++++++
 .../r/RandomForestClassificationWrapper.scala   | 28 +++++++++++++++++---
 2 files changed, 48 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ddf6947/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 33e9d0d..b76f75d 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", {
   expect_equal(stats$numTrees, 20)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
+  # Test string prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("setosa", predictions)), 50)
+  expect_equal(length(grep("versicolor", predictions)), 50)
 
   modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
   write.ml(model, modelPath)
@@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", {
   expect_equal(stats$numClasses, stats2$numClasses)
 
   unlink(modelPath)
+
+  # Test numeric response variable
+  labelToIndex <- function(species) {
+    switch(as.character(species),
+      setosa = 0.0,
+      versicolor = 1.0,
+      virginica = 2.0
+    )
+  }
+  iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
+  data <- suppressWarnings(createDataFrame(iris[-5]))
+  model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
+                              maxDepth = 5, maxBins = 16)
+  stats <- summary(model)
+  expect_equal(stats$numFeatures, 2)
+  expect_equal(stats$numTrees, 20)
+  # Test numeric prediction values
+  predictions <- collect(predict(model, data))$prediction
+  expect_equal(length(grep("1.0", predictions)), 50)
+  expect_equal(length(grep("2.0", predictions)), 50)
 })
 
 test_that("spark.gbt", {

http://git-wip-us.apache.org/repos/asf/spark/blob/5ddf6947/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 6947ba7..31f846d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private (
   val formula: String,
   val features: Array[String]) extends MLWritable {
 
+  import RandomForestClassifierWrapper._
+
   private val rfcModel: RandomForestClassificationModel =
     pipeline.stages(1).asInstanceOf[RandomForestClassificationModel]
 
@@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private (
   def summary: String = rfcModel.toDebugString
 
   def transform(dataset: Dataset[_]): DataFrame = {
-    pipeline.transform(dataset).drop(rfcModel.getFeaturesCol)
+    pipeline.transform(dataset)
+      .drop(PREDICTED_LABEL_INDEX_COL)
+      .drop(rfcModel.getFeaturesCol)
   }
 
   override def write: MLWriter = new
@@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private (
 }
 
 private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] {
+
+  val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+  val PREDICTED_LABEL_COL = "prediction"
+
   def fit(  // scalastyle:ignore
       data: DataFrame,
       formula: String,
@@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
 
     val rFormula = new RFormula()
       .setFormula(formula)
+      .setForceIndexLabel(true)
     RWrapperUtils.checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
@@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .attributes.get
     val features = featureAttrs.map(_.name.get)
 
+    // get label names from output schema
+    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+      .asInstanceOf[NominalAttribute]
+    val labels = labelAttr.values.get
+
     // assemble and fit the pipeline
     val rfc = new RandomForestClassifier()
       .setMaxDepth(maxDepth)
@@ -97,10 +111,16 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       .setCacheNodeIds(cacheNodeIds)
       .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
 
+    val idxToStr = new IndexToString()
+      .setInputCol(PREDICTED_LABEL_INDEX_COL)
+      .setOutputCol(PREDICTED_LABEL_COL)
+      .setLabels(labels)
+
     val pipeline = new Pipeline()
-      .setStages(Array(rFormulaModel, rfc))
+      .setStages(Array(rFormulaModel, rfc, idxToStr))
       .fit(data)
 
     new RandomForestClassifierWrapper(pipeline, formula, features)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message