spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation
Date Tue, 24 Feb 2015 23:13:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master da505e592 -> 2a0fe3489


[SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation

Training can be stopped early if the decrease in validation error is less than a certain tolerance,
or if the validation error increases because the model overfits the training data.

This introduces a new method, runWithValidation, which takes a pair of RDDs: one for the
training data and the other for validation.
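The scaladoc of runWithValidation in this commit suggests creating the two datasets with `RDD.randomSplit`. As a local illustration only, the sketch below splits a plain Scala collection in the same spirit. Each element is independently assigned to the training side with a fixed probability, so the split is reproducible for a given seed while its sizes are only approximate. The names `SplitSketch` and `splitTrainValidation` are hypothetical and not part of MLlib.

```scala
import scala.util.Random

object SplitSketch {
  /** Hypothetical helper: puts each element on the training side with probability
    * `trainFraction`, otherwise on the validation side. Deterministic for a fixed seed. */
  def splitTrainValidation[T](data: Seq[T], trainFraction: Double, seed: Long): (Seq[T], Seq[T]) = {
    require(trainFraction > 0.0 && trainFraction < 1.0, "trainFraction must be in (0, 1)")
    val rng = new Random(seed)
    // Draw one uniform number per element, in order, so the result depends only on the seed.
    val draws = data.map(x => (x, rng.nextDouble()))
    val (train, validation) = draws.partition { case (_, u) => u < trainFraction }
    (train.map(_._1), validation.map(_._1))
  }
}
```

On an actual RDD, the corresponding call would be along the lines of `data.randomSplit(Array(0.6, 0.4), seed)`, whose two resulting RDDs can then be passed to runWithValidation as the training and validation inputs.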

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4677 from MechCoder/spark-5436 and squashes the following commits:

1bb21d4 [MechCoder] Combine regression and classification tests into a single one
e4d799b [MechCoder] Addresses indentation and doc comments
b48a70f [MechCoder] COSMIT
b928a19 [MechCoder] Move validation while training section under usage tips
fad9b6e [MechCoder] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance.
55e5c3b [MechCoder] One liner for prevValidateError
3e74372 [MechCoder] TST: Add test for classification
77549a9 [MechCoder] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a0fe348
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a0fe348
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a0fe348

Branch: refs/heads/master
Commit: 2a0fe34891882e0fde1b5722d8227aa99acc0f1f
Parents: da505e5
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Tue Feb 24 15:13:22 2015 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Tue Feb 24 15:13:22 2015 -0800

----------------------------------------------------------------------
 docs/mllib-ensembles.md                         | 11 +++
 .../spark/mllib/tree/GradientBoostedTrees.scala | 75 ++++++++++++++++++--
 .../tree/configuration/BoostingStrategy.scala   |  6 +-
 .../mllib/tree/GradientBoostedTreesSuite.scala  | 36 ++++++++++
 4 files changed, 122 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a0fe348/docs/mllib-ensembles.md
----------------------------------------------------------------------
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index fb90b70..00040e6 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -427,6 +427,17 @@ We omit some decision tree parameters since those are covered in the [decision t
 
 * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter.
 
+#### Validation while training
+
+Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while
+training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDDs as arguments, the
+first one being the training dataset and the second being the validation dataset.
+
+The training is stopped when the improvement in the validation error is not more than a certain tolerance
+(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
+decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
+and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
+iterations.
 
 ### Examples
 
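The documentation added in this hunk recommends a large negative tolerance followed by reading the number of iterations off the validation curve. That amounts to choosing the iteration count with the lowest recorded validation error. A minimal sketch, assuming the per-iteration validation errors have already been collected into a local sequence (for example by evaluating each prefix of a trained model); `ValidationCurveSketch` is a hypothetical name.

```scala
object ValidationCurveSketch {
  /** errors(i) is the validation error of the model built from the first i + 1 trees.
    * Returns the number of trees with the lowest error (the earliest one on ties). */
  def bestNumTrees(errors: IndexedSeq[Double]): Int = {
    require(errors.nonEmpty, "need at least one recorded error")
    // minBy keeps the first minimal element, so ties resolve to fewer trees.
    errors.zipWithIndex.minBy(_._1)._2 + 1
  }
}
```

For a made-up curve such as 0.50, 0.40, 0.42, 0.35, 0.38, 0.41, this picks 4 trees. A tolerance of 0.0 would instead stop at the first increase (the third tree) and keep only 2 trees, which is the non-monotonic case the documentation warns about.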

http://git-wip-us.apache.org/repos/asf/spark/blob/2a0fe348/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 61f6b13..b4466ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
-      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+      case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+        GradientBoostedTrees.boost(remappedInput,
+          remappedInput, boostingStrategy, validate=false)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
     }
@@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
   def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
     run(input.rdd)
   }
-}
 
+  /**
+   * Method to validate a gradient boosting model
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @param validationInput Validation dataset:
+                          RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+                          Should be different from and follow the same distribution as input.
+                          e.g., these two datasets could be created from an original dataset
+                          by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def runWithValidation(
+      input: RDD[LabeledPoint],
+      validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case Regression => GradientBoostedTrees.boost(
+        input, validationInput, boostingStrategy, validate=true)
+      case Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(
+          x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val remappedValidationInput = validationInput.map(
+          x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+          validate=true)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+    }
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
+   */
+  def runWithValidation(
+      input: JavaRDD[LabeledPoint],
+      validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+    runWithValidation(input.rdd, validationInput.rdd)
+  }
+}
 
 object GradientBoostedTrees extends Logging {
 
@@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging {
   /**
    * Internal method for performing regression using trees as base learners.
    * @param input training dataset
+   * @param validationInput validation dataset, ignored if validate is set to false.
    * @param boostingStrategy boosting parameters
+   * @param validate whether or not to use the validation dataset.
    * @return a gradient boosted trees model that can be used for prediction
    */
   private def boost(
       input: RDD[LabeledPoint],
-      boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+      validationInput: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy,
+      validate: Boolean): GradientBoostedTreesModel = {
 
     val timer = new TimeTracker()
     timer.start("total")
@@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging {
     val learningRate = boostingStrategy.learningRate
     // Prepare strategy for individual trees, which use regression with variance impurity.
     val treeStrategy = boostingStrategy.treeStrategy.copy
+    val validationTol = boostingStrategy.validationTol
     treeStrategy.algo = Regression
     treeStrategy.impurity = Variance
     treeStrategy.assertValid()
@@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging {
     baseLearnerWeights(0) = 1.0
     val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
     logDebug("error of gbt = " + loss.computeError(startingModel, input))
+
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
 
+    var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+    var bestM = 1
+
     // psuedo-residual for second iteration
     data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
       point.features))
-
     var m = 1
     while (m < numIterations) {
       timer.start(s"building tree $m")
@@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging {
       val partialModel = new GradientBoostedTreesModel(
         Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
       logDebug("error of gbt = " + loss.computeError(partialModel, input))
+
+      if (validate) {
+        // Stop training early if
+        // 1. Reduction in error is less than the validationTol or
+        // 2. If the error increases, that is if the model is overfit.
+        // We want the model returned corresponding to the best validation error.
+        val currentValidateError = loss.computeError(partialModel, validationInput)
+        if (bestValidateError - currentValidateError < validationTol) {
+          return new GradientBoostedTreesModel(
+            boostingStrategy.treeStrategy.algo,
+            baseLearners.slice(0, bestM),
+            baseLearnerWeights.slice(0, bestM))
+        } else if (currentValidateError < bestValidateError) {
+            bestValidateError = currentValidateError
+            bestM = m + 1
+        }
+      }
       // Update data with pseudo-residuals
       data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
         point.features))
@@ -191,4 +255,5 @@ object GradientBoostedTrees extends Logging {
     new GradientBoostedTreesModel(
       boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
   }
+
 }
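The stopping rule added to `boost` can be replayed without Spark on a list of precomputed validation errors. The helper below (`EarlyStoppingSketch.treesKept`) is hypothetical and not part of MLlib. It assumes `errors(i)` is the validation error of the model built from the first `i + 1` trees, which in the real code comes from `loss.computeError(partialModel, validationInput)`, and it returns how many trees the returned model would contain.

```scala
object EarlyStoppingSketch {
  /** Mirrors the bestValidateError / bestM bookkeeping in boost, on precomputed errors. */
  def treesKept(errors: IndexedSeq[Double], validationTol: Double): Int = {
    require(errors.nonEmpty, "need the error of the first tree")
    var bestValidateError = errors(0) // error of the starting model (tree 0)
    var bestM = 1
    var m = 1
    while (m < errors.length) {
      val currentValidateError = errors(m) // error of the model with m + 1 trees
      if (bestValidateError - currentValidateError < validationTol) {
        return bestM // stop early and keep the best prefix seen so far
      } else if (currentValidateError < bestValidateError) {
        bestValidateError = currentValidateError
        bestM = m + 1
      }
      m += 1
    }
    errors.length // loop finished without stopping: all trees are kept
  }
}
```

With the made-up curve 0.50, 0.40, 0.42, 0.35, 0.38, 0.41, a tolerance of 0.0 keeps 2 trees, -0.05 keeps 4, and -1.0 never triggers the rule. In that last case, as in the diff, the loop runs to completion and all trees are returned, even though the best validation error was reached earlier.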

http://git-wip-us.apache.org/repos/asf/spark/blob/2a0fe348/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index ed8e6a7..664c8df 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
  *                      weak hypotheses used in the final model.
  * @param learningRate Learning rate for shrinking the contribution of each estimator. The
  *                     learning rate should be between in the interval (0, 1]
+ * @param validationTol Useful when runWithValidation is used. If the error rate on the
+ *                      validation input between two iterations is less than the validationTol
+ *                      then stop. Ignored when [[run]] is used.
  */
 @Experimental
 case class BoostingStrategy(
@@ -42,7 +45,8 @@ case class BoostingStrategy(
     @BeanProperty var loss: Loss,
     // Optional boosting parameters
     @BeanProperty var numIterations: Int = 100,
-    @BeanProperty var learningRate: Double = 0.1) extends Serializable {
+    @BeanProperty var learningRate: Double = 0.1,
+    @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
 
   /**
    * Check validity of parameters.
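Because `validationTol` is appended as the last constructor parameter with a default of 1e-5, existing positional calls such as `new BoostingStrategy(treeStrategy, loss, numIterations)` keep compiling, and the new value can be set by name, as the test in this commit does with `validationTol = 0.0`. `@BeanProperty` also generates `getValidationTol` and `setValidationTol` for Java callers. The case class below is a hypothetical stand-in that mirrors only the three optional parameters. It is meant to show the default-argument and bean-accessor behavior, not to reproduce the real class.

```scala
import scala.beans.BeanProperty

/** Hypothetical stand-in for the optional part of BoostingStrategy after this commit;
  * the real class also takes treeStrategy and loss before these parameters. */
case class BoostingParamsSketch(
    @BeanProperty var numIterations: Int = 100,
    @BeanProperty var learningRate: Double = 0.1,
    @BeanProperty var validationTol: Double = 1e-5)
```

A positional call like `BoostingParamsSketch(20)` still works and picks up the default tolerance, while `BoostingParamsSketch(20, validationTol = 0.0)` overrides only the new field.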

http://git-wip-us.apache.org/repos/asf/spark/blob/2a0fe348/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index bde4760..b437aea 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,6 +158,40 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
       }
     }
   }
+
+  test("runWithValidation stops early and performs better on a validation dataset") {
+    // Set numIterations large enough so that it stops early.
+    val numIterations = 20
+    val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
+    val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
+
+    val algos = Array(Regression, Regression, Classification)
+    val losses = Array(SquaredError, AbsoluteError, LogLoss)
+    (algos zip losses) map {
+      case (algo, loss) => {
+        val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+          categoricalFeaturesInfo = Map.empty)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+        val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+          .runWithValidation(trainRdd, validateRdd)
+        assert(gbtValidate.numTrees !== numIterations)
+
+        // Test that it performs better on the validation dataset.
+        val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+        val (errorWithoutValidation, errorWithValidation) = {
+          if (algo == Classification) {
+            val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+            (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+          } else {
+            (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
+          }
+        }
+        assert(errorWithValidation <= errorWithoutValidation)
+      }
+    }
+  }
+
 }
 
 private object GradientBoostedTreesSuite {
@@ -166,4 +200,6 @@ private object GradientBoostedTreesSuite {
   val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
 
   val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+  val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
+  val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
 }
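For classification, both runWithValidation and the new test remap labels with `2 * label - 1` before any loss is computed, because boosting treats binary classification as regression on {-1, +1} labels. The sketch below assumes a margin-based log loss of the form 2 * log(1 + exp(-2 * y * F)) for a label y in {-1, +1} and raw score F; treat that formula as an assumption here rather than a quote of MLlib's `LogLoss`. `RemapSketch` is a hypothetical name.

```scala
object RemapSketch {
  /** Maps a binary label in {0, 1} to {-1, +1}, as (x.label * 2) - 1 does in the diff. */
  def remap(label: Double): Double = label * 2 - 1

  /** Assumed margin-based log loss for y in {-1, +1} and raw score f. */
  def logLoss(y: Double, f: Double): Double = 2.0 * math.log1p(math.exp(-2.0 * y * f))

  /** Mean loss over (label, rawScore) pairs whose labels are still in {0, 1}. */
  def meanLogLoss(points: Seq[(Double, Double)]): Double = {
    require(points.nonEmpty, "need at least one point")
    points.map { case (label, score) => logLoss(remap(label), score) }.sum / points.size
  }
}
```

Without remapping, a label of 0 makes the margin zero, so every such point contributes 2 * log(2) whatever the model predicts. This is why the validation RDD has to be remapped alongside the training RDD.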



