Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 6F434200BC7 for ; Fri, 11 Nov 2016 02:13:36 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id 6D114160B10; Fri, 11 Nov 2016 01:13:36 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 8AF65160B01 for ; Fri, 11 Nov 2016 02:13:35 +0100 (CET) Received: (qmail 6929 invoked by uid 500); 11 Nov 2016 01:13:34 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 6920 invoked by uid 99); 11 Nov 2016 01:13:34 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 11 Nov 2016 01:13:34 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id A0932E00E5; Fri, 11 Nov 2016 01:13:34 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: yliang@apache.org To: commits@spark.apache.org Message-Id: <903543d68f5e4262a0d43b3d36896e90@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-18401][SPARKR][ML] SparkR random forest should support output original label. Date: Fri, 11 Nov 2016 01:13:34 +0000 (UTC) archived-at: Fri, 11 Nov 2016 01:13:36 -0000 Repository: spark Updated Branches: refs/heads/branch-2.1 064d4315f -> 51dca6143 [SPARK-18401][SPARKR][ML] SparkR random forest should support output original label. ## What changes were proposed in this pull request? 
SparkR ```spark.randomForest``` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Add unit tests. Author: Yanbo Liang Closes #15842 from yanboliang/spark-18401. (cherry picked from commit 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95) Signed-off-by: Yanbo Liang Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/51dca614 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/51dca614 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/51dca614 Branch: refs/heads/branch-2.1 Commit: 51dca6143670ec1c1cb090047c3941becaf41fa9 Parents: 064d431 Author: Yanbo Liang Authored: Thu Nov 10 17:13:10 2016 -0800 Committer: Yanbo Liang Committed: Thu Nov 10 17:13:26 2016 -0800 ---------------------------------------------------------------------- R/pkg/inst/tests/testthat/test_mllib.R | 24 +++++++++++++++++ .../r/RandomForestClassificationWrapper.scala | 28 +++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/51dca614/R/pkg/inst/tests/testthat/test_mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1e456ef..33e85b7 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", 
predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) }) test_that("spark.gbt", { http://git-wip-us.apache.org/repos/asf/spark/blob/51dca614/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 6947ba7..31f846d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} -import 
org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -35,6 +35,8 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { + import RandomForestClassifierWrapper._ + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] @@ -46,7 +48,9 @@ private[r] class RandomForestClassifierWrapper private ( def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(rfcModel.getFeaturesCol) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(rfcModel.getFeaturesCol) } override def write: MLWriter = new @@ -54,6 +58,10 @@ private[r] class RandomForestClassifierWrapper private ( } private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( // scalastyle:ignore data: DataFrame, formula: String, @@ -73,6 +81,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) + .setForceIndexLabel(true) RWrapperUtils.checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) @@ -82,6 +91,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .attributes.get val features = featureAttrs.map(_.name.get) + // get label names from output schema + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + // assemble and fit the pipeline val rfc = new RandomForestClassifier() .setMaxDepth(maxDepth) @@ -97,10 +111,16 @@ private[r] object 
RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, rfc)) + .setStages(Array(rFormulaModel, rfc, idxToStr)) .fit(data) new RandomForestClassifierWrapper(pipeline, formula, features) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org