spark-commits mailing list archives

From m...@apache.org
Subject git commit: [SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix.
Date Fri, 12 Sep 2014 08:38:08 GMT
Repository: spark
Updated Branches:
  refs/heads/master 42904b8d0 -> b8634df1f


[SPARK-3160] [SPARK-3494] [mllib] DecisionTree: eliminate pre-allocated nodes, parentImpurities arrays. Memory calc bug fix.

This PR includes some code simplifications and re-organization which will be helpful for implementing
random forests.  The main changes are that the nodes and parentImpurities arrays are no longer
pre-allocated in the main train() method.

Also added 2 bug fixes:
* maxMemoryUsage calculation
* over-allocation of space for bins in DTStatsAggregator for unordered features.
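The maxMemoryUsage fix addresses a 32-bit overflow. When every operand is an `Int`, `strategy.maxMemoryInMB * 1024 * 1024` wraps to a negative number once the budget reaches 2048 MB. Writing the constants as `Long` literals promotes the whole product to `Long`. Below is a minimal sketch, assuming `maxMemoryInMB` is an `Int`. The object and method names are hypothetical.

```scala
object MaxMemoryOverflow {
  // Old form: all operands are Int, so the product wraps past Int.MaxValue.
  def maxMemoryUsageInt(maxMemoryInMB: Int): Int = maxMemoryInMB * 1024 * 1024

  // Fixed form: the Long literals promote the whole product to Long.
  def maxMemoryUsageLong(maxMemoryInMB: Int): Long = maxMemoryInMB * 1024L * 1024L

  def main(args: Array[String]): Unit = {
    println(maxMemoryUsageInt(256))   // 268435456
    println(maxMemoryUsageInt(2048))  // -2147483648 (wrapped)
    println(maxMemoryUsageLong(2048)) // 2147483648
  }
}
```

The same widening idea appears in getElementsPerNode, which now sums `numBins.map(_.toLong)`.
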

Relation to RFs:
* Since RFs will be deeper and will therefore more likely be sparse (not full trees), avoiding
pre-allocation of a full tree could save memory.
* The associated re-organization also reduces bookkeeping, which will make RFs easier to implement.
* The return code doneTraining may be generalized to include cases such as nodes ready for
local training.
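To give a sense of scale, the removed train() code sized both the nodes and parentImpurities arrays as `Node.startIndexInLevel(maxDepth + 1)`. That is `1 << (maxDepth + 1)` slots, allocated no matter how many nodes the tree actually needs. The small sketch below reproduces that arithmetic. Only the `1 << level` formula comes from Node.scala; the helper names are hypothetical.

```scala
object PreallocationSize {
  // Same formula as Node.startIndexInLevel(level) in Node.scala.
  def startIndexInLevel(level: Int): Int = 1 << level

  // Length of the removed nodes/parentImpurities arrays: every node of a full tree of
  // depth maxDepth (indices 1 until 1 << (maxDepth + 1)), plus the unused slot 0.
  def maxNumNodesPlus1(maxDepth: Int): Int = startIndexInLevel(maxDepth + 1)

  // The same count computed as a Long, which does not overflow for maxDepth <= 30.
  def maxNumNodesPlus1Long(maxDepth: Int): Long = 1L << (maxDepth + 1)

  def main(args: Array[String]): Unit = {
    for (d <- Seq(5, 10, 20, 30)) {
      println(s"maxDepth = $d: Int = ${maxNumNodesPlus1(d)}, Long = ${maxNumNodesPlus1Long(d)}")
    }
  }
}
```

By this arithmetic, maxDepth = 30 (the largest value train() accepts) gives `1 << 31`. That value is negative as an `Int`, so the old pre-allocation could not succeed at that depth.
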

Details:

No longer pre-allocate parentImpurities array in main train() method.
* parentImpurities values are now stored in individual nodes (in Node.stats.impurity).
* These were not really needed.  They were used in calculateGainForSplit(), but they can be
calculated anyway using parentNodeAgg.
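This is why parentImpurities can be dropped. In the new calculateGainForSplit(), the parent's statistics are just the sum of the left and right child statistics (`parentNodeAgg = leftImpurityCalculator.copy; parentNodeAgg.add(rightImpurityCalculator)`), so the parent impurity can be recomputed locally. The sketch below illustrates the idea with Gini impurity over class counts. `GiniAgg` is a hypothetical stand-in and is not Spark's ImpurityCalculator.

```scala
object ParentImpurityFromChildren {
  // Hypothetical stand-in for ImpurityCalculator: per-class label counts, Gini impurity.
  final class GiniAgg(val counts: Array[Double]) {
    def copy: GiniAgg = new GiniAgg(counts.clone())

    // Adds another aggregate's counts into this one, in place.
    def add(other: GiniAgg): GiniAgg = {
      var i = 0
      while (i < counts.length) {
        counts(i) += other.counts(i)
        i += 1
      }
      this
    }

    def count: Double = counts.sum

    def calculate(): Double = {
      val total = count
      if (total == 0) 0.0
      else 1.0 - counts.map { c => val p = c / total; p * p }.sum
    }
  }

  // Gain = impurity(parent) - weighted child impurities, where the parent aggregate is
  // rebuilt from the two children instead of being passed in.
  def gain(left: GiniAgg, right: GiniAgg): Double = {
    val parentNodeAgg = left.copy
    parentNodeAgg.add(right)
    val total = parentNodeAgg.count
    parentNodeAgg.calculate() -
      (left.count / total) * left.calculate() -
      (right.count / total) * right.calculate()
  }
}
```

Consider left counts (4, 0) and right counts (1, 5). The rebuilt parent is (5, 5), with Gini 0.5, and the gain comes out to 1/3. Because `copy` clones the counts, the left aggregate itself is left untouched.
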

No longer using Node.build since the tree structure is constructed on-the-fly.
* Did not remove it since it is part of the public (Developer) API; marked it as deprecated instead.

Eliminated pre-allocated nodes array in main train() method.
* Nodes are constructed and added to the tree structure as needed during training.
* Moved tree construction from the main train() method into findBestSplitsPerGroup(), since there
is no need to keep the (split, gain) array for an entire level of nodes.  Only one element
of that array is needed at a time, so we no longer store the array.
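Nodes can be attached on the fly because their positions follow an implicit heap-style numbering. The root is index 1, and level L starts at `1 << L` (`Node.startIndexInLevel`). The helpers below reuse the names called in the diff (`parentIndex`, `isLeftChild`, `leftChildIndex`, `rightChildIndex`, `indexToLevel`). Their bodies, however, are a sketch consistent with that numbering, not Spark's implementation.

```scala
object HeapIndex {
  // Root has index 1; level L occupies indices [1 << L, 1 << (L + 1)).
  def startIndexInLevel(level: Int): Int = 1 << level
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1

  // Left children have even indices; the root (index 1) is nobody's child.
  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0

  // Level = position of the highest set bit.
  def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)

  def main(args: Array[String]): Unit = {
    // Node 11 is at level 3; its parent is 5, and 11 is a right child.
    println((indexToLevel(11), parentIndex(11), isLeftChild(11))) // (3,5,false)
  }
}
```
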

findBestSplits() now returns 2 items:
* rootNode (newly created root node on first iteration, same root node on later iterations)
* doneTraining (indicating if all nodes at that level were leaves)
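With this return value, the driver loop in train() only has to keep the root produced at level 0 and stop once doneTraining is reported. The sketch below shows that control flow. `ToyNode` and the `findBestSplits` function parameter are hypothetical stand-ins for Node and DecisionTree.findBestSplits.

```scala
object LevelWiseLoop {
  // Hypothetical minimal node; only the index matters for this sketch.
  final class ToyNode(val id: Int)

  // Keeps the root returned at level 0 and stops once a level reports doneTraining
  // or maxDepth is passed. Returns the root and the number of levels processed.
  def train(
      maxDepth: Int,
      findBestSplits: (Int, ToyNode) => (ToyNode, Boolean)): (ToyNode, Int) = {
    var topNode: ToyNode = null // set on first iteration
    var level = 0
    var break = false
    while (level <= maxDepth && !break) {
      val (tmpTopNode, doneTraining) = findBestSplits(level, topNode)
      if (level == 0) topNode = tmpTopNode
      if (doneTraining) break = true
      level += 1
    }
    (topNode, level)
  }
}
```

Consider a stub that creates the root at level 0 and reports doneTraining at level 2. It stops after three levels even with maxDepth = 5. A stub that never reports done runs all maxDepth + 1 levels.
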

Updated DecisionTreeSuite.  Notes:
* Improved test "Second level node building with vs. without groups"
** generateOrderedLabeledPoints() modified so that it really does require 2 levels of internal
nodes.
* Related update: Added Node.deepCopy (private[tree]), used by the test suite.
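Node.deepCopy copies the whole subtree recursively. Presumably this lets the suite snapshot a tree before further training mutates its child links. The sketch below uses a hypothetical `MiniNode` and relies on `Option.map`, which is equivalent to the `isEmpty`/`get` branches in the diff.

```scala
object DeepCopySketch {
  // Hypothetical node with mutable child links, like Node's leftNode/rightNode vars.
  final class MiniNode(
      val id: Int,
      var leftNode: Option[MiniNode] = None,
      var rightNode: Option[MiniNode] = None) {

    // Recursively copies the subtree rooted at this node.
    def deepCopy(): MiniNode =
      new MiniNode(id, leftNode.map(_.deepCopy()), rightNode.map(_.deepCopy()))

    // Number of nodes in the subtree, including this one.
    def size: Int =
      1 + leftNode.map(_.size).getOrElse(0) + rightNode.map(_.size).getOrElse(0)
  }
}
```

After copying, attaching a new child to the original tree does not change the copy.
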

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2341 from jkbradley/dt-spark-3160 and squashes the following commits:

07dd1ee [Joseph K. Bradley] Fixed overflow bug with computing maxMemoryUsage in DecisionTree. Also fixed bug with over-allocating space in DTStatsAggregator for unordered features.
debe072 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8634df1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8634df1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8634df1

Branch: refs/heads/master
Commit: b8634df1f1eb6ce909bec779522c9c9912c7d06a
Parents: 42904b8
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Authored: Fri Sep 12 01:37:59 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Sep 12 01:37:59 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 191 ++++++-------
 .../mllib/tree/configuration/Strategy.scala     |   3 +
 .../mllib/tree/impl/DTStatsAggregator.scala     |  11 +-
 .../mllib/tree/impl/DecisionTreeMetadata.scala  |   3 +-
 .../mllib/tree/model/DecisionTreeModel.scala    |   2 +-
 .../apache/spark/mllib/tree/model/Node.scala    |  37 +++
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 277 ++++++++++---------
 7 files changed, 268 insertions(+), 256 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 9859656..56bb881 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     val maxDepth = strategy.maxDepth
     require(maxDepth <= 30,
       s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
-    // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
-    val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
-    // Initialize an array to hold parent impurity calculations for each node.
-    val parentImpurities = new Array[Double](maxNumNodesPlus1)
-    // dummy value for top node (updated during first split calculation)
-    val nodes = new Array[Node](maxNumNodesPlus1)
 
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
-    val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
+    val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
     // TODO: Calculate memory usage more precisely.
     val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
@@ -120,81 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
      * beforehand and is not used in later levels.
      */
 
+    var topNode: Node = null // set on first iteration
     var level = 0
     var break = false
     while (level <= maxDepth && !break) {
-
       logDebug("#####################################")
       logDebug("level = " + level)
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
-        DecisionTree.findBestSplits(treeInput, parentImpurities,
-          metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+      val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
+        metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
 
-      val levelNodeIndexOffset = Node.startIndexInLevel(level)
-      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
-        val nodeIndex = levelNodeIndexOffset + index
-
-        // Extract info for this node (index) at the current level.
-        timer.start("extractNodeInfo")
-        val split = nodeSplitStats._1
-        val stats = nodeSplitStats._2
-        val predict = nodeSplitStats._3.predict
-        val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-        val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
-        logDebug("Node = " + node)
-        nodes(nodeIndex) = node
-        timer.stop("extractNodeInfo")
-
-        if (level != 0) {
-          // Set parent.
-          val parentNodeIndex = Node.parentIndex(nodeIndex)
-          if (Node.isLeftChild(nodeIndex)) {
-            nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
-          } else {
-            nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
-          }
-        }
-        // Extract info for nodes at the next lower level.
-        timer.start("extractInfoForLowerLevels")
-        if (level < maxDepth) {
-          val leftChildIndex = Node.leftChildIndex(nodeIndex)
-          val leftImpurity = stats.leftImpurity
-          logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
-          parentImpurities(leftChildIndex) = leftImpurity
-
-          val rightChildIndex = Node.rightChildIndex(nodeIndex)
-          val rightImpurity = stats.rightImpurity
-          logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
-          parentImpurities(rightChildIndex) = rightImpurity
-        }
-        timer.stop("extractInfoForLowerLevels")
-        logDebug("final best split = " + split)
+      if (level == 0) {
+        topNode = tmpTopNode
       }
-      require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
-      // Check whether all the nodes at the current level at leaves.
-      val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
-      logDebug("all leaf = " + allLeaf)
-      if (allLeaf) {
-        break = true // no more tree construction
-      } else {
-        level += 1
+      if (doneTraining) {
+        break = true
+        logDebug("done training")
       }
+
+      level += 1
     }
 
     logDebug("#####################################")
     logDebug("Extracting tree model")
     logDebug("#####################################")
 
-    // Initialize the top or root node of the tree.
-    val topNode = nodes(1)
-    // Build the full tree using the node info calculated in the level-wise best split calculations.
-    topNode.build(nodes)
-
     timer.stop("total")
 
     logInfo("Internal timing for DecisionTree:")
@@ -409,24 +357,26 @@ object DecisionTree extends Serializable with Logging {
    * multiple groups if the level-wise training task could lead to memory overflow.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
+   * @param topNode Root node of the tree (or invalid node when training first level).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
-   * @return array (over nodes) of splits with best split for each node at a given level.
+   * @return  (root, doneTraining) where:
+   *          root = Root node (which is newly created on the first iteration),
+   *          doneTraining = true if no more internal nodes were created.
    */
   private[tree] def findBestSplits(
       input: RDD[TreePoint],
-      parentImpurities: Array[Double],
       metadata: DecisionTreeMetadata,
       level: Int,
-      nodes: Array[Node],
+      topNode: Node,
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
+      timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
+
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -435,18 +385,18 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
+      var doneTraining = true
       while (groupIndex < numGroups) {
-        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
-          nodes, splits, bins, timer, numGroups, groupIndex)
-        bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
+        val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+          topNode, splits, bins, timer, numGroups, groupIndex)
+        doneTraining = doneTraining && doneTrainingGroup
         groupIndex += 1
       }
-      bestSplits
+      (topNode, doneTraining) // Not first iteration, so topNode was already set.
     } else {
-      findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
+      findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
     }
   }
 
@@ -586,27 +536,27 @@ object DecisionTree extends Serializable with Logging {
    * Returns an array of optimal splits for a group of nodes at a given level
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param nodes Array of all nodes in the tree.  Used for matching data points to nodes.
+   * @param topNode Root node of the tree (or invalid node when training first level).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param numGroups total number of node groups at the current level. Default value is set to 1.
    * @param groupIndex index of the node group being processed. Default value is set to 0.
-   * @return array of splits with best splits for all nodes at a given level.
+   * @return  (root, doneTraining) where:
+   *          root = Root node (which is newly created on the first iteration),
+   *          doneTraining = true if no more internal nodes were created.
    */
   private def findBestSplitsPerGroup(
       input: RDD[TreePoint],
-      parentImpurities: Array[Double],
       metadata: DecisionTreeMetadata,
       level: Int,
-      nodes: Array[Node],
+      topNode: Node,
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
+      groupIndex: Int = 0): (Node, Boolean) = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -663,7 +613,7 @@ object DecisionTree extends Serializable with Logging {
         0
       } else {
         val globalNodeIndex =
-          predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+          predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
         globalNodeIndex - globalNodeIndexOffset
       }
     }
@@ -706,33 +656,63 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
-    // Iterating over all nodes at this level
+    // On the first iteration, we need to get and return the newly created root node.
+    var newTopNode: Node = topNode
+
+    // Iterate over all nodes at this level
     var nodeIndex = 0
+    var internalNodeCount = 0
     while (nodeIndex < numNodes) {
-      val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
-      logDebug("node impurity = " + nodeImpurity)
-      bestSplits(nodeIndex) =
-        binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
-      logDebug("best split = " + bestSplits(nodeIndex)._1)
+      val (split: Split, stats: InformationGainStats, predict: Predict) =
+        binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
+      logDebug("best split = " + split)
+
+      val globalNodeIndex = globalNodeIndexOffset + nodeIndex
+
+      // Extract info for this node at the current level.
+      val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
+      val node =
+        new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
+      logDebug("Node = " + node)
+
+      if (!isLeaf) {
+        internalNodeCount += 1
+      }
+      if (level == 0) {
+        newTopNode = node
+      } else {
+        // Set parent.
+        val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
+        if (Node.isLeftChild(globalNodeIndex)) {
+          parentNode.leftNode = Some(node)
+        } else {
+          parentNode.rightNode = Some(node)
+        }
+      }
+      if (level < metadata.maxDepth) {
+        logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
+          ", impurity = " + stats.leftImpurity)
+        logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
+          ", impurity = " + stats.rightImpurity)
+      }
+
       nodeIndex += 1
     }
     timer.stop("chooseSplits")
 
-    bestSplits
+    val doneTraining = internalNodeCount == 0
+    (newTopNode, doneTraining)
   }
 
   /**
    * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    * @param leftImpurityCalculator left node aggregates for this (feature, split)
    * @param rightImpurityCalculator right node aggregate for this (feature, split)
-   * @param topImpurity impurity of the parent node
    * @return information gain and statistics for all splits
    */
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      topImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
@@ -747,14 +727,10 @@ object DecisionTree extends Serializable with Logging {
 
     val totalCount = leftCount + rightCount
 
-    // impurity of parent node
-    val impurity = if (level > 0) {
-      topImpurity
-    } else {
-      val parentNodeAgg = leftImpurityCalculator.copy
-      parentNodeAgg.add(rightImpurityCalculator)
-      parentNodeAgg.calculate()
-    }
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+
+    val impurity = parentNodeAgg.calculate()
 
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
@@ -795,19 +771,15 @@ object DecisionTree extends Serializable with Logging {
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
    * @param nodeIndex Index for node to split in this (level, group).
-   * @param nodeImpurity Impurity of the node (nodeIndex).
    * @return tuple for best split: (Split, information gain)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       nodeIndex: Int,
-      nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
       splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
-    logDebug("node impurity = " + nodeImpurity)
-
     // calculate predict only once
     var predict: Option[Predict] = None
 
@@ -831,8 +803,7 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -845,8 +816,7 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -917,8 +887,7 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -937,8 +906,8 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Get the number of values to be stored per node in the bin aggregates.
    */
-  private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
-    val totalBins = metadata.numBins.sum
+  private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
+    val totalBins = metadata.numBins.map(_.toLong).sum
     if (metadata.isClassification) {
       metadata.numClasses * totalBins
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 987fe63..31d1e8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -75,6 +75,9 @@ class Strategy (
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
   }
+  require(minInstancesPerNode >= 1,
+    s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+
   val isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2
   val isMulticlassWithCategoricalFeatures

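The new check in Strategy runs in the constructor body. An invalid Strategy therefore fails at construction time with Scala's standard `IllegalArgumentException`, whose message is prefixed with `requirement failed: `. The sketch below uses a hypothetical `ToyStrategy` that repeats the same `require` call.

```scala
object StrategyCheckSketch {
  // Hypothetical stand-in for Strategy: validates its argument in the constructor body,
  // as the new require(...) in Strategy.scala does.
  final class ToyStrategy(val minInstancesPerNode: Int) {
    require(minInstancesPerNode >= 1,
      s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
  }

  // Returns the error message, or None if construction succeeds.
  def validate(minInstancesPerNode: Int): Option[String] =
    try { new ToyStrategy(minInstancesPerNode); None }
    catch { case e: IllegalArgumentException => Some(e.getMessage) }
}
```
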
http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 866d85a..61a9424 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
    * Offset for each feature for calculating indices into the [[allStats]] array.
    */
   private val featureOffsets: Array[Int] = {
-    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
-      if (isUnordered(featureIndex)) {
-        total + 2 * numBins(featureIndex)
-      } else {
-        total + numBins(featureIndex)
-      }
-    }
-    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
   }
 
   /**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
       s" but was called for ordered feature $featureIndex.")
     val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**

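The DTStatsAggregator change can be checked with concrete numbers. The new getLeftRightNodeFeatureOffsets starts the right-child block at `numBins >> 1`. That implies numBins for an unordered feature already counts the left and right bins together. The old featureOffsets, which reserved `2 * numBins` for unordered features, therefore allocated twice the needed space. The sketch below compares both versions. The bin counts and statsSize are made-up inputs.

```scala
object FeatureOffsetsSketch {
  // Old computation (removed lines): unordered features were given 2 * numBins slots.
  def oldOffsets(numBins: Array[Int], isUnordered: Int => Boolean, statsSize: Int): Array[Int] = {
    def featureOffsetsCalc(total: Int, featureIndex: Int): Int =
      if (isUnordered(featureIndex)) total + 2 * numBins(featureIndex)
      else total + numBins(featureIndex)
    Range(0, numBins.length).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
  }

  // New computation: every feature gets exactly numBins * statsSize slots.
  def newOffsets(numBins: Array[Int], statsSize: Int): Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  // New left/right offsets for an unordered feature: the right-child block starts
  // halfway through the feature's numBins * statsSize slots.
  def leftRightOffsets(baseOffset: Int, numBinsForFeature: Int, statsSize: Int): (Int, Int) =
    (baseOffset, baseOffset + (numBinsForFeature >> 1) * statsSize)
}
```

Take numBins = (4, 6, 3), with feature 1 unordered and statsSize = 2. The old offsets are (0, 8, 32, 38) and the new ones are (0, 8, 20, 26). The unordered feature's right-child block then starts at 8 + 3 * 2 = 14 and ends exactly at the next feature's offset, 20.
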
http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 5ceaa81..b6d49e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
     val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy,
+    val maxDepth: Int,
     val minInstancesPerNode: Int,
     val minInfoGain: Double) extends Serializable {
 
@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
       strategy.minInstancesPerNode, strategy.minInfoGain)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 0594fd0..271b2c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * Predict values for the given data set using the model trained.
    *
    * @param features RDD representing data points to be predicted
-   * @return RDD[Int] where each entry contains the corresponding prediction
+   * @return RDD of predictions for each of the given data points
    */
   def predict(features: RDD[Vector]): RDD[Double] = {
     features.map(x => predict(x))

http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5b8a4cb..5f0095d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -55,6 +55,8 @@ class Node (
    * build the left node and right nodes if not leaf
    * @param nodes array of nodes
    */
+  @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
+    "1.2.0")
   def build(nodes: Array[Node]): Unit = {
     logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
@@ -94,6 +96,23 @@ class Node (
   }
 
   /**
+   * Returns a deep copy of the subtree rooted at this node.
+   */
+  private[tree] def deepCopy(): Node = {
+    val leftNodeCopy = if (leftNode.isEmpty) {
+      None
+    } else {
+      Some(leftNode.get.deepCopy())
+    }
+    val rightNodeCopy = if (rightNode.isEmpty) {
+      None
+    } else {
+      Some(rightNode.get.deepCopy())
+    }
+    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+  }
+
+  /**
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.
    */
@@ -190,4 +209,22 @@ private[tree] object Node {
    */
   def startIndexInLevel(level: Int): Int = 1 << level
 
+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * This assumes the node exists.
+   */
+  def getNode(nodeIndex: Int, rootNode: Node): Node = {
+    var tmpNode: Node = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        tmpNode = tmpNode.leftNode.get
+      } else {
+        tmpNode = tmpNode.rightNode.get
+      }
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
 }

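Node.getNode treats the node index as a path. After the leading 1 bit, each lower bit selects a child (0 = left, 1 = right), working from the top level down. The self-contained sketch below uses a hypothetical `TNode`. Here `indexToLevel` is implemented with `Integer.numberOfLeadingZeros`, which may differ from Spark's helper. `fullTree` attaches each child to its parent by index, the same way findBestSplitsPerGroup() now does.

```scala
object GetNodeSketch {
  // Hypothetical node with mutable child links, numbered heap-style (root = 1).
  final class TNode(val id: Int) {
    var leftNode: Option[TNode] = None
    var rightNode: Option[TNode] = None
  }

  def indexToLevel(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex)

  // Same walk as Node.getNode: examine the bits of nodeIndex from just below the
  // leading 1 down to bit 0; 0 = go left, 1 = go right. Assumes the node exists.
  def getNode(nodeIndex: Int, rootNode: TNode): TNode = {
    var tmpNode = rootNode
    var levelsToGo = indexToLevel(nodeIndex)
    while (levelsToGo > 0) {
      if ((nodeIndex & (1 << (levelsToGo - 1))) == 0) {
        tmpNode = tmpNode.leftNode.get
      } else {
        tmpNode = tmpNode.rightNode.get
      }
      levelsToGo -= 1
    }
    tmpNode
  }

  // Builds a full tree of the given depth level by level, looking up each parent by index.
  def fullTree(maxDepth: Int): TNode = {
    val root = new TNode(1)
    for (level <- 1 to maxDepth; id <- (1 << level) until (1 << (level + 1))) {
      val parent = getNode(id >> 1, root)
      if (id % 2 == 0) {
        parent.leftNode = Some(new TNode(id))
      } else {
        parent.rightNode = Some(new TNode(id))
      }
    }
    root
  }
}
```

For index 11 (binary 1011), the walk goes left, right, right: 1 → 2 → 5 → 11.
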
http://git-wip-us.apache.org/repos/asf/spark/blob/b8634df1/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index fd8547c..1bd7ea0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -270,19 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 0)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode: Node, doneTraining: Boolean) =
+      DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
 
-    val split = bestSplits(0)._1
+    val split = rootNode.split.get
     assert(split.categories === List(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
 
-    val stats = bestSplits(0)._2
-    val predict = bestSplits(0)._3
+    val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(predict.predict === 1)
-    assert(predict.prob === 0.6)
+    assert(rootNode.predict === 1)
     assert(stats.impurity > 0.2)
   }
 
@@ -303,19 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    val split = bestSplits(0)._1
+    val split = rootNode.split.get
     assert(split.categories.length === 1)
     assert(split.categories.contains(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
 
-    val stats = bestSplits(0)._2
-    val predict = bestSplits(0)._3.predict
+    val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(predict === 0.6)
+    assert(rootNode.predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -356,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Gini") {
@@ -382,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 1)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -409,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 0)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -436,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 1)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -459,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     // Train a 1-node model
-    val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+      numClassesForClassification = 2, maxBins = 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val nodes: Array[Node] = new Array[Node](8)
-    nodes(1) = modelOneNode.topNode
-    nodes(1).leftNode = None
-    nodes(1).rightNode = None
-
-    val parentImpurities = Array(0, 0.5, 0.5, 0.5)
+    val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
+    val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
 
     // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
-      splits, bins, 10)
-    assert(bestSplits.length === 2)
-    assert(bestSplits(0)._2.gain > 0)
-    assert(bestSplits(1)._2.gain > 0)
+    val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+      rootNodeCopy1, splits, bins, 10)
+    assert(rootNode.leftNode.nonEmpty)
+    assert(rootNode.rightNode.nonEmpty)
+    val children1 = new Array[Node](2)
+    children1(0) = rootNode.leftNode.get
+    children1(1) = rootNode.rightNode.get
 
     // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
     // level tree construction.
-    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
-      nodes, splits, bins, 0)
-    assert(bestSplitsWithGroups.length === 2)
-    assert(bestSplitsWithGroups(0)._2.gain > 0)
-    assert(bestSplitsWithGroups(1)._2.gain > 0)
+    val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+      rootNodeCopy2, splits, bins, 0)
+    assert(rootNode2.leftNode.nonEmpty)
+    assert(rootNode2.rightNode.nonEmpty)
+    val children2 = new Array[Node](2)
+    children2(0) = rootNode2.leftNode.get
+    children2(1) = rootNode2.rightNode.get
 
     // Verify whether the splits obtained using single group and multiple group level
     // construction strategies are the same.
-    for (i <- 0 until bestSplits.length) {
-      assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
-      assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
-      assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
-      assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
-      assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
-      assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
+    for (i <- 0 until 2) {
+      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+      assert(children1(i).split === children2(i).split)
+      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+      val stats1 = children1(i).stats.get
+      val stats2 = children2(i).stats.get
+      assert(stats1.gain === stats2.gain)
+      assert(stats1.impurity === stats2.impurity)
+      assert(stats1.leftImpurity === stats2.leftImpurity)
+      assert(stats1.rightImpurity === stats2.rightImpurity)
+      assert(children1(i).predict === children2(i).predict)
     }
   }
 
@@ -508,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1))
-    assert(bestSplit.featureType === Categorical)
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1))
+    assert(split.featureType === Categorical)
   }
 
   test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
@@ -573,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1))
-    assert(bestSplit.featureType === Categorical)
-    val gain = bestSplits(0)._2
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1))
+    assert(split.featureType === Categorical)
+
+    val gain = rootNode.stats.get
     assert(gain.leftImpurity === 0)
     assert(gain.rightImpurity === 0)
   }
@@ -600,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplit.feature === 1)
-    assert(bestSplit.featureType === Continuous)
-    assert(bestSplit.threshold > 1980)
-    assert(bestSplit.threshold < 2020)
+    val split = rootNode.split.get
+    assert(split.feature === 1)
+    assert(split.featureType === Continuous)
+    assert(split.threshold > 1980)
+    assert(split.threshold < 2020)
 
   }
 
@@ -627,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-
-    assert(bestSplit.feature === 1)
-    assert(bestSplit.featureType === Continuous)
-    assert(bestSplit.threshold > 1980)
-    assert(bestSplit.threshold < 2020)
+    val split = rootNode.split.get
+    assert(split.feature === 1)
+    assert(split.featureType === Continuous)
+    assert(split.threshold > 1980)
+    assert(split.threshold < 2020)
   }
 
   test("Multiclass classification stump with 10-ary (ordered) categorical features") {
@@ -652,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1.0))
-    assert(bestSplit.featureType === Categorical)
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1.0))
+    assert(split.featureType === Categorical)
   }
 
   test("Multiclass classification tree with 10-ary (ordered) categorical features," +
@@ -698,12 +707,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestInfoStats = bestSplits(0)._2
-    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+    val gain = rootNode.stats.get
+    assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 
   test("don't choose split that doesn't satisfy min instance per node requirements") {
@@ -722,14 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestSplit = bestSplits(0)._1
-    val bestSplitStats = bestSplits(0)._1
-    assert(bestSplit.feature == 1)
-    assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+    val split = rootNode.split.get
+    val gain = rootNode.stats.get
+    assert(split.feature == 1)
+    assert(gain != InformationGainStats.invalidInformationGainStats)
   }
 
   test("split must satisfy min info gain requirements") {
@@ -754,12 +761,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestInfoStats = bestSplits(0)._2
-    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+    val gain = rootNode.stats.get
+    assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 }
 
@@ -786,13 +792,16 @@ object DecisionTreeSuite {
   def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
-      if (i < 600) {
-        val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
-        arr(i) = lp
+      val label = if (i < 100) {
+        0.0
+      } else if (i < 500) {
+        1.0
+      } else if (i < 900) {
+        0.0
       } else {
-        val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
-        arr(i) = lp
+        1.0
       }
+      arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i))
     }
     arr
   }
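The updated generateOrderedLabeledPoints() replaces the single 0/1 boundary at index 600 with a 0, 1, 0, 1 label pattern (boundaries at 100, 500 and 900). One plausible reading, not stated in the commit, is that this keeps both children of the root impure, and therefore still splittable, so the rewritten "Second level node building" test can assert a positive gain for each child. The sketch below checks that idea under simplifying assumptions: it tries every integer cut point on feature 0 (the index i) and scores cuts with Gini impurity, whereas the test itself uses Entropy and 100 bins. OrderedLabelsSketch and its members are hypothetical names for illustration only.

```scala
object OrderedLabelsSketch {
  // Label rule of the NEW generator: 0.0, 1.0, 0.0, 1.0 over four index ranges.
  def newLabel(i: Int): Double =
    if (i < 100) 0.0
    else if (i < 500) 1.0
    else if (i < 900) 0.0
    else 1.0

  // Label rule of the OLD generator: 0.0 below index 600, 1.0 from 600 on.
  def oldLabel(i: Int): Double = if (i < 600) 0.0 else 1.0

  // Gini impurity of a collection of 0.0/1.0 labels (0.0 for an empty collection).
  def gini(labels: Seq[Double]): Double =
    if (labels.isEmpty) 0.0
    else {
      val p1 = labels.count(_ == 1.0).toDouble / labels.size
      1.0 - p1 * p1 - (1.0 - p1) * (1.0 - p1)
    }

  // Exhaustive search for the cut on feature 0 (= i) minimizing the weighted Gini
  // impurity of the two children. The left child holds indices [0, cut), the right
  // child [cut, n). Returns (cut, leftImpurity, rightImpurity).
  def bestCut(labels: IndexedSeq[Double]): (Int, Double, Double) = {
    val n = labels.size
    val cut = (1 until n).minBy { c =>
      val (l, r) = labels.splitAt(c)
      (l.size * gini(l) + r.size * gini(r)) / n
    }
    val (l, r) = labels.splitAt(cut)
    (cut, gini(l), gini(r))
  }

  def main(args: Array[String]): Unit = {
    println(s"new data: ${bestCut((0 until 1000).map(newLabel))}")
    println(s"old data: ${bestCut((0 until 1000).map(oldLabel))}")
  }
}
```

Under these assumptions the new labels give a best cut at index 500 with both children mixed (100 zeros/400 ones on the left, 400 zeros/100 ones on the right), while the old labels give a cut at 600 with two pure children, which would leave nothing for a second level to split.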

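Throughout these hunks, findBestSplits no longer returns an array of positional (split, stats, predict) tuples for a level; it returns the root node together with the doneTraining flag, and the tests read split, stats and predict off that node. The sketch below contrasts the two access patterns using simplified, hypothetical stand-in types: Split, GainStats, Node and RootNodeAccessSketch are illustrative only and are not the MLlib classes, which carry more fields (the old tuple's third element, for instance, was a prediction object rather than a bare Double). It also shows an Option-based alternative to .get, which throws if the root was not split.

```scala
object RootNodeAccessSketch {
  // Hypothetical, simplified stand-ins for illustration; NOT the real MLlib classes.
  final case class Split(feature: Int, threshold: Double)
  final case class GainStats(gain: Double, leftImpurity: Double, rightImpurity: Double)
  final case class Node(
      predict: Double,
      split: Option[Split],
      stats: Option[GainStats],
      leftNode: Option[Node],
      rightNode: Option[Node])

  // Old style: per-node results came back as positional tuples read via _1/_2/_3.
  def oldStyleFeature(bestSplits: Array[(Split, GainStats, Double)]): Int =
    bestSplits(0)._1.feature

  // New style: read named Option fields off the returned root node. Using map
  // instead of .get yields None, rather than throwing, when the root was not split.
  def newStyleFeature(root: Node): Option[Int] = root.split.map(_.feature)

  // Mirrors the second-level test: collect both children only if both exist.
  def children(root: Node): Option[(Node, Node)] =
    for {
      l <- root.leftNode
      r <- root.rightNode
    } yield (l, r)
}
```

Whether .get or Option combinators fit better is a judgment call; in a test, .get fails loudly on a missing split, which is often what is wanted.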


