spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-14814][MLLIB] API: Java compatibility, docs
Date Mon, 09 May 2016 08:09:12 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 238b7b416 -> eb0db9090


[SPARK-14814][MLLIB] API: Java compatibility, docs

## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-14814
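Fix a Java compatibility issue in MLlib's `DecisionTreeModel.predict(JavaRDD[Vector])`: it declared its return type as `JavaRDD[Double]` (Scala `Double`) rather than `JavaRDD[java.lang.Double]`. As agreed on the JIRA issue, the other compatibility issues do not need fixes.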

## How was this patch tested?

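Existing unit tests, plus a new Java compatibility check in JavaDecisionTreeSuite that calls `predict` on a `JavaRDD<Vector>`.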

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #12971 from hhbyyh/javacompatibility.

(cherry picked from commit 68abc1b4e9afbb6c2a87689221a46b835dded102)
Signed-off-by: Sean Owen <sowen@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eb0db909
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eb0db909
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eb0db909

Branch: refs/heads/branch-2.0
Commit: eb0db909009afd9289d24fd5a59eb060b8aafc5f
Parents: 238b7b4
Author: Yuhao Yang <hhbyyh@gmail.com>
Authored: Mon May 9 09:08:54 2016 +0100
Committer: Sean Owen <sowen@cloudera.com>
Committed: Mon May 9 09:09:07 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/model/DecisionTreeModel.scala |  4 ++--
 .../apache/spark/mllib/tree/JavaDecisionTreeSuite.java    | 10 ++++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eb0db909/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a87f8a6..c13b9a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -75,8 +75,8 @@ class DecisionTreeModel @Since("1.0.0") (
    * @return JavaRDD of predictions for each of the given data points
    */
   @Since("1.2.0")
-  def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
-    predict(features.rdd)
+  def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+    predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/eb0db909/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 8dd2906..60585d2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -28,6 +28,8 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.tree.configuration.Algo;
 import org.apache.spark.mllib.tree.configuration.Strategy;
@@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable {
 
     DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
 
+    // java compatibility test
+    JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint,
Vector>() {
+      @Override
+      public Vector call(LabeledPoint v1) {
+        return v1.features();
+      }
+    }));
+
     int numCorrect = validatePrediction(arr, model);
     Assert.assertTrue(numCorrect == rdd.count());
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message