spark-commits mailing list archives

From jkbrad...@apache.org
Subject [2/2] spark git commit: [SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML
Date Sat, 02 Apr 2016 04:23:59 GMT
[SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML

## What changes were proposed in this pull request?

Decision tree helper classes will be migrated to ML. This patch moves those internal classes that are not part of the public API and removes ones that are no longer used, after [SPARK-12183](https://github.com/apache/spark/pull/11855). No functional changes are made.

Details:
* Bin.scala is removed as the ML implementation does not require bins
* mllib NodeIdCache is removed. It was only used by the mllib implementation previously, which no longer exists
* mllib TreePoint is removed. It was only used by the mllib implementation previously, which no longer exists
* BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, BaggedPointSuite and TimeTracker are all moved to ML.

## How was this patch tested?

No functional changes are made. Existing unit tests ensure behavior is unchanged.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12097 from sethah/cleanup_mllib_tree.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4fc35e6f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4fc35e6f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4fc35e6f

Branch: refs/heads/master
Commit: 4fc35e6f5c590feb47cbcb5b1136f2e985677b3f
Parents: 36e8fb8
Author: sethah <seth.hendrickson16@gmail.com>
Authored: Fri Apr 1 21:23:35 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Fri Apr 1 21:23:35 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tree/impl/BaggedPoint.scala | 125 +++++++++++
 .../spark/ml/tree/impl/DTStatsAggregator.scala  | 181 ++++++++++++++++
 .../ml/tree/impl/DecisionTreeMetadata.scala     | 217 +++++++++++++++++++
 .../ml/tree/impl/GradientBoostedTrees.scala     |   1 -
 .../apache/spark/ml/tree/impl/NodeIdCache.scala |   1 -
 .../spark/ml/tree/impl/RandomForest.scala       |   4 +-
 .../apache/spark/ml/tree/impl/TimeTracker.scala |  70 ++++++
 .../apache/spark/ml/tree/impl/TreePoint.scala   |   1 -
 .../spark/mllib/tree/GradientBoostedTrees.scala |   3 +-
 .../spark/mllib/tree/impl/BaggedPoint.scala     | 125 -----------
 .../mllib/tree/impl/DTStatsAggregator.scala     | 178 ---------------
 .../mllib/tree/impl/DecisionTreeMetadata.scala  | 217 -------------------
 .../spark/mllib/tree/impl/NodeIdCache.scala     | 195 -----------------
 .../spark/mllib/tree/impl/TimeTracker.scala     |  70 ------
 .../spark/mllib/tree/impl/TreePoint.scala       | 150 -------------
 .../spark/mllib/tree/impurity/Entropy.scala     |   2 +-
 .../apache/spark/mllib/tree/impurity/Gini.scala |   2 +-
 .../spark/mllib/tree/impurity/Variance.scala    |   2 +-
 .../org/apache/spark/mllib/tree/model/Bin.scala |  47 ----
 .../spark/ml/tree/impl/BaggedPointSuite.scala   |  99 +++++++++
 .../spark/ml/tree/impl/RandomForestSuite.scala  |   1 -
 .../spark/mllib/tree/DecisionTreeSuite.scala    |   2 +-
 .../mllib/tree/impl/BaggedPointSuite.scala      |  99 ---------
 23 files changed, 699 insertions(+), 1093 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
new file mode 100644
index 0000000..4e37270
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.commons.math3.distribution.PoissonDistribution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
+ * particularly for bagging (e.g., for random forests).
+ *
+ * This holds one instance, as well as an array of weights which represent the (weighted)
+ * number of times which this instance appears in each subsample.
+ * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
+ * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
+ *
+ * @param datum  Data instance
+ * @param subsampleWeights  Weight of this instance in each subsampled dataset.
+ *
+ * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
+ *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
+ */
+private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+  extends Serializable
+
+private[spark] object BaggedPoint {
+
+  /**
+   * Convert an input dataset into its BaggedPoint representation,
+   * choosing subsample counts for each instance.
+   * Each subsample is drawn from the original dataset at rate subsamplingRate,
+   * with or without replacement, as selected by withReplacement.
+   * @param input Input dataset.
+   * @param subsamplingRate Fraction of the training data used for learning decision tree.
+   * @param numSubsamples Number of subsamples of this RDD to take.
+   * @param withReplacement Sampling with/without replacement.
+   * @param seed Random seed.
+   * @return BaggedPoint dataset representation.
+   */
+  def convertToBaggedRDD[Datum] (
+      input: RDD[Datum],
+      subsamplingRate: Double,
+      numSubsamples: Int,
+      withReplacement: Boolean,
+      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+    if (withReplacement) {
+      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+    } else {
+      if (numSubsamples == 1 && subsamplingRate == 1.0) {
+        convertToBaggedRDDWithoutSampling(input)
+      } else {
+        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+      }
+    }
+  }
+
+  private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+      input: RDD[Datum],
+      subsamplingRate: Double,
+      numSubsamples: Int,
+      seed: Long): RDD[BaggedPoint[Datum]] = {
+    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+      val rng = new XORShiftRandom
+      rng.setSeed(seed + partitionIndex + 1)
+      instances.map { instance =>
+        val subsampleWeights = new Array[Double](numSubsamples)
+        var subsampleIndex = 0
+        while (subsampleIndex < numSubsamples) {
+          val x = rng.nextDouble()
+          subsampleWeights(subsampleIndex) = {
+            if (x < subsamplingRate) 1.0 else 0.0
+          }
+          subsampleIndex += 1
+        }
+        new BaggedPoint(instance, subsampleWeights)
+      }
+    }
+  }
+
+  private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+      input: RDD[Datum],
+      subsample: Double,
+      numSubsamples: Int,
+      seed: Long): RDD[BaggedPoint[Datum]] = {
+    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+      val poisson = new PoissonDistribution(subsample)
+      poisson.reseedRandomGenerator(seed + partitionIndex + 1)
+      instances.map { instance =>
+        val subsampleWeights = new Array[Double](numSubsamples)
+        var subsampleIndex = 0
+        while (subsampleIndex < numSubsamples) {
+          subsampleWeights(subsampleIndex) = poisson.sample()
+          subsampleIndex += 1
+        }
+        new BaggedPoint(instance, subsampleWeights)
+      }
+    }
+  }
+
+  private def convertToBaggedRDDWithoutSampling[Datum] (
+      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+    input.map(datum => new BaggedPoint(datum, Array(1.0)))
+  }
+
+}
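
The two sampling branches in BaggedPoint above can be illustrated without Spark. The following is a minimal standalone sketch, not the code from this commit: `BaggingSketch` and its method names are hypothetical, `java.util.Random` stands in for `XORShiftRandom`, and Knuth's multiplication method stands in for Commons Math's `PoissonDistribution` so that the block has no dependencies.

```scala
import java.util.Random

// Standalone sketch (no Spark): per-instance subsample weights.
// All names here are hypothetical and not part of the commit.
object BaggingSketch {

  // Without replacement: each subsample keeps the instance with probability `rate`,
  // so every weight is either 0.0 or 1.0.
  def bernoulliWeights(rate: Double, numSubsamples: Int, rng: Random): Array[Double] =
    Array.fill(numSubsamples)(if (rng.nextDouble() < rate) 1.0 else 0.0)

  // Knuth's multiplication method for one Poisson(mean) draw.
  // Assumption: used here only as a dependency-free stand-in for PoissonDistribution.
  private def poissonSample(mean: Double, rng: Random): Int = {
    val limit = math.exp(-mean)
    var k = 0
    var p = rng.nextDouble()
    while (p > limit) {
      k += 1
      p *= rng.nextDouble()
    }
    k
  }

  // With replacement: the number of copies in each subsample is a Poisson(mean) count.
  def poissonWeights(mean: Double, numSubsamples: Int, rng: Random): Array[Double] =
    Array.fill(numSubsamples)(poissonSample(mean, rng).toDouble)

  def main(args: Array[String]): Unit = {
    // Mirrors the "seed + partitionIndex + 1" seeding convention for partition 0.
    val rng = new Random(42L + 0 + 1)
    println(bernoulliWeights(0.5, 3, rng).mkString("[", ", ", "]"))
    println(poissonWeights(1.0, 3, rng).mkString("[", ", ", "]"))
  }
}
```

A weight array such as [1, 0, 4] then reads as in the class comment: one copy in the first subsample, none in the second, four in the third.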

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
new file mode 100644
index 0000000..61091bb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.mllib.tree.impurity._
+
+
+
+/**
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
+ * and helps with indexing.
+ * The optional featureSubset supports learning with and without feature subsampling.
+ */
+private[spark] class DTStatsAggregator(
+    val metadata: DecisionTreeMetadata,
+    featureSubset: Option[Array[Int]]) extends Serializable {
+
+  /**
+   * [[ImpurityAggregator]] instance specifying the impurity type.
+   */
+  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
+    case Gini => new GiniAggregator(metadata.numClasses)
+    case Entropy => new EntropyAggregator(metadata.numClasses)
+    case Variance => new VarianceAggregator()
+    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
+  }
+
+  /**
+   * Number of elements (Double values) used for the sufficient statistics of each bin.
+   */
+  private val statsSize: Int = impurityAggregator.statsSize
+
+  /**
+   * Number of bins for each feature.  This is indexed by the feature index.
+   */
+  private val numBins: Array[Int] = {
+    if (featureSubset.isDefined) {
+      featureSubset.get.map(metadata.numBins(_))
+    } else {
+      metadata.numBins
+    }
+  }
+
+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   */
+  private val featureOffsets: Array[Int] = {
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+  }
+
+  /**
+   * Total number of elements stored in this aggregator
+   */
+  private val allStatsSize: Int = featureOffsets.last
+
+  /**
+   * Flat array of elements.
+   * Index for start of stats for a (feature, bin) is:
+   *   index = featureOffsets(featureIndex) + binIndex * statsSize
+   */
+  private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
+
+  /**
+   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+   *
+   * @param featureOffset  This is a pre-computed (node, feature) offset
+   *                           from [[getFeatureOffset]].
+   */
+  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
+  /**
+   * Update the stats for a given (feature, bin) for ordered features, using the given label.
+   */
+  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+    val i = featureOffsets(featureIndex) + binIndex * statsSize
+    impurityAggregator.update(allStats, i, label, instanceWeight)
+  }
+
+  /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
+  /**
+   * Faster version of [[update]].
+   * Update the stats for a given (feature, bin), using the given label.
+   *
+   * @param featureOffset  This is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
+   */
+  def featureUpdate(
+      featureOffset: Int,
+      binIndex: Int,
+      label: Double,
+      instanceWeight: Double): Unit = {
+    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+      label, instanceWeight)
+  }
+
+  /**
+   * Pre-compute feature offset for use with [[featureUpdate]].
+   * For ordered features only.
+   */
+  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
+
+  /**
+   * For a given feature, merge the stats for two bins.
+   *
+   * @param featureOffset  This is a pre-computed feature offset
+   *                           from [[getFeatureOffset]].
+   * @param binIndex  The other bin is merged into this bin.
+   * @param otherBinIndex  This bin is not modified.
+   */
+  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+      featureOffset + otherBinIndex * statsSize)
+  }
+
+  /**
+   * Merge this aggregator with another, and return this aggregator.
+   * This method modifies this aggregator in-place.
+   */
+  def merge(other: DTStatsAggregator): DTStatsAggregator = {
+    require(allStatsSize == other.allStatsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+    var i = 0
+    // TODO: Test BLAS.axpy
+    while (i < allStatsSize) {
+      allStats(i) += other.allStats(i)
+      i += 1
+    }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+        s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
+    this
+  }
+}
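
The flat-array layout used by DTStatsAggregator above (index = featureOffsets(featureIndex) + binIndex * statsSize) may be easier to follow with concrete numbers. This is a small standalone sketch under assumed values; `FlatStatsIndexSketch` and its helpers are hypothetical names, and the bin counts and stats size are chosen only for illustration.

```scala
// Standalone sketch of the flat (feature, bin) indexing scheme.
// Names and example sizes are hypothetical, not taken from the commit.
object FlatStatsIndexSketch {

  // Start offset of each feature's block; the last element is the total array length.
  def featureOffsets(numBins: Array[Int], statsSize: Int): Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  // Index of the first statistic for a given (feature, bin).
  def statsIndex(offsets: Array[Int], featureIndex: Int, binIndex: Int, statsSize: Int): Int =
    offsets(featureIndex) + binIndex * statsSize

  def main(args: Array[String]): Unit = {
    val statsSize = 2                       // assumed: two sufficient statistics per bin
    val offsets = featureOffsets(Array(3, 2), statsSize)
    println(offsets.mkString(", "))         // 0, 6, 10
    println(statsIndex(offsets, 1, 1, statsSize)) // 8
  }
}
```

With two features of 3 and 2 bins, feature 0 occupies slots 0 to 5 and feature 1 occupies slots 6 to 9, so the array holds 10 values in total.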

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
new file mode 100644
index 0000000..df8eb5d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.rdd.RDD
+
+/**
+ * Learning and dataset metadata for DecisionTree.
+ *
+ * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
+ *                      For regression: fixed at 0 (no meaning).
+ * @param maxBins  Maximum number of bins, for all features.
+ * @param featureArity  Map: categorical feature index --> arity.
+ *                      I.e., the feature takes values in {0, ..., arity - 1}.
+ * @param numBins  Number of bins for each feature.
+ */
+private[spark] class DecisionTreeMetadata(
+    val numFeatures: Int,
+    val numExamples: Long,
+    val numClasses: Int,
+    val maxBins: Int,
+    val featureArity: Map[Int, Int],
+    val unorderedFeatures: Set[Int],
+    val numBins: Array[Int],
+    val impurity: Impurity,
+    val quantileStrategy: QuantileStrategy,
+    val maxDepth: Int,
+    val minInstancesPerNode: Int,
+    val minInfoGain: Double,
+    val numTrees: Int,
+    val numFeaturesPerNode: Int) extends Serializable {
+
+  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+  def isClassification: Boolean = numClasses >= 2
+
+  def isMulticlass: Boolean = numClasses > 2
+
+  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+  /**
+   * Number of splits for the given feature.
+   * For unordered features, there is 1 bin per split.
+   * For ordered features, there is 1 more bin than split.
+   */
+  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+    numBins(featureIndex)
+  } else {
+    numBins(featureIndex) - 1
+  }
+
+
+  /**
+   * Set number of splits for a continuous feature.
+   * For a continuous feature, number of bins is number of splits plus 1.
+   */
+  def setNumSplits(featureIndex: Int, numSplits: Int) {
+    require(isContinuous(featureIndex),
+      s"Only number of bin for a continuous feature can be set.")
+    numBins(featureIndex) = numSplits + 1
+  }
+
+  /**
+   * Indicates if feature subsampling is being used.
+   */
+  def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
+}
+
+private[spark] object DecisionTreeMetadata extends Logging {
+
+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String): DecisionTreeMetadata = {
+
+    val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
+      throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
+        s"but was given by empty one.")
+    }
+    val numExamples = input.count()
+    val numClasses = strategy.algo match {
+      case Classification => strategy.numClasses
+      case Regression => 0
+    }
+
+    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+    if (maxPossibleBins < strategy.maxBins) {
+      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+        s" (= number of training instances)")
+    }
+
+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+    // based on the number of training examples.
+    if (strategy.categoricalFeaturesInfo.nonEmpty) {
+      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+      val maxCategory =
+        strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
+      require(maxCategoriesPerFeature <= maxPossibleBins,
+        s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+        s"number of values in each categorical feature, but categorical feature $maxCategory " +
+        s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
+        "features with a large number of values, or add more training examples.")
+    }
+
+    val unorderedFeatures = new mutable.HashSet[Int]()
+    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
+    if (numClasses > 2) {
+      // Multiclass classification
+      val maxCategoriesForUnorderedFeature =
+        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+        // TODO(SPARK-9957): Handle this properly by filtering out those features.
+        if (numCategories > 1) {
+          // Decide if some categorical features should be treated as unordered features,
+          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+          // We do this check with log values to prevent overflows in case numCategories is large.
+          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+          if (numCategories <= maxCategoriesForUnorderedFeature) {
+            unorderedFeatures.add(featureIndex)
+            numBins(featureIndex) = numUnorderedBins(numCategories)
+          } else {
+            numBins(featureIndex) = numCategories
+          }
+        }
+      }
+    } else {
+      // Binary classification or regression
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+        if (numCategories > 1) {
+          numBins(featureIndex) = numCategories
+        }
+      }
+    }
+
+    // Set number of features to use per node (for random forests).
+    val _featureSubsetStrategy = featureSubsetStrategy match {
+      case "auto" =>
+        if (numTrees == 1) {
+          "all"
+        } else {
+          if (strategy.algo == Classification) {
+            "sqrt"
+          } else {
+            "onethird"
+          }
+        }
+      case _ => featureSubsetStrategy
+    }
+    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+      case "all" => numFeatures
+      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+      case "onethird" => (numFeatures / 3.0).ceil.toInt
+    }
+
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
+      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+  }
+
+  /**
+   * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
+   */
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy): DecisionTreeMetadata = {
+    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+  }
+
+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+   * there are math.pow(2, arity - 1) - 1 such splits.
+   * Each split has 2 corresponding bins.
+   */
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
+
+}
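
The arithmetic in buildMetadata above (features per node, and the threshold for treating a categorical feature as unordered) can be checked in isolation. The sketch below copies the formulas from the diff into a hypothetical `TreeMetadataSketch` object; the explicit error for an unknown strategy is an addition for this example and is not in the original code.

```scala
// Standalone sketch of the metadata arithmetic; the object name is hypothetical.
object TreeMetadataSketch {

  // Number of features considered at each node for a resolved strategy name.
  def numFeaturesPerNode(strategy: String, numFeatures: Int): Int = strategy match {
    case "all" => numFeatures
    case "sqrt" => math.sqrt(numFeatures).ceil.toInt
    case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
    case "onethird" => (numFeatures / 3.0).ceil.toInt
    // Assumption: explicit error added for this sketch only.
    case other => throw new IllegalArgumentException(s"Unsupported strategy: $other")
  }

  // Largest arity for which 2 * ((1 << (arity - 1)) - 1) <= maxBins,
  // computed with logarithms as in the diff to avoid overflow for large arities.
  def maxCategoriesForUnordered(maxBins: Int): Int =
    ((math.log(maxBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt

  // Same expression as numUnorderedBins in the diff, with the shift parenthesized.
  def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

  def main(args: Array[String]): Unit = {
    println(numFeaturesPerNode("sqrt", 10))     // 4
    println(numFeaturesPerNode("onethird", 10)) // 4
    println(maxCategoriesForUnordered(32))      // 5
    println(numUnorderedBins(5))                // 15
  }
}
```

For maxBins = 32, a categorical feature with up to 5 categories qualifies as unordered in the multiclass case, since 2 * ((1 << 4) - 1) = 30 fits within 32 while the corresponding value for 6 categories (62) does not.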

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index b37f4e8..0749d93 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
-import org.apache.spark.mllib.tree.impl.TimeTracker
 import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
 import org.apache.spark.rdd.RDD

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index 2c82867..9d697a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.tree.{LearningNode, Split}
-import org.apache.spark.mllib.tree.impl.BaggedPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index cccf052..7b1fd08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
-  TimeTracker}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
@@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging {
   /**
    * Given a group of nodes, this finds the best split for each node.
    *
-   * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+   * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
    * @param metadata Learning and dataset metadata
    * @param topNodes Root node for each tree.  Used for matching instances with nodes.
    * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
new file mode 100644
index 0000000..4cc250a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import scala.collection.mutable.{HashMap => MutableHashMap}
+
+/**
+ * Time tracker implementation which holds labeled timers.
+ */
+private[spark] class TimeTracker extends Serializable {
+
+  private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+  private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+  /**
+   * Starts a new timer, or re-starts a stopped timer.
+   */
+  def start(timerLabel: String): Unit = {
+    val currentTime = System.nanoTime()
+    if (starts.contains(timerLabel)) {
+      throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
+        s" timerLabel = $timerLabel before that timer was stopped.")
+    }
+    starts(timerLabel) = currentTime
+  }
+
+  /**
+   * Stops a timer and returns the elapsed time in seconds.
+   */
+  def stop(timerLabel: String): Double = {
+    val currentTime = System.nanoTime()
+    if (!starts.contains(timerLabel)) {
+      throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
+        s" timerLabel = $timerLabel, but that timer was not started.")
+    }
+    val elapsed = currentTime - starts(timerLabel)
+    starts.remove(timerLabel)
+    if (totals.contains(timerLabel)) {
+      totals(timerLabel) += elapsed
+    } else {
+      totals(timerLabel) = elapsed
+    }
+    elapsed / 1e9
+  }
+
+  /**
+   * Returns all timing results, in seconds, as a newline-separated string.
+   */
+  override def toString: String = {
+    totals.map { case (label, elapsed) =>
+        s"  $label: ${elapsed / 1e9}"
+      }.mkString("\n")
+  }
+}
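
For readers unfamiliar with the timer contract above, here is a minimal sketch of it, not the Spark class itself. `TimeTracker` is `private[spark]`, so the hypothetical `LabeledTimers` stand-in below only mirrors its start/stop/accumulate behavior. The injectable `clock` parameter and the `totalSeconds` accessor are assumptions added to keep the example deterministic, and the sketch signals misuse with `require` and `IllegalStateException` rather than `RuntimeException`.

```scala
import scala.collection.mutable

// Hypothetical stand-in mirroring the start/stop contract of TimeTracker;
// the real class is private[spark] and cannot be used from user code.
final class LabeledTimers(clock: () => Long = () => System.nanoTime()) {
  private val starts = mutable.HashMap.empty[String, Long]
  private val totals = mutable.HashMap.empty[String, Long]

  // Starts a timer; starting a label that is already running is an error.
  def start(label: String): Unit = {
    require(!starts.contains(label), s"timer '$label' is already running")
    starts(label) = clock()
  }

  // Returns the elapsed seconds of this interval and adds it to the running total.
  def stop(label: String): Double = {
    val begin = starts.remove(label).getOrElse(
      throw new IllegalStateException(s"timer '$label' was not started"))
    val elapsed = clock() - begin
    totals(label) = totals.getOrElse(label, 0L) + elapsed
    elapsed / 1e9
  }

  // Accumulated seconds over all completed intervals for this label.
  def totalSeconds(label: String): Double = totals.getOrElse(label, 0L) / 1e9
}

object LabeledTimersDemo {
  // A fake clock advancing one second (1e9 ns) per reading keeps the demo deterministic.
  def run(): (Double, Double) = {
    var now = 0L
    val timers = new LabeledTimers(() => { now += 1000000000L; now })
    timers.start("findBestSplits")
    val first = timers.stop("findBestSplits")
    timers.start("findBestSplits")
    timers.stop("findBestSplits")
    (first, timers.totalSeconds("findBestSplits"))
  }
}
```

With the fake clock, `LabeledTimersDemo.run()` yields 1.0 second for the first interval and 2.0 seconds in total, because repeated start/stop cycles on one label accumulate.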

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index 9fa27e5..3a2bf3c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.ml.tree.{ContinuousSplit, Split}
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.rdd.RDD
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index d166dc7..0f0c6b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.impl.TimeTracker
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impl.TimeTracker
 import org.apache.spark.mllib.tree.impurity.Variance
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
 import org.apache.spark.rdd.RDD
@@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging {
 
   /**
    * Internal method for performing regression using trees as base learners.
+   *
    * @param input Training dataset.
    * @param validationInput Validation dataset, ignored if validate is set to false.
    * @param boostingStrategy Boosting parameters.

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
deleted file mode 100644
index 572815d..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.commons.math3.distribution.PoissonDistribution
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
-import org.apache.spark.util.random.XORShiftRandom
-
-/**
- * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
- * particularly for bagging (e.g., for random forests).
- *
- * This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times this instance appears in each subsample.
- * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
- * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
- *
- * @param datum  Data instance
- * @param subsampleWeights  Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances.  Once MLlib has weighted
- *       dataset support, update.  (We store subsampleWeights as Double for this future extension.)
- */
-private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
-  extends Serializable
-
-private[spark] object BaggedPoint {
-
-  /**
-   * Convert an input dataset into its BaggedPoint representation,
-   * choosing subsample counts for each instance.
-   * Each subsample is drawn from the original dataset at rate subsamplingRate,
-   * either with or without replacement (see withReplacement).
-   * @param input Input dataset.
-   * @param subsamplingRate Fraction of the training data used for learning decision tree.
-   * @param numSubsamples Number of subsamples of this RDD to take.
-   * @param withReplacement Sampling with/without replacement.
-   * @param seed Random seed.
-   * @return BaggedPoint dataset representation.
-   */
-  def convertToBaggedRDD[Datum] (
-      input: RDD[Datum],
-      subsamplingRate: Double,
-      numSubsamples: Int,
-      withReplacement: Boolean,
-      seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
-    if (withReplacement) {
-      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
-    } else {
-      if (numSubsamples == 1 && subsamplingRate == 1.0) {
-        convertToBaggedRDDWithoutSampling(input)
-      } else {
-        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
-      }
-    }
-  }
-
-  private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
-      input: RDD[Datum],
-      subsamplingRate: Double,
-      numSubsamples: Int,
-      seed: Long): RDD[BaggedPoint[Datum]] = {
-    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
-      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
-      val rng = new XORShiftRandom
-      rng.setSeed(seed + partitionIndex + 1)
-      instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
-        var subsampleIndex = 0
-        while (subsampleIndex < numSubsamples) {
-          val x = rng.nextDouble()
-          subsampleWeights(subsampleIndex) = {
-            if (x < subsamplingRate) 1.0 else 0.0
-          }
-          subsampleIndex += 1
-        }
-        new BaggedPoint(instance, subsampleWeights)
-      }
-    }
-  }
-
-  private def convertToBaggedRDDSamplingWithReplacement[Datum] (
-      input: RDD[Datum],
-      subsample: Double,
-      numSubsamples: Int,
-      seed: Long): RDD[BaggedPoint[Datum]] = {
-    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
-      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
-      val poisson = new PoissonDistribution(subsample)
-      poisson.reseedRandomGenerator(seed + partitionIndex + 1)
-      instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
-        var subsampleIndex = 0
-        while (subsampleIndex < numSubsamples) {
-          subsampleWeights(subsampleIndex) = poisson.sample()
-          subsampleIndex += 1
-        }
-        new BaggedPoint(instance, subsampleWeights)
-      }
-    }
-  }
-
-  private def convertToBaggedRDDWithoutSampling[Datum] (
-      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
-    input.map(datum => new BaggedPoint(datum, Array(1.0)))
-  }
-
-}
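
The without-replacement branch of the removed file can be sketched without Spark as follows. This is an illustration under stated assumptions rather than the Spark code: it works on a plain instance count instead of an RDD, the names `BaggingSketch` and `subsampleWeights` are hypothetical, and `scala.util.Random` stands in for Spark's `XORShiftRandom`.

```scala
import scala.util.Random

object BaggingSketch {
  // For each instance, draw one Bernoulli(subsamplingRate) weight per subsample:
  // 1.0 means the instance is included in that subsample, 0.0 means it is not.
  // scala.util.Random replaces XORShiftRandom here (an assumption of this sketch).
  def subsampleWeights(
      numInstances: Int,
      subsamplingRate: Double,
      numSubsamples: Int,
      seed: Long): Seq[Array[Double]] = {
    val rng = new Random(seed)
    Seq.fill(numInstances) {
      Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
    }
  }
}
```

A fixed seed makes the weights reproducible, which is the same reason the removed code derives its per-partition seed from `seed + partitionIndex + 1`.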

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
deleted file mode 100644
index c745e9f..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ /dev/null
@@ -1,178 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.spark.mllib.tree.impurity._
-
-
-
-/**
- * DecisionTree statistics aggregator for a node.
- * This holds a flat array of statistics for a set of (features, bins)
- * and helps with indexing.
- * The optional featureSubset supports learning with and without feature subsampling.
- */
-private[spark] class DTStatsAggregator(
-    val metadata: DecisionTreeMetadata,
-    featureSubset: Option[Array[Int]]) extends Serializable {
-
-  /**
-   * [[ImpurityAggregator]] instance specifying the impurity type.
-   */
-  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
-    case Gini => new GiniAggregator(metadata.numClasses)
-    case Entropy => new EntropyAggregator(metadata.numClasses)
-    case Variance => new VarianceAggregator()
-    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
-  }
-
-  /**
-   * Number of elements (Double values) used for the sufficient statistics of each bin.
-   */
-  private val statsSize: Int = impurityAggregator.statsSize
-
-  /**
-   * Number of bins for each feature.  This is indexed by the feature index.
-   */
-  private val numBins: Array[Int] = {
-    if (featureSubset.isDefined) {
-      featureSubset.get.map(metadata.numBins(_))
-    } else {
-      metadata.numBins
-    }
-  }
-
-  /**
-   * Offset for each feature for calculating indices into the [[allStats]] array.
-   */
-  private val featureOffsets: Array[Int] = {
-    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
-  }
-
-  /**
-   * Total number of elements stored in this aggregator
-   */
-  private val allStatsSize: Int = featureOffsets.last
-
-  /**
-   * Flat array of elements.
-   * Index for start of stats for a (feature, bin) is:
-   *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   */
-  private val allStats: Array[Double] = new Array[Double](allStatsSize)
-
-  /**
-   * Array of parent node sufficient stats.
-   *
-   * Note: this is necessary because stats for the parent node are not available
-   *       on the first iteration of tree learning.
-   */
-  private val parentStats: Array[Double] = new Array[Double](statsSize)
-
-  /**
-   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset  This is a pre-computed (node, feature) offset
-   *                           from [[getFeatureOffset]].
-   */
-  def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
-    impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
-  }
-
-  /**
-   * Get an [[ImpurityCalculator]] for the parent node.
-   */
-  def getParentImpurityCalculator(): ImpurityCalculator = {
-    impurityAggregator.getCalculator(parentStats, 0)
-  }
-
-  /**
-   * Update the stats for a given (feature, bin) for ordered features, using the given label.
-   */
-  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
-    val i = featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
-  }
-
-  /**
-   * Update the parent node stats using the given label.
-   */
-  def updateParent(label: Double, instanceWeight: Double): Unit = {
-    impurityAggregator.update(parentStats, 0, label, instanceWeight)
-  }
-
-  /**
-   * Faster version of [[update]].
-   * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset  This is a pre-computed feature offset
-   *                           from [[getFeatureOffset]].
-   */
-  def featureUpdate(
-      featureOffset: Int,
-      binIndex: Int,
-      label: Double,
-      instanceWeight: Double): Unit = {
-    impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
-      label, instanceWeight)
-  }
-
-  /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For ordered features only.
-   */
-  def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
-
-  /**
-   * For a given feature, merge the stats for two bins.
-   * @param featureOffset  This is a pre-computed feature offset
-   *                           from [[getFeatureOffset]].
-   * @param binIndex  The other bin is merged into this bin.
-   * @param otherBinIndex  This bin is not modified.
-   */
-  def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
-    impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
-      featureOffset + otherBinIndex * statsSize)
-  }
-
-  /**
-   * Merges this aggregator with another and returns this aggregator.
-   * This method modifies this aggregator in-place.
-   */
-  def merge(other: DTStatsAggregator): DTStatsAggregator = {
-    require(allStatsSize == other.allStatsSize,
-      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
-        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
-    var i = 0
-    // TODO: Test BLAS.axpy
-    while (i < allStatsSize) {
-      allStats(i) += other.allStats(i)
-      i += 1
-    }
-
-    require(statsSize == other.statsSize,
-      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
-        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
-        s"but the other is ${other.statsSize}.")
-    var j = 0
-    while (j < statsSize) {
-      parentStats(j) += other.parentStats(j)
-      j += 1
-    }
-
-    this
-  }
-}
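
The flat-array layout used by the removed aggregator (`featureOffsets` plus `binIndex * statsSize`) can be illustrated in isolation. The sketch below is a standalone example; the object name `FlatStatsIndexSketch` and the helper `statsIndex` are hypothetical, while the `scanLeft` expression is the one from the diff.

```scala
object FlatStatsIndexSketch {
  // Offsets into a flat stats array: feature f starts at offsets(f),
  // and the last entry is the total number of elements.
  def featureOffsets(numBins: Array[Int], statsSize: Int): Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  // Start index of the stats for a (feature, bin) pair.
  def statsIndex(offsets: Array[Int], featureIndex: Int, binIndex: Int, statsSize: Int): Int =
    offsets(featureIndex) + binIndex * statsSize
}
```

For example, with `numBins = Array(3, 2, 4)` and `statsSize = 2`, the offsets are 0, 6, 10, 18, so the stats for feature 1, bin 1 start at index 8 and the whole array holds 18 values.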

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
deleted file mode 100644
index 4f27dc4..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impurity.Impurity
-import org.apache.spark.rdd.RDD
-
-/**
- * Learning and dataset metadata for DecisionTree.
- *
- * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
- *                      For regression: fixed at 0 (no meaning).
- * @param maxBins  Maximum number of bins, for all features.
- * @param featureArity  Map: categorical feature index --> arity.
- *                      I.e., the feature takes values in {0, ..., arity - 1}.
- * @param numBins  Number of bins for each feature.
- */
-private[spark] class DecisionTreeMetadata(
-    val numFeatures: Int,
-    val numExamples: Long,
-    val numClasses: Int,
-    val maxBins: Int,
-    val featureArity: Map[Int, Int],
-    val unorderedFeatures: Set[Int],
-    val numBins: Array[Int],
-    val impurity: Impurity,
-    val quantileStrategy: QuantileStrategy,
-    val maxDepth: Int,
-    val minInstancesPerNode: Int,
-    val minInfoGain: Double,
-    val numTrees: Int,
-    val numFeaturesPerNode: Int) extends Serializable {
-
-  def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
-
-  def isClassification: Boolean = numClasses >= 2
-
-  def isMulticlass: Boolean = numClasses > 2
-
-  def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
-
-  def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
-
-  def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
-
-  /**
-   * Number of splits for the given feature.
-   * For unordered features, there is 1 bin per split.
-   * For ordered features, there is 1 more bin than split.
-   */
-  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex)
-  } else {
-    numBins(featureIndex) - 1
-  }
-
-
-  /**
-   * Set number of splits for a continuous feature.
-   * For a continuous feature, number of bins is number of splits plus 1.
-   */
-  def setNumSplits(featureIndex: Int, numSplits: Int) {
-    require(isContinuous(featureIndex),
-      s"Only the number of bins for a continuous feature can be set.")
-    numBins(featureIndex) = numSplits + 1
-  }
-
-  /**
-   * Indicates if feature subsampling is being used.
-   */
-  def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
-
-}
-
-private[spark] object DecisionTreeMetadata extends Logging {
-
-  /**
-   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
-   * This computes which categorical features will be ordered vs. unordered,
-   * as well as the number of splits and bins for each feature.
-   */
-  def buildMetadata(
-      input: RDD[LabeledPoint],
-      strategy: Strategy,
-      numTrees: Int,
-      featureSubsetStrategy: String): DecisionTreeMetadata = {
-
-    val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
-      throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
-        s"but was given an empty one.")
-    }
-    val numExamples = input.count()
-    val numClasses = strategy.algo match {
-      case Classification => strategy.numClasses
-      case Regression => 0
-    }
-
-    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
-    if (maxPossibleBins < strategy.maxBins) {
-      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
-        s" (= number of training instances)")
-    }
-
-    // We check the number of bins here against maxPossibleBins.
-    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
-    // based on the number of training examples.
-    if (strategy.categoricalFeaturesInfo.nonEmpty) {
-      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
-      val maxCategory =
-        strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
-      require(maxCategoriesPerFeature <= maxPossibleBins,
-        s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
-        s"number of values in each categorical feature, but categorical feature $maxCategory " +
-        s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " +
-        "features with a large number of values, or adding more training examples.")
-    }
-
-    val unorderedFeatures = new mutable.HashSet[Int]()
-    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
-    if (numClasses > 2) {
-      // Multiclass classification
-      val maxCategoriesForUnorderedFeature =
-        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
-      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        // Hack: If a categorical feature has only 1 category, we treat it as continuous.
-        // TODO(SPARK-9957): Handle this properly by filtering out those features.
-        if (numCategories > 1) {
-          // Decide if some categorical features should be treated as unordered features,
-          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
-          // We do this check with log values to prevent overflows in case numCategories is large.
-          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
-          if (numCategories <= maxCategoriesForUnorderedFeature) {
-            unorderedFeatures.add(featureIndex)
-            numBins(featureIndex) = numUnorderedBins(numCategories)
-          } else {
-            numBins(featureIndex) = numCategories
-          }
-        }
-      }
-    } else {
-      // Binary classification or regression
-      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-        // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
-        if (numCategories > 1) {
-          numBins(featureIndex) = numCategories
-        }
-      }
-    }
-
-    // Set number of features to use per node (for random forests).
-    val _featureSubsetStrategy = featureSubsetStrategy match {
-      case "auto" =>
-        if (numTrees == 1) {
-          "all"
-        } else {
-          if (strategy.algo == Classification) {
-            "sqrt"
-          } else {
-            "onethird"
-          }
-        }
-      case _ => featureSubsetStrategy
-    }
-    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
-      case "all" => numFeatures
-      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
-      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
-      case "onethird" => (numFeatures / 3.0).ceil.toInt
-    }
-
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
-  }
-
-  /**
-   * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
-   */
-  def buildMetadata(
-      input: RDD[LabeledPoint],
-      strategy: Strategy): DecisionTreeMetadata = {
-    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
-  }
-
-  /**
-   * Given the arity of a categorical feature (arity = number of categories),
-   * return the number of bins for the feature if it is to be treated as an unordered feature.
-   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
-   * there are math.pow(2, arity - 1) - 1 such splits.
-   * Each split has 2 corresponding bins.
-   */
-  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
-
-}
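
Two pieces of arithmetic in the removed metadata builder are easier to follow with concrete numbers: the per-node feature count and the unordered-feature bin budget. The sketch below copies the expressions from the diff into a standalone object. The name `TreeMetadataSketch` is hypothetical, and the explicit fallback case in `numFeaturesPerNode` is an addition of this sketch; the match in the diff has no default case.

```scala
object TreeMetadataSketch {
  // Number of candidate features per node for a named subset strategy.
  // The final fallback case is an assumption added by this sketch.
  def numFeaturesPerNode(strategy: String, numFeatures: Int): Int = strategy match {
    case "all" => numFeatures
    case "sqrt" => math.sqrt(numFeatures).ceil.toInt
    case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
    case "onethird" => (numFeatures / 3.0).ceil.toInt
    case other => throw new IllegalArgumentException(s"Unsupported strategy: $other")
  }

  // Bins for an unordered categorical feature of the given arity: 2^(arity - 1) - 1.
  // In Scala, `-` binds tighter than `<<`, so the diff's `1 << arity - 1` means the same.
  def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

  // Largest arity treated as unordered for a given bin budget, computed in log space
  // as in the diff so that a large arity never reaches the shift.
  def maxCategoriesForUnorderedFeature(maxPossibleBins: Int): Int =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
}
```

With 10 features, "sqrt" and "onethird" both give 4 features per node. With a budget of 32 bins, categorical features of arity up to 5 are treated as unordered, and an arity-5 feature gets 15 bins.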

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
deleted file mode 100644
index dc7e969..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: DeveloperApi ::
- * This is used by the node id cache to find the child id that a data point would belong to.
- * @param split Split information.
- * @param nodeIndex The current node index of a data point that this will update.
- */
-@DeveloperApi
-private[tree] case class NodeIndexUpdater(
-    split: Split,
-    nodeIndex: Int) {
-  /**
-   * Determine a child node index based on the feature value and the split.
-   * @param binnedFeatures Binned feature values.
-   * @param bins Bin information to convert the bin indices to approximate feature values.
-   * @return Child node index to update to.
-   */
-  def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
-    if (split.featureType == Continuous) {
-      val featureIndex = split.feature
-      val binIndex = binnedFeatures(featureIndex)
-      val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
-      if (featureValueUpperBound <= split.threshold) {
-        Node.leftChildIndex(nodeIndex)
-      } else {
-        Node.rightChildIndex(nodeIndex)
-      }
-    } else {
-      if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
-        Node.leftChildIndex(nodeIndex)
-      } else {
-        Node.rightChildIndex(nodeIndex)
-      }
-    }
-  }
-}
-
-/**
- * :: DeveloperApi ::
- * A given TreePoint would belong to a particular node per tree.
- * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
- * in each tree. Initially, values should all be 1 for root node.
- * The nodeIdsForInstances RDD needs to be updated at each iteration.
- * @param nodeIdsForInstances The initial values in the cache
- *            (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointInterval The checkpointing interval
- *                           (how often the cache should be checkpointed).
- */
-@DeveloperApi
-private[spark] class NodeIdCache(
-  var nodeIdsForInstances: RDD[Array[Int]],
-  val checkpointInterval: Int) {
-
-  // Keep a reference to the previous node IDs for instances.
-  // Because we will keep on re-persisting updated node Ids,
-  // we want to unpersist the previous RDD.
-  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
-
-  // To keep track of the past checkpointed RDDs.
-  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
-  private var rddUpdateCount = 0
-
-  /**
-   * Update the node index values in the cache.
-   * This updates the RDD and its lineage.
-   * TODO: Passing bin information to executors seems unnecessary and costly.
-   * @param data The RDD of training rows.
-   * @param nodeIdUpdaters A map of node index updaters.
-   *                       The key is the indices of nodes that we want to update.
-   * @param bins Bin information needed to find child node indices.
-   */
-  def updateNodeIndices(
-      data: RDD[BaggedPoint[TreePoint]],
-      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
-      bins: Array[Array[Bin]]): Unit = {
-    if (prevNodeIdsForInstances != null) {
-      // Unpersist the previous one if one exists.
-      prevNodeIdsForInstances.unpersist()
-    }
-
-    prevNodeIdsForInstances = nodeIdsForInstances
-    nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
-      case (point, node) => {
-        var treeId = 0
-        while (treeId < nodeIdUpdaters.length) {
-          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
-          if (nodeIdUpdater != null) {
-            val newNodeIndex = nodeIdUpdater.updateNodeIndex(
-              binnedFeatures = point.datum.binnedFeatures,
-              bins = bins)
-            node(treeId) = newNodeIndex
-          }
-
-          treeId += 1
-        }
-
-        node
-      }
-    }
-
-    // Keep on persisting new ones.
-    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
-    rddUpdateCount += 1
-
-    // Handle checkpointing if the directory is not None.
-    if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
-      (rddUpdateCount % checkpointInterval) == 0) {
-      // Let's see if we can delete previous checkpoints.
-      var canDelete = true
-      while (checkpointQueue.size > 1 && canDelete) {
-        // We can delete the oldest checkpoint iff
-        // the next checkpoint actually exists in the file system.
-        if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
-          val old = checkpointQueue.dequeue()
-
-          // Since the old checkpoint is not deleted by Spark,
-          // we'll manually delete it here.
-          val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-          fs.delete(new Path(old.getCheckpointFile.get), true)
-        } else {
-          canDelete = false
-        }
-      }
-
-      nodeIdsForInstances.checkpoint()
-      checkpointQueue.enqueue(nodeIdsForInstances)
-    }
-  }
-
-  /**
-   * Call this after training is finished to delete any remaining checkpoints.
-   */
-  def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.nonEmpty) {
-      val old = checkpointQueue.dequeue()
-      for (checkpointFile <- old.getCheckpointFile) {
-        val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-        fs.delete(new Path(checkpointFile), true)
-      }
-    }
-    if (prevNodeIdsForInstances != null) {
-      // Unpersist the previous one if one exists.
-      prevNodeIdsForInstances.unpersist()
-    }
-  }
-}
-
-private[spark] object NodeIdCache {
-  /**
-   * Initialize the node Id cache with initial node Id values.
-   * @param data The RDD of training rows.
-   * @param numTrees The number of trees that we want to create cache for.
-   * @param checkpointInterval The checkpointing interval
-   *                           (how often should the cache be checkpointed.).
-   * @param initVal The initial values in the cache.
-   * @return A node Id cache containing an RDD of initial root node Indices.
-   */
-  def init(
-      data: RDD[BaggedPoint[TreePoint]],
-      numTrees: Int,
-      checkpointInterval: Int,
-      initVal: Int = 1): NodeIdCache = {
-    new NodeIdCache(
-      data.map(_ => Array.fill[Int](numTrees)(initVal)),
-      checkpointInterval)
-  }
-}
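
Note on the checkpoint handling in the removed NodeIdCache: the oldest checkpoint is deleted only once the next one in the queue has actually been written. The rule can be sketched without Spark as below. This is a minimal illustration, not the removed implementation: `Checkpoint`, `written`, and `CheckpointPruning.prune` are hypothetical names, with `written` standing in for `getCheckpointFile.isDefined`.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a checkpointed RDD: `written` plays the role of
// getCheckpointFile.isDefined in the removed code.
final case class Checkpoint(path: String, written: Boolean)

object CheckpointPruning {
  /** Dequeues and returns the checkpoints that are safe to delete. */
  def prune(queue: mutable.Queue[Checkpoint]): List[Checkpoint] = {
    val deletable = mutable.ListBuffer.empty[Checkpoint]
    var canDelete = true
    while (queue.size > 1 && canDelete) {
      // The oldest entry may go only once its successor is materialized.
      if (queue(1).written) {
        deletable += queue.dequeue()
      } else {
        canDelete = false
      }
    }
    deletable.toList
  }
}
```

The loop always leaves at least one entry in the queue, so the most recent checkpoint is never a deletion candidate.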

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
deleted file mode 100644
index 70afaa1..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable.{HashMap => MutableHashMap}
-
-/**
- * Time tracker implementation which holds labeled timers.
- */
-private[spark] class TimeTracker extends Serializable {
-
-  private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
-
-  private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
-
-  /**
-   * Starts a new timer, or re-starts a stopped timer.
-   */
-  def start(timerLabel: String): Unit = {
-    val currentTime = System.nanoTime()
-    if (starts.contains(timerLabel)) {
-      throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
-        s" timerLabel = $timerLabel before that timer was stopped.")
-    }
-    starts(timerLabel) = currentTime
-  }
-
-  /**
-   * Stops a timer and returns the elapsed time in seconds.
-   */
-  def stop(timerLabel: String): Double = {
-    val currentTime = System.nanoTime()
-    if (!starts.contains(timerLabel)) {
-      throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
-        s" timerLabel = $timerLabel, but that timer was not started.")
-    }
-    val elapsed = currentTime - starts(timerLabel)
-    starts.remove(timerLabel)
-    if (totals.contains(timerLabel)) {
-      totals(timerLabel) += elapsed
-    } else {
-      totals(timerLabel) = elapsed
-    }
-    elapsed / 1e9
-  }
-
-  /**
-   * Print all timing results in seconds.
-   */
-  override def toString: String = {
-    totals.map { case (label, elapsed) =>
-        s"  $label: ${elapsed / 1e9}"
-      }.mkString("\n")
-  }
-}
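
Note on TimeTracker semantics: each label accumulates elapsed nanoseconds across start/stop pairs, and `stop` returns the seconds for the latest run only. The sketch below mirrors that behavior in a standalone class. `SimpleTimeTracker` and `totalSeconds` are hypothetical names introduced for illustration, and it throws `IllegalArgumentException` / `IllegalStateException` where the class above throws `RuntimeException`.

```scala
import scala.collection.mutable

// Standalone sketch; SimpleTimeTracker is a hypothetical name, not a Spark class.
class SimpleTimeTracker {
  private val starts = mutable.HashMap.empty[String, Long]
  private val totals = mutable.HashMap.empty[String, Long]

  /** Starts a timer; fails if the same label is already running. */
  def start(label: String): Unit = {
    require(!starts.contains(label), s"timer '$label' is already running")
    starts(label) = System.nanoTime()
  }

  /** Stops the timer and returns the elapsed time of this run in seconds. */
  def stop(label: String): Double = {
    val now = System.nanoTime()
    val begun = starts.remove(label).getOrElse(
      throw new IllegalStateException(s"timer '$label' was not started"))
    val elapsed = now - begun
    totals(label) = totals.getOrElse(label, 0L) + elapsed
    elapsed / 1e9
  }

  /** Accumulated seconds over all completed runs of the timer. */
  def totalSeconds(label: String): Double = totals.getOrElse(label, 0L) / 1e9
}
```

Typical use would wrap a phase of training: call `start("findSplits")`, run the phase, then call `stop("findSplits")`.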

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
deleted file mode 100644
index 21919d6..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.Bin
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Internal representation of LabeledPoint for DecisionTree.
- * This bins feature values based on a subsampled of data as follows:
- *  (a) Continuous features are binned into ranges.
- *  (b) Unordered categorical features are binned based on subsets of feature values.
- *      "Unordered categorical features" are categorical features with low arity used in
- *      multiclass classification.
- *  (c) Ordered categorical features are binned based on feature values.
- *      "Ordered categorical features" are categorical features with high arity,
- *      or any categorical feature used in regression or binary classification.
- *
- * @param label  Label from LabeledPoint
- * @param binnedFeatures  Binned feature values.
- *                        Same length as LabeledPoint.features, but values are bin indices.
- */
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
-  extends Serializable {
-}
-
-private[spark] object TreePoint {
-
-  /**
-   * Convert an input dataset into its TreePoint representation,
-   * binning feature values in preparation for DecisionTree training.
-   * @param input     Input dataset.
-   * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param metadata  Learning and dataset metadata
-   * @return  TreePoint dataset representation
-   */
-  def convertToTreeRDD(
-      input: RDD[LabeledPoint],
-      bins: Array[Array[Bin]],
-      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
-    // Construct arrays for featureArity for efficiency in the inner loop.
-    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
-    var featureIndex = 0
-    while (featureIndex < metadata.numFeatures) {
-      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
-      featureIndex += 1
-    }
-    input.map { x =>
-      TreePoint.labeledPointToTreePoint(x, bins, featureArity)
-    }
-  }
-
-  /**
-   * Convert one LabeledPoint into its TreePoint representation.
-   * @param bins      Bins for features, of size (numFeatures, numBins).
-   * @param featureArity  Array indexed by feature, with value 0 for continuous and numCategories
-   *                      for categorical features.
-   */
-  private def labeledPointToTreePoint(
-      labeledPoint: LabeledPoint,
-      bins: Array[Array[Bin]],
-      featureArity: Array[Int]): TreePoint = {
-    val numFeatures = labeledPoint.features.size
-    val arr = new Array[Int](numFeatures)
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
-        bins)
-      featureIndex += 1
-    }
-    new TreePoint(labeledPoint.label, arr)
-  }
-
-  /**
-   * Find bin for one (labeledPoint, feature).
-   *
-   * @param featureArity  0 for continuous features; number of categories for categorical features.
-   * @param bins   Bins for features, of size (numFeatures, numBins).
-   */
-  private def findBin(
-      featureIndex: Int,
-      labeledPoint: LabeledPoint,
-      featureArity: Int,
-      bins: Array[Array[Bin]]): Int = {
-
-    /**
-     * Binary search helper method for continuous feature.
-     */
-    def binarySearchForBins(): Int = {
-      val binForFeatures = bins(featureIndex)
-      val feature = labeledPoint.features(featureIndex)
-      var left = 0
-      var right = binForFeatures.length - 1
-      while (left <= right) {
-        val mid = left + (right - left) / 2
-        val bin = binForFeatures(mid)
-        val lowThreshold = bin.lowSplit.threshold
-        val highThreshold = bin.highSplit.threshold
-        if ((lowThreshold < feature) && (highThreshold >= feature)) {
-          return mid
-        } else if (lowThreshold >= feature) {
-          right = mid - 1
-        } else {
-          left = mid + 1
-        }
-      }
-      -1
-    }
-
-    if (featureArity == 0) {
-      // Perform binary search for finding bin for continuous features.
-      val binIndex = binarySearchForBins()
-      if (binIndex == -1) {
-        throw new RuntimeException("No bin was found for continuous feature." +
-          " This error can occur when given invalid data values (such as NaN)." +
-          s" Feature index: $featureIndex.  Feature value: ${labeledPoint.features(featureIndex)}")
-      }
-      binIndex
-    } else {
-      // Categorical feature bins are indexed by feature values.
-      val featureValue = labeledPoint.features(featureIndex)
-      if (featureValue < 0 || featureValue >= featureArity) {
-        throw new IllegalArgumentException(
-          s"DecisionTree given invalid data:" +
-            s" Feature $featureIndex is categorical with values in" +
-            s" {0,...,${featureArity - 1}," +
-            s" but a data point gives it value $featureValue.\n" +
-            "  Bad data point: " + labeledPoint.toString)
-      }
-      featureValue.toInt
-    }
-  }
-}
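
Note on the continuous-feature lookup in the removed TreePoint.findBin: it is a binary search over bins with an exclusive lower and inclusive upper threshold, returning -1 when no bin matches (which is how NaN surfaces as an error). A minimal sketch follows, assuming the bins are represented by a single ascending array of boundaries rather than `Bin` objects; `BinSearch` and `boundaries` are hypothetical names.

```scala
object BinSearch {
  /**
   * `boundaries` is ascending; bin i covers (boundaries(i), boundaries(i + 1)].
   * Returns the bin index, or -1 if no bin contains the value (e.g. NaN).
   */
  def findBin(boundaries: Array[Double], value: Double): Int = {
    var left = 0
    var right = boundaries.length - 2 // index of the last bin
    while (left <= right) {
      val mid = left + (right - left) / 2
      val low = boundaries(mid)
      val high = boundaries(mid + 1)
      if (low < value && value <= high) {
        return mid
      } else if (low >= value) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    -1
  }
}
```

With boundaries `(-Infinity, 1.0, 2.5, +Infinity)` there are three bins, and a value equal to a boundary falls into the bin below it because the upper bound is inclusive.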

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 13aff11..ff7700d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -85,7 +85,7 @@ object Entropy extends Impurity {
  * Note: Instances of this class do not hold the data; they operate on views of the data.
  * @param numClasses  Number of classes for label.
  */
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
   extends ImpurityAggregator(numClasses) with Serializable {
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 39c7f9c..58dc79b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -81,7 +81,7 @@ object Gini extends Impurity {
  * Note: Instances of this class do not hold the data; they operate on views of the data.
  * @param numClasses  Number of classes for label.
  */
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
   extends ImpurityAggregator(numClasses) with Serializable {
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 92d74a1..2423516 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -71,7 +71,7 @@ object Variance extends Impurity {
  * in order to compute impurity from a sample.
  * Note: Instances of this class do not hold the data; they operate on views of the data.
  */
-private[tree] class VarianceAggregator()
+private[spark] class VarianceAggregator()
   extends ImpurityAggregator(statsSize = 3) with Serializable {
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
deleted file mode 100644
index 0cad473..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-
-/**
- * Used for "binning" the feature values for faster best split calculation.
- *
- * For a continuous feature, the bin is determined by a low and a high split,
- *  where an example with featureValue falls into the bin s.t.
- *  lowSplit.threshold < featureValue <= highSplit.threshold.
- *
- * For ordered categorical features, there is a 1-1-1 correspondence between
- *  bins, splits, and feature values.  The bin is determined by category/feature value.
- *  However, the bins are not necessarily ordered by feature value;
- *  they are ordered using impurity.
- *
- * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
- *  where bins and splits correspond to subsets of feature values (in highSplit.categories).
- *  An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
- *  partitionings of categories into 2 disjoint, non-empty sets.
- *
- * @param lowSplit signifying the lower threshold for the continuous feature to be
- *                 accepted in the bin
- * @param highSplit signifying the upper threshold for the continuous feature to be
- *                  accepted in the bin
- * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for ordered features
- */
-private[tree]
-case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
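
Note on the bin count for unordered categorical features quoted in the Bin scaladoc: in Scala, `-` binds tighter than `<<`, so `(1 << k - 1) - 1` evaluates as 2^(k-1) - 1. The sketch below writes the parentheses explicitly and cross-checks the formula by enumeration. `UnorderedBins`, `numBins`, and `countPartitions` are hypothetical names, and the enumeration is only intended for small k.

```scala
object UnorderedBins {
  /** Number of ways to split k categories into two disjoint, non-empty sets. */
  def numBins(k: Int): Int = {
    require(k >= 1 && k <= 31, "k must be in [1, 31] to avoid Int overflow")
    (1 << (k - 1)) - 1
  }

  /**
   * Brute-force count for small k: every proper, non-empty subset is a bit mask.
   * Counting only masks that contain category 0 counts each {S, complement}
   * pair exactly once.
   */
  def countPartitions(k: Int): Int =
    (1 until (1 << k) - 1).count(mask => (mask & 1) == 1)
}
```

For k = 3 both functions give 3, matching the partitions {0}|{1,2}, {0,1}|{2}, and {0,2}|{1}.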

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
new file mode 100644
index 0000000..77ab3d8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.tree.EnsembleTestHelper
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[BaggedPoint]].
+ */
+class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext  {
+
+  test("BaggedPoint RDD: without subsampling") {
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val rdd = sc.parallelize(arr)
+    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+    baggedRDD.collect().foreach { baggedPoint =>
+      assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
+    val numSubsamples = 100
+    val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
+    val numSubsamples = 100
+    val subsample = 0.5
+    val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample))
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
+    val numSubsamples = 100
+    val (expectedMean, expectedStddev) = (1.0, 0)
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
+    val numSubsamples = 100
+    val subsample = 0.5
+    val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample)))
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+}
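
Note on the expected values in BaggedPointSuite: the (mean, stddev) pairs are what one would get if weights were Poisson(fraction) when sampling with replacement and Bernoulli(fraction) when sampling without replacement. That distributional model is an assumption inferred from the test constants, not something this diff shows. A small sketch of the arithmetic, with `BaggingMoments` and its methods as hypothetical names:

```scala
import scala.util.Random

object BaggingMoments {
  /** (mean, stddev) of a Poisson(fraction) weight: assumed model with replacement. */
  def withReplacement(fraction: Double): (Double, Double) =
    (fraction, math.sqrt(fraction))

  /** (mean, stddev) of a Bernoulli(fraction) weight: assumed model without replacement. */
  def withoutReplacement(fraction: Double): (Double, Double) =
    (fraction, math.sqrt(fraction * (1 - fraction)))

  /** Empirical (mean, stddev) of n Bernoulli(fraction) draws, for a sanity check. */
  def simulateBernoulli(fraction: Double, n: Int, seed: Long): (Double, Double) = {
    val rng = new Random(seed)
    val draws = Array.fill(n)(if (rng.nextDouble() < fraction) 1.0 else 0.0)
    val mean = draws.sum / n
    val variance = draws.map(d => (d - mean) * (d - mean)).sum / n
    (mean, math.sqrt(variance))
  }
}
```

Under these assumptions, fraction = 1.0 gives (1.0, 1.0) with replacement and (1.0, 0.0) without, and fraction = 0.5 gives (0.5, sqrt(0.5)) and (0.5, 0.5) respectively, which lines up with the constants used in the tests above.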

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 441338e..e64551f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc35e6f/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bb1041b..49cb7e1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree
 import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext



