spark-commits mailing list archives

From ma...@apache.org
Subject [2/2] git commit: MLI-1 Decision Trees
Date Wed, 02 Apr 2014 04:41:04 GMT
MLI-1 Decision Trees

Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010.

Key features:
+ Supports binary classification and regression
+ Supports gini, entropy and variance for information gain calculation
+ Supports both continuous and categorical features
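
For reference, the three impurity measures named above can be sketched as follows. This is a minimal, self-contained illustration, not the code in the impurity package added by this commit. The object name ImpuritySketch, its method signatures, and the use of log base 2 for entropy are assumptions. The classification measures take the counts of label 0 and label 1. Variance takes a count, a sum and a sum of squares, which are the per-bin statistics the regression path aggregates.

```scala
// Hypothetical sketch of the impurity measures; not the commit's impurity classes.
object ImpuritySketch {

  /** Gini impurity for a binary label: 1 - p0^2 - p1^2. */
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0) 0.0
    else {
      val p0 = c0 / total
      val p1 = c1 / total
      1.0 - p0 * p0 - p1 * p1
    }
  }

  /** Entropy (log base 2, an assumption) for a binary label; 0 * log(0) is taken as 0. */
  def entropy(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    def term(c: Double): Double =
      if (c == 0) 0.0 else { val p = c / total; -p * math.log(p) / math.log(2) }
    if (total == 0) 0.0 else term(c0) + term(c1)
  }

  /** Variance of a regression target computed from count, sum and sum of squares. */
  def variance(count: Double, sum: Double, sumSquares: Double): Double =
    if (count == 0) 0.0 else sumSquares / count - (sum / count) * (sum / count)

  /** Information gain: parent impurity minus the weighted impurities of the children. */
  def gain(parent: Double, leftCount: Double, leftImpurity: Double,
           rightCount: Double, rightImpurity: Double): Double = {
    val n = leftCount + rightCount
    parent - (leftCount / n) * leftImpurity - (rightCount / n) * rightImpurity
  }
}
```

Under these definitions a perfectly mixed node has Gini impurity 0.5 and entropy 1.0. A split that separates it into two pure children therefore has a gain equal to the parent impurity.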

The algorithm has gone through several development iterations over the last few months, leading to a highly optimized implementation. Optimizations include:

1. Level-wise training to reduce passes over the entire dataset.
2. Bin-wise split calculation to reduce computation overhead.
3. Aggregation over partitions before combining to reduce communication overhead.
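
Optimizations 2 and 3 can be illustrated together with a small sketch that does not depend on Spark. The sketch makes four assumptions: a single continuous feature, binary 0/1 labels, sorted split thresholds, and Gini impurity. Partitions are simulated as nested Seqs. foldLeft stands in for the per-partition sequential operation, and reduce stands in for the combine step of RDD.aggregate. The names BinAggregationSketch, findBin, seqOp, combOp and bestSplit are hypothetical, and this is not the implementation in this commit.

The sketch works as follows:

- Each sample is mapped to a bin once.
- Per-bin label counts are accumulated into one flat array per partition.
- The partition arrays are summed.
- The left and right counts for each split come from cumulative bin sums, so the samples are not scanned again.

```scala
// Hypothetical sketch: bin-wise split search with per-partition aggregation.
object BinAggregationSketch {

  // Bin b covers (thresholds(b - 1), thresholds(b)], with -Inf and +Inf at the ends,
  // so there are thresholds.length + 1 bins. Binary search returns the index of the
  // first threshold >= x, or thresholds.length if there is none.
  def findBin(thresholds: Array[Double], x: Double): Int = {
    var lo = 0
    var hi = thresholds.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (thresholds(mid) < x) lo = mid + 1 else hi = mid
    }
    lo
  }

  // Sequential step: add one (feature, label) sample to a flat array laid out as
  // [bin0 label-0 count, bin0 label-1 count, bin1 label-0 count, ...].
  def seqOp(thresholds: Array[Double], agg: Array[Double], p: (Double, Double)): Array[Double] = {
    val bin = findBin(thresholds, p._1)
    agg(2 * bin + (if (p._2 == 0.0) 0 else 1)) += 1
    agg
  }

  // Combine step: element-wise sum of two partial aggregates.
  def combOp(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }

  def gini(c0: Double, c1: Double): Double = {
    val n = c0 + c1
    if (n == 0) 0.0 else 1.0 - (c0 / n) * (c0 / n) - (c1 / n) * (c1 / n)
  }

  /** Returns (best split index, gain); split s sends bins 0..s to the left child. */
  def bestSplit(partitions: Seq[Seq[(Double, Double)]],
                thresholds: Array[Double]): (Int, Double) = {
    val numBins = thresholds.length + 1
    // One zero-initialized array per partition, followed by a single combine.
    val agg = partitions
      .map(part => part.foldLeft(new Array[Double](2 * numBins))((acc, p) => seqOp(thresholds, acc, p)))
      .reduce((a, b) => combOp(a, b))
    val total0 = (0 until numBins).map(b => agg(2 * b)).sum
    val total1 = (0 until numBins).map(b => agg(2 * b + 1)).sum
    val n = total0 + total1
    val parent = gini(total0, total1)
    var left0 = 0.0
    var left1 = 0.0
    var best = (-1, Double.NegativeInfinity)
    for (s <- thresholds.indices) {
      // Cumulative counts: the left child holds every bin up to and including s.
      left0 += agg(2 * s)
      left1 += agg(2 * s + 1)
      val right0 = total0 - left0
      val right1 = total1 - left1
      val gain = parent - ((left0 + left1) / n) * gini(left0, left1) -
        ((right0 + right1) / n) * gini(right0, right1)
      if (gain > best._2) best = (s, gain)
    }
    best
  }
}
```

Consider an example spread over two partitions. Features 1.0 and 2.0 are labelled 0, features 4.0 and 5.0 are labelled 1, and the thresholds are 1.5, 3.0 and 4.5. The middle threshold (index 1) separates the classes perfectly, so its gain of 0.5 equals the parent Gini impurity.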

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #79 from manishamde/tree and squashes the following commits:

1e8c704 [Manish Amde] remove numBins field in the Strategy class
7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree
f536ae9 [Xiangrui Meng] another pass on code style
e1dd86f [Manish Amde] implementing code style suggestions
62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing
201702f [Manish Amde] making some more methods private
f963ef5 [Manish Amde] making methods private
c487e6a [manishamde] Merge pull request #1 from mengxr/dtree
24500c5 [Xiangrui Meng] minor style updates
4576b64 [Manish Amde] documentation and for to while loop conversion
ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins
632818f [Manish Amde] removing threshold for classification predict method
2116360 [Manish Amde] removing dummy bin calculation for categorical variables
6068356 [Manish Amde] ensuring num bins is always greater than max number of categories
62c2562 [Manish Amde] fixing comment indentation
ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions
d1ef4f6 [Manish Amde] more documentation
794ff4d [Manish Amde] minor improvements to docs and style
eb8fcbe [Manish Amde] minor code style updates
cd2c2b4 [Manish Amde] fixing code style based on feedback
63e786b [Manish Amde] added multiple train methods for java compatability
d3023b3 [Manish Amde] adding more docs for nested methods
84f85d6 [Manish Amde] code documentation
9372779 [Manish Amde] code style: max line lenght <= 100
dd0c0d7 [Manish Amde] minor: some docs
0dd7659 [manishamde] basic doc
5841c28 [Manish Amde] unit tests for categorical features
f067d68 [Manish Amde] minor cleanup
c0e522b [Manish Amde] updated predict and split threshold logic
b09dc98 [Manish Amde] minor refactoring
6b7de78 [Manish Amde] minor refactoring and tests
d504eb1 [Manish Amde] more tests for categorical features
dbb7ac1 [Manish Amde] categorical feature support
6df35b9 [Manish Amde] regression predict logic
53108ed [Manish Amde] fixing index for highest bin
e23c2e5 [Manish Amde] added regression support
c8f6d60 [Manish Amde] adding enum for feature type
b0e3e76 [Manish Amde] adding enum for feature type
154aa77 [Manish Amde] enums for configurations
733d6dd [Manish Amde] fixed tests
02c595c [Manish Amde] added command line parsing
98ec8d5 [Manish Amde] tree building and prediction logic
b0eb866 [Manish Amde] added logic to handle leaf nodes
80e8c66 [Manish Amde] working version of multi-level split calculation
4798aae [Manish Amde] added gain stats class
dad0afc [Manish Amde] decison stump functionality working
03f534c [Manish Amde] some more tests
0012a77 [Manish Amde] basic stump working
8bca1e2 [Manish Amde] additional code for creating intermediate RDD
92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested.
cd53eae [Manish Amde] skeletal framework


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8b3045ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8b3045ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8b3045ce

Branch: refs/heads/master
Commit: 8b3045ceab591a3f3ca18823c7e2c5faca38a06e
Parents: 45df912
Author: Manish Amde <manish9ue@gmail.com>
Authored: Tue Apr 1 21:40:49 2014 -0700
Committer: Matei Zaharia <matei@databricks.com>
Committed: Tue Apr 1 21:40:49 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 1150 ++++++++++++++++++
 .../scala/org/apache/spark/mllib/tree/README.md |   17 +
 .../spark/mllib/tree/configuration/Algo.scala   |   26 +
 .../mllib/tree/configuration/FeatureType.scala  |   26 +
 .../tree/configuration/QuantileStrategy.scala   |   26 +
 .../mllib/tree/configuration/Strategy.scala     |   43 +
 .../spark/mllib/tree/impurity/Entropy.scala     |   47 +
 .../apache/spark/mllib/tree/impurity/Gini.scala |   46 +
 .../spark/mllib/tree/impurity/Impurity.scala    |   42 +
 .../spark/mllib/tree/impurity/Variance.scala    |   37 +
 .../org/apache/spark/mllib/tree/model/Bin.scala |   33 +
 .../mllib/tree/model/DecisionTreeModel.scala    |   49 +
 .../apache/spark/mllib/tree/model/Filter.scala  |   28 +
 .../mllib/tree/model/InformationGainStats.scala |   39 +
 .../apache/spark/mllib/tree/model/Node.scala    |   90 ++
 .../apache/spark/mllib/tree/model/Split.scala   |   64 +
 .../spark/mllib/tree/DecisionTreeSuite.scala    |  425 +++++++
 17 files changed, 2188 insertions(+)
----------------------------------------------------------------------
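
Level-wise training in DecisionTree.scala relies on a heap-style numbering of nodes:

- Level l holds 2^l nodes, starting at index 2^l - 1.
- The two children of the node at position index on level l sit at 2^(l+1) - 1 + 2 * index and the slot after it.
- A node's parent is at (nodeIndex - 1) / 2.

The helpers below are hypothetical and not part of the commit. They restate that arithmetic so it can be checked in isolation, using bit shifts in place of scala.math.pow. One aggregation pass is made per level, so a tree with maxDepth levels (up to 2^maxDepth - 1 nodes) needs at most maxDepth passes.

```scala
// Hypothetical helpers mirroring the node-index arithmetic used for level-wise training.
object NodeIndexSketch {
  /** Index of the node at 0-based position `index` within `level`. */
  def nodeIndex(level: Int, index: Int): Int = (1 << level) - 1 + index

  /** Index of the left (right = false) or right (right = true) child. */
  def childIndex(level: Int, index: Int, right: Boolean): Int =
    (1 << (level + 1)) - 1 + 2 * index + (if (right) 1 else 0)

  /** Index of the parent of a non-root node. */
  def parentIndex(nodeIndex: Int): Int = (nodeIndex - 1) / 2

  /** Maximum number of nodes in a tree with `maxDepth` levels. */
  def maxNumNodes(maxDepth: Int): Int = (1 << maxDepth) - 1

  /** All node indices on one level; each level is handled in a single pass. */
  def levelIndices(level: Int): Range = nodeIndex(level, 0) until nodeIndex(level + 1, 0)
}
```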


http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
new file mode 100644
index 0000000..33205b9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -0,0 +1,1150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.util.control.Breaks._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A class that implements a decision tree algorithm for classification and regression. It
+ * supports both continuous and categorical features.
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ *                 of algorithm (classification, regression, etc.), feature type (continuous,
+ *                 categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
+
+  /**
+   * Method to train a decision tree model over an RDD
+   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+
+    // Cache input RDD for speedup during multiple passes.
+    input.cache()
+    logDebug("algo = " + strategy.algo)
+
+    // Find the splits and the corresponding bins (interval between the splits) using a sample
+    // of the input data.
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    logDebug("numSplits = " + bins(0).length)
+
+    // depth of the decision tree
+    val maxDepth = strategy.maxDepth
+    // the max number of nodes possible given the depth of the tree
+    val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
+    // Initialize an array to hold filters applied to points for each node.
+    val filters = new Array[List[Filter]](maxNumNodes)
+    // The filter at the top node is an empty list.
+    filters(0) = List()
+    // Initialize an array to hold parent impurity calculations for each node. The entry for
+    // the top node is a dummy value (updated during the first split calculation).
+    val parentImpurities = new Array[Double](maxNumNodes)
+    val nodes = new Array[Node](maxNumNodes)
+
+
+    /*
+     * The main idea here is to perform level-wise training of the decision tree nodes thus
+     * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
+     * Each data sample is checked for validity w.r.t. each node at a given level, i.e.,
+     * the sample is only used for the split calculation at the node if the sample would have
+     * still survived the filters of the parent nodes.
+     */
+
+    // TODO: Convert for loop to while loop
+    breakable {
+      for (level <- 0 until maxDepth) {
+
+        logDebug("#####################################")
+        logDebug("level = " + level)
+        logDebug("#####################################")
+
+        // Find best split for all nodes at a level.
+        val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
+          level, filters, splits, bins)
+
+        for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+          // Extract info for nodes at the current level.
+          extractNodeInfo(nodeSplitStats, level, index, nodes)
+          // Extract info for nodes at the next lower level.
+          extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
+            filters)
+          logDebug("final best split = " + nodeSplitStats._1)
+        }
+        require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+        // Check whether all the nodes at the current level are leaves.
+        val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+        logDebug("all leaf = " + allLeaf)
+        if (allLeaf) break // no more tree construction
+      }
+    }
+
+    // Initialize the top or root node of the tree.
+    val topNode = nodes(0)
+    // Build the full tree using the node info calculated in the level-wise best split calculations.
+    topNode.build(nodes)
+
+    new DecisionTreeModel(topNode, strategy.algo)
+  }
+
+  /**
+   * Extract the decision tree node information for the given tree level and node index
+   */
+  private def extractNodeInfo(
+      nodeSplitStats: (Split, InformationGainStats),
+      level: Int,
+      index: Int,
+      nodes: Array[Node]): Unit = {
+    val split = nodeSplitStats._1
+    val stats = nodeSplitStats._2
+    val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
+    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+    val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+    logDebug("Node = " + node)
+    nodes(nodeIndex) = node
+  }
+
+  /**
+   *  Extract the decision tree node information for the children of the node
+   */
+  private def extractInfoForLowerLevels(
+      level: Int,
+      index: Int,
+      maxDepth: Int,
+      nodeSplitStats: (Split, InformationGainStats),
+      parentImpurities: Array[Double],
+      filters: Array[List[Filter]]): Unit = {
+    // 0 corresponds to the left child node and 1 corresponds to the right child node.
+    // TODO: Convert to while loop
+    for (i <- 0 to 1) {
+      // Calculate the index of the node from the node level and the index at the current level.
+      val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
+      if (level < maxDepth - 1) {
+        val impurity = if (i == 0) {
+          nodeSplitStats._2.leftImpurity
+        } else {
+          nodeSplitStats._2.rightImpurity
+        }
+        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
+        // noting the parent impurities
+        parentImpurities(nodeIndex) = impurity
+        // noting the parents filters for the child nodes
+        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
+        filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
+        for (filter <- filters(nodeIndex)) {
+          logDebug("Filter = " + filter)
+        }
+      }
+    }
+  }
+}
+
+object DecisionTree extends Serializable with Logging {
+
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For the
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes. The parameters for the algorithm are specified using the strategy parameter.
+   *
+   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   *              for DecisionTree
+   * @param strategy The configuration parameters for the tree algorithm which specify the type
+   *                 of algorithm (classification, regression, etc.), feature type (continuous,
+   *                 categorical), depth of the tree, quantile calculation strategy, etc.
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
+    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+  }
+
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary classification and regression. For the
+   * binary classification, the label for each instance should either be 0 or 1 to denote the two
+   * classes.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   *              training data
+   * @param algo algorithm, classification or regression
+   * @param impurity impurity criterion used for information gain calculation
+   * @param maxDepth maximum depth of the tree
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo, impurity, maxDepth)
+    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+  }
+
+
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The decision tree method supports binary classification and
+   * regression. For the binary classification, the label for each instance should either be 0 or
+   * 1 to denote the two classes. The method also supports categorical features inputs where the
+   * number of categories can be specified using the categoricalFeaturesInfo option.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   *              training data for DecisionTree
+   * @param algo classification or regression
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth  maximum depth of the tree
+   * @param maxBins maximum number of bins used for splitting features
+   * @param quantileCalculationStrategy  algorithm for calculating quantiles
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example,
+   *                                an entry (n -> k) implies the feature n is categorical with k
+   *                                categories 0, 1, 2, ... , k-1. It's important to note that
+   *                                features are zero-indexed.
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      maxBins: Int,
+      quantileCalculationStrategy: QuantileStrategy,
+      categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
+    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
+      categoricalFeaturesInfo)
+    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+  }
+
+  private val InvalidBinIndex = -1
+
+  /**
+   * Returns an array of optimal splits for all nodes at a given level
+   *
+   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   *              for DecisionTree
+   * @param parentImpurities Impurities for all parent nodes for the current level
+   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+   *                 parameters for constructing the DecisionTree
+   * @param level Level of the tree
+   * @param filters Filters for all nodes at a given level
+   * @param splits possible splits for all features
+   * @param bins possible bins for all features
+   * @return array of splits with best splits for all nodes at a given level.
+   */
+  protected[tree] def findBestSplits(
+      input: RDD[LabeledPoint],
+      parentImpurities: Array[Double],
+      strategy: Strategy,
+      level: Int,
+      filters: Array[List[Filter]],
+      splits: Array[Array[Split]],
+      bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
+
+    /*
+     * The high-level description for the best split optimizations are noted here.
+     *
+     * *Level-wise training*
+     * We perform bin calculations for all nodes at the given level to avoid making multiple
+     * passes over the data. Thus, for a slightly increased computation and storage cost we save
+     * several iterations over the data especially at higher levels of the decision tree.
+     *
+     * *Bin-wise computation*
+     * We use a bin-wise best split computation strategy instead of a straightforward best split
+     * computation strategy. Instead of analyzing each sample for contribution to the left/right
+     * child node impurity of every split, we first categorize each feature of a sample into a
+     * bin. Each bin is an interval between a low and a high split. Since the splits, and thus
+     * the bins, are ordered (see findSplitsBins for the ordering of categorical variables),
+     * we exploit this structure to calculate aggregates for bins and then use these aggregates
+     * to calculate information gain for each split.
+     *
+     * *Aggregation over partitions*
+     * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+     * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+     * indices) in a single array for all bins and rely upon the RDD aggregate method to
+     * drastically reduce the communication overhead.
+     */
+
+    // common calculations for multiple nested methods
+    val numNodes = scala.math.pow(2, level).toInt
+    logDebug("numNodes = " + numNodes)
+    // Find the number of features by looking at the first sample.
+    val numFeatures = input.first().features.length
+    logDebug("numFeatures = " + numFeatures)
+    val numBins = bins(0).length
+    logDebug("numBins = " + numBins)
+
+    /** Find the filters used before reaching the current node. */
+    def findParentFilters(nodeIndex: Int): List[Filter] = {
+      if (level == 0) {
+        List[Filter]()
+      } else {
+        val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
+        filters(nodeFilterIndex)
+      }
+    }
+
+    /**
+     * Find whether the sample is valid input for the current node, i.e., whether it passes through
+     * all the filters for the current node.
+     */
+    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+      // leaf
+      if ((level > 0) && (parentFilters.length == 0)) {
+        return false
+      }
+
+      // Apply each filter and check sample validity. Return false when invalid condition found.
+      for (filter <- parentFilters) {
+        val features = labeledPoint.features
+        val featureIndex = filter.split.feature
+        val threshold = filter.split.threshold
+        val comparison = filter.comparison
+        val categories = filter.split.categories
+        val isFeatureContinuous = filter.split.featureType == Continuous
+        val feature = features(featureIndex)
+        if (isFeatureContinuous) {
+          comparison match {
+            case -1 => if (feature > threshold) return false
+            case 1 => if (feature <= threshold) return false
+          }
+        } else {
+          val containsFeature = categories.contains(feature)
+          comparison match {
+            case -1 => if (!containsFeature) return false
+            case 1 => if (containsFeature) return false
+          }
+
+        }
+      }
+
+      // Return true when the sample is valid for all filters.
+      true
+    }
+
+    /**
+     * Find bin for one feature.
+     */
+    def findBin(
+        featureIndex: Int,
+        labeledPoint: LabeledPoint,
+        isFeatureContinuous: Boolean): Int = {
+      val binForFeatures = bins(featureIndex)
+      val feature = labeledPoint.features(featureIndex)
+
+      /**
+       * Binary search helper method for continuous feature.
+       */
+      def binarySearchForBins(): Int = {
+        var left = 0
+        var right = binForFeatures.length - 1
+        while (left <= right) {
+          val mid = left + (right - left) / 2
+          val bin = binForFeatures(mid)
+          val lowThreshold = bin.lowSplit.threshold
+          val highThreshold = bin.highSplit.threshold
+          if ((lowThreshold < feature) && (highThreshold >= feature)) {
+            return mid
+          }
+          else if (lowThreshold >= feature) {
+            right = mid - 1
+          }
+          else {
+            left = mid + 1
+          }
+        }
+        -1
+      }
+
+      /**
+       * Sequential search helper method to find bin for categorical feature.
+       */
+      def sequentialBinSearchForCategoricalFeature(): Int = {
+        val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+        var binIndex = 0
+        while (binIndex < numCategoricalBins) {
+          val bin = bins(featureIndex)(binIndex)
+          val category = bin.category
+          val features = labeledPoint.features
+          if (category == features(featureIndex)) {
+            return binIndex
+          }
+          binIndex += 1
+        }
+        -1
+      }
+
+      if (isFeatureContinuous) {
+        // Perform binary search for finding bin for continuous features.
+        val binIndex = binarySearchForBins()
+        if (binIndex == -1) {
+          throw new UnknownError("no bin was found for continuous variable.")
+        }
+        binIndex
+      } else {
+        // Perform sequential search to find bin for categorical features.
+        val binIndex = sequentialBinSearchForCategoricalFeature()
+        if (binIndex == -1) {
+          throw new UnknownError("no bin was found for categorical variable.")
+        }
+        binIndex
+      }
+    }
+
+    /**
+     * Finds bins for all nodes (and all features) at a given level.
+     * For l nodes and k features, the storage is as follows:
+     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, .. , b_l1, b_l2, .. , b_lk,
+     * where b_ij is an integer between 0 and numBins - 1.
+     * An invalid sample is denoted by setting the bin for feature 1 to -1.
+     */
+    def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+      // Calculate bin index and label per feature per node.
+      val arr = new Array[Double](1 + (numFeatures * numNodes))
+      arr(0) = labeledPoint.label
+      var nodeIndex = 0
+      while (nodeIndex < numNodes) {
+        val parentFilters = findParentFilters(nodeIndex)
+        // Find out whether the sample qualifies for the particular node.
+        val sampleValid = isSampleValid(parentFilters, labeledPoint)
+        val shift = 1 + numFeatures * nodeIndex
+        if (!sampleValid) {
+          // Marking one bin as -1 is sufficient.
+          arr(shift) = InvalidBinIndex
+        } else {
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous)
+            featureIndex += 1
+          }
+        }
+        nodeIndex += 1
+      }
+      arr
+    }
+
+    /**
+     * Performs a sequential aggregation over a partition for classification. For l nodes
+     * and k features, either the label-0 count or the label-1 count of one of the bins
+     * is incremented, depending on whether the sample's label is 0 or 1.
+     *
+     * @param arr Array[Double] of size 1 + (numFeatures * numNodes) holding the label and
+     *            the bin index of each feature for each node
+     * @param agg Array[Double] storing aggregate calculation of size
+     *            2 * numBins * numFeatures * numNodes for classification; it is updated
+     *            in place
+     */
+    def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+      // Iterate over all nodes.
+      var nodeIndex = 0
+      while (nodeIndex < numNodes) {
+        // Check whether the instance was valid for this nodeIndex.
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+        if (isSampleValidForNode) {
+          // actual class label
+          val label = arr(0)
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            // Find the bin index for this feature.
+            val arrShift = 1 + numFeatures * nodeIndex
+            val arrIndex = arrShift + featureIndex
+            // Update the left or right count for one bin.
+            val aggShift = 2 * numBins * numFeatures * nodeIndex
+            val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
+            label match {
+              case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
+              case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
+            }
+            featureIndex += 1
+          }
+        }
+        nodeIndex += 1
+      }
+    }
+
+    /**
+     * Performs a sequential aggregation over a partition for regression. For l nodes, k features,
+     * the count, sum, and sum of squares of one of the bins are incremented.
+     *
+     * @param arr Array[Double] of size 1 + (numFeatures * numNodes) holding the label and
+     *            the bin index of each feature for each node
+     * @param agg Array[Double] storing aggregate calculation of size
+     *            3 * numBins * numFeatures * numNodes for regression; it is updated
+     *            in place
+     */
+    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+      // Iterate over all nodes.
+      var nodeIndex = 0
+      while (nodeIndex < numNodes) {
+        // Check whether the instance was valid for this nodeIndex.
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+        if (isSampleValidForNode) {
+          // actual label
+          val label = arr(0)
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            // Find the bin index for this feature.
+            val arrShift = 1 + numFeatures * nodeIndex
+            val arrIndex = arrShift + featureIndex
+            // Update count, sum, and sum^2 for one bin.
+            val aggShift = 3 * numBins * numFeatures * nodeIndex
+            val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
+            agg(aggIndex) = agg(aggIndex) + 1
+            agg(aggIndex + 1) = agg(aggIndex + 1) + label
+            agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
+            featureIndex += 1
+          }
+        }
+        nodeIndex += 1
+      }
+    }
+
+    /**
+     * Performs a sequential aggregation over a partition.
+     */
+    def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
+      strategy.algo match {
+        case Classification => classificationBinSeqOp(arr, agg)
+        case Regression => regressionBinSeqOp(arr, agg)
+      }
+      agg
+    }
+
+    // Calculate bin aggregate length for classification or regression.
+    val binAggregateLength = strategy.algo match {
+      case Classification => 2 * numBins * numFeatures * numNodes
+      case Regression => 3 * numBins * numFeatures * numNodes
+    }
+    logDebug("binAggregateLength = " + binAggregateLength)
+
+    /**
+     * Combines the aggregates from partitions.
+     * @param agg1 Array containing aggregates from one or more partitions
+     * @param agg2 Array containing aggregates from one or more partitions
+     * @return Combined aggregate from agg1 and agg2
+     */
+    def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
+      var index = 0
+      val combinedAggregate = new Array[Double](binAggregateLength)
+      while (index < binAggregateLength) {
+        combinedAggregate(index) = agg1(index) + agg2(index)
+        index += 1
+      }
+      combinedAggregate
+    }
+
+    // Find feature bins for all nodes at a level.
+    val binMappedRDD = input.map(x => findBinsForLevel(x))
+
+    // Calculate bin aggregates.
+    val binAggregates = {
+      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+    }
+    logDebug("binAggregates.length = " + binAggregates.length)
+
+    /**
+     * Calculates the information gain for a split based upon left/right split aggregates.
+     * @param leftNodeAgg left node aggregates
+     * @param featureIndex feature index
+     * @param splitIndex split index
+     * @param rightNodeAgg right node aggregate
+     * @param topImpurity impurity of the parent node
+     * @return information gain and statistics for the split
+     */
+    def calculateGainForSplit(
+        leftNodeAgg: Array[Array[Double]],
+        featureIndex: Int,
+        splitIndex: Int,
+        rightNodeAgg: Array[Array[Double]],
+        topImpurity: Double): InformationGainStats = {
+      strategy.algo match {
+        case Classification =>
+          val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
+          val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
+          val leftCount = left0Count + left1Count
+
+          val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
+          val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
+          val rightCount = right0Count + right1Count
+
+          val impurity = {
+            if (level > 0) {
+              topImpurity
+            } else {
+              // Calculate impurity for root node.
+              strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+            }
+          }
+
+          if (leftCount == 0) {
+            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
+          }
+          if (rightCount == 0) {
+            return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 0)
+          }
+
+          val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
+          val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+
+          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+          val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+          val predict = (left1Count + right1Count) / (leftCount + rightCount)
+
+          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+        case Regression =>
+          val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
+          val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
+          val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+          val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
+          val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
+          val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+          val impurity = {
+            if (level > 0) {
+              topImpurity
+            } else {
+              // Calculate impurity for root node.
+              val count = leftCount + rightCount
+              val sum = leftSum + rightSum
+              val sumSquares = leftSumSquares + rightSumSquares
+              strategy.impurity.calculate(count, sum, sumSquares)
+            }
+          }
+
+          if (leftCount == 0) {
+            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+              rightSum / rightCount)
+          }
+          if (rightCount == 0) {
+            return new InformationGainStats(0, topImpurity, topImpurity,
+              Double.MinValue, leftSum / leftCount)
+          }
+
+          val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
+          val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+
+          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+          val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
+          val predict = (leftSum + rightSum) / (leftCount + rightCount)
+          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+      }
+    }
+
+    /**
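+     * Extracts left and right split aggregates.
+     * @param binData Array[Double] of size 2 * numFeatures * numBins for classification or
+     *                3 * numFeatures * numBins for regression
+     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Double]],
+     *         Array[Array[Double]]) where each array is of size (numFeatures, 2 * (numBins - 1))
+     *         for classification or (numFeatures, 3 * (numBins - 1)) for regression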
+     */
+    def extractLeftRightNodeAggregates(
+        binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+      strategy.algo match {
+        case Classification =>
+          // Initialize left and right split aggregates.
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            // shift for this featureIndex
+            val shift = 2 * featureIndex * numBins
+
+            // left node aggregate for the lowest split
+            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+
+            // right node aggregate for the highest split
+            rightNodeAgg(featureIndex)(2 * (numBins - 2))
+              = binData(shift + (2 * (numBins - 1)))
+            rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
+              = binData(shift + (2 * (numBins - 1)) + 1)
+
+            // Iterate over all splits.
+            var splitIndex = 1
+            while (splitIndex < numBins - 1) {
+              // calculating left node aggregate for a split as a sum of left node aggregate of a
+              // lower split and the left bin aggregate of a bin where the split is a high split
+              leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
+                leftNodeAgg(featureIndex)(2 * splitIndex - 2)
+              leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
+                leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
+
+              // calculating right node aggregate for a split as a sum of right node aggregate of a
+              // higher split and the right bin aggregate of a bin where the split is a low split
+              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
+                binData(shift + (2 * (numBins - 2 - splitIndex))) +
+                rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
+              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
+                binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) +
+                rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
+
+              splitIndex += 1
+            }
+            featureIndex += 1
+          }
+          (leftNodeAgg, rightNodeAgg)
+        case Regression =>
+          // Initialize left and right split aggregates.
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            // shift for this featureIndex
+            val shift = 3 * featureIndex * numBins
+            // left node aggregate for the lowest split
+            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+            leftNodeAgg(featureIndex)(2) = binData(shift + 2)
+
+            // right node aggregate for the highest split
+            rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
+              binData(shift + (3 * (numBins - 1)))
+            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
+              binData(shift + (3 * (numBins - 1)) + 1)
+            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
+              binData(shift + (3 * (numBins - 1)) + 2)
+
+            // Iterate over all splits.
+            var splitIndex = 1
+            while (splitIndex < numBins - 1) {
+              // calculating left node aggregate for a split as a sum of left node aggregate of a
+              // lower split and the left bin aggregate of a bin where the split is a high split
+              leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
+                leftNodeAgg(featureIndex)(3 * splitIndex - 3)
+              leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
+                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
+              leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
+                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
+
+              // calculating right node aggregate for a split as a sum of right node aggregate of a
+              // higher split and the right bin aggregate of a bin where the split is a low split
+              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
+                binData(shift + (3 * (numBins - 2 - splitIndex))) +
+                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
+              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
+                binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
+              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
+                binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
+
+              splitIndex += 1
+            }
+            featureIndex += 1
+          }
+          (leftNodeAgg, rightNodeAgg)
+      }
+    }
+
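+    /**
+     * Calculates the information gain for every split of every feature of a node.
+     * @param leftNodeAgg left node aggregates for all features and splits
+     * @param rightNodeAgg right node aggregates for all features and splits
+     * @param nodeImpurity impurity of the node being split
+     * @return gain statistics, indexed by feature and then by split
+     */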
+    def calculateGainsForAllNodeSplits(
+        leftNodeAgg: Array[Array[Double]],
+        rightNodeAgg: Array[Array[Double]],
+        nodeImpurity: Double): Array[Array[InformationGainStats]] = {
+      val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
+
+      for (featureIndex <- 0 until numFeatures) {
+        for (splitIndex <- 0 until numBins - 1) {
+          gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
+            splitIndex, rightNodeAgg, nodeImpurity)
+        }
+      }
+      gains
+    }
+
+    /**
+     * Find the best split for a node.
+     * @param binData Array[Double] of size 2 * numBins * numFeatures for classification or
+     *                3 * numBins * numFeatures for regression
+     * @param nodeImpurity impurity of the top node
+     * @return tuple of split and information gain
+     */
+    def binsToBestSplit(
+        binData: Array[Double],
+        nodeImpurity: Double): (Split, InformationGainStats) = {
+
+      logDebug("node impurity = " + nodeImpurity)
+
+      // Extract left right node aggregates.
+      val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
+
+      // Calculate gains for all splits.
+      val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
+
+      val (bestFeatureIndex, bestSplitIndex, gainStats) = {
+        // Initialize with infeasible values.
+        var bestFeatureIndex = Int.MinValue
+        var bestSplitIndex = Int.MinValue
+        var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
+        // Iterate over features.
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          // Iterate over all splits.
+          var splitIndex = 0
+          while (splitIndex < numBins - 1) {
+            val gainStats = gains(featureIndex)(splitIndex)
+            if (gainStats.gain > bestGainStats.gain) {
+              bestGainStats = gainStats
+              bestFeatureIndex = featureIndex
+              bestSplitIndex = splitIndex
+            }
+            splitIndex += 1
+          }
+          featureIndex += 1
+        }
+        (bestFeatureIndex, bestSplitIndex, bestGainStats)
+      }
+
+      logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+      logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
+
+      (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
+    }
+
+    /**
+     * Get bin data for one node.
+     */
+    def getBinDataForNode(node: Int): Array[Double] = {
+      strategy.algo match {
+        case Classification =>
+          val shift = 2 * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
+          binsForNode
+        case Regression =>
+          val shift = 3 * node * numBins * numFeatures
+          val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+          binsForNode
+      }
+    }
+
+    // Calculate best splits for all nodes at a given level.
+    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    // Iterate over all nodes at this level.
+    var node = 0
+    while (node < numNodes) {
+      val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
+      val binsForNode: Array[Double] = getBinDataForNode(node)
+      logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
+      val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
+      logDebug("node impurity = " + parentNodeImpurity)
+      bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
+      node += 1
+    }
+
+    bestSplits
+  }
+
+  /**
+   * Returns splits and bins for decision tree calculation.
+   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+   *              for DecisionTree
+   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+   *                 parameters for constructing the DecisionTree
+   * @return a tuple of (splits, bins) where splits is an Array of
+   *         [[org.apache.spark.mllib.tree.model.Split]] of size (numFeatures, numBins - 1) and
+   *         bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] of size
+   *         (numFeatures, numBins)
+   */
+  protected[tree] def findSplitsBins(
+      input: RDD[LabeledPoint],
+      strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+    val count = input.count()
+
+    // Find the number of features by looking at the first sample
+    val numFeatures = input.take(1)(0).features.length
+
+    val maxBins = strategy.maxBins
+    val numBins = if (maxBins <= count) maxBins else count.toInt
+    logDebug("numBins = " + numBins)
+
+    /*
+     * Require the number of bins to be at least the maximum number of categories of any
+     * categorical feature. It's a limitation of the current implementation but a reasonable
+     * trade-off since features with a large number of categories get favored over continuous
+     * features.
+     */
+    if (strategy.categoricalFeaturesInfo.size > 0) {
+      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+      require(numBins >= maxCategoriesForFeatures,
+        "number of bins should be at least the maximum number of categories")
+    }
+
+    // Calculate the number of samples for approximate quantile calculation.
+    val requiredSamples = numBins * numBins
+    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
+    logDebug("fraction of data used for calculating quantiles = " + fraction)
+
+    // Sample the input and collect it to the driver for quantile calculation.
+    val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect()
+    val numSamples = sampledInput.length
+
+    val stride: Double = numSamples.toDouble / numBins
+    logDebug("stride = " + stride)
+
+    strategy.quantileCalculationStrategy match {
+      case Sort =>
+        val splits = Array.ofDim[Split](numFeatures, numBins - 1)
+        val bins = Array.ofDim[Bin](numFeatures, numBins)
+
+        // Find all splits.
+
+        // Iterate over all features.
+        var featureIndex = 0
+        while (featureIndex < numFeatures) {
+          // Check whether the feature is continuous.
+          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          if (isFeatureContinuous) {
+            val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+            for (index <- 0 until numBins - 1) {
+              // Multiply before truncating so that a fractional stride is not rounded down first.
+              val sampleIndex = ((index + 1) * stride).toInt
+              val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
+              splits(featureIndex)(index) = split
+            }
+          } else {
+            val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+            require(maxFeatureValue < numBins, "number of categories should be less than number " +
+              "of bins")
+
+            // For categorical variables, each bin is a category. The bins are sorted and they
+            // are ordered by calculating the centroid of their corresponding labels.
+            val centroidForCategories =
+              sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+                .groupBy(_._1)
+                .mapValues(x => x.map(_._2).sum / x.length)
+
+            // Check for missing categorical variables and put them last in the sorted list.
+            val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
+            for (i <- 0 until maxFeatureValue) {
+              if (centroidForCategories.contains(i)) {
+                fullCentroidForCategories(i) = centroidForCategories(i)
+              } else {
+                fullCentroidForCategories(i) = Double.MaxValue
+              }
+            }
+
+            // bins sorted by centroids
+            val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+            logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
+
+            var categoriesForSplit = List[Double]()
+            categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+              case ((key, value), index) =>
+                categoriesForSplit = key :: categoriesForSplit
+                splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
+                  categoriesForSplit)
+                bins(featureIndex)(index) = {
+                  if (index == 0) {
+                    new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+                      splits(featureIndex)(0), Categorical, key)
+                  } else {
+                    new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+                      Categorical, key)
+                  }
+                }
+            }
+          }
+          featureIndex += 1
+        }
+
+        // Find all bins.
+        featureIndex = 0
+        while (featureIndex < numFeatures) {
+          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
+            bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+              splits(featureIndex)(0), Continuous, Double.MinValue)
+            for (index <- 1 until numBins - 1) {
+              val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+                Continuous, Double.MinValue)
+              bins(featureIndex)(index) = bin
+            }
+            bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
+              new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+          }
+          featureIndex += 1
+        }
+        (splits, bins)
+      case MinMax =>
+        throw new UnsupportedOperationException("minmax not supported yet.")
+      case ApproxHist =>
+        throw new UnsupportedOperationException("approximate histogram not supported yet.")
+    }
+  }
+
+  val usage = """
+    Usage: DecisionTreeRunner <master>[slices] --algo <Classification,Regression>
+      --trainDataDir path --testDataDir path --maxDepth num
+      [--impurity <Gini,Entropy,Variance>] [--maxBins num]
+    """
+
+  def main(args: Array[String]) {
+
+    if (args.length < 2) {
+      System.err.println(usage)
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(args(0), "DecisionTree")
+
+    val argList = args.toList.drop(1)
+    type OptionMap = Map[Symbol, Any]
+
+    def nextOption(map: OptionMap, list: List[String]): OptionMap = {
+      list match {
+        case Nil => map
+        case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
+        case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
+        case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
+        case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
+        case "--trainDataDir" :: string :: tail =>
+          nextOption(map ++ Map('trainDataDir -> string), tail)
+        case "--testDataDir" :: string :: tail =>
+          nextOption(map ++ Map('testDataDir -> string), tail)
+        case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
+        case option :: tail => logError("Unknown option " + option)
+          sys.exit(1)
+      }
+    }
+    val options = nextOption(Map(), argList)
+    logDebug(options.toString())
+
+    // Load training data.
+    val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
+
+    // Identify the type of algorithm.
+    val algoStr = options.get('algo).get.toString
+    val algo = algoStr match {
+      case "Classification" => Classification
+      case "Regression" => Regression
+    }
+
+    // Identify the type of impurity.
+    val impurityStr = options.getOrElse('impurity,
+      if (algo == Classification) "Gini" else "Variance").toString
+    val impurity = impurityStr match {
+      case "Gini" => Gini
+      case "Entropy" => Entropy
+      case "Variance" => Variance
+    }
+
+    val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
+    val maxBins = options.getOrElse('maxBins, "100").toString.toInt
+
+    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+    val model = DecisionTree.train(trainData, strategy)
+
+    // Load test data.
+    val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
+
+    // Measure algorithm accuracy
+    if (algo == Classification) {
+      val accuracy = accuracyScore(model, testData)
+      logDebug("accuracy = " + accuracy)
+    }
+
+    if (algo == Regression) {
+      val mse = meanSquaredError(model, testData)
+      logDebug("mean square error = " + mse)
+    }
+
+    sc.stop()
+  }
+
+  /**
+   * Load labeled data from a file. The data format used here is
+   * <L>,<f1>,<f2>,...
+   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
+   *
+   * @param sc SparkContext
+   * @param dir Path to the input data files.
+   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
+   *         the label, and the second element represents the feature values (an array of Double).
+   */
+  def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+    sc.textFile(dir).map { line =>
+      val parts = line.trim().split(",")
+      val label = parts(0).toDouble
+      val features = parts.slice(1, parts.length).map(_.toDouble)
+      LabeledPoint(label, features)
+    }
+  }
+
+  // TODO: Port this method to a generic metrics package.
+  /**
+   * Calculates the classifier accuracy.
+   */
+  private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
+      threshold: Double = 0.5): Double = {
+    def predictedValue(features: Array[Double]) = {
+      if (model.predict(features) < threshold) 0.0 else 1.0
+    }
+    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+    val count = data.count()
+    logDebug("correct prediction count = " + correctCount)
+    logDebug("data count = " + count)
+    correctCount.toDouble / count
+  }
+
+  // TODO: Port this method to a generic metrics package
+  /**
+   * Calculates the mean squared error for regression.
+   */
+  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
+    data.map { y =>
+      val err = tree.predict(y.features) - y.label
+      err * err
+    }.mean()
+  }
+}
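
The left/right aggregate extraction and per-split gain calculation in `extractLeftRightNodeAggregates` and `calculateGainForSplit` amount to prefix sums, suffix sums, and a weighted impurity decrease. The following is a minimal, Spark-free sketch of that idea for one feature of one node, assuming binary classification with Gini impurity and per-bin class counts interleaved as `[c0(bin 0), c1(bin 0), c0(bin 1), c1(bin 1), ...]`, which is the layout the classification branch above appears to use. `SplitGainSketch`, `gini`, and `bestSplit` are hypothetical names, not MLlib API.

```scala
object SplitGainSketch {
  // Gini impurity of a binary node with class counts c0 and c1: 1 - p0^2 - p1^2.
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0) 0.0
    else {
      val p0 = c0 / total
      val p1 = c1 / total
      1.0 - p0 * p0 - p1 * p1
    }
  }

  // binCounts holds interleaved class counts for one feature (hypothetical layout):
  // [c0(bin 0), c1(bin 0), c0(bin 1), c1(bin 1), ...].
  // Split i sends bins 0..i to the left child and bins i+1..numBins-1 to the right.
  // Returns (best split index, its information gain).
  def bestSplit(binCounts: Array[Double]): (Int, Double) = {
    val numBins = binCounts.length / 2
    require(numBins >= 2, "need at least two bins to form a split")
    val numSplits = numBins - 1

    // Prefix sums: left(i)(c) is the class-c count over bins 0..i.
    val left = Array.ofDim[Double](numSplits, 2)
    for (i <- 0 until numSplits; c <- 0 until 2) {
      left(i)(c) = binCounts(2 * i + c) + (if (i > 0) left(i - 1)(c) else 0.0)
    }

    // Suffix sums: right(i)(c) is the class-c count over bins i+1..numBins-1.
    val right = Array.ofDim[Double](numSplits, 2)
    for (i <- (numSplits - 1) to 0 by -1; c <- 0 until 2) {
      right(i)(c) = binCounts(2 * (i + 1) + c) +
        (if (i < numSplits - 1) right(i + 1)(c) else 0.0)
    }

    // Every split partitions the same parent, so split 0 gives the parent totals.
    val parentImpurity = gini(left(0)(0) + right(0)(0), left(0)(1) + right(0)(1))
    val total = binCounts.sum

    var best = (-1, Double.NegativeInfinity)
    for (i <- 0 until numSplits) {
      val leftCount = left(i)(0) + left(i)(1)
      val rightCount = right(i)(0) + right(i)(1)
      // Weighted impurity decrease; an empty side has weight 0 and contributes nothing.
      val gain = parentImpurity -
        (leftCount / total) * gini(left(i)(0), left(i)(1)) -
        (rightCount / total) * gini(right(i)(0), right(i)(1))
      if (gain > best._2) best = (i, gain)
    }
    best
  }
}
```

Unlike the code above, the sketch recomputes the parent impurity from the counts rather than reusing a stored parent impurity at deeper levels. For regression, the same prefix/suffix pattern would apply to (count, sum, sum of squares) triples instead of class-count pairs.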

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
new file mode 100644
index 0000000..0fd71aa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
@@ -0,0 +1,17 @@
+This package contains the default implementation of the decision tree algorithm.
+
+The decision tree algorithm supports:
++ Binary classification
++ Regression
++ Information gain calculation with entropy and Gini for classification and variance for regression
++ Both continuous and categorical features
+
+# Future Tree Improvements
++ Node model pruning
++ Printing to dot files
+
+# Future Ensemble Extensions
+
++ Random forests
++ Boosting
++ Extremely randomized trees
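
The README names entropy and Gini for classification and variance for regression. The sketch below writes the three measures in terms of the statistics the training code aggregates: class counts for classification and (count, sum, sum of squares) for regression, matching the argument shapes passed to `strategy.impurity.calculate` in `DecisionTree.scala`. It is not MLlib's `Impurity` implementation; `ImpuritySketch` is a hypothetical name, and the base-2 logarithm for entropy is an assumption.

```scala
object ImpuritySketch {
  // Gini impurity for binary class counts: 1 - p0^2 - p1^2.
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0) 0.0
    else {
      val p0 = c0 / total
      val p1 = c1 / total
      1.0 - p0 * p0 - p1 * p1
    }
  }

  // Base-2 logarithm (assumed base for entropy in this sketch).
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Entropy in bits: -p0 log2 p0 - p1 log2 p1, treating 0 * log 0 as 0.
  def entropy(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0) 0.0
    else Seq(c0, c1).filter(_ > 0).map { c =>
      val p = c / total
      -p * log2(p)
    }.sum
  }

  // Variance from sufficient statistics: E[y^2] - (E[y])^2.
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    if (count == 0) 0.0
    else sumSquares / count - (sum / count) * (sum / count)
  }
}
```

Because each measure needs only these aggregates, partitions can sum them independently and combine the sums, which is what makes the bin-wise aggregation in the training code possible.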

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
new file mode 100644
index 0000000..2dd1f0f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to select the algorithm for the decision tree
+ */
+object Algo extends Enumeration {
+  type Algo = Value
+  val Classification, Regression = Value
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
new file mode 100644
index 0000000..09ee058
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
+object FeatureType extends Enumeration {
+  type FeatureType = Value
+  val Continuous, Categorical = Value
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
new file mode 100644
index 0000000..2457a48
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
+object QuantileStrategy extends Enumeration {
+  type QuantileStrategy = Value
+  val Sort, MinMax, ApproxHist = Value
+}
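
Under the `Sort` strategy, `findSplitsBins` samples roughly `numBins * numBins` points, sorts each continuous feature's sampled values, and takes evenly spaced values as split thresholds. A minimal, Spark-free sketch of that selection step follows, assuming the sample has already been collected into an array; `SortQuantileSketch` and `splitThresholds` are hypothetical names. It computes each position as `((i + 1) * stride).toInt`, so a fractional stride is not truncated before it is multiplied.

```scala
object SortQuantileSketch {
  // Returns numBins - 1 candidate thresholds taken at evenly spaced positions
  // of the sorted sample, in the spirit of the Sort quantile strategy.
  def splitThresholds(sample: Array[Double], numBins: Int): Array[Double] = {
    require(numBins >= 2, "need at least two bins")
    require(sample.nonEmpty, "sample must not be empty")
    val sorted = sample.sorted
    val stride = sorted.length.toDouble / numBins
    Array.tabulate(numBins - 1) { i =>
      // (i + 1) * stride < sorted.length for i <= numBins - 2, so the index is in range.
      val sampleIndex = ((i + 1) * stride).toInt
      sorted(sampleIndex)
    }
  }
}
```

For example, with the values 1 through 10 and four bins the stride is 2.5 and the thresholds land at sorted positions 2, 5, and 7, that is 3.0, 6.0, and 8.0.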

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
new file mode 100644
index 0000000..df565f3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ *                                number of discrete values they take. For example, an entry (n ->
+ *                                k) implies the feature n is categorical with k categories 0,
+ *                                1, 2, ... , k-1. It's important to note that features are
+ *                                zero-indexed.
+ */
+class Strategy (
+    val algo: Algo,
+    val impurity: Impurity,
+    val maxDepth: Int,
+    val maxBins: Int = 100,
+    val quantileCalculationStrategy: QuantileStrategy = Sort,
+    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
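The scaladoc above describes `categoricalFeaturesInfo` as a map from a zero-based feature index `n` to its number of categories `k`, with valid values 0 through k-1. The standalone sketch below (plain Scala, no Spark dependency; the object, example map and helpers are hypothetical) shows one way to read such a map. `fitsInBins` is an assumed sanity check relating category counts to `maxBins`. The commit log mentions keeping the bin count above the largest category count, but the exact condition the implementation enforces is not shown in this part of the diff.

```scala
object CategoricalInfoSketch {
  // Hypothetical example: feature 0 takes categories {0, 1, 2}; feature 2 takes {0, ..., 4}.
  // Any feature index missing from the map (e.g. feature 1) is treated as continuous.
  val categoricalFeaturesInfo: Map[Int, Int] = Map(0 -> 3, 2 -> 5)

  def isCategorical(feature: Int): Boolean = categoricalFeaturesInfo.contains(feature)

  /** True if `value` is one of the categories 0, 1, ..., k - 1 declared for `feature`. */
  def isValidCategory(feature: Int, value: Double): Boolean =
    categoricalFeaturesInfo.get(feature) match {
      case Some(k) => value >= 0 && value < k && value == math.floor(value)
      case None    => false // not a categorical feature
    }

  /** Hypothetical check that no categorical feature has more categories than maxBins. */
  def fitsInBins(maxBins: Int): Boolean = categoricalFeaturesInfo.values.forall(_ <= maxBins)
}
```

With Strategy's default `maxBins = 100`, this example map passes the check, while a limit of 4 would not accommodate feature 2's five categories.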

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
new file mode 100644
index 0000000..b93995f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
+object Entropy extends Impurity {
+
+  def log2(x: Double): Double = scala.math.log(x) / scala.math.log(2)
+
+  /**
+   * entropy calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return entropy value
+   */
+  def calculate(c0: Double, c1: Double): Double = {
+    if (c0 == 0 || c1 == 0) {
+      0
+    } else {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      -(f0 * log2(f0)) - (f1 * log2(f1))
+    }
+  }
+
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Entropy.calculate")
+}
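To make the binary entropy formula in `Entropy.calculate` concrete, here is a standalone re-implementation (the object name `EntropySketch` is hypothetical) with the same pure-node short circuit, which also sidesteps evaluating `log2(0)`.

```scala
object EntropySketch {
  // Change-of-base formula: log2(x) = ln(x) / ln(2)
  def log2(x: Double): Double = math.log(x) / math.log(2)

  /** Binary entropy of a node holding c0 instances of label 0 and c1 instances of label 1. */
  def entropy(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0.0 // a pure node has zero entropy
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      -(f0 * log2(f0)) - (f1 * log2(f1))
    }
  }
}
```

For example, `EntropySketch.entropy(5, 5)` gives 1.0 bit (the maximum for two classes), while `EntropySketch.entropy(1, 3)` gives roughly 0.811.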

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
new file mode 100644
index 0000000..c040755
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating the
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
+ * during binary classification.
+ */
+object Gini extends Impurity {
+
+  /**
+   * Gini impurity calculation
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return Gini impurity value
+   */
+  override def calculate(c0: Double, c1: Double): Double = {
+    if (c0 == 0 || c1 == 0) {
+      0
+    } else {
+      val total = c0 + c1
+      val f0 = c0 / total
+      val f1 = c1 / total
+      1 - f0 * f0 - f1 * f1
+    }
+  }
+
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("Gini.calculate")
+}
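The Gini impurity of a two-class node is 1 - f0^2 - f1^2, where f0 and f1 are the label fractions. The following standalone sketch (hypothetical object name `GiniSketch`) mirrors `Gini.calculate` so the numbers can be checked by hand.

```scala
object GiniSketch {
  /** Gini impurity 1 - f0^2 - f1^2 for a node with c0 / c1 instances of labels 0 / 1. */
  def gini(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0.0 // a pure node has zero impurity
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0 * f0 - f1 * f1
    }
  }
}
```

A 50/50 node gives 0.5, the maximum for two classes, and counts (1, 3) give 1 - 0.0625 - 0.5625 = 0.375. Unlike entropy, no logarithm is needed.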

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
new file mode 100644
index 0000000..a406906
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Trait for calculating information gain.
+ */
+trait Impurity extends Serializable {
+
+  /**
+   * information calculation for binary classification
+   * @param c0 count of instances with label 0
+   * @param c1 count of instances with label 1
+   * @return information value
+   */
+  def calculate(c0: Double, c1: Double): Double
+
+  /**
+   * information calculation for regression
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares sum of squares of the labels
+   * @return information value
+   */
+  def calculate(count: Double, sum: Double, sumSquares: Double): Double
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
new file mode 100644
index 0000000..b74577d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating variance during regression
+ */
+object Variance extends Impurity {
+  override def calculate(c0: Double, c1: Double): Double =
+    throw new UnsupportedOperationException("Variance.calculate")
+
+  /**
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares sum of squares of the labels
+   * @return variance of the labels
+   */
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+    val squaredLoss = sumSquares - (sum * sum) / count
+    squaredLoss / count
+  }
+}
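`Variance.calculate` works from the sufficient statistics (count, sum, sumSquares) rather than from the raw labels. The standalone sketch below (hypothetical names) reproduces that formula and shows that statistics from disjoint groups of labels combine by plain addition. This additive property is one reason such statistics are convenient when aggregating per partition before combining.

```scala
object VarianceSketch {
  /** Population variance from sufficient statistics; `count` must be positive. */
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }

  /** Sufficient statistics (count, sum, sumSquares) for a sequence of labels. */
  def stats(labels: Seq[Double]): (Double, Double, Double) =
    (labels.size.toDouble, labels.sum, labels.map(y => y * y).sum)

  /** Statistics from two disjoint groups combine element-wise. */
  def merge(a: (Double, Double, Double), b: (Double, Double, Double)): (Double, Double, Double) =
    (a._1 + b._1, a._2 + b._2, a._3 + b._3)
}
```

For labels 1, 2, 3, 4 the statistics are (4, 10, 30), so the variance is (30 - 100 / 4) / 4 = 1.25, matching the direct computation around the mean 2.5. Note that this one-pass formula can lose precision when the mean is large relative to the spread of the labels.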

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
new file mode 100644
index 0000000..a57faa1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+/**
+ * Used for "binning" the feature values for faster best-split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * a bin is determined by a single category value.
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ *                 accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ *                 accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
+case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
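The commit log mentions using binary search to find bins. As a sketch of that idea, assume the continuous bins of one feature are contiguous intervals (low, high] ordered by their split thresholds. The bin containing a value is then the index of the first threshold that is greater than or equal to it. The object and method names below are hypothetical, and this is not the code in `DecisionTree.scala`.

```scala
object BinSearchSketch {
  /**
   * Index of the bin containing `value`, where `thresholds` holds the sorted split thresholds
   * of one continuous feature. Bin 0 is (-inf, t(0)], bin i is (t(i - 1), t(i)] and the last
   * bin, index thresholds.length, is (t(last), +inf).
   */
  def findBin(thresholds: Array[Double], value: Double): Int = {
    var lo = 0
    var hi = thresholds.length
    // Invariant: the answer lies in [lo, hi]; narrow to the first threshold >= value.
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (value <= thresholds(mid)) hi = mid else lo = mid + 1
    }
    lo
  }
}
```

A value equal to a threshold lands in the lower bin, which is consistent with the `<=` test `Node.predictIfLeaf` uses to send continuous values left. Each lookup costs O(log m) comparisons for m thresholds instead of a linear scan.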

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
new file mode 100644
index 0000000..a8bbf21
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param features array representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predict(features: Array[Double]): Double = {
+    topNode.predictIfLeaf(features)
+  }
+
+  /**
+   * Predict values for the given data set using the model trained.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Array[Double]]): RDD[Double] = {
+    features.map(x => predict(x))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
new file mode 100644
index 0000000..ebc9595
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <,=,>
+ */
+case class Filter(split: Split, comparison: Int) {
+  // Comparison -1, 0, 1 signifies <, =, >
+  override def toString = "split = " + split + ", comparison = " + comparison
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
new file mode 100644
index 0000000..99bf79c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
+class InformationGainStats(
+    val gain: Double,
+    val impurity: Double,
+    val leftImpurity: Double,
+    val rightImpurity: Double,
+    val predict: Double) extends Serializable {
+
+  override def toString = {
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+  }
+}
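`InformationGainStats` records a split's gain alongside the impurities of the parent and both children. A common definition of gain, assumed here because this part of the diff does not show how it is computed, is the parent impurity minus the child impurities weighted by their share of the instances. The standalone sketch below (hypothetical names) applies that definition with Gini impurity.

```scala
object InfoGainSketch {
  /** Binary Gini impurity, as in the Gini object above. */
  def gini(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0.0
    } else {
      val f0 = c0 / (c0 + c1)
      val f1 = c1 / (c0 + c1)
      1 - f0 * f0 - f1 * f1
    }
  }

  /**
   * Gain of splitting a node into a left child with label counts (l0, l1) and a right child
   * with label counts (r0, r1): parent impurity minus the count-weighted child impurities.
   */
  def gain(l0: Double, l1: Double, r0: Double, r1: Double): Double = {
    val leftCount = l0 + l1
    val rightCount = r0 + r1
    val total = leftCount + rightCount
    val impurity = gini(l0 + r0, l1 + r1)
    impurity - (leftCount / total) * gini(l0, l1) - (rightCount / total) * gini(r0, r1)
  }
}
```

A parent with counts (4, 4) has impurity 0.5. Splitting it perfectly into (4, 0) and (0, 4) yields a gain of 0.5, splitting it into (3, 1) and (1, 3) yields 0.5 - 0.375 = 0.125, and a split that leaves both children at (2, 2) yields no gain.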

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
new file mode 100644
index 0000000..ea4693c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the node is a leaf
+ * @param split split to calculate left and right nodes
+ * @param leftNode  left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
+class Node (
+    val id: Int,
+    val predict: Double,
+    val isLeaf: Boolean,
+    val split: Option[Split],
+    var leftNode: Option[Node],
+    var rightNode: Option[Node],
+    val stats: Option[InformationGainStats]) extends Serializable with Logging {
+
+  override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+    "split = " + split + ", stats = " + stats
+
+  /**
+   * Recursively build the left and right child nodes if this node is not a leaf
+   * @param nodes array of nodes
+   */
+  def build(nodes: Array[Node]): Unit = {
+
+    logDebug("building node " + id + " at level " +
+      (scala.math.log(id + 1) / scala.math.log(2)).toInt)
+    logDebug("id = " + id + ", split = " + split)
+    logDebug("stats = " + stats)
+    logDebug("predict = " + predict)
+    if (!isLeaf) {
+      val leftNodeIndex = id*2 + 1
+      val rightNodeIndex = id*2 + 2
+      leftNode = Some(nodes(leftNodeIndex))
+      rightNode = Some(nodes(rightNodeIndex))
+      leftNode.get.build(nodes)
+      rightNode.get.build(nodes)
+    }
+  }
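`Node.build` assumes the nodes are stored level by level in a flat array: node `id` has children at `2 * id + 1` and `2 * id + 2`, and the debug message derives its level as floor(log2(id + 1)). The standalone sketch below (hypothetical object name) captures the same indexing scheme, computing the level with integer bit operations rather than a floating-point division of logarithms.

```scala
object HeapIndexSketch {
  // Children of node `id` in a flat, level-ordered array (root at index 0).
  def leftChild(id: Int): Int = 2 * id + 1
  def rightChild(id: Int): Int = 2 * id + 2

  // Inverse mapping: index of the parent of a non-root node.
  def parent(id: Int): Int = (id - 1) / 2

  // Level of node `id` (root = 0), i.e. floor(log2(id + 1)) via the highest set bit.
  def level(id: Int): Int = 31 - Integer.numberOfLeadingZeros(id + 1)
}
```

Under this scheme, level d occupies indices 2^d - 1 through 2^(d+1) - 2. For example, nodes 3 to 6 form level 2.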
+
+  /**
+   * Predict the value for a data point by descending from this node until a leaf is reached
+   * @param feature feature values of a single data point
+   * @return predicted value
+   */
+  def predictIfLeaf(feature: Array[Double]) : Double = {
+    if (isLeaf) {
+      predict
+    } else {
+      if (split.get.featureType == Continuous) {
+        if (feature(split.get.feature) <= split.get.threshold) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
+      } else {
+        if (split.get.categories.contains(feature(split.get.feature))) {
+          leftNode.get.predictIfLeaf(feature)
+        } else {
+          rightNode.get.predictIfLeaf(feature)
+        }
+      }
+    }
+  }
+}
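`predictIfLeaf` routes a data point left when a continuous value is at or below the split threshold, or when a categorical value is among the split's accepted categories, and right otherwise. The standalone sketch below re-expresses those routing rules with an immutable sealed hierarchy instead of `Option`-wrapped mutable children. The types are hypothetical and not part of MLlib. Because every recursive call is in tail position, `@tailrec` lets the compiler turn the descent into a loop.

```scala
object TraversalSketch {
  sealed trait TreeNode
  final case class Leaf(predict: Double) extends TreeNode
  final case class ContinuousSplit(
      feature: Int, threshold: Double, left: TreeNode, right: TreeNode) extends TreeNode
  final case class CategoricalSplit(
      feature: Int, categories: Set[Double], left: TreeNode, right: TreeNode) extends TreeNode

  @annotation.tailrec
  def predict(node: TreeNode, features: Array[Double]): Double = node match {
    case Leaf(p) => p
    case ContinuousSplit(f, t, l, r) =>
      // values at or below the threshold go left, mirroring predictIfLeaf
      predict(if (features(f) <= t) l else r, features)
    case CategoricalSplit(f, cats, l, r) =>
      // values in the accepted category set go left
      predict(if (cats.contains(features(f))) l else r, features)
  }

  // Hypothetical example: split feature 0 at 2.5, then split feature 1 on categories {1, 3}.
  val example: TreeNode =
    ContinuousSplit(0, 2.5, Leaf(0.0), CategoricalSplit(1, Set(1.0, 3.0), Leaf(1.0), Leaf(0.0)))
}
```

With the example tree, `Array(3.0, 3.0)` goes right at the first split and left at the second, so it predicts 1.0.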

http://git-wip-us.apache.org/repos/asf/spark/blob/8b3045ce/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
new file mode 100644
index 0000000..4e64a81
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+
+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
+case class Split(
+    feature: Int,
+    threshold: Double,
+    featureType: FeatureType,
+    categories: List[Double]) {
+
+  override def toString =
+    "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
+      ", categories = " + categories
+}
+
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType: FeatureType)
+  extends Split(feature, Double.MinValue, featureType, List())
+
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType: FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
+
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
+
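`DummyLowSplit` and `DummyHighSplit` use `Double.MinValue` and `Double.MaxValue` as thresholds so that the first and last continuous bins are bounded on both sides. The standalone sketch below (hypothetical names) pads a sorted list of real thresholds with the same sentinels and pairs adjacent boundaries into bins. Note that in Scala, `Double.MinValue` is the most negative finite double, unlike Java's `Double.MIN_VALUE`, which is the smallest positive one.

```scala
object BinBoundarySketch {
  /**
   * Turns sorted split thresholds into (low, high] bin boundaries, padding the ends with
   * Double.MinValue and Double.MaxValue, the thresholds of DummyLowSplit and DummyHighSplit.
   */
  def bins(thresholds: Seq[Double]): Seq[(Double, Double)] = {
    val bounds = (Double.MinValue +: thresholds) :+ Double.MaxValue
    bounds.zip(bounds.tail)
  }
}
```

With m real thresholds this yields m + 1 bins, and even an empty threshold list produces a single bin covering every finite value.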

