spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mengxr <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-6113] [ml] Stabilize DecisionTree API
Date Thu, 16 Apr 2015 06:39:19 GMT
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5530#discussion_r28486987
  
    --- Diff: mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala ---
    @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
         assert(bins(0).length === 0)
       }
     
    +  test("Avoid aggregation on the last level") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue leaf nodes into node queue
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Avoid aggregation if impurity is 0.0") {
    +    val arr = new Array[LabeledPoint](4)
    +    arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
    +    arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
    +    arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
    +    arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
    +    val input = sc.parallelize(arr)
    +
    +    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    +      numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    val topNode = Node.emptyNode(nodeIndex = 1)
    +    assert(topNode.predict.predict === Double.MinValue)
    +    assert(topNode.impurity === -1.0)
    +    assert(topNode.isLeaf === false)
    +
    +    val nodesForGroup = Map((0, Array(topNode)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (topNode.id, new RandomForest.NodeIndexInfo(0, None))
    +    )))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +
    +    // don't enqueue a node into node queue if its impurity is 0.0
    +    assert(nodeQueue.isEmpty)
    +
    +    // set impurity and predict for topNode
    +    assert(topNode.predict.predict !== Double.MinValue)
    +    assert(topNode.impurity !== -1.0)
    +
    +    // set impurity and predict for child nodes
    +    assert(topNode.leftNode.get.predict.predict === 0.0)
    +    assert(topNode.rightNode.get.predict.predict === 1.0)
    +    assert(topNode.leftNode.get.impurity === 0.0)
    +    assert(topNode.rightNode.get.impurity === 0.0)
    +  }
    +
    +  test("Second level node building with vs. without groups") {
    +    val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
    +    assert(arr.length === 1000)
    +    val rdd = sc.parallelize(arr)
    +    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    +    val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
    +    val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
    +    assert(splits.length === 2)
    +    assert(splits(0).length === 99)
    +    assert(bins.length === 2)
    +    assert(bins(0).length === 100)
    +
    +    // Train a 1-node model
    +    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
    +      numClasses = 2, maxBins = 100)
    +    val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
    +    val rootNode1 = modelOneNode.topNode.deepCopy()
    +    val rootNode2 = modelOneNode.topNode.deepCopy()
    +    assert(rootNode1.leftNode.nonEmpty)
    +    assert(rootNode1.rightNode.nonEmpty)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
    +
    +    // Single group second level tree construction.
    +    val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
    +    val treeToNodeToIndexInfo = Map((0, Map(
    +      (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
    +      (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
    +    val nodeQueue = new mutable.Queue[(Int, Node)]()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    +    val children1 = new Array[Node](2)
    +    children1(0) = rootNode1.leftNode.get
    +    children1(1) = rootNode1.rightNode.get
    +
    +    // Train one second-level node at a time.
    +    val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
    +    val treeToNodeToIndexInfoA = Map((0, Map(
    +      (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
    +    val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
    +    val treeToNodeToIndexInfoB = Map((0, Map(
    +      (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
    +    nodeQueue.clear()
    +    DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
    +      nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
    +    val children2 = new Array[Node](2)
    +    children2(0) = rootNode2.leftNode.get
    +    children2(1) = rootNode2.rightNode.get
    +
    +    // Verify whether the splits obtained using single group and multiple group level
    +    // construction strategies are the same.
    +    for (i <- 0 until 2) {
    +      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
    +      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
    +      assert(children1(i).split === children2(i).split)
    +      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
    +      val stats1 = children1(i).stats.get
    +      val stats2 = children2(i).stats.get
    +      assert(stats1.gain === stats2.gain)
    +      assert(stats1.impurity === stats2.impurity)
    +      assert(stats1.leftImpurity === stats2.leftImpurity)
    +      assert(stats1.rightImpurity === stats2.rightImpurity)
    +      assert(children1(i).predict.predict === children2(i).predict.predict)
    +    }
    +  }
    --- End diff --
    
    Could you keep the original ordering of tests? It helps with the diff.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message