spark-commits mailing list archives

From m...@apache.org
Subject [2/2] git commit: [SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up
Date Sun, 17 Aug 2014 06:53:22 GMT
[SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up

DecisionTree needs to match each example to a node at each iteration.  It currently does this very inefficiently with a set of filters: for each example, it examines each node at the current level and traces up to the root to see whether that node should handle the example.

Fix: Filter top-down using the partly built tree itself.
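
A minimal sketch of the top-down idea, not the code in this commit: SimpleNode and route are hypothetical names, only continuous (threshold) splits are modeled, and raw feature values stand in for MLlib's binned features. Each example is passed from the root down the partly built tree once, and either stops at a leaf or lands on the index of a child slot that has not been trained yet.

```scala
// Hypothetical, simplified node type for illustration; this is not MLlib's Node class.
final case class SimpleNode(
    id: Int,
    feature: Int,
    threshold: Double,
    left: Option[SimpleNode] = None,
    right: Option[SimpleNode] = None,
    isLeaf: Boolean = false)

object TopDownFilterSketch {
  /**
   * Walk from `node` towards the leaves.
   * Returns the id of the leaf that is reached, or the index of the
   * not-yet-built child slot (2 * id + 1 for left, 2 * id + 2 for right).
   */
  @annotation.tailrec
  def route(node: SimpleNode, x: Array[Double]): Int = {
    if (node.isLeaf) {
      node.id
    } else {
      val goLeft = x(node.feature) <= node.threshold
      val child = if (goLeft) node.left else node.right
      child match {
        case Some(c) => route(c, x)
        case None => if (goLeft) 2 * node.id + 1 else 2 * node.id + 2
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val leaf = SimpleNode(id = 1, feature = 0, threshold = 0.0, isLeaf = true)
    val inner = SimpleNode(id = 2, feature = 1, threshold = 5.0)
    val root = SimpleNode(id = 0, feature = 0, threshold = 1.0,
      left = Some(leaf), right = Some(inner))
    println(route(root, Array(0.5, 9.0))) // stops at leaf 1
    println(route(root, Array(2.0, 3.0))) // left child slot of node 2, i.e. 5
    println(route(root, Array(2.0, 9.0))) // right child slot of node 2, i.e. 6
  }
}
```

The cost per example is one walk of at most the current depth, rather than one root-ward trace per node at the current level.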

Major changes:
* Eliminated Filter class, findBinsForLevel() method.
* Set up node parent links in main loop over levels in train().
* Added predictNodeIndex() for filtering top-down.
* Added DTMetadata class (renamed to DecisionTreeMetadata in commit a0ed0da).
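
The parent links mentioned above rely on nodes being stored in a flat array in level order, with the root at index 0. The following sketch of that index arithmetic uses hypothetical helper names (levelOffset, isLeftChild, parentIndex) that do not appear in the commit; it assumes the same layout in which the children of node i sit at 2 * i + 1 and 2 * i + 2.

```scala
object NodeIndexSketch {
  /** Index of the first node at `level` (the root is level 0, index 0). */
  def levelOffset(level: Int): Int = (1 << level) - 1

  /** Left children occupy odd indices; the root (index 0) is nobody's child. */
  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 0 && nodeIndex % 2 == 1

  /** Integer division maps both 2 * p + 1 and 2 * p + 2 back to p. */
  def parentIndex(nodeIndex: Int): Int = {
    require(nodeIndex > 0, "the root has no parent")
    (nodeIndex - 1) / 2
  }

  def main(args: Array[String]): Unit = {
    val level = 2
    for (index <- 0 until (1 << level)) {
      val nodeIndex = levelOffset(level) + index
      val side = if (isLeftChild(nodeIndex)) "left" else "right"
      println(s"node $nodeIndex is the $side child of node ${parentIndex(nodeIndex)}")
    }
  }
}
```

For level 2 this visits nodes 3 to 6, whose parents are nodes 1, 1, 2 and 2.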

Other changes:
* Pre-compute set of unorderedFeatures.

Notes for following expected PR based on [https://issues.apache.org/jira/browse/SPARK-3043]:
* The unorderedFeatures set will next be stored in a metadata structure to simplify function calls; that structure will also hold other items, such as the data in strategy.

I've done initial tests indicating that this speeds things up, but am only now running large-scale ones.

CC: mengxr manishamde chouqin  Any comments are welcome---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1975 from jkbradley/dt-opt2 and squashes the following commits:

a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata.  Small doc updates.
3726d20 [Joseph K. Bradley] Small code improvements based on code review.
ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow.
db0d773 [Joseph K. Bradley] scala style fix
6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code
931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level.  Needed to update treePointToNodeIndex with groupShift.
f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review.  1 major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used.  Removed debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serializable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals
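
Commit ac0b9f8 above replaces math.pow with the << operator. A small sketch of why the two agree for the tree sizes involved; pow2 is a hypothetical helper, not part of the commit, and the range check reflects the assumption that the exponent stays below 31 so the Int result does not overflow. Note that in Scala the - operator binds tighter than <<, so 1 << level - k parses as 1 << (level - k).

```scala
object ShiftVsPowSketch {
  /** 2^k as an Int, computed with a shift; valid only for 0 <= k < 31. */
  def pow2(k: Int): Int = {
    require(k >= 0 && k < 31, s"exponent $k out of range for Int")
    1 << k
  }

  def main(args: Array[String]): Unit = {
    for (level <- 0 to 20) {
      // Number of nodes at a level: 2^level.
      assert(pow2(level) == math.pow(2, level).toInt)
      // Max number of nodes in a tree of this depth: 2^(level + 1) - 1.
      assert((2 << level) - 1 == math.pow(2, level + 1).toInt - 1)
    }
    // Precedence: subtraction is applied before the shift.
    val level = 12
    val maxLevelForSingleGroup = 10
    assert((1 << level - maxLevelForSingleGroup) == 4)
    println("shift and math.pow agree")
  }
}
```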


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/73ab7f14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/73ab7f14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/73ab7f14

Branch: refs/heads/master
Commit: 73ab7f141c205df277c6ac19252e590d6806c41f
Parents: fbad722
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Authored: Sat Aug 16 23:53:14 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sat Aug 16 23:53:14 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 878 ++++++++-----------
 .../mllib/tree/impl/DecisionTreeMetadata.scala  | 101 +++
 .../spark/mllib/tree/impl/TreePoint.scala       |  30 +-
 .../org/apache/spark/mllib/tree/model/Bin.scala |  18 +-
 .../mllib/tree/model/DecisionTreeModel.scala    |   2 +-
 .../apache/spark/mllib/tree/model/Filter.scala  |  28 -
 .../apache/spark/mllib/tree/model/Node.scala    |  16 +-
 .../apache/spark/mllib/tree/model/Split.scala   |   5 +-
 .../spark/mllib/tree/DecisionTreeSuite.scala    | 167 ++--
 9 files changed, 615 insertions(+), 630 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 2a3107a..6b9a8f7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -62,43 +62,38 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     timer.start("init")
 
     val retaggedInput = input.retag(classOf[LabeledPoint])
+    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
     val numBins = bins(0).length
     timer.stop("findSplitsBins")
     logDebug("numBins = " + numBins)
 
+    // Bin feature values (TreePoint representation).
     // Cache input RDD for speedup during multiple passes.
-    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
+    val numFeatures = metadata.numFeatures
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
-    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
-    // Initialize an array to hold filters applied to points for each node.
-    val filters = new Array[List[Filter]](maxNumNodes)
-    // The filter at the top node is an empty list.
-    filters(0) = List()
+    val maxNumNodes = (2 << maxDepth) - 1
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
-    // num features
-    val numFeatures = treeInput.take(1)(0).binnedFeatures.size
 
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
-      strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
-      strategy.algo)
+    val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins)
 
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -114,9 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     /*
      * The main idea here is to perform level-wise training of the decision tree nodes thus
      * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
-     * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
-     * the sample is only used for the split calculation at the node if the sampled would have
-     * still survived the filters of the parent nodes.
+     * Each data sample is handled by a particular node at that level (or it reaches a leaf
+     * beforehand and is not used in later levels.
      */
 
     var level = 0
@@ -130,22 +124,37 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
       val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
-        strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
+        metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
 
+      val levelNodeIndexOffset = (1 << level) - 1
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+        val nodeIndex = levelNodeIndexOffset + index
+        val isLeftChild = level != 0 && nodeIndex % 2 == 1
+        val parentNodeIndex = if (isLeftChild) { // -1 for root node
+            (nodeIndex - 1) / 2
+          } else {
+            (nodeIndex - 2) / 2
+          }
+        // Extract info for this node (index) at the current level.
         timer.start("extractNodeInfo")
-        // Extract info for nodes at the current level.
         extractNodeInfo(nodeSplitStats, level, index, nodes)
         timer.stop("extractNodeInfo")
-        timer.start("extractInfoForLowerLevels")
+        if (level != 0) {
+          // Set parent.
+          if (isLeftChild) {
+            nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
+          } else {
+            nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
+          }
+        }
         // Extract info for nodes at the next lower level.
-        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
-          filters)
+        timer.start("extractInfoForLowerLevels")
+        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities)
         timer.stop("extractInfoForLowerLevels")
         logDebug("final best split = " + nodeSplitStats._1)
       }
-      require(math.pow(2, level) == splitsStatsForLevel.length)
+      require((1 << level) == splitsStatsForLevel.length)
       // Check whether all the nodes at the current level at leaves.
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
@@ -183,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       nodes: Array[Node]): Unit = {
     val split = nodeSplitStats._1
     val stats = nodeSplitStats._2
-    val nodeIndex = math.pow(2, level).toInt - 1 + index
+    val nodeIndex = (1 << level) - 1 + index
     val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
     val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
     logDebug("Node = " + node)
@@ -198,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       index: Int,
       maxDepth: Int,
       nodeSplitStats: (Split, InformationGainStats),
-      parentImpurities: Array[Double],
-      filters: Array[List[Filter]]): Unit = {
-    // 0 corresponds to the left child node and 1 corresponds to the right child node.
-    var i = 0
-    while (i <= 1) {
-     // Calculate the index of the node from the node level and the index at the current level.
-      val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
-      if (level < maxDepth) {
-        val impurity = if (i == 0) {
-          nodeSplitStats._2.leftImpurity
-        } else {
-          nodeSplitStats._2.rightImpurity
-        }
-        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
-        // noting the parent impurities
-        parentImpurities(nodeIndex) = impurity
-        // noting the parents filters for the child nodes
-        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
-        filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
-        for (filter <- filters(nodeIndex)) {
-          logDebug("Filter = " + filter)
-        }
-      }
-      i += 1
+      parentImpurities: Array[Double]): Unit = {
+
+    if (level >= maxDepth) {
+      return
     }
+
+    val leftNodeIndex = (2 << level) - 1 + 2 * index
+    val leftImpurity = nodeSplitStats._2.leftImpurity
+    logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity)
+    parentImpurities(leftNodeIndex) = leftImpurity
+
+    val rightNodeIndex = leftNodeIndex + 1
+    val rightImpurity = nodeSplitStats._2.rightImpurity
+    logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity)
+    parentImpurities(rightNodeIndex) = rightImpurity
   }
 }
 
@@ -434,10 +433,8 @@ object DecisionTree extends Serializable with Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for constructing the DecisionTree
+   * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param filters Filters for all nodes at a given level
    * @param splits possible splits for all features
    * @param bins possible bins for all features
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
@@ -446,9 +443,9 @@ object DecisionTree extends Serializable with Logging {
   protected[tree] def findBestSplits(
       input: RDD[TreePoint],
       parentImpurities: Array[Double],
-      strategy: Strategy,
+      metadata: DecisionTreeMetadata,
       level: Int,
-      filters: Array[List[Filter]],
+      nodes: Array[Node],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
@@ -459,34 +456,32 @@ object DecisionTree extends Serializable with Logging {
       // the nodes are divided into multiple groups at each level with the number of groups
       // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
-      val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
+      val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
       var bestSplits = new Array[(Split, InformationGainStats)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
-        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
-          filters, splits, bins, timer, numGroups, groupIndex)
+        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
+          nodes, splits, bins, timer, numGroups, groupIndex)
         bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
         groupIndex += 1
       }
       bestSplits
     } else {
-      findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
+      findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
     }
   }
 
-    /**
+  /**
    * Returns an array of optimal splits for a group of nodes at a given level
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
    * @param parentImpurities Impurities for all parent nodes for the current level
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for constructing the DecisionTree
+   * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param filters Filters for all nodes at a given level
    * @param splits possible splits for all features
-   * @param bins possible bins for all features
+   * @param bins possible bins for all features, indexed as (numFeatures)(numBins)
    * @param numGroups total number of node groups at the current level. Default value is set to 1.
    * @param groupIndex index of the node group being processed. Default value is set to 0.
    * @return array of splits with best splits for all nodes at a given level.
@@ -494,9 +489,9 @@ object DecisionTree extends Serializable with Logging {
   private def findBestSplitsPerGroup(
       input: RDD[TreePoint],
       parentImpurities: Array[Double],
-      strategy: Strategy,
+      metadata: DecisionTreeMetadata,
       level: Int,
-      filters: Array[List[Filter]],
+      nodes: Array[Node],
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       timer: TimeTracker,
@@ -515,7 +510,7 @@ object DecisionTree extends Serializable with Logging {
      * We use a bin-wise best split computation strategy instead of a straightforward best split
      * computation strategy. Instead of analyzing each sample for contribution to the left/right
      * child node impurity of every split, we first categorize each feature of a sample into a
-     * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+     * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
      * is ordered (read ordering for categorical variables in the findSplitsBins method),
      * we exploit this structure to calculate aggregates for bins and then use these aggregates
      * to calculate information gain for each split.
@@ -531,160 +526,124 @@ object DecisionTree extends Serializable with Logging {
 
     // numNodes:  Number of nodes in this (level of tree, group),
     //            where nodes at deeper (larger) levels may be divided into groups.
-    val numNodes = math.pow(2, level).toInt / numGroups
+    val numNodes = (1 << level) / numGroups
     logDebug("numNodes = " + numNodes)
 
     // Find the number of features by looking at the first sample.
-    val numFeatures = input.first().binnedFeatures.size
+    val numFeatures = metadata.numFeatures
     logDebug("numFeatures = " + numFeatures)
 
     // numBins:  Number of bins = 1 + number of possible splits
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
-    val numClasses = strategy.numClassesForClassification
+    val numClasses = metadata.numClasses
     logDebug("numClasses = " + numClasses)
 
-    val isMulticlassClassification = strategy.isMulticlassClassification
-    logDebug("isMulticlassClassification = " + isMulticlassClassification)
+    val isMulticlass = metadata.isMulticlass
+    logDebug("isMulticlass = " + isMulticlass)
 
-    val isMulticlassClassificationWithCategoricalFeatures
-      = strategy.isMulticlassWithCategoricalFeatures
-    logDebug("isMultiClassWithCategoricalFeatures = " +
-      isMulticlassClassificationWithCategoricalFeatures)
+    val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures
+    logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures)
 
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
 
-    /** Find the filters used before reaching the current code. */
-    def findParentFilters(nodeIndex: Int): List[Filter] = {
-      if (level == 0) {
-        List[Filter]()
-      } else {
-        val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift
-        filters(nodeFilterIndex)
-      }
-    }
-
     /**
-     * Find whether the sample is valid input for the current node, i.e., whether it passes through
-     * all the filters for the current node.
+     * Get the node index corresponding to this data point.
+     * This function mimics prediction, passing an example from the root node down to a node
+     * at the current level being trained; that node's index is returned.
+     *
+     * @return  Leaf index if the data point reaches a leaf.
+     *          Otherwise, last node reachable in tree matching this example.
      */
-    def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = {
-      // leaf
-      if ((level > 0) && (parentFilters.length == 0)) {
-        return false
-      }
-
-      // Apply each filter and check sample validity. Return false when invalid condition found.
-      parentFilters.foreach { filter =>
-        val featureIndex = filter.split.feature
-        val comparison = filter.comparison
-        val isFeatureContinuous = filter.split.featureType == Continuous
-        if (isFeatureContinuous) {
-          val binId = treePoint.binnedFeatures(featureIndex)
-          val bin = bins(featureIndex)(binId)
-          val featureValue = bin.highSplit.threshold
-          val threshold = filter.split.threshold
-          comparison match {
-            case -1 => if (featureValue > threshold) return false
-            case 1 => if (featureValue <= threshold) return false
+    def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = {
+      if (node.isLeaf) {
+        node.id
+      } else {
+        val featureIndex = node.split.get.feature
+        val splitLeft = node.split.get.featureType match {
+          case Continuous => {
+            val binIndex = binnedFeatures(featureIndex)
+            val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+            // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
+            // We do not need to check lowSplit since bins are separated by splits.
+            featureValueUpperBound <= node.split.get.threshold
           }
-        } else {
-          val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-          val isSpaceSufficientForAllCategoricalSplits =
-            numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
-          val isUnorderedFeature =
-            isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-          val featureValue = if (isUnorderedFeature) {
-            treePoint.binnedFeatures(featureIndex)
+          case Categorical => {
+            val featureValue = if (metadata.isUnordered(featureIndex)) {
+                binnedFeatures(featureIndex)
+              } else {
+                val binIndex = binnedFeatures(featureIndex)
+                bins(featureIndex)(binIndex).category
+              }
+            node.split.get.categories.contains(featureValue)
+          }
+          case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
+        }
+        if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
+          // Return index from next layer of nodes to train
+          if (splitLeft) {
+            node.id * 2 + 1 // left
           } else {
-            val binId = treePoint.binnedFeatures(featureIndex)
-            bins(featureIndex)(binId).category
+            node.id * 2 + 2 // right
           }
-          val containsFeature = filter.split.categories.contains(featureValue)
-          comparison match {
-            case -1 => if (!containsFeature) return false
-            case 1 => if (containsFeature) return false
+        } else {
+          if (splitLeft) {
+            predictNodeIndex(node.leftNode.get, binnedFeatures)
+          } else {
+            predictNodeIndex(node.rightNode.get, binnedFeatures)
           }
         }
       }
+    }
 
-      // Return true when the sample is valid for all filters.
-      true
+    def nodeIndexToLevel(idx: Int): Int = {
+      if (idx == 0) {
+        0
+      } else {
+        math.floor(math.log(idx) / math.log(2)).toInt
+      }
     }
 
+    // Used for treePointToNodeIndex
+    val levelOffset = (1 << level) - 1
+
     /**
-     * Finds bins for all nodes (and all features) at a given level.
-     * For l nodes, k features the storage is as follows:
-     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-     * where b_ij is an integer between 0 and numBins - 1 for regressions and binary
-     * classification and the categorical feature value in  multiclass classification.
-     * Invalid sample is denoted by noting bin for feature 1 as -1.
-     *
-     * For unordered features, the "bin index" returned is actually the feature value (category).
-     *
-     * @return  Array of size 1 + numFeatures * numNodes, where
-     *          arr(0) = label for labeledPoint, and
-     *          arr(1 + numFeatures * nodeIndex + featureIndex) =
-     *            bin index for this labeledPoint
-     *            (or InvalidBinIndex if labeledPoint is not handled by this node)
+     * Find the node index for the given example.
+     * Nodes are indexed from 0 at the start of this (level, group).
+     * If the example does not reach this level, returns a value < 0.
      */
-    def findBinsForLevel(treePoint: TreePoint): Array[Double] = {
-      // Calculate bin index and label per feature per node.
-      val arr = new Array[Double](1 + (numFeatures * numNodes))
-      // First element of the array is the label of the instance.
-      arr(0) = treePoint.label
-      // Iterate over nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        val parentFilters = findParentFilters(nodeIndex)
-        // Find out whether the sample qualifies for the particular node.
-        val sampleValid = isSampleValid(parentFilters, treePoint)
-        val shift = 1 + numFeatures * nodeIndex
-        if (!sampleValid) {
-          // Mark one bin as -1 is sufficient.
-          arr(shift) = InvalidBinIndex
-        } else {
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def treePointToNodeIndex(treePoint: TreePoint): Int = {
+      if (level == 0) {
+        0
+      } else {
+        val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures)
+        // Get index for this (level, group).
+        globalNodeIndex - levelOffset - groupShift
       }
-      arr
     }
 
-    // Find feature bins for all nodes at a level.
-    timer.start("aggregation")
-    val binMappedRDD = input.map(x => findBinsForLevel(x))
-
     /**
      * Increment aggregate in location for (node, feature, bin, label).
      *
-     * @param arr  Bin mapping from findBinsForLevel.  arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint  Data point being aggregated.
      * @param agg  Array storing aggregate calculation, of size:
      *             numClasses * numBins * numFeatures * numNodes.
      *             Indexed by (node, feature, bin, label) where label is the least significant bit.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
     def updateBinForOrderedFeature(
-        arr: Array[Double],
+        treePoint: TreePoint,
         agg: Array[Double],
         nodeIndex: Int,
-        label: Double,
         featureIndex: Int): Unit = {
-      // Find the bin index for this feature.
-      val arrShift = 1 + numFeatures * nodeIndex
-      val arrIndex = arrShift + featureIndex
       // Update the left or right count for one bin.
       val aggIndex =
         numClasses * numBins * numFeatures * nodeIndex +
         numClasses * numBins * featureIndex +
-        numClasses * arr(arrIndex).toInt +
-        label.toInt
+        numClasses * treePoint.binnedFeatures(featureIndex) +
+        treePoint.label.toInt
       agg(aggIndex) += 1
     }
 
@@ -693,8 +652,8 @@ object DecisionTree extends Serializable with Logging {
      * where [bins] ranges over all bins.
      * Updates left or right side of aggregate depending on split.
      *
-     * @param arr  arr(0) = label.
-     *             arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category)
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+     * @param treePoint  Data point being aggregated.
      * @param agg  Indexed by (left/right, node, feature, bin, label)
      *             where label is the least significant bit.
      *             The left/right specifier is a 0/1 index indicating left/right child info.
@@ -703,21 +662,18 @@ object DecisionTree extends Serializable with Logging {
     def updateBinForUnorderedFeature(
         nodeIndex: Int,
         featureIndex: Int,
-        arr: Array[Double],
-        label: Double,
+        treePoint: TreePoint,
         agg: Array[Double],
         rightChildShift: Int): Unit = {
-      // Find the bin index for this feature.
-      val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
-      val featureValue = arr(arrIndex).toInt
+      val featureValue = treePoint.binnedFeatures(featureIndex)
       // Update the left or right count for one bin.
       val aggShift =
         numClasses * numBins * numFeatures * nodeIndex +
         numClasses * numBins * featureIndex +
-        label.toInt
+        treePoint.label.toInt
       // Find all matching bins and increment their values
-      val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-      val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+      val featureCategories = metadata.featureArity(featureIndex)
+      val numCategoricalBins = (1 << featureCategories - 1) - 1
       var binIndex = 0
       while (binIndex < numCategoricalBins) {
         val aggIndex = aggShift + binIndex * numClasses
@@ -733,30 +689,21 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Helper for binSeqOp.
      *
-     * @param arr  Bin mapping from findBinsForLevel. arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
      * @param agg  Array storing aggregate calculation, of size:
      *             numClasses * numBins * numFeatures * numNodes.
      *             Indexed by (node, feature, bin, label) where label is the least significant bit.
+     * @param treePoint  Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
-    def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def binaryOrNotCategoricalBinSeqOp(
+        agg: Array[Double],
+        treePoint: TreePoint,
+        nodeIndex: Int): Unit = {
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
+        featureIndex += 1
       }
     }
 
@@ -765,49 +712,28 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Helper for binSeqOp.
      *
-     * @param arr  Bin mapping from findBinsForLevel. arr(0) stores the class label.
-     *             Array of size 1 + (numFeatures * numNodes).
-     *             For ordered features,
-     *               arr(1 + featureIndex + nodeIndex * numFeatures) = bin index.
-     *             For unordered features,
-     *               arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category).
      * @param agg  Array storing aggregate calculation.
      *             For ordered features, this is of size:
      *               numClasses * numBins * numFeatures * numNodes.
      *             For unordered features, this is of size:
      *               2 * numClasses * numBins * numFeatures * numNodes.
+     * @param treePoint  Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      */
-    def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            if (isFeatureContinuous) {
-              updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-            } else {
-              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-              val isSpaceSufficientForAllCategoricalSplits
-                = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-              if (isSpaceSufficientForAllCategoricalSplits) {
-                updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
-                  rightChildShift)
-              } else {
-                updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
-              }
-            }
-            featureIndex += 1
-          }
+    def multiclassWithCategoricalBinSeqOp(
+        agg: Array[Double],
+        treePoint: TreePoint,
+        nodeIndex: Int): Unit = {
+      val label = treePoint.label
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        if (metadata.isUnordered(featureIndex)) {
+          updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift)
+        } else {
+          updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex)
         }
-        nodeIndex += 1
+        featureIndex += 1
       }
     }
 
@@ -818,36 +744,25 @@ object DecisionTree extends Serializable with Logging {
      *
      * @param agg Array storing aggregate calculation, updated by this function.
      *            Size: 3 * numBins * numFeatures * numNodes
-     * @param arr Bin mapping from findBinsForLevel.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint  Data point being aggregated.
+     * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
      * @return agg
      */
-    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
-      // Iterate over all nodes.
-      var nodeIndex = 0
-      while (nodeIndex < numNodes) {
-        // Check whether the instance was valid for this nodeIndex.
-        val validSignalIndex = 1 + numFeatures * nodeIndex
-        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
-        if (isSampleValidForNode) {
-          // actual class label
-          val label = arr(0)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            // Find the bin index for this feature.
-            val arrShift = 1 + numFeatures * nodeIndex
-            val arrIndex = arrShift + featureIndex
-            // Update count, sum, and sum^2 for one bin.
-            val aggShift = 3 * numBins * numFeatures * nodeIndex
-            val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
-            agg(aggIndex) = agg(aggIndex) + 1
-            agg(aggIndex + 1) = agg(aggIndex + 1) + label
-            agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
-            featureIndex += 1
-          }
-        }
-        nodeIndex += 1
+    def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
+      val label = treePoint.label
+      // Iterate over all features.
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
+        // Update count, sum, and sum^2 for one bin.
+        val binIndex = treePoint.binnedFeatures(featureIndex)
+        val aggIndex =
+          3 * numBins * numFeatures * nodeIndex +
+          3 * numBins * featureIndex +
+          3 * binIndex
+        agg(aggIndex) += 1
+        agg(aggIndex + 1) += label
+        agg(aggIndex + 2) += label * label
+        featureIndex += 1
       }
     }
 
@@ -866,26 +781,30 @@ object DecisionTree extends Serializable with Logging {
      *              2 * numClasses * numBins * numFeatures * numNodes for unordered features.
      *            Size for regression:
      *              3 * numBins * numFeatures * numNodes.
-     * @param arr  Bin mapping from findBinsForLevel.
-     *             Array of size 1 + (numFeatures * numNodes).
+     * @param treePoint   Data point being aggregated.
      * @return  agg
      */
-    def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
-      strategy.algo match {
-        case Classification =>
-          if(isMulticlassClassificationWithCategoricalFeatures) {
-            multiclassWithCategoricalBinSeqOp(arr, agg)
+    def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+      val nodeIndex = treePointToNodeIndex(treePoint)
+      // If the example does not reach this level, then nodeIndex < 0.
+      // If the example reaches this level but is handled in a different group,
+      //  then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
+      if (nodeIndex >= 0 && nodeIndex < numNodes) {
+        if (metadata.isClassification) {
+          if (isMulticlassWithCategoricalFeatures) {
+            multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
           } else {
-            binaryOrNotCategoricalBinSeqOp(arr, agg)
+            binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
           }
-        case Regression => regressionBinSeqOp(arr, agg)
+        } else {
+          regressionBinSeqOp(agg, treePoint, nodeIndex)
+        }
       }
       agg
     }
 
     // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
-        isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
+    val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins)
     logDebug("binAggregateLength = " + binAggregateLength)
 
     /**
@@ -905,144 +824,134 @@ object DecisionTree extends Serializable with Logging {
     }
 
     // Calculate bin aggregates.
+    timer.start("aggregation")
     val binAggregates = {
-      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
+      input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
     }
     timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)
 
     /**
-     * Calculates the information gain for all splits based upon left/right split aggregates.
-     * @param leftNodeAgg left node aggregates
-     * @param featureIndex feature index
-     * @param splitIndex split index
-     * @param rightNodeAgg right node aggregate
+     * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+     * @param leftNodeAgg left node aggregates for this (feature, split)
+     * @param rightNodeAgg right node aggregates for this (feature, split)
      * @param topImpurity impurity of the parent node
      * @return information gain and statistics for all splits
      */
     def calculateGainForSplit(
-        leftNodeAgg: Array[Array[Array[Double]]],
-        featureIndex: Int,
-        splitIndex: Int,
-        rightNodeAgg: Array[Array[Array[Double]]],
+        leftNodeAgg: Array[Double],
+        rightNodeAgg: Array[Double],
         topImpurity: Double): InformationGainStats = {
-      strategy.algo match {
-        case Classification =>
-          val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
-          val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
-          val leftTotalCount = leftCounts.sum
-          val rightTotalCount = rightCounts.sum
-
-          val impurity = {
-            if (level > 0) {
-              topImpurity
-            } else {
-              // Calculate impurity for root node.
-              val rootNodeCounts = new Array[Double](numClasses)
-              var classIndex = 0
-              while (classIndex < numClasses) {
-                rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
-                classIndex += 1
-              }
-              strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
-            }
-          }
+      if (metadata.isClassification) {
+        val leftTotalCount = leftNodeAgg.sum
+        val rightTotalCount = rightNodeAgg.sum
 
-          val totalCount = leftTotalCount + rightTotalCount
-          if (totalCount == 0) {
-            // Return arbitrary prediction.
-            return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+        val impurity = {
+          if (level > 0) {
+            topImpurity
+          } else {
+            // Calculate impurity for root node.
+            val rootNodeCounts = new Array[Double](numClasses)
+            var classIndex = 0
+            while (classIndex < numClasses) {
+              rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex)
+              classIndex += 1
+            }
+            metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
           }
+        }
 
-          // Sum of count for each label
-          val leftRightCounts: Array[Double] =
-            leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
-              leftCount + rightCount
-            }
+        val totalCount = leftTotalCount + rightTotalCount
+        if (totalCount == 0) {
+          // Return arbitrary prediction.
+          return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+        }
 
-          def indexOfLargestArrayElement(array: Array[Double]): Int = {
-            val result = array.foldLeft(-1, Double.MinValue, 0) {
-              case ((maxIndex, maxValue, currentIndex), currentValue) =>
-                if (currentValue > maxValue) {
-                  (currentIndex, currentValue, currentIndex + 1)
-                } else {
-                  (maxIndex, maxValue, currentIndex + 1)
-                }
-            }
-            if (result._1 < 0) {
-              throw new RuntimeException("DecisionTree internal error:" +
-                " calculateGainForSplit failed in indexOfLargestArrayElement")
-            }
-            result._1
+        // Sum of count for each label
+        val leftrightNodeAgg: Array[Double] =
+          leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) =>
+            leftCount + rightCount
           }
 
-          val predict = indexOfLargestArrayElement(leftRightCounts)
-          val prob = leftRightCounts(predict) / totalCount
-
-          val leftImpurity = if (leftTotalCount == 0) {
-            topImpurity
-          } else {
-            strategy.impurity.calculate(leftCounts, leftTotalCount)
+        def indexOfLargestArrayElement(array: Array[Double]): Int = {
+          val result = array.foldLeft(-1, Double.MinValue, 0) {
+            case ((maxIndex, maxValue, currentIndex), currentValue) =>
+              if (currentValue > maxValue) {
+                (currentIndex, currentValue, currentIndex + 1)
+              } else {
+                (maxIndex, maxValue, currentIndex + 1)
+              }
           }
-          val rightImpurity = if (rightTotalCount == 0) {
-            topImpurity
-          } else {
-            strategy.impurity.calculate(rightCounts, rightTotalCount)
+          if (result._1 < 0) {
+            throw new RuntimeException("DecisionTree internal error:" +
+              " calculateGainForSplit failed in indexOfLargestArrayElement")
           }
+          result._1
+        }
 
-          val leftWeight = leftTotalCount / totalCount
-          val rightWeight = rightTotalCount / totalCount
+        val predict = indexOfLargestArrayElement(leftrightNodeAgg)
+        val prob = leftrightNodeAgg(predict) / totalCount
 
-          val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+        val leftImpurity = if (leftTotalCount == 0) {
+          topImpurity
+        } else {
+          metadata.impurity.calculate(leftNodeAgg, leftTotalCount)
+        }
+        val rightImpurity = if (rightTotalCount == 0) {
+          topImpurity
+        } else {
+          metadata.impurity.calculate(rightNodeAgg, rightTotalCount)
+        }
 
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+        val leftWeight = leftTotalCount / totalCount
+        val rightWeight = rightTotalCount / totalCount
 
-        case Regression =>
-          val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
-          val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
-          val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
+        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-          val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
-          val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
-          val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
 
-          val impurity = {
-            if (level > 0) {
-              topImpurity
-            } else {
-              // Calculate impurity for root node.
-              val count = leftCount + rightCount
-              val sum = leftSum + rightSum
-              val sumSquares = leftSumSquares + rightSumSquares
-              strategy.impurity.calculate(count, sum, sumSquares)
-            }
-          }
+      } else {
+        // Regression
 
-          if (leftCount == 0) {
-            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
-              rightSum / rightCount)
-          }
-          if (rightCount == 0) {
-            return new InformationGainStats(0, topImpurity ,topImpurity,
-              Double.MinValue, leftSum / leftCount)
+        val leftCount = leftNodeAgg(0)
+        val leftSum = leftNodeAgg(1)
+        val leftSumSquares = leftNodeAgg(2)
+
+        val rightCount = rightNodeAgg(0)
+        val rightSum = rightNodeAgg(1)
+        val rightSumSquares = rightNodeAgg(2)
+
+        val impurity = {
+          if (level > 0) {
+            topImpurity
+          } else {
+            // Calculate impurity for root node.
+            val count = leftCount + rightCount
+            val sum = leftSum + rightSum
+            val sumSquares = leftSumSquares + rightSumSquares
+            metadata.impurity.calculate(count, sum, sumSquares)
           }
+        }
+
+        if (leftCount == 0) {
+          return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+            rightSum / rightCount)
+        }
+        if (rightCount == 0) {
+          return new InformationGainStats(0, topImpurity, topImpurity,
+            Double.MinValue, leftSum / leftCount)
+        }
 
-          val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
-          val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+        val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares)
+        val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares)
 
-          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+        val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+        val rightWeight = rightCount.toDouble / (leftCount + rightCount)
 
-          val gain = {
-            if (level > 0) {
-              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-            } else {
-              impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-            }
-          }
+        val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-          val predict = (leftSum + rightSum) / (leftCount + rightCount)
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+        val predict = (leftSum + rightSum) / (leftCount + rightCount)
+        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
       }
     }
 
@@ -1065,6 +974,19 @@ object DecisionTree extends Serializable with Logging {
         binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
 
 
+      /**
+       * The input binData is indexed as (feature, bin, class).
+       * This computes cumulative sums over splits.
+       * Each (feature, class) pair is handled separately.
+       * Note: numSplits = numBins - 1.
+       * @param leftNodeAgg  Each (feature, class) slice is an array over splits.
+       *                     Element i (i = 0, ..., numSplits - 2) is set to be
+       *                     the cumulative sum (from left) over binData for bins 0, ..., i.
+       * @param rightNodeAgg Each (feature, class) slice is an array over splits.
+       *                     Element i (i = 1, ..., numSplits - 1) is set to be
+       *                     the cumulative sum (from right) over binData for bins
+       *                     numBins - 1, ..., numBins - 1 - i.
+       */
       def findAggForOrderedFeatureClassification(
           leftNodeAgg: Array[Array[Array[Double]]],
           rightNodeAgg: Array[Array[Array[Double]]],
@@ -1169,45 +1091,32 @@ object DecisionTree extends Serializable with Logging {
         }
       }
 
-      strategy.algo match {
-        case Classification =>
-          // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            if (isMulticlassClassificationWithCategoricalFeatures) {
-              val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-              if (isFeatureContinuous) {
-                findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-              } else {
-                val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-                val isSpaceSufficientForAllCategoricalSplits
-                  = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-                if (isSpaceSufficientForAllCategoricalSplits) {
-                  findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-                } else {
-                  findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-                }
-              }
-            } else {
-              findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
-            }
-            featureIndex += 1
-          }
-
-          (leftNodeAgg, rightNodeAgg)
-        case Regression =>
-          // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
-          // Iterate over all features.
-          var featureIndex = 0
-          while (featureIndex < numFeatures) {
-            findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
-            featureIndex += 1
+      if (metadata.isClassification) {
+        // Initialize left and right split aggregates.
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          if (metadata.isUnordered(featureIndex)) {
+            findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+          } else {
+            findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
           }
-          (leftNodeAgg, rightNodeAgg)
+          featureIndex += 1
+        }
+        (leftNodeAgg, rightNodeAgg)
+      } else {
+        // Regression
+        // Initialize left and right split aggregates.
+        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+        // Iterate over all features.
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
+          featureIndex += 1
+        }
+        (leftNodeAgg, rightNodeAgg)
       }
     }
 
@@ -1225,8 +1134,9 @@ object DecisionTree extends Serializable with Logging {
         val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
         var splitIndex = 0
         while (splitIndex < numSplitsForFeature) {
-          gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
-            splitIndex, rightNodeAgg, nodeImpurity)
+          gains(featureIndex)(splitIndex) =
+            calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex),
+              rightNodeAgg(featureIndex)(splitIndex), nodeImpurity)
           splitIndex += 1
         }
         featureIndex += 1
@@ -1238,18 +1148,14 @@ object DecisionTree extends Serializable with Logging {
      * Get the number of splits for a feature.
      */
     def getNumSplitsForFeature(featureIndex: Int): Int = {
-      val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-      if (isFeatureContinuous) {
+      if (metadata.isContinuous(featureIndex)) {
         numBins - 1
       } else {
         // Categorical feature
-        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-        val isSpaceSufficientForAllCategoricalSplits =
-          numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
-          math.pow(2.0, featureCategories - 1).toInt - 1
+        val featureCategories = metadata.featureArity(featureIndex)
+        if (metadata.isUnordered(featureIndex)) {
+          (1 << featureCategories - 1) - 1
         } else {
-          // Ordered features
           featureCategories
         }
       }
@@ -1308,29 +1214,29 @@ object DecisionTree extends Serializable with Logging {
      * Get bin data for one node.
      */
     def getBinDataForNode(node: Int): Array[Double] = {
-      strategy.algo match {
-        case Classification =>
-          if (isMulticlassClassificationWithCategoricalFeatures) {
-            val shift = numClasses * node * numBins * numFeatures
-            val rightChildShift = numClasses * numBins * numFeatures * numNodes
-            val binsForNode = {
-              val leftChildData
-                = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-              val rightChildData
-              = binAggregates.slice(rightChildShift + shift,
-                rightChildShift + shift + numClasses * numBins * numFeatures)
-              leftChildData ++ rightChildData
-            }
-            binsForNode
-          } else {
-            val shift = numClasses * node * numBins * numFeatures
-            val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
-            binsForNode
+      if (metadata.isClassification) {
+        if (isMulticlassWithCategoricalFeatures) {
+          val shift = numClasses * node * numBins * numFeatures
+          val rightChildShift = numClasses * numBins * numFeatures * numNodes
+          val binsForNode = {
+            val leftChildData
+            = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+            val rightChildData
+            = binAggregates.slice(rightChildShift + shift,
+              rightChildShift + shift + numClasses * numBins * numFeatures)
+            leftChildData ++ rightChildData
           }
-        case Regression =>
-          val shift = 3 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
           binsForNode
+        } else {
+          val shift = numClasses * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+          binsForNode
+        }
+      } else {
+        // Regression
+        val shift = 3 * node * numBins * numFeatures
+        val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+        binsForNode
       }
     }
 
@@ -1340,7 +1246,7 @@ object DecisionTree extends Serializable with Logging {
     // Iterating over all nodes at this level
     var node = 0
     while (node < numNodes) {
-      val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift
+      val nodeImpurityIndex = (1 << level) - 1 + node + groupShift
       val binsForNode: Array[Double] = getBinDataForNode(node)
       logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
@@ -1358,20 +1264,15 @@ object DecisionTree extends Serializable with Logging {
    *
    * @param numBins  Number of bins = 1 + number of possible splits.
    */
-  private def getElementsPerNode(
-      numFeatures: Int,
-      numBins: Int,
-      numClasses: Int,
-      isMulticlassClassificationWithCategoricalFeatures: Boolean,
-      algo: Algo): Int = {
-    algo match {
-      case Classification =>
-        if (isMulticlassClassificationWithCategoricalFeatures) {
-          2 * numClasses * numBins * numFeatures
-        } else {
-          numClasses * numBins * numFeatures
-        }
-      case Regression => 3 * numBins * numFeatures
+  private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = {
+    if (metadata.isClassification) {
+      if (metadata.isMulticlassWithCategoricalFeatures) {
+        2 * metadata.numClasses * numBins * metadata.numFeatures
+      } else {
+        metadata.numClasses * numBins * metadata.numFeatures
+      }
+    } else {
+      3 * numBins * metadata.numFeatures
     }
   }
 
@@ -1390,16 +1291,15 @@ object DecisionTree extends Serializable with Logging {
    *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+   *       There are (1 << (maxFeatureValue - 1)) - 1 splits.
    *   (b) "ordered features"
    *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
    *       there is one bin per category.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
-   *                 parameters for construction the DecisionTree
-   * @return A tuple of (splits,bins).
+   * @param metadata Learning and dataset metadata
+   * @return A tuple of (splits, bins).
    *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *          of size (numFeatures, numBins - 1).
    *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
@@ -1407,19 +1307,18 @@ object DecisionTree extends Serializable with Logging {
    */
   protected[tree] def findSplitsBins(
       input: RDD[LabeledPoint],
-      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+      metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
 
     val count = input.count()
 
     // Find the number of features by looking at the first sample
     val numFeatures = input.take(1)(0).features.size
 
-    val maxBins = strategy.maxBins
+    val maxBins = metadata.maxBins
     val numBins = if (maxBins <= count) maxBins else count.toInt
     logDebug("numBins = " + numBins)
-    val isMulticlassClassification = strategy.isMulticlassClassification
-    logDebug("isMulticlassClassification = " + isMulticlassClassification)
-
+    val isMulticlass = metadata.isMulticlass
+    logDebug("isMulticlass = " + isMulticlass)
 
     /*
      * Ensure numBins is always greater than the categories. For multiclass classification,
@@ -1431,13 +1330,12 @@ object DecisionTree extends Serializable with Logging {
      * by the number of training examples.
      * TODO: Allow this case, where we simply will know nothing about some categories.
      */
-    if (strategy.categoricalFeaturesInfo.size > 0) {
-      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+    if (metadata.featureArity.size > 0) {
+      val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2
       require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
         "in categorical features")
     }
 
-
     // Calculate the number of sample for approximate quantile calculation.
     val requiredSamples = numBins*numBins
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1451,7 +1349,7 @@ object DecisionTree extends Serializable with Logging {
     val stride: Double = numSamples.toDouble / numBins
     logDebug("stride = " + stride)
 
-    strategy.quantileCalculationStrategy match {
+    metadata.quantileStrategy match {
       case Sort =>
         val splits = Array.ofDim[Split](numFeatures, numBins - 1)
         val bins = Array.ofDim[Bin](numFeatures, numBins)
@@ -1462,7 +1360,7 @@ object DecisionTree extends Serializable with Logging {
         var featureIndex = 0
         while (featureIndex < numFeatures) {
           // Check whether the feature is continuous.
-          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          val isFeatureContinuous = metadata.isContinuous(featureIndex)
           if (isFeatureContinuous) {
             val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
             val stride: Double = numSamples.toDouble / numBins
@@ -1475,18 +1373,14 @@ object DecisionTree extends Serializable with Logging {
               splits(featureIndex)(index) = split
             }
           } else { // Categorical feature
-            val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
-            val isSpaceSufficientForAllCategoricalSplits
-              = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+            val featureCategories = metadata.featureArity(featureIndex)
 
             // Use different bin/split calculation strategy for categorical features in multiclass
             // classification that satisfy the space constraint.
-            val isUnorderedFeature =
-              isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-            if (isUnorderedFeature) {
+            if (metadata.isUnordered(featureIndex)) {
               // 2^(maxFeatureValue- 1) - 1 combinations
               var index = 0
-              while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+              while (index < (1 << featureCategories - 1) - 1) {
                 val categories: List[Double]
                   = extractMultiClassCategories(index + 1, featureCategories)
                 splits(featureIndex)(index)
@@ -1516,7 +1410,7 @@ object DecisionTree extends Serializable with Logging {
                * centroidForCategories is a mapping: category (for the given feature) --> centroid
                */
               val centroidForCategories = {
-                if (isMulticlassClassification) {
+                if (isMulticlass) {
                   // For categorical variables in multiclass classification,
                   // each bin is a category. The bins are sorted and they
                   // are ordered by calculating the impurity of their corresponding labels.
@@ -1524,7 +1418,7 @@ object DecisionTree extends Serializable with Logging {
                    .groupBy(_._1)
                    .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
                    .map(x => (x._1, x._2.values.toArray))
-                   .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+                   .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum)))
                 } else { // regression or binary classification
                   // For categorical variables in regression and binary classification,
                   // each bin is a category. The bins are sorted and they
@@ -1576,7 +1470,7 @@ object DecisionTree extends Serializable with Logging {
         // Find all bins.
         featureIndex = 0
         while (featureIndex < numFeatures) {
-          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          val isFeatureContinuous = metadata.isContinuous(featureIndex)
           if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
             bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
               splits(featureIndex)(0), Continuous, Double.MinValue)
@@ -1590,7 +1484,7 @@ object DecisionTree extends Serializable with Logging {
           }
           featureIndex += 1
         }
-        (splits,bins)
+        (splits, bins)
       case MinMax =>
         throw new UnsupportedOperationException("minmax not supported yet.")
       case ApproxHist =>

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
new file mode 100644
index 0000000..d9eda35
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
+ *                      For regression: fixed at 0 (no meaning).
+ * @param featureArity  Map: categorical feature index --> arity.
+ *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ */
+private[tree] class DecisionTreeMetadata(
+    val numFeatures: Int,
+    val numExamples: Long,
+    val numClasses: Int,
+    val maxBins: Int,
+    val featureArity: Map[Int, Int],
+    val unorderedFeatures: Set[Int],
+    val impurity: Impurity,
+    val quantileStrategy: QuantileStrategy) extends Serializable {
+
+  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+  def isClassification: Boolean = numClasses >= 2
+
+  def isMulticlass: Boolean = numClasses > 2
+
+  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+}
+
+private[tree] object DecisionTreeMetadata {
+
+  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+
+    val numFeatures = input.take(1)(0).features.size
+    val numExamples = input.count()
+    val numClasses = strategy.algo match {
+      case Classification => strategy.numClassesForClassification
+      case Regression => 0
+    }
+
+    val maxBins = math.min(strategy.maxBins, numExamples).toInt
+    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+
+    val unorderedFeatures = new mutable.HashSet[Int]()
+    if (numClasses > 2) {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        if (k - 1 < log2MaxBinsp1) {
+          // Note: The above check is equivalent to checking:
+          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
+          unorderedFeatures.add(f)
+        } else {
+          // TODO: Allow this case, where we simply will know nothing about some categories?
+          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+            s"in categorical features (>= $k)")
+        }
+      }
+    } else {
+      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
+          s"in categorical features (>= $k)")
+      }
+    }
+
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+      strategy.impurity, strategy.quantileCalculationStrategy)
+  }
+
+}
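
The comment in buildMetadata states that `k - 1 < log2(maxBins + 1)` is equivalent to `(1 << k - 1) - 1 < maxBins`. A minimal sketch of that equivalence follows; the object and method names are hypothetical and not part of this commit, and it assumes arity < 32 so the Int shift does not overflow.

```scala
// Hypothetical helper, not part of the commit: compares the two forms of the
// "does this categorical feature fit as unordered?" check.
object UnorderedCheck {

  // Number of bins needed to enumerate all unordered splits of a feature
  // with `arity` categories: 2^(arity - 1) - 1. Assumes 1 <= arity < 32.
  def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

  // Integer form of the check (the form given in the code comment).
  def fitsByShift(arity: Int, maxBins: Int): Boolean =
    numUnorderedBins(arity) < maxBins

  // Logarithmic form of the check (the form actually evaluated in buildMetadata).
  // Floating-point rounding could matter when maxBins + 1 is an exact power of two.
  def fitsByLog(arity: Int, maxBins: Int): Boolean =
    arity - 1 < math.log(maxBins + 1) / math.log(2.0)
}
```

Both forms follow from 2^(k - 1) - 1 < maxBins, which rearranges to 2^(k - 1) < maxBins + 1; taking log base 2 of both sides gives the logarithmic form.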

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index ccac103..170e43e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.mllib.tree.impl
 
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD
 
@@ -48,50 +47,35 @@ private[tree] object TreePoint {
    * Convert an input dataset into its TreePoint representation,
    * binning feature values in preparation for DecisionTree training.
    * @param input     Input dataset.
-   * @param strategy  DecisionTree training info, used for dataset metadata.
    * @param bins      Bins for features, of size (numFeatures, numBins).
+   * @param metadata Learning and dataset metadata
    * @return  TreePoint dataset representation
    */
   def convertToTreeRDD(
       input: RDD[LabeledPoint],
-      strategy: Strategy,
-      bins: Array[Array[Bin]]): RDD[TreePoint] = {
+      bins: Array[Array[Bin]],
+      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
     input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
-        strategy.categoricalFeaturesInfo)
+      TreePoint.labeledPointToTreePoint(x, bins, metadata)
     }
   }
 
   /**
    * Convert one LabeledPoint into its TreePoint representation.
    * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param categoricalFeaturesInfo  Map over categorical features: feature index --> feature arity
    */
   private def labeledPointToTreePoint(
       labeledPoint: LabeledPoint,
-      isMulticlassClassification: Boolean,
       bins: Array[Array[Bin]],
-      categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+      metadata: DecisionTreeMetadata): TreePoint = {
 
     val numFeatures = labeledPoint.features.size
     val numBins = bins(0).size
     val arr = new Array[Int](numFeatures)
     var featureIndex = 0
     while (featureIndex < numFeatures) {
-      val featureInfo = categoricalFeaturesInfo.get(featureIndex)
-      val isFeatureContinuous = featureInfo.isEmpty
-      if (isFeatureContinuous) {
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
-          bins, categoricalFeaturesInfo)
-      } else {
-        val featureCategories = featureInfo.get
-        val isSpaceSufficientForAllCategoricalSplits
-          = numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        val isUnorderedFeature =
-          isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
-        arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
-          isUnorderedFeature, bins, categoricalFeaturesInfo)
-      }
+      arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
+        metadata.isUnordered(featureIndex), bins, metadata.featureArity)
       featureIndex += 1
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index c89c1e3..af35d88 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 
 /**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the features bins for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ *  where an example with featureValue falls into the bin s.t.
+ *  lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence between
+ *  bins, splits, and feature values.  The bin is determined by category/feature value.
+ *  However, the bins are not necessarily ordered by feature value;
+ *  they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
+ *  where bins and splits correspond to subsets of feature values (in highSplit.categories).
+ *
  * @param lowSplit signifying the lower threshold for the continuous feature to be
  *                 accepted in the bin
  * @param highSplit signifying the upper threshold for the continuous feature to be
  *                 accepted in the bin
  * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for binary classification
+ * @param category categorical label value accepted in the bin for ordered features
  */
 private[tree]
 case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
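
The updated Bin documentation says a continuous value falls into the bin with `lowSplit.threshold < featureValue <= highSplit.threshold`. That rule can be sketched as below. This is a simplified linear scan over an assumed sorted array of thresholds, with hypothetical names; it is not MLlib's findBin.

```scala
// Hypothetical illustration, not part of the commit.
object BinLookup {

  // `thresholds` holds the split thresholds for one continuous feature,
  // sorted ascending. With n thresholds there are n + 1 bins:
  //   bin 0     covers (-inf, thresholds(0)]
  //   bin i     covers (thresholds(i - 1), thresholds(i)]
  //   bin n     covers (thresholds(n - 1), +inf)
  def findBin(thresholds: Array[Double], value: Double): Int = {
    var i = 0
    // Advance while the value lies strictly above the bin's upper threshold.
    while (i < thresholds.length && value > thresholds(i)) {
      i += 1
    }
    i
  }
}
```

A value equal to a threshold lands in the lower bin, matching the `<=` on the high split.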

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 3d3406b..0594fd0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * @return Double prediction from the trained model
    */
   def predict(features: Vector): Double = {
-    topNode.predictIfLeaf(features)
+    topNode.predict(features)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
deleted file mode 100644
index 2deaf4a..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-/**
- * Filter specifying a split and type of comparison to be applied on features
- * @param split split specifying the feature index, type and threshold
- * @param comparison integer specifying <,=,>
- */
-private[tree] case class Filter(split: Split, comparison: Int) {
-  // Comparison -1,0,1 signifies <.=,>
-  override def toString = " split = " + split + "comparison = " + comparison
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 944f11c..0eee626 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -69,24 +69,24 @@ class Node (
 
   /**
    * predict value if node is not leaf
-   * @param feature feature value
+   * @param features feature value
    * @return predicted value
    */
-  def predictIfLeaf(feature: Vector) : Double = {
+  def predict(features: Vector) : Double = {
     if (isLeaf) {
       predict
     } else{
       if (split.get.featureType == Continuous) {
-        if (feature(split.get.feature) <= split.get.threshold) {
-          leftNode.get.predictIfLeaf(feature)
+        if (features(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       } else {
-        if (split.get.categories.contains(feature(split.get.feature))) {
-          leftNode.get.predictIfLeaf(feature)
+        if (split.get.categories.contains(features(split.get.feature))) {
+          leftNode.get.predict(features)
         } else {
-          rightNode.get.predictIfLeaf(feature)
+          rightNode.get.predict(features)
         }
       }
     }
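
The top-down routing that Node.predict performs (left if `feature <= threshold` for a continuous split, left if the value is in `categories` for a categorical split) can be sketched on a stripped-down tree. All types and names below are hypothetical stand-ins for illustration, not the MLlib classes.

```scala
// Hypothetical simplified types, not part of the commit.
sealed trait SimpleSplit
case class ContinuousSplit(feature: Int, threshold: Double) extends SimpleSplit
case class CategoricalSplit(feature: Int, categories: Set[Double]) extends SimpleSplit

sealed trait SimpleNode
case class Leaf(prediction: Double) extends SimpleNode
case class Internal(split: SimpleSplit, left: SimpleNode, right: SimpleNode) extends SimpleNode

object SimpleTree {

  // Walk from the given node down to a leaf, choosing a child at each split.
  @annotation.tailrec
  def predict(node: SimpleNode, features: Array[Double]): Double = node match {
    case Leaf(p) => p
    case Internal(split, left, right) =>
      val goLeft = split match {
        // Continuous: go left if feature <= threshold, else right.
        case ContinuousSplit(f, t) => features(f) <= t
        // Categorical: go left if the feature value is in the set, else right.
        case CategoricalSplit(f, cats) => cats.contains(features(f))
      }
      predict(if (goLeft) left else right, features)
  }
}
```

Each example costs one comparison per level when routed this way, which is the idea behind filtering top-down through the partly built tree.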

http://git-wip-us.apache.org/repos/asf/spark/blob/73ab7f14/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index d7ffd38..50fb48b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
  * :: DeveloperApi ::
  * Split applied to a feature
  * @param feature feature index
- * @param threshold threshold for continuous feature
+ * @param threshold Threshold for continuous feature.
+ *                  Split left if feature <= threshold, else right.
  * @param featureType type of feature -- categorical or continuous
- * @param categories accepted values for categorical variables
+ * @param categories Split left if categorical feature value is in this set, else right.
  */
 @DeveloperApi
 case class Split(

