spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From felixche...@apache.org
Subject spark git commit: [SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models
Date Sun, 12 Mar 2017 19:15:22 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2f5187bde -> 9f8ce4825


[SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models

## What changes were proposed in this pull request?

The RandomForest and GBT R wrappers now expose the `maxDepth` parameter to R models, so it appears in `summary()` output.

The following four R wrappers are changed:
* `RandomForestClassificationWrapper`
* `RandomForestRegressionWrapper`
* `GBTClassificationWrapper`
* `GBTRegressionWrapper`

## How was this patch tested?

Tested manually on my local machine; `maxDepth` assertions were also added to `R/pkg/inst/tests/testthat/test_mllib_tree.R`.

Author: Xin Ren <iamshrek@126.com>

Closes #17207 from keypointt/SPARK-19282.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9f8ce482
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9f8ce482
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9f8ce482

Branch: refs/heads/master
Commit: 9f8ce4825e378b6a856ce65cb9986a5a0f0b624e
Parents: 2f5187b
Author: Xin Ren <iamshrek@126.com>
Authored: Sun Mar 12 12:15:19 2017 -0700
Committer: Felix Cheung <felixcheung@apache.org>
Committed: Sun Mar 12 12:15:19 2017 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib_tree.R                                     | 11 +++++++----
 R/pkg/inst/tests/testthat/test_mllib_tree.R              | 10 ++++++++++
 .../org/apache/spark/ml/r/GBTClassificationWrapper.scala |  1 +
 .../org/apache/spark/ml/r/GBTRegressionWrapper.scala     |  1 +
 .../spark/ml/r/RandomForestClassificationWrapper.scala   |  1 +
 .../spark/ml/r/RandomForestRegressionWrapper.scala       |  1 +
 6 files changed, 21 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/R/pkg/R/mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 40a806c..82279be 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) {
   numFeatures <- callJMethod(jobj, "numFeatures")
   features <-  callJMethod(jobj, "features")
   featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
   numTrees <- callJMethod(jobj, "numTrees")
   treeWeights <- callJMethod(jobj, "treeWeights")
   list(formula = formula,
        numFeatures = numFeatures,
        features = features,
        featureImportances = featureImportances,
+       maxDepth = maxDepth,
        numTrees = numTrees,
        treeWeights = treeWeights,
        jobj = jobj)
@@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) {
   cat("\nNumber of features: ", x$numFeatures)
   cat("\nFeatures: ", unlist(x$features))
   cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
   cat("\nNumber of trees: ", x$numTrees)
   cat("\nTree weights: ", unlist(x$treeWeights))
 
@@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list of components includes \code{formula} (formula),
 #'         \code{numFeatures} (number of features), \code{features} (list of features),
-#'         \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#'         and \code{treeWeights} (tree weights).
+#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.gbt
 #' @aliases summary,GBTRegressionModel-method
 #' @export
@@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list of components includes \code{formula} (formula),
 #'         \code{numFeatures} (number of features), \code{features} (list of features),
-#'         \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#'         and \code{treeWeights} (tree weights).
+#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.randomForest
 #' @aliases summary,RandomForestRegressionModel-method
 #' @export

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/R/pkg/inst/tests/testthat/test_mllib_tree.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index e6fda25..e0802a9 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -39,6 +39,7 @@ test_that("spark.gbt", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_equal(stats$formula, "Employed ~ .")
   expect_equal(stats$numFeatures, 6)
   expect_equal(length(stats$treeWeights), 20)
@@ -53,6 +54,7 @@ test_that("spark.gbt", {
   expect_equal(stats$numFeatures, stats2$numFeatures)
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$numTrees, stats2$numTrees)
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
@@ -66,6 +68,7 @@ test_that("spark.gbt", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   predictions <- collect(predict(model, data))$prediction
@@ -93,6 +96,7 @@ test_that("spark.gbt", {
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
   expect_equal(s$numFeatures, 5)
   expect_equal(s$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
 
   # spark.gbt classification can work on libsvm data
   data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
@@ -116,6 +120,7 @@ test_that("spark.randomForest", {
 
   stats <- summary(model)
   expect_equal(stats$numTrees, 1)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
 
@@ -129,6 +134,7 @@ test_that("spark.randomForest", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
 
   modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
   write.ml(model, modelPath)
@@ -141,6 +147,7 @@ test_that("spark.randomForest", {
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
   expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
   unlink(modelPath)
@@ -153,6 +160,7 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   # Test string prediction values
@@ -187,6 +195,8 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
+
   # Test numeric prediction values
   predictions <- collect(predict(model, data))$prediction
   expect_equal(length(grep("1.0", predictions)), 50)

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index aacb41e..c07eadb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private (
   lazy val featureImportances: Vector = gbtcModel.featureImportances
   lazy val numTrees: Int = gbtcModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+  lazy val maxDepth: Int = gbtcModel.getMaxDepth
 
   def summary: String = gbtcModel.toDebugString
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
index 5850775..b568d78 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private (
   lazy val featureImportances: Vector = gbtrModel.featureImportances
   lazy val numTrees: Int = gbtrModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtrModel.treeWeights
+  lazy val maxDepth: Int = gbtrModel.getMaxDepth
 
   def summary: String = gbtrModel.toDebugString
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 366f375..8a83d4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private (
   lazy val featureImportances: Vector = rfcModel.featureImportances
   lazy val numTrees: Int = rfcModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfcModel.treeWeights
+  lazy val maxDepth: Int = rfcModel.getMaxDepth
 
   def summary: String = rfcModel.toDebugString
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9f8ce482/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index 4b9a3a7..038bd79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private (
   lazy val featureImportances: Vector = rfrModel.featureImportances
   lazy val numTrees: Int = rfrModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfrModel.treeWeights
+  lazy val maxDepth: Int = rfrModel.getMaxDepth
 
   def summary: String = rfrModel.toDebugString
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message