spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject git commit: [SPARK-2207][SPARK-3272][MLLib]Add minimum information gain and minimum instances per node as training parameters for decision tree.
Date Wed, 10 Sep 2014 22:37:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master 558962a83 -> 79cdb9b64


[SPARK-2207][SPARK-3272][MLLib]Add minimum information gain and minimum instances per node
as training parameters for decision tree.

These two parameters can act as early stop rules to do pre-pruning. When a split would cause the
left or right child to have fewer than `minInstancesPerNode` instances, or yields less information
gain than `minInfoGain`, the current node will not be split by that split.

When there are no possible splits that satisfy the requirements, there are no useful information
gain stats, but we still need to calculate the predicted value for the current node. So I separated
the calculation of predict from the calculation of information gain, which can also save computation
when the number of possible splits is large. Please see [SPARK-3272](https://issues.apache.org/jira/browse/SPARK-3272)
for more details.

CC: mengxr manishamde jkbradley, please help me review this, thanks.

Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: chouqin <liqiping1991@gmail.com>

Closes #2332 from chouqin/dt-preprune and squashes the following commits:

f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/79cdb9b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/79cdb9b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/79cdb9b6

Branch: refs/heads/master
Commit: 79cdb9b64ad2fa3ab7f2c221766d36658b917c40
Parents: 558962a
Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Authored: Wed Sep 10 15:37:10 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Sep 10 15:37:10 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  |  72 +++++++++----
 .../mllib/tree/configuration/Strategy.scala     |   9 ++
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |   7 +-
 .../mllib/tree/model/InformationGainStats.scala |  20 ++--
 .../apache/spark/mllib/tree/model/Predict.scala |  36 +++++++
 .../apache/spark/mllib/tree/model/Split.scala   |   2 +
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 103 +++++++++++++++++--
 7 files changed, 213 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index d1309b2..9859656 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable
with Lo
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
         DecisionTree.findBestSplits(treeInput, parentImpurities,
           metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable
with Lo
         timer.start("extractNodeInfo")
         val split = nodeSplitStats._1
         val stats = nodeSplitStats._2
+        val predict = nodeSplitStats._3.predict
         val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-        val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+        val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
         logDebug("Node = " + node)
         nodes(nodeIndex) = node
         timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)]
= {
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats)](0)
+      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
     // Iterating over all nodes at this level
     var nodeIndex = 0
     while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
       topImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
-
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
-    val totalCount = leftCount + rightCount
-    if (totalCount == 0) {
-      // Return arbitrary prediction.
-      return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+    // If left child or right child doesn't satisfy minimum instances per node,
+    // then this split is invalid, return invalid information gain stats.
+    if ((leftCount < metadata.minInstancesPerNode) ||
+        (rightCount < metadata.minInstancesPerNode)) {
+      return InformationGainStats.invalidInformationGainStats
     }
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
+    val totalCount = leftCount + rightCount
+
     // impurity of parent node
     val impurity = if (level > 0) {
       topImpurity
     } else {
+      val parentNodeAgg = leftImpurityCalculator.copy
+      parentNodeAgg.add(rightImpurityCalculator)
       parentNodeAgg.calculate()
     }
 
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count
= 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    // if information gain doesn't satisfy minimum information gain,
+    // then this split is invalid, return invalid information gain stats.
+    if (gain < metadata.minInfoGain) {
+      return InformationGainStats.invalidInformationGainStats
+    }
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+  }
+
+  /**
+   * Calculate predict value for current node, given stats of any split.
+   * Note that this function is called only once for each node.
+   * @param leftImpurityCalculator left node aggregates for a split
+   * @param rightImpurityCalculator right node aggregates for a node
+   * @return predict value for current node
+   */
+  private def calculatePredict(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator): Predict =  {
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+    val predict = parentNodeAgg.predict
+    val prob = parentNodeAgg.prob(predict)
+
+    new Predict(predict, prob)
   }
 
   /**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
       nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
     logDebug("node impurity = " + nodeImpurity)
 
+    // calculate predict only once
+    var predict: Option[Predict] = None
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    Range(0, metadata.numFeatures).map { featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset,
numSplits)
             rightChildStats.subtract(leftChildStats)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level,
metadata)
             (splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level,
metadata)
             (splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats =
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level,
metadata)
             (splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
         (bestFeatureSplit, bestFeatureGainStats)
       }
     }.maxBy(_._2.gain)
+
+    require(predict.isDefined, "must calculate predict for each node")
+
+    (bestSplit, bestSplitStats, predict.get)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 23f74d5..987fe63 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                k) implies the feature n is categorical with k categories
0,
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
+ * @param minInstancesPerNode Minimum number of instances each child must have after split.
+ *                            Default value is 1. If a split cause left or right child
+ *                            to have less than minInstancesPerNode,
+ *                            this split will not be considered as a valid split.
+ * @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
+ *                    If a split has less information gain than minInfoGain,
+ *                    this split will not be considered as a valid split.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default
value is
  *                      256 MB.
  */
