spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6093] [MLLIB] Add RegressionMetrics in PySpark/MLlib
Date Thu, 07 May 2015 18:18:42 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.4 3038b26f1 -> ef835dc52


[SPARK-6093] [MLLIB] Add RegressionMetrics in PySpark/MLlib

https://issues.apache.org/jira/browse/SPARK-6093

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5941 from yanboliang/spark-6093 and squashes the following commits:

6934af3 [Yanbo Liang] change to @property
aac3bc5 [Yanbo Liang] Add RegressionMetrics in PySpark/MLlib

(cherry picked from commit 1712a7c7057bf6dd5da8aea1d7fbecdf96ea4b32)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ef835dc5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ef835dc5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ef835dc5

Branch: refs/heads/branch-1.4
Commit: ef835dc526b685886781d454c46e837644d8f446
Parents: 3038b26
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Thu May 7 11:18:32 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu May 7 11:18:38 2015 -0700

----------------------------------------------------------------------
 .../mllib/evaluation/RegressionMetrics.scala    |  9 +++
 python/pyspark/mllib/evaluation.py              | 78 +++++++++++++++++++-
 2 files changed, 85 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ef835dc5/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 693117d..e577bf8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
+import org.apache.spark.sql.DataFrame
 
 /**
  * :: Experimental ::
@@ -33,6 +34,14 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Multivariate
 class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
 
   /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndObservations a DataFrame with two double columns:
+   *                                  prediction and observation
+   */
+  private[mllib] def this(predictionAndObservations: DataFrame) =
+    this(predictionAndObservations.map(r => (r.getDouble(0), r.getDouble(1))))
+
+  /**
   * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
    */
   private lazy val summary: MultivariateStatisticalSummary = {

http://git-wip-us.apache.org/repos/asf/spark/blob/ef835dc5/python/pyspark/mllib/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 16cb49c..3e11df0 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -27,9 +27,9 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
     >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
-    >>> metrics.areaUnderROC()
+    >>> metrics.areaUnderROC
     0.70...
-    >>> metrics.areaUnderPR()
+    >>> metrics.areaUnderPR
     0.83...
     >>> metrics.unpersist()
     """
@@ -47,6 +47,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
         java_model = java_class(df._jdf)
         super(BinaryClassificationMetrics, self).__init__(java_model)
 
+    @property
     def areaUnderROC(self):
         """
         Computes the area under the receiver operating characteristic
@@ -54,6 +55,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
         """
         return self.call("areaUnderROC")
 
+    @property
     def areaUnderPR(self):
         """
         Computes the area under the precision-recall curve.
@@ -67,6 +69,78 @@ class BinaryClassificationMetrics(JavaModelWrapper):
         self.call("unpersist")
 
 
+class RegressionMetrics(JavaModelWrapper):
+    """
+    Evaluator for regression.
+
+    >>> predictionAndObservations = sc.parallelize([
+    ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
+    >>> metrics = RegressionMetrics(predictionAndObservations)
+    >>> metrics.explainedVariance
+    0.95...
+    >>> metrics.meanAbsoluteError
+    0.5...
+    >>> metrics.meanSquaredError
+    0.37...
+    >>> metrics.rootMeanSquaredError
+    0.61...
+    >>> metrics.r2
+    0.94...
+    """
+
+    def __init__(self, predictionAndObservations):
+        """
+        :param predictionAndObservations: an RDD of (prediction, observation) pairs.
+        """
+        sc = predictionAndObservations.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
+            StructField("prediction", DoubleType(), nullable=False),
+            StructField("observation", DoubleType(), nullable=False)]))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
+        java_model = java_class(df._jdf)
+        super(RegressionMetrics, self).__init__(java_model)
+
+    @property
+    def explainedVariance(self):
+        """
+        Returns the explained variance regression score.
+        explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+        """
+        return self.call("explainedVariance")
+
+    @property
+    def meanAbsoluteError(self):
+        """
+        Returns the mean absolute error, which is a risk function corresponding to the
+        expected value of the absolute error loss or l1-norm loss.
+        """
+        return self.call("meanAbsoluteError")
+
+    @property
+    def meanSquaredError(self):
+        """
+        Returns the mean squared error, which is a risk function corresponding to the
+        expected value of the squared error loss or quadratic loss.
+        """
+        return self.call("meanSquaredError")
+
+    @property
+    def rootMeanSquaredError(self):
+        """
+        Returns the root mean squared error, which is defined as the square root of
+        the mean squared error.
+        """
+        return self.call("rootMeanSquaredError")
+
+    @property
+    def r2(self):
+        """
+        Returns R^2^, the coefficient of determination.
+        """
+        return self.call("r2")
+
+
 def _test():
     import doctest
     from pyspark import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message