Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 3EA2D19594 for ; Wed, 9 Mar 2016 22:44:57 +0000 (UTC) Received: (qmail 57821 invoked by uid 500); 9 Mar 2016 22:44:57 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 57788 invoked by uid 500); 9 Mar 2016 22:44:57 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 57779 invoked by uid 99); 9 Mar 2016 22:44:57 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 09 Mar 2016 22:44:57 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 00FB4DFB8A; Wed, 9 Mar 2016 22:44:56 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: jkbradley@apache.org To: commits@spark.apache.org Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-11861][ML] Add feature importances for decision trees Date: Wed, 9 Mar 2016 22:44:57 +0000 (UTC) Repository: spark Updated Branches: refs/heads/master c6aa356cd -> e1772d3f1 [SPARK-11861][ML] Add feature importances for decision trees This patch adds an API entry point for single decision tree feature importances. Author: sethah Closes #9912 from sethah/SPARK-11861. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1772d3f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1772d3f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1772d3f Branch: refs/heads/master Commit: e1772d3f19bed7e69a80de7900ed22d3eeb05300 Parents: c6aa356 Author: sethah Authored: Wed Mar 9 14:44:51 2016 -0800 Committer: Joseph K. Bradley Committed: Wed Mar 9 14:44:51 2016 -0800 ---------------------------------------------------------------------- .../classification/DecisionTreeClassifier.scala | 19 +++++++++++++ .../classification/RandomForestClassifier.scala | 4 +-- .../ml/regression/DecisionTreeRegressor.scala | 19 +++++++++++++ .../ml/regression/RandomForestRegressor.scala | 4 +-- .../spark/ml/tree/impl/RandomForest.scala | 30 ++++++++++++++++---- .../DecisionTreeClassifierSuite.scala | 21 ++++++++++++++ .../RandomForestClassifierSuite.scala | 10 ++----- .../org/apache/spark/ml/impl/TreeTests.scala | 13 +++++++++ .../regression/DecisionTreeRegressorSuite.scala | 20 +++++++++++++ .../regression/RandomForestRegressorSuite.scala | 13 ++------- 10 files changed, 126 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8c4cec1..7f0397f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -169,6 +169,25 @@ final class DecisionTreeClassificationModel private[ml] ( s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * Note: Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. + */ + @Since("2.0.0") + lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures) + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f7d662d..5da04d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -230,10 +230,10 @@ final class RandomForestClassificationModel private[ml] ( * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. + * - Normalize importances for tree to sum to 1. * - Normalize feature importance vector to sum to 1. */ + @Since("1.5.0") lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 18c94f3..897b233 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -169,6 +169,25 @@ final class DecisionTreeRegressionModel private[ml] ( s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } + /** + * Estimate of the importance of each feature. + * + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * Note: Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. + */ + @Since("2.0.0") + lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures) + /** Convert to a model in the old API */ private[ml] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 71e40b5..798947b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -189,10 +189,10 @@ final class RandomForestRegressionModel private[ml] ( * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. + * - Normalize importances for tree to sum to 1. * - Normalize feature importance vector to sum to 1. */ + @Since("1.5.0") lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures) /** (private[ml]) Convert to a model in the old API */ http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index ea733d5..f994c25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -1089,12 +1089,9 @@ private[ml] object RandomForest extends Logging { * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree based on total number of training instances used - * to build tree. + * - Normalize importances for tree to sum to 1. * - Normalize feature importance vector to sum to 1. * - * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for - * independently trained trees. * @param trees Unweighted forest of trees * @param numFeatures Number of features in model (even if not all are explicitly used by * the model). @@ -1128,14 +1125,35 @@ private[ml] object RandomForest extends Logging { maxFeatureIndex + 1 } if (d == 0) { - assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" + - s" importance: No splits in forest, but some non-zero importances.") + assert(totalImportances.size == 0, s"Unknown error in computing feature" + + s" importance: No splits found, but some non-zero importances.") } val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip Vectors.sparse(d, indices.toArray, values.toArray) } /** + * Given a Decision Tree model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * @param tree Decision tree to compute importances for. + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = { + featureImportances(Array(tree), numFeatures) + } + + /** * Recursive method for computing feature importances for one tree. * This walks down the tree, adding to the importance of 1 feature at each node. * @param node Current node in recursion http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 9169bcd..6d68364 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -313,6 +313,27 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte } } + test("Feature importance with toy data") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val numFeatures = data.first().features.size + val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap + val df = TreeTests.setMetadata(data, categoricalFeatures, 2) + + val model = dt.fit(df) + + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index deb8ec7..6b810ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -167,19 +167,15 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setSeed(123) // In this data, feature 1 is very important. - val data: RDD[LabeledPoint] = sc.parallelize(Seq( - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) - )) + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val importances = rf.fit(df).featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) } ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index a808177..5561f6f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ +import org.apache.spark.SparkContext import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} @@ -141,4 +143,15 @@ private[ml] object TreeTests extends SparkFunSuite { val pred = parentImp.predict new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) } + + /** + * Create some toy data for testing feature importances. + */ + def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) } http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 13165f6..56b335a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -96,6 +96,26 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex } } + test("Feature importance with toy data") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + + val importances = model.featureImportances + val mostImportantFeature = importances.argmax + assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/e1772d3f/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 7e751e4..efb117f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -82,23 +81,17 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setSeed(123) // In this data, feature 1 is very important. - val data: RDD[LabeledPoint] = sc.parallelize(Seq( - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), - new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), - new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) - )) + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) val categoricalFeatures = Map.empty[Int, Int] val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) val model = rf.fit(df) - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) val importances = model.featureImportances val mostImportantFeature = importances.argmax assert(mostImportantFeature === 1) + assert(importances.toArray.sum === 1.0) + assert(importances.toArray.forall(_ >= 0.0)) } ///////////////////////////////////////////////////////////////////////////// --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org