spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
Date Tue, 11 Oct 2016 00:04:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master 29f186bfd -> 03c40202f


[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training

## What changes were proposed in this pull request?

The method `findSplitsForContinuousFeature` produces a nonsensical split candidate for
continuous features in decision tree training: a threshold at the largest distinct feature
value, which sends every sample to the same side and therefore can never partition the data.
This PR removes the superfluous split and updates the unit tests accordingly. Additionally,
the assertion that the number of splits found is `> 0` is removed; features with zero possible
splits (e.g. constant features) are now ignored rather than causing training to fail.
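
To illustrate the off-by-one, here is a minimal standalone sketch (hypothetical names, not the
Spark implementation: the real `findSplitsForContinuousFeature` weights distinct values by their
sample counts and strides over them when there are more candidates than `numSplits`):

```scala
object SuperfluousSplitSketch {
  // n distinct values admit only n - 1 usable thresholds: under the
  // "feature <= threshold goes left" convention, a threshold at the
  // maximum value sends every sample left and can never split the data.
  def candidateSplits(featureSamples: Seq[Double], numSplits: Int): Array[Double] = {
    if (featureSamples.isEmpty) {
      Array.empty[Double]
    } else {
      val distinctSorted = featureSamples.distinct.sorted.toArray
      val possibleSplits = distinctSorted.length - 1 // the fix: previously .length
      if (possibleSplits <= numSplits) {
        distinctSorted.init                 // drop the superfluous maximum value
      } else {
        distinctSorted.init.take(numSplits) // crude stand-in for Spark's striding
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Three distinct values -> two thresholds, not three.
    println(candidateSplits(Seq(1, 1, 2, 2, 3).map(_.toDouble), 5).mkString(", ")) // 1.0, 2.0
    // A constant feature admits no threshold at all.
    println(candidateSplits(Seq(2, 2, 2).map(_.toDouble), 5).mkString(", "))       // (empty)
  }
}
```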

## How was this patch tested?

Unit tests were added to check that finding splits for a constant feature (or for an empty
sample) produces an empty array, and that training on constant features succeeds, yielding a
depth-0 tree. Existing split-count assertions were updated for the corrected behavior.
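
Condensed against the sketch above (the real test constructs a `DecisionTreeMetadata` and calls
`RandomForest.findSplitsForContinuousFeature` directly), the behavior being pinned down is:

```scala
// Constant and empty feature samples both yield zero split candidates,
// instead of tripping the old `splits.length > 0` assertion.
assert(SuperfluousSplitSketch.candidateSplits(Seq(0.0, 0.0, 0.0), numSplits = 3).isEmpty)
assert(SuperfluousSplitSketch.candidateSplits(Seq.empty[Double], numSplits = 3).isEmpty)
```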

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12374 from sethah/SPARK-14610.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/03c40202
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/03c40202
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/03c40202

Branch: refs/heads/master
Commit: 03c40202f36ea9fc93071b79fed21ed3f2190ba1
Parents: 29f186b
Author: sethah <seth.hendrickson16@gmail.com>
Authored: Mon Oct 10 17:04:11 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Oct 10 17:04:11 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/tree/impl/RandomForest.scala       | 31 +++++++-------
 .../spark/ml/tree/impl/RandomForestSuite.scala  | 44 ++++++++++++++++----
 2 files changed, 52 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/03c40202/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 0b7ad92..b504f41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
       node.stats
     }
 
+    val validFeatureSplits =
+      Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+        featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+          .getOrElse((featureIndexIdx, featureIndexIdx))
+      }.withFilter { case (_, featureIndex) =>
+        binAggregates.metadata.numSplits(featureIndex) != 0
+      }
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
-      Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
-        val featureIndex = if (featuresForNode.nonEmpty) {
-          featuresForNode.get.apply(featureIndexIdx)
-        } else {
-          featureIndexIdx
-        }
+      validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
         val numSplits = binAggregates.metadata.numSplits(featureIndex)
         if (binAggregates.metadata.isContinuous(featureIndex)) {
           // Cumulative sum (scanLeft) of bin statistics.
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
    *                 NOTE: `metadata.numbins` will be changed accordingly
    *                       if there are not enough splits to be found
    * @param featureIndex feature index to find splits
-   * @return array of splits
+   * @return array of split thresholds
    */
   private[tree] def findSplitsForContinuousFeature(
       featureSamples: Iterable[Double],
@@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
     require(metadata.isContinuous(featureIndex),
       "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
 
-    val splits = {
+    val splits = if (featureSamples.isEmpty) {
+      Array.empty[Double]
+    } else {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
@@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
       // if possible splits is not enough or just enough, just return all possible splits
-      val possibleSplits = valueCounts.length
+      val possibleSplits = valueCounts.length - 1
       if (possibleSplits <= numSplits) {
-        valueCounts.map(_._1)
+        valueCounts.map(_._1).init
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
         splitsBuilder.result()
       }
     }
-
-    // TODO: Do not fail; just ignore the useless feature.
-    assert(splits.length > 0,
-      s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value."
+
-        "  Please remove this feature and then try again.")
-
     splits
   }
 

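An aside on the `validFeatureSplits` change above: the lazy view maps each per-node feature slot
to its global feature index and filters out features with zero usable splits, which is how
constant features are now ignored rather than asserted on. A minimal REPL sketch of the same
pattern, with hypothetical data in place of Spark's metadata types:

```scala
// Hypothetical inputs standing in for binAggregates.metadata and featuresForNode.
val numSplitsPerFeature = Array(3, 0, 2)        // feature 1 is constant: 0 splits
val featuresForNode: Option[Array[Int]] = None  // None => this node sees all features

// Pair each per-node slot with its global index; lazily skip 0-split features.
val validFeatureSplits =
  Range(0, numSplitsPerFeature.length).view.map { featureIndexIdx =>
    featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
      .getOrElse((featureIndexIdx, featureIndexIdx))
  }.withFilter { case (_, featureIndex) =>
    numSplitsPerFeature(featureIndex) != 0
  }

validFeatureSplits.foreach(println)             // prints (0,0) and (2,2); (1,1) is skipped
```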
http://git-wip-us.apache.org/repos/asf/spark/blob/03c40202/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 79b19ea..499d386 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -115,7 +115,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       )
       val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 3)
+      assert(splits === Array(1.0, 2.0))
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -129,23 +129,53 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       )
       val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 2)
-      assert(splits(0) === 2.0)
-      assert(splits(1) === 3.0)
+      assert(splits === Array(2.0, 3.0))
     }
 
     // find splits when most samples close to the maximum
     {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
-        Array(3), Gini, QuantileStrategy.Sort,
+        Array(2), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
       val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits.length === 1)
-      assert(splits(0) === 1.0)
+      assert(splits === Array(1.0))
     }
+
+    // find splits for constant feature
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 0, 0).map(_.toDouble)
+      val featureSamplesEmpty = Array.empty[Double]
+      val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits === Array[Double]())
+      val splitsEmpty =
+        RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
+      assert(splitsEmpty === Array[Double]())
+    }
+  }
+
+  test("train with constant features") {
+    val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
+    val data = Array.fill(5)(lp)
+    val rdd = sc.parallelize(data)
+    val strategy = new OldStrategy(
+          OldAlgo.Classification,
+          Gini,
+          maxDepth = 2,
+          numClasses = 2,
+          maxBins = 100,
+          categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
+    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+    assert(tree.rootNode.impurity === -1.0)
+    assert(tree.depth === 0)
+    assert(tree.rootNode.prediction === lp.label)
   }
 
   test("Multiclass classification with unordered categorical features: split calculations")
{


