From: felixcheung@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models
Date: Sun, 12 Mar 2017 19:15:22 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master 2f5187bde -> 9f8ce4825


[SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models

## What changes were proposed in this pull request?

RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models.
The 4 R wrappers below are changed:
* `RandomForestClassificationWrapper`
* `RandomForestRegressionWrapper`
* `GBTClassificationWrapper`
* `GBTRegressionWrapper`

## How was this patch tested?

Tested manually on my local machine.

Author: Xin Ren

Closes #17207 from keypointt/SPARK-19282.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9f8ce482
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9f8ce482
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9f8ce482

Branch: refs/heads/master
Commit: 9f8ce4825e378b6a856ce65cb9986a5a0f0b624e
Parents: 2f5187b
Author: Xin Ren
Authored: Sun Mar 12 12:15:19 2017 -0700
Committer: Felix Cheung
Committed: Sun Mar 12 12:15:19 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_tree.R                                     | 11 +++++++----
 R/pkg/inst/tests/testthat/test_mllib_tree.R              | 10 ++++++++++
 .../org/apache/spark/ml/r/GBTClassificationWrapper.scala |  1 +
 .../org/apache/spark/ml/r/GBTRegressionWrapper.scala     |  1 +
 .../spark/ml/r/RandomForestClassificationWrapper.scala   |  1 +
 .../spark/ml/r/RandomForestRegressionWrapper.scala       |  1 +
 6 files changed, 21 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 40a806c..82279be 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) {
   numFeatures <- callJMethod(jobj, "numFeatures")
   features <- callJMethod(jobj, "features")
   featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
   numTrees <- callJMethod(jobj, "numTrees")
   treeWeights <- callJMethod(jobj, "treeWeights")
   list(formula = formula,
        numFeatures = numFeatures,
        features = features,
        featureImportances = featureImportances,
+       maxDepth = maxDepth,
        numTrees = numTrees,
        treeWeights = treeWeights,
        jobj = jobj)
@@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) {
   cat("\nNumber of features: ", x$numFeatures)
   cat("\nFeatures: ", unlist(x$features))
   cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
   cat("\nNumber of trees: ", x$numTrees)
   cat("\nTree weights: ", unlist(x$treeWeights))

@@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #' The list of components includes \code{formula} (formula),
 #' \code{numFeatures} (number of features), \code{features} (list of features),
-#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#' and \code{treeWeights} (tree weights).
+#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.gbt
 #' @aliases summary,GBTRegressionModel-method
 #' @export
@@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #' The list of components includes \code{formula} (formula),
 #' \code{numFeatures} (number of features), \code{features} (list of features),
-#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#' and \code{treeWeights} (tree weights).
+#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.randomForest
 #' @aliases summary,RandomForestRegressionModel-method
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/R/pkg/inst/tests/testthat/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index e6fda25..e0802a9 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -39,6 +39,7 @@ test_that("spark.gbt", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_equal(stats$formula, "Employed ~ .")
   expect_equal(stats$numFeatures, 6)
   expect_equal(length(stats$treeWeights), 20)
@@ -53,6 +54,7 @@ test_that("spark.gbt", {
   expect_equal(stats$numFeatures, stats2$numFeatures)
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$numTrees, stats2$numTrees)
   expect_equal(stats$treeWeights, stats2$treeWeights)

@@ -66,6 +68,7 @@ test_that("spark.gbt", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   predictions <- collect(predict(model, data))$prediction
@@ -93,6 +96,7 @@ test_that("spark.gbt", {
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
   expect_equal(s$numFeatures, 5)
   expect_equal(s$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)

   # spark.gbt classification can work on libsvm data
   data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
@@ -116,6 +120,7 @@ test_that("spark.randomForest", {

   stats <- summary(model)
   expect_equal(stats$numTrees, 1)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)

@@ -129,6 +134,7 @@ test_that("spark.randomForest", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)

   modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
   write.ml(model, modelPath)
@@ -141,6 +147,7 @@ test_that("spark.randomForest", {
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
   expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$treeWeights, stats2$treeWeights)

   unlink(modelPath)
@@ -153,6 +160,7 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   # Test string prediction values
@@ -187,6 +195,8 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
+
   # Test numeric prediction values
   predictions <- collect(predict(model, data))$prediction
   expect_equal(length(grep("1.0", predictions)), 50)

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index aacb41e..c07eadb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private (
   lazy val featureImportances: Vector = gbtcModel.featureImportances
   lazy val numTrees: Int =
     gbtcModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+  lazy val maxDepth: Int = gbtcModel.getMaxDepth

   def summary: String = gbtcModel.toDebugString

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
index 5850775..b568d78 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private (
   lazy val featureImportances: Vector = gbtrModel.featureImportances
   lazy val numTrees: Int = gbtrModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtrModel.treeWeights
+  lazy val maxDepth: Int = gbtrModel.getMaxDepth

   def summary: String = gbtrModel.toDebugString

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 366f375..8a83d4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private (
   lazy val featureImportances: Vector = rfcModel.featureImportances
   lazy val numTrees: Int = rfcModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfcModel.treeWeights
+  lazy val maxDepth: Int = rfcModel.getMaxDepth

   def summary: String = rfcModel.toDebugString

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index 4b9a3a7..038bd79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private (
   lazy val featureImportances: Vector = rfrModel.featureImportances
   lazy val numTrees: Int = rfrModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfrModel.treeWeights
+  lazy val maxDepth: Int = rfrModel.getMaxDepth

   def summary: String = rfrModel.toDebugString
