spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-8736] [ML] GBTRegressor should not threshold prediction
Date Tue, 30 Jun 2015 21:02:51 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4bb8375fc -> 3ba23ffd3


[SPARK-8736] [ML] GBTRegressor should not threshold prediction

Changed GBTRegressor so it does NOT threshold the prediction.  Added a test that fails with
the bug but passes after the fix.

CC: feynmanliang  mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7134 from jkbradley/gbrt-fix and squashes the following commits:

613b90e [Joseph K. Bradley] Changed GBTRegressor so it does NOT threshold the prediction


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ba23ffd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ba23ffd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ba23ffd

Branch: refs/heads/master
Commit: 3ba23ffd377d12383d923d1550ac8e2b916090fc
Parents: 4bb8375
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Tue Jun 30 14:02:50 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Jun 30 14:02:50 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/GBTRegressor.scala      |  3 +--
 .../spark/ml/regression/GBTRegressorSuite.scala | 23 +++++++++++++++++++-
 2 files changed, 23 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3ba23ffd/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 036e3ac..47c110d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -172,8 +172,7 @@ final class GBTRegressionModel(
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predict(features))
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
   }
 
   override def copy(extra: ParamMap): GBTRegressionModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/3ba23ffd/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 98fb3d3..9682edc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 
 /**
@@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("GBTRegressor behaves reasonably on toy data") {
+    val df = sqlContext.createDataFrame(Seq(
+      LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
+      LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
+      LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
+      LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)),
+      LabeledPoint(9, Vectors.dense(1, 2, 6, 4)),
+      LabeledPoint(-4, Vectors.dense(6, 3, 2, 2))
+    ))
+    val gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxIter(2)
+    val model = gbt.fit(df)
+    val preds = model.transform(df)
+    val predictions = preds.select("prediction").map(_.getDouble(0))
+    // Checks based on SPARK-8736 (to ensure it is not doing classification)
+    assert(predictions.max() > 2)
+    assert(predictions.min() < -1)
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message