spark-commits mailing list archives

From m...@apache.org
Subject [1/2] [SPARK-1545] [mllib] Add Random Forests
Date Mon, 29 Sep 2014 04:44:57 GMT
Repository: spark
Updated Branches:
  refs/heads/master f350cd307 -> 0dc2b6361


http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index b6d49e5..212dce2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata(
     val quantileStrategy: QuantileStrategy,
     val maxDepth: Int,
     val minInstancesPerNode: Int,
-    val minInfoGain: Double) extends Serializable {
+    val minInfoGain: Double,
+    val numTrees: Int,
+    val numFeaturesPerNode: Int) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata(
     numBins(featureIndex) - 1
   }
 
+  /**
+   * Indicates if feature subsampling is being used.
+   */
+  def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
 }
 
 private[tree] object DecisionTreeMetadata {
@@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata {
    * This computes which categorical features will be ordered vs. unordered,
    * as well as the number of splits and bins for each feature.
    */
-  def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy,
+      numTrees: Int,
+      featureSubsetStrategy: String): DecisionTreeMetadata = {
 
     val numFeatures = input.take(1)(0).features.size
     val numExamples = input.count()
@@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata {
       }
     }
 
+    // Set number of features to use per node (for random forests).
+    val _featureSubsetStrategy = featureSubsetStrategy match {
+      case "auto" =>
+        if (numTrees == 1) {
+          "all"
+        } else {
+          if (strategy.algo == Classification) {
+            "sqrt"
+          } else {
+            "onethird"
+          }
+        }
+      case _ => featureSubsetStrategy
+    }
+    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+      case "all" => numFeatures
+      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+      case "onethird" => (numFeatures / 3.0).ceil.toInt
+    }
+
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain)
+      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
   }
 
   /**
+   * Version of [[buildMetadata()]] for DecisionTree.
+   */
+  def buildMetadata(
+      input: RDD[LabeledPoint],
+      strategy: Strategy): DecisionTreeMetadata = {
+    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+  }
+
+  /**
    * Given the arity of a categorical feature (arity = number of categories),
    * return the number of bins for the feature if it is to be treated as an unordered feature.
    * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
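
A standalone sketch of the feature-subset resolution added above may help: the helper below mirrors the two match expressions in buildMetadata. The function name and the isClassification flag (standing in for strategy.algo == Classification) are illustrative, not MLlib API.

    // Mirrors the resolution logic in buildMetadata: "auto" picks a default
    // per algorithm, then the chosen strategy maps to a feature count.
    def resolveNumFeaturesPerNode(
        numFeatures: Int,
        numTrees: Int,
        featureSubsetStrategy: String,
        isClassification: Boolean): Int = {
      val resolved = featureSubsetStrategy match {
        case "auto" =>
          if (numTrees == 1) "all"
          else if (isClassification) "sqrt"
          else "onethird"
        case other => other
      }
      resolved match {
        case "all" => numFeatures
        case "sqrt" => math.sqrt(numFeatures).ceil.toInt
        case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
        case "onethird" => (numFeatures / 3.0).ceil.toInt
      }
    }

With 100 features this yields 100 for "all", 10 for "sqrt", 7 for "log2", and 34 for "onethird". The trailing context of this hunk concerns unordered categorical features: a feature with arity k admits 2^(k-1) - 1 partitions of its categories into two non-empty sets, which is why the corrected test comment later in this patch reads 2^(10-1) - 1 > 100.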

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 1c8afc2..0e02345 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int)
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
     if (label >= statsSize) {
       throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
-    allStats(offset + label.toInt) += 1
+    allStats(offset + label.toInt) += instanceWeight
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 5cfdf34..7c83cd4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int)
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
     if (label >= statsSize) {
       throw new IllegalArgumentException(s"GiniAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
-    allStats(offset + label.toInt) += 1
+    allStats(offset + label.toInt) += instanceWeight
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 5a047d6..60e2ab2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
 
   /**
    * Get an [[ImpurityCalculator]] for a (node, feature, bin).
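
The signature change applied uniformly to the aggregators here exists to support bagged input: rather than replicating a training point once per bootstrap occurrence, each point carries a per-tree weight, and update adds that weight to the bin's statistics. As a rough standalone sketch (not the patch's code), entropy over such weighted class counts looks like:

    // Entropy over weighted class counts -- the kind of statistics the
    // classification aggregators accumulate (one weighted count per class).
    // Base-2 logarithm, as is conventional for the entropy impurity.
    def entropyOfCounts(classCounts: Array[Double]): Double = {
      val total = classCounts.sum
      if (total == 0.0) {
        0.0
      } else {
        classCounts.filter(_ > 0.0).map { count =>
          val freq = count / total
          -freq * (math.log(freq) / math.log(2))
        }.sum
      }
    }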

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index e9ccecb..df9eafa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -75,10 +75,10 @@ private[tree] class VarianceAggregator()
    * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
    * @param offset    Start index of stats for this (node, feature, bin).
    */
-  def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
-    allStats(offset) += 1
-    allStats(offset + 1) += label
-    allStats(offset + 2) += label * label
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+    allStats(offset) += instanceWeight
+    allStats(offset + 1) += instanceWeight * label
+    allStats(offset + 2) += instanceWeight * label * label
   }
 
   /**
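
The three slots VarianceAggregator fills per (node, feature, bin) -- weighted count, weighted sum, and weighted sum of squares -- are sufficient statistics for the variance. A minimal sketch of the recovery step (illustrative helper, not part of this patch):

    // Var(x) = E[x^2] - (E[x])^2, computed from the statistics stored at
    // offset, offset + 1, and offset + 2 in the update method above.
    def varianceFromStats(count: Double, sum: Double, sumSquares: Double): Double = {
      if (count == 0.0) {
        0.0
      } else {
        val mean = sum / count
        sumSquares / count - mean * mean
      }
    }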

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5f0095d..56c3e25 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector
 @DeveloperApi
 class Node (
     val id: Int,
-    val predict: Double,
-    val isLeaf: Boolean,
-    val split: Option[Split],
+    var predict: Double,
+    var isLeaf: Boolean,
+    var split: Option[Split],
     var leftNode: Option[Node],
     var rightNode: Option[Node],
-    val stats: Option[InformationGainStats]) extends Serializable with Logging {
+    var stats: Option[InformationGainStats]) extends Serializable with Logging {
 
  override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
     "split = " + split + ", stats = " + stats
@@ -168,6 +168,11 @@ class Node (
 private[tree] object Node {
 
   /**
+   * Return a node with the given node id (but nothing else set).
+   */
+  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+
+  /**
    * Return the index of the left child of this node.
    */
   def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
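
The new emptyNode factory pairs with the heap-style numbering visible in leftChildIndex: the root has index 1 (see Node.emptyNode(nodeIndex = 1) in the test suite below), and the children of node n are 2n and 2n + 1. A sketch of the companion arithmetic; rightChildIndex and parentIndex are assumed counterparts shown only for illustration:

    // Heap-style numbering: root = 1, children of n are 2n and 2n + 1,
    // so a tree of depth d fits in indices 1 .. 2^(d+1) - 1.
    def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
    def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1  // assumed counterpart
    def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1            // assumed counterpart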

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
new file mode 100644
index 0000000..538c0e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Random forest model for classification or regression.
+ * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
+ * aggregate predictions.
+ * @param trees Trees which make up this forest.  This cannot be empty.
+ * @param algo algorithm type -- classification or regression
+ */
+@Experimental
+class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
+
+  require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+
+  /**
+   * Predict values for a single data point.
+   *
+   * @param features array representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predict(features: Vector): Double = {
+    algo match {
+      case Classification =>
+        val predictionToCount = new mutable.HashMap[Int, Int]()
+        trees.foreach { tree =>
+          val prediction = tree.predict(features).toInt
+          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+        }
+        predictionToCount.maxBy(_._2)._1
+      case Regression =>
+        trees.map(_.predict(features)).sum / trees.size
+    }
+  }
+
+  /**
+   * Predict values for the given data set.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Vector]): RDD[Double] = {
+    features.map(x => predict(x))
+  }
+
+  /**
+   * Get number of trees in forest.
+   */
+  def numTrees: Int = trees.size
+
+  /**
+   * Print full model.
+   */
+  override def toString: String = {
+    val header = algo match {
+      case Classification =>
+        s"RandomForestModel classifier with $numTrees trees\n"
+      case Regression =>
+        s"RandomForestModel regressor with $numTrees trees\n"
+      case _ => throw new IllegalArgumentException(
+        s"RandomForestModel given unknown algo parameter: $algo.")
+    }
+    header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+    }.fold("")(_ + _)
+  }
+
+}
+
+private[tree] object RandomForestModel {
+
+  def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
+    require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+    val algo: Algo = trees(0).algo
+    require(trees.forall(_.algo == algo),
+      "RandomForestModel cannot combine trees which have different output types" +
+      " (classification/regression).")
+    new RandomForestModel(trees, algo)
+  }
+
+}
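
The aggregation rule in predict is a majority vote over per-tree predictions for classification and the per-tree mean for regression. Isolated from the model class, it amounts to the following sketch (standalone helpers, not part of the patch); note that maxBy breaks ties arbitrarily, matching the HashMap-based implementation above:

    import scala.collection.mutable

    // Majority vote over per-tree class predictions, as in predict() above.
    def majorityVote(treePredictions: Seq[Double]): Double = {
      val counts = new mutable.HashMap[Int, Int]()
      treePredictions.foreach { p =>
        counts(p.toInt) = counts.getOrElse(p.toInt, 0) + 1
      }
      counts.maxBy(_._2)._1.toDouble
    }

    // Regression aggregates by averaging instead.
    def averagePrediction(treePredictions: Seq[Double]): Double =
      treePredictions.sum / treePredictions.size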

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2b2e579..a48ed71 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.scalatest.FunSuite
 
@@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
-  def validateClassifier(
-      model: DecisionTreeModel,
-      input: Seq[LabeledPoint],
-      requiredAccuracy: Double) {
-    val predictions = input.map(x => model.predict(x.features))
-    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
-      prediction != expected.label
-    }
-    val accuracy = (input.length - numOffPredictions).toDouble / input.length
-    assert(accuracy >= requiredAccuracy,
-      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
-  }
-
-  def validateRegressor(
-      model: DecisionTreeModel,
-      input: Seq[LabeledPoint],
-      requiredMSE: Double) {
-    val predictions = input.map(x => model.predict(x.features))
-    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      val err = prediction - expected.label
-      err * err
-    }.sum
-    val mse = squaredError / input.length
-    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
-  }
-
   test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
     assert(arr.length === 1000)
@@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 100,
       maxBins = 100,
       categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
-    // 2^10 - 1 > 100, so categorical features will be ordered
+    // 2^(10-1) - 1 > 100, so categorical features will be ordered
 
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
     assert(!metadata.isUnordered(featureIndex = 0))
@@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(splits(0).length === 0)
     assert(bins(0).length === 0)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode: Node, doneTraining: Boolean) =
-      DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.categories === List(1.0))
@@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.categories.length === 1)
@@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateRegressor(model, arr, 0.0)
+    DecisionTreeSuite.validateRegressor(model, arr, 0.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
   }
@@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins.length === 2)
     assert(bins(0).length === 100)
 
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
-
-    val split = rootNode.split.get
-    assert(split.feature === 0)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val stats = rootNode.stats.get
     assert(stats.gain === 0)
@@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
       numClassesForClassification = 2, maxBins = 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
-    val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
+    val rootNode1 = modelOneNode.topNode.deepCopy()
+    val rootNode2 = modelOneNode.topNode.deepCopy()
+    assert(rootNode1.leftNode.nonEmpty)
+    assert(rootNode1.rightNode.nonEmpty)
 
-    // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
-      rootNodeCopy1, splits, bins, 10)
-    assert(rootNode.leftNode.nonEmpty)
-    assert(rootNode.rightNode.nonEmpty)
+    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+    // Single group second level tree construction.
+    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+    val treeToNodeToIndexInfo = Map((0, Map(
+      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+    val nodeQueue = new mutable.Queue[(Int, Node)]()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
     val children1 = new Array[Node](2)
-    children1(0) = rootNode.leftNode.get
-    children1(1) = rootNode.rightNode.get
-
-    // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
-    // level tree construction.
-    val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
-      rootNodeCopy2, splits, bins, 0)
-    assert(rootNode2.leftNode.nonEmpty)
-    assert(rootNode2.rightNode.nonEmpty)
+    children1(0) = rootNode1.leftNode.get
+    children1(1) = rootNode1.rightNode.get
+
+    // Train one second-level node at a time.
+    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+    val treeToNodeToIndexInfoA = Map((0, Map(
+      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+    nodeQueue.clear()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+    val treeToNodeToIndexInfoB = Map((0, Map(
+      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+    nodeQueue.clear()
+    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
     val children2 = new Array[Node](2)
     children2(0) = rootNode2.leftNode.get
     children2(1) = rootNode2.rightNode.get
@@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 0))
     assert(metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
   }
@@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       numClassesForClassification = 2)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
     assert(model.topNode.split.get.feature === 1)
@@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 1.0)
+    DecisionTreeSuite.validateClassifier(model, arr, 1.0)
     assert(model.numNodes === 3)
     assert(model.depth === 1)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 1)
@@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(metadata.isUnordered(featureIndex = 0))
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.9)
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = model.topNode
 
     val split = rootNode.split.get
     assert(split.feature === 1)
@@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(!metadata.isUnordered(featureIndex = 0))
     assert(!metadata.isUnordered(featureIndex = 1))
 
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     assert(split.feature === 0)
@@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(strategy.isMulticlassClassification)
 
     val model = DecisionTree.train(rdd, strategy)
-    validateClassifier(model, arr, 0.6)
+    DecisionTreeSuite.validateClassifier(model, arr, 0.6)
   }
 
   test("split must satisfy min instances per node requirements") {
@@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
     arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
 
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
       maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
 
-    val model = DecisionTree.train(input, strategy)
+    val model = DecisionTree.train(rdd, strategy)
     assert(model.topNode.isLeaf)
     assert(model.topNode.predict == 0.0)
-    val predicts = input.map(p => model.predict(p.features)).collect()
+    val predicts = rdd.map(p => model.predict(p.features)).collect()
     predicts.foreach { predict =>
       assert(predict == 0.0)
     }
 
-    // test for findBestSplits when no valid split can be found
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    // test when no valid split can be found
+    val rootNode = model.topNode
 
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
     arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
 
-    val input = sc.parallelize(arr)
+    val rdd = sc.parallelize(arr)
     val strategy = new Strategy(algo = Classification, impurity = Gini,
       maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
       numClassesForClassification = 2, minInstancesPerNode = 2)
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+
+    val rootNode = DecisionTree.train(rdd, strategy).topNode
 
     val split = rootNode.split.get
     val gain = rootNode.stats.get
@@ -757,12 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(predict == 0.0)
     }
 
-    // test for findBestSplits when no valid split can be found
-    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
-    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
-      null, splits, bins, 10)
+    // test when no valid split can be found
+    val rootNode = model.topNode
 
     val gain = rootNode.stats.get
     assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
 object DecisionTreeSuite {
 
+  def validateClassifier(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredAccuracy: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      prediction != expected.label
+    }
+    val accuracy = (input.length - numOffPredictions).toDouble / input.length
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+  }
+
+  def validateRegressor(
+      model: DecisionTreeModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      val err = prediction - expected.label
+      err * err
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+  }
+
   def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {

http://git-wip-us.apache.org/repos/asf/spark/blob/0dc2b636/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
new file mode 100644
index 0000000..30669fc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.util.StatCounter
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends FunSuite with LocalSparkContext {
+
+  test("BaggedPoint RDD: without subsampling") {
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+    val rdd = sc.parallelize(arr)
+    val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
+    baggedRDD.collect().foreach { baggedPoint =>
+      assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+    }
+  }
+
+  test("BaggedPoint RDD: with subsampling") {
+    val numSubsamples = 100
+    val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+    val seeds = Array(123, 5354, 230, 349867, 23987)
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+    val rdd = sc.parallelize(arr)
+    seeds.foreach { seed =>
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
+      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+        expectedStddev, epsilon = 0.01)
+    }
+  }
+
+  test("Binary classification with continuous features:" +
+      " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val numTrees = 1
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
+
+    val dt = DecisionTree.train(rdd, strategy)
+
+    RandomForestSuite.validateClassifier(rf, arr, 0.9)
+    DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
+
+    // Make sure trees are the same.
+    assert(rfTree.toString == dt.toString)
+  }
+
+  test("Regression with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val numTrees = 1
+
+    val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf.trees.size === 1)
+    val rfTree = rf.trees(0)
+
+    val dt = DecisionTree.train(rdd, strategy)
+
+    RandomForestSuite.validateRegressor(rf, arr, 0.01)
+    DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
+
+    // Make sure trees are the same.
+    assert(rfTree.toString == dt.toString)
+  }
+
+  test("Binary classification with continuous features: subsampling features") {
+    val numFeatures = 50
+    val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
+    val rdd = sc.parallelize(arr)
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+    // Select feature subset for top nodes.  Return true if OK.
+    def checkFeatureSubsetStrategy(
+        numTrees: Int,
+        featureSubsetStrategy: String,
+        numFeaturesPerNode: Int): Unit = {
+      val seeds = Array(123, 5354, 230, 349867, 23987)
+      val maxMemoryUsage: Long = 128 * 1024L * 1024L
+      val metadata =
+        DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+      seeds.foreach { seed =>
+        val failString = s"Failed on test with:" +
+          s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+          s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+        val nodeQueue = new mutable.Queue[(Int, Node)]()
+        val topNodes: Array[Node] = new Array[Node](numTrees)
+        Range(0, numTrees).foreach { treeIndex =>
+          topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
+          nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+        }
+        val rng = new scala.util.Random(seed = seed)
+        val (nodesForGroup: Map[Int, Array[Node]],
+            treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+          RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+        assert(nodesForGroup.size === numTrees, failString)
+        assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+        if (numFeaturesPerNode == numFeatures) {
+          // featureSubset values should all be None
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+            failString)
+        } else {
+          // Check number of features.
+          assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+            _.featureSubset.get.size === numFeaturesPerNode)), failString)
+        }
+      }
+    }
+
+    checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+    checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+    checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "log2",
+      (math.log(numFeatures) / math.log(2)).ceil.toInt)
+    checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+  }
+
+}
+
+object RandomForestSuite {
+
+  /**
+   * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+   * epsilon of the expected values.
+   * @param data  Every element of the data should be an i.i.d. sample from some distribution.
+   */
+  def testRandomArrays(
+      data: Array[Array[Double]],
+      numCols: Int,
+      expectedMean: Double,
+      expectedStddev: Double,
+      epsilon: Double) {
+    val values = new mutable.ArrayBuffer[Double]()
+    data.foreach { row =>
+      assert(row.size == numCols)
+      values ++= row
+    }
+    val stats = new StatCounter(values)
+    assert(math.abs(stats.mean - expectedMean) < epsilon)
+    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+  }
+
+  def validateClassifier(
+      model: RandomForestModel,
+      input: Seq[LabeledPoint],
+      requiredAccuracy: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      prediction != expected.label
+    }
+    val accuracy = (input.length - numOffPredictions).toDouble / input.length
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+  }
+
+  def validateRegressor(
+      model: RandomForestModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      val err = prediction - expected.label
+      err * err
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+  }
+
+  def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
+    val numInstances = 1000
+    val arr = new Array[LabeledPoint](numInstances)
+    for (i <- 0 until numInstances) {
+      val label = if (i < numInstances / 10) {
+        0.0
+      } else if (i < numInstances / 2) {
+        1.0
+      } else if (i < numInstances * 0.9) {
+        0.0
+      } else {
+        1.0
+      }
+      val features = Array.fill[Double](numFeatures)(i.toDouble)
+      arr(i) = new LabeledPoint(label, Vectors.dense(features))
+    }
+    arr
+  }
+
+}
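
One note on the "with subsampling" test above: expecting subsample weights with mean 1.0 and standard deviation 1.0 is consistent with drawing each point's per-tree count from a Poisson(1) distribution, a common distributed-friendly approximation to bootstrap resampling since it needs no global pass over the data. Assuming that is what BaggedPoint.convertToBaggedRDD does (its hunk is not shown in this part of the patch), a minimal sketch of such a draw:

    // Poisson(1) sample via Knuth's inversion method; mean 1, variance 1,
    // matching the expectations asserted in the subsampling test. Adequate
    // for small means such as the lambda = 1 used for bagging weights.
    def poissonOne(rng: scala.util.Random): Int = {
      val limit = math.exp(-1.0)
      var k = 0
      var p = 1.0
      do {
        k += 1
        p *= rng.nextDouble()
      } while (p > limit)
      k - 1
    }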

