spark-commits mailing list archives

From m...@apache.org
Subject git commit: [mllib] DecisionTree: treeAggregate + Python example bug fix
Date Mon, 18 Aug 2014 21:40:28 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.1 25cabd7ee -> 98778fffd


[mllib] DecisionTree: treeAggregate + Python example bug fix

Small DecisionTree updates:
* Changed main DecisionTree aggregate to treeAggregate.
* Fixed bug in python example decision_tree_runner.py with missing argument
  (since categoricalFeaturesInfo is no longer an optional argument for trainClassifier).
* Fixed same bug in python doc tests, and added tree.py to doc tests.
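
The switch from aggregate to treeAggregate can be illustrated with a small pure-Python sketch. This is not Spark's implementation: partitions are modeled as plain lists, and the names flat_aggregate, tree_aggregate, fan_in, bin_seq_op and bin_comb_op are hypothetical stand-ins chosen for this example. It only shows the idea that per-partition partial results can be merged level by level in small groups rather than all at once in a single place, while yielding the same result.

```python
from functools import reduce


def flat_aggregate(partitions, zero, seq_op, comb_op):
    """Fold each partition, then merge every partial result in one step."""
    partials = [reduce(seq_op, part, list(zero)) for part in partitions]
    return reduce(comb_op, partials, list(zero))


def tree_aggregate(partitions, zero, seq_op, comb_op, fan_in=2):
    """Fold each partition, then merge partials level by level.

    fan_in (hypothetical parameter) is how many partials are merged
    together at each level of the tree.
    """
    if fan_in < 2:
        raise ValueError("fan_in must be at least 2")
    partials = [reduce(seq_op, part, list(zero)) for part in partitions]
    # Shrink the list of partials until one final merge is enough.
    while len(partials) > fan_in:
        partials = [
            reduce(comb_op, partials[i:i + fan_in])
            for i in range(0, len(partials), fan_in)
        ]
    return reduce(comb_op, partials, list(zero))


def bin_seq_op(agg, bin_index):
    """Count one data point in its bin (mutates and returns the aggregate)."""
    agg[bin_index] += 1.0
    return agg


def bin_comb_op(left, right):
    """Merge two bin-count arrays element-wise."""
    return [x + y for x, y in zip(left, right)]


# Five "partitions" of bin indices, aggregated into four bins.
partitions = [[0, 1, 1], [2], [0, 0], [1, 2, 2], [3]]
zero = [0.0] * 4

flat_result = flat_aggregate(partitions, zero, bin_seq_op, bin_comb_op)
tree_result = tree_aggregate(partitions, zero, bin_seq_op, bin_comb_op)
print(flat_result == tree_result)
```

Both functions return the same bin counts as long as the combine operation is associative. The tree version only changes how many partial arrays are merged in any single step, which is relevant here because each partial is an array of length binAggregateLength.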

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2015 from jkbradley/dt-opt2 and squashes the following commits:

b5114fa [Joseph K. Bradley] Fixed python tree.py doc test (extra newline)
8e4665d [Joseph K. Bradley] Added tree.py to python doc tests.  Fixed bug from missing categoricalFeaturesInfo
argument.
b7b2922 [Joseph K. Bradley] Fixed bug in python example decision_tree_runner.py with missing
argument.  Changed main DecisionTree aggregate to treeAggregate.
85bbc1f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
66d076f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata.  Small doc updates.
3726d20 [Joseph K. Bradley] Small code improvements based on code review.
ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using <<
instead of math.pow.
db0d773 [Joseph K. Bradley] scala style fix
6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code
931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2
797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level.  Needed
to update treePointToNodeIndex with groupShift.
f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review.  1 major change: persisting to memory
+ disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used.  Removed
debugging println calls in DecisionTree.scala.
356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2
430d782 [Joseph K. Bradley] Added more debug info on binning error.  Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it
up.  Removed debugging println calls from DecisionTree.  Made TreePoint extend Serializable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit
to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification
b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes
b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt
0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid
calling findBin multiple times. * (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] timing for DecisionTree internals

(cherry picked from commit 115eeb30dd9c9dd10685a71f2c23ca23794d3142)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/98778fff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/98778fff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/98778fff

Branch: refs/heads/branch-1.1
Commit: 98778fffdb4e11593149eb7770071a0728653f19
Parents: 25cabd7
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Authored: Mon Aug 18 14:40:05 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Aug 18 14:40:21 2014 -0700

----------------------------------------------------------------------
 .../src/main/python/mllib/decision_tree_runner.py     |  4 +++-
 .../org/apache/spark/mllib/tree/DecisionTree.scala    |  3 ++-
 python/pyspark/mllib/tree.py                          | 14 ++++++++------
 python/run-tests                                      |  1 +
 4 files changed, 14 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/98778fff/examples/src/main/python/mllib/decision_tree_runner.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py
index 8efadb5..db96a7c 100755
--- a/examples/src/main/python/mllib/decision_tree_runner.py
+++ b/examples/src/main/python/mllib/decision_tree_runner.py
@@ -124,7 +124,9 @@ if __name__ == "__main__":
     (reindexedData, origToNewLabels) = reindexClassLabels(points)
 
     # Train a classifier.
-    model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
+    categoricalFeaturesInfo={} # no categorical features
+    model = DecisionTree.trainClassifier(reindexedData, numClasses=2,
+                                         categoricalFeaturesInfo=categoricalFeaturesInfo)
     # Print learned tree and stats.
     print "Trained DecisionTree for classification:"
     print "  Model numNodes: %d\n" % model.numNodes()

http://git-wip-us.apache.org/repos/asf/spark/blob/98778fff/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 6b9a8f7..5cdd258 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -826,7 +827,7 @@ object DecisionTree extends Serializable with Logging {
     // Calculate bin aggregates.
     timer.start("aggregation")
     val binAggregates = {
-      input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
+      input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
     }
     timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)

http://git-wip-us.apache.org/repos/asf/spark/blob/98778fff/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index e1a4671..e9d778d 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -88,7 +88,8 @@ class DecisionTree(object):
                   It will probably be modified for Spark v1.2.
 
     Example usage:
-    >>> from numpy import array, ndarray
+    >>> from numpy import array
+    >>> import sys
     >>> from pyspark.mllib.regression import LabeledPoint
     >>> from pyspark.mllib.tree import DecisionTree
     >>> from pyspark.mllib.linalg import SparseVector
@@ -99,15 +100,15 @@ class DecisionTree(object):
     ...     LabeledPoint(1.0, [2.0]),
     ...     LabeledPoint(1.0, [3.0])
     ... ]
-    >>>
-    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
-    >>> print(model)
+    >>> categoricalFeaturesInfo = {} # no categorical features
+    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
+    ...                                      categoricalFeaturesInfo=categoricalFeaturesInfo)
+    >>> sys.stdout.write(model)
     DecisionTreeModel classifier
       If (feature 0 <= 0.5)
        Predict: 0.0
       Else (feature 0 > 0.5)
        Predict: 1.0
-
     >>> model.predict(array([1.0])) > 0
     True
     >>> model.predict(array([0.0])) == 0
@@ -119,7 +120,8 @@ class DecisionTree(object):
     ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
     ... ]
     >>>
-    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
+    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
+    ...                                     categoricalFeaturesInfo=categoricalFeaturesInfo)
     >>> model.predict(array([0.0, 1.0])) == 1
     True
     >>> model.predict(array([0.0, 0.0])) == 0

http://git-wip-us.apache.org/repos/asf/spark/blob/98778fff/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index 1218edc..a6271e0 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py"
 run_test "pyspark/mllib/recommendation.py"
 run_test "pyspark/mllib/regression.py"
 run_test "pyspark/mllib/tests.py"
+run_test "pyspark/mllib/tree.py"
 run_test "pyspark/mllib/util.py"
 
 if [[ $FAILED == 0 ]]; then



