spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier
Date Mon, 03 Aug 2015 17:46:39 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8be198c86 -> 69f5a7c93


[SPARK-9528] [ML] Changed RandomForestClassifier to extend ProbabilisticClassifier

RandomForestClassifier now outputs rawPrediction based on tree probabilities, plus probability
column computed from normalized rawPrediction.

CC: holdenk

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7859 from jkbradley/rf-prob and squashes the following commits:

6c28f51 [Joseph K. Bradley] Changed RandomForestClassifier to extend ProbabilisticClassifier


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69f5a7c9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69f5a7c9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69f5a7c9

Branch: refs/heads/master
Commit: 69f5a7c934ac553ed52c00679b800bcffe83c1d6
Parents: 8be198c
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Mon Aug 3 10:46:34 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Aug 3 10:46:34 2015 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |  8 +----
 .../ProbabilisticClassifier.scala               | 27 +++++++++++++-
 .../classification/RandomForestClassifier.scala | 37 ++++++++++++++------
 .../RandomForestClassifierSuite.scala           | 36 ++++++++++++++-----
 4 files changed, 81 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/69f5a7c9/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index f27cfd0..f2b992f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -131,13 +131,7 @@ final class DecisionTreeClassificationModel private[ml] (
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     rawPrediction match {
       case dv: DenseVector =>
-        var i = 0
-        val size = dv.size
-        val sum = dv.values.sum
-        while (i < size) {
-          dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
-          i += 1
-        }
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
         dv
       case sv: SparseVector =>
         throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +

http://git-wip-us.apache.org/repos/asf/spark/blob/69f5a7c9/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index dad4511..f9c9c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
@@ -175,3 +175,28 @@ private[spark] abstract class ProbabilisticClassificationModel[
    */
   protected def probability2prediction(probability: Vector): Double = probability.argmax
 }
+
+private[ml] object ProbabilisticClassificationModel {
+
+  /**
+   * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
+   *
+   * The input raw predictions should be >= 0.
+   * The output vector sums to 1, unless the input vector is all-0 (in which case the output is
+   * all-0 too).
+   *
+   * NOTE: This is NOT applicable to all models, only ones which effectively use class
+   *       instance counts for raw predictions.
+   */
+  def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
+    val sum = v.values.sum
+    if (sum != 0) {
+      var i = 0
+      val size = v.size
+      while (i < size) {
+        v.values(i) /= sum
+        i += 1
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/69f5a7c9/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0c7eb4a..56e80cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
 
 package org.apache.spark.ml.classification
 
-import scala.collection.mutable
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
 
 /**
  * :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -127,7 +124,7 @@ final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
     override val numClasses: Int)
-  extends ClassificationModel[Vector, RandomForestClassificationModel]
+  extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -157,15 +154,33 @@ final class RandomForestClassificationModel private[ml] (
   override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
-    // Ignore the weights since all are 1.0 for now.
-    val votes = new Array[Double](numClasses)
+    // Ignore the tree weights since all are 1.0 for now.
+    val votes = Array.fill[Double](numClasses)(0.0)
     _trees.view.foreach { tree =>
-      val prediction = tree.rootNode.predictImpl(features).prediction.toInt
-      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+      val total = classCounts.sum
+      if (total != 0) {
+        var i = 0
+        while (i < numClasses) {
+          votes(i) += classCounts(i) / total
+          i += 1
+        }
+      }
     }
     Vectors.dense(votes)
   }
 
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
+  }
+
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/69f5a7c9/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index dbb2577..edf848b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -121,6 +122,33 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
   }
 
+  test("predictRaw and predictProbability") {
+    val rdd = orderedLabeledPoints5_20
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setSeed(123)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+
+    val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val model = rf.fit(df)
+
+    val predictions = model.transform(df)
+      .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
+      .collect()
+
+    predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+      assert(pred === rawPred.argmax,
+        s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+      val sum = rawPred.toArray.sum
+      assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+        "probability prediction mismatch")
+      assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
+    }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
@@ -173,13 +201,5 @@ private object RandomForestClassifierSuite {
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
     assert(newModel.numClasses == numClasses)
-    val results = newModel.transform(newData)
-    results.select("rawPrediction", "prediction").collect().foreach {
-      case Row(raw: Vector, prediction: Double) => {
-        assert(raw.size == numClasses)
-        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
-        assert(predFromRaw == prediction)
-      }
-    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message