@@ -61,6 +68,8 @@ class Strategy (
     val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    val minInstancesPerNode: Int = 1,
+    val minInfoGain: Double = 0.0,
     val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index e95add7..5ceaa81 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
     val unorderedFeatures: Set[Int],
     val numBins: Array[Int],
     val impurity: Impurity,
-    val quantileStrategy: QuantileStrategy) extends Serializable {
+    val quantileStrategy: QuantileStrategy,
+    val minInstancesPerNode: Int,
+    val minInfoGain: Double) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy)
+      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.minInstancesPerNode, strategy.minInfoGain)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index fb12298..f3e2619 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
- * @param predict predicted value
- * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double,
-    val predict: Double,
-    val prob: Double = 0.0) extends Serializable {
+    val rightImpurity: Double) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob
= %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity)
   }
 }
+
+
+private[tree] object InformationGainStats {
+  /**
+   * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
+   * denote that current split doesn't satisfies minimum info gain or
+   * minimum number of instances per node.
+   */
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0,
-1.0)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
new file mode 100644
index 0000000..6fac2be
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Predicted value for a node
+ * @param predict predicted value
+ * @param prob probability of the label (classification only)
+ */
+@DeveloperApi
+private[tree] class Predict(
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable{
+
+  override def toString = {
+    "predict = %f, prob = %f".format(predict, prob)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 50fb48b..b7a85f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
 /**
  * :: DeveloperApi ::

http://git-wip-us.apache.org/repos/asf/spark/blob/79cdb9b6/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 69482f2..fd8547c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -279,9 +279,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3
     assert(stats.gain > 0)
-    assert(stats.predict === 1)
-    assert(stats.prob === 0.6)
+    assert(predict.predict === 1)
+    assert(predict.prob === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -312,8 +313,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3.predict
     assert(stats.gain > 0)
-    assert(stats.predict === 0.6)
+    assert(predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -387,7 +389,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -414,7 +416,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 0)
+    assert(bestSplits(0)._3.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -441,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -490,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
       assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
       assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
-      assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+      assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
     }
   }
 
@@ -674,6 +676,91 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     validateClassifier(model, arr, 0.6)
   }
 
+  test("split must satisfy min instances per node requirements") {
+    val arr = new Array[LabeledPoint](3)
+    arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+    arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+    arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini,
+      maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
+
+    val model = DecisionTree.train(input, strategy)
+    assert(model.topNode.isLeaf)
+    assert(model.topNode.predict == 0.0)
+    val predicts = input.map(p => model.predict(p.features)).collect()
+    predicts.foreach { predict =>
+      assert(predict == 0.0)
+    }
+
+    // test for findBestSplits when no valid split can be found
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestInfoStats = bestSplits(0)._2
+    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+  }
+
+  test("don't choose split that doesn't satisfy min instance per node requirements") {
+    // if a split doesn't satisfy min instances per node requirements,
+    // this split is invalid, even though the information gain of split is large.
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+    arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini,
+      maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
+      numClassesForClassification = 2, minInstancesPerNode = 2)
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestSplit = bestSplits(0)._1
+    val bestSplitStats = bestSplits(0)._1
+    assert(bestSplit.feature == 1)
+    assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+  }
+
+  test("split must satisfy min info gain requirements") {
+    val arr = new Array[LabeledPoint](3)
+    arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+    arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+    arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, minInfoGain = 1.0)
+
+    val model = DecisionTree.train(input, strategy)
+    assert(model.topNode.isLeaf)
+    assert(model.topNode.predict == 0.0)
+    val predicts = input.map(p => model.predict(p.features)).collect()
+    predicts.foreach { predict =>
+      assert(predict == 0.0)
+    }
+
+    // test for findBestSplits when no valid split can be found
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestInfoStats = bestSplits(0)._2
+    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+  }
 }
 
 object DecisionTreeSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message