spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-4369] [MLLib] fix TreeModel.predict() with RDD
Date Wed, 12 Nov 2014 21:56:55 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.2 127c19b44 -> 16da988c5


[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD

Fix TreeModel.predict() with RDD, and add tests for it.

(Also checked that other models don't have this issue)

Author: Davies Liu <davies@databricks.com>

Closes #3230 from davies/predict and squashes the following commits:

81172aa [Davies Liu] fix predict

(cherry picked from commit bd86118c4e980f94916f892c76fb808fd4c8bd85)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/16da988c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/16da988c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/16da988c

Branch: refs/heads/branch-1.2
Commit: 16da988c5cdae935151e307a66a5385bac5167c3
Parents: 127c19b
Author: Davies Liu <davies@databricks.com>
Authored: Wed Nov 12 13:56:41 2014 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Nov 12 13:56:50 2014 -0800

----------------------------------------------------------------------
 .../mllib/tree/model/DecisionTreeModel.scala    | 12 +++++++++
 python/pyspark/mllib/tree.py                    | 26 +++++++++++---------
 2 files changed, 26 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/16da988c/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ec1d99a..ac4d02e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.tree.model
 
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.rdd.RDD
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     features.map(x => predict(x))
   }
 
+
+  /**
+   * Predict values for the given data set using the model trained.
+   *
+   * @param features JavaRDD representing data points to be predicted
+   * @return JavaRDD of predictions for each of the given data points
+   */
+  def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
+    predict(features.rdd)
+  }
+
   /**
    * Get number of nodes in tree, including leaf nodes.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/16da988c/python/pyspark/mllib/tree.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0..ef0d556 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
            Predict: 0.0
           Else (feature 0 > 0.0)
            Predict: 1.0
-        >>> model.predict(array([1.0])) > 0
-        True
-        >>> model.predict(array([0.0])) == 0
-        True
+        >>> model.predict(array([1.0]))
+        1.0
+        >>> model.predict(array([0.0]))
+        0.0
+        >>> rdd = sc.parallelize([[1.0], [0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
         ... ]
         >>>
         >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
-        >>> model.predict(array([0.0, 1.0])) == 1
-        True
-        >>> model.predict(array([0.0, 0.0])) == 0
-        True
-        >>> model.predict(SparseVector(2, {1: 1.0})) == 1
-        True
-        >>> model.predict(SparseVector(2, {1: 0.0})) == 0
-        True
+        >>> model.predict(SparseVector(2, {1: 1.0}))
+        1.0
+        >>> model.predict(SparseVector(2, {1: 0.0}))
+        0.0
+        >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message