spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression
Date Mon, 09 Nov 2015 16:56:27 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 0f03bd13e -> 029e931da


[SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression

Expose R-like summary statistics in SparkR::glm for linear regression. The output of `summary`
looks like:
```R
$DevianceResiduals
 Min        Max
 -0.9509607 0.7291832

$Coefficients
                   Estimate   Std. Error t value   Pr(>|t|)
(Intercept)        1.6765     0.2353597  7.123139  4.456124e-11
Sepal_Length       0.3498801  0.04630128 7.556598  4.187317e-12
Species_versicolor -0.9833885 0.07207471 -13.64402 0
Species_virginica  -1.00751   0.09330565 -10.79796 0
```

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9561 from yanboliang/spark-11494.

(cherry picked from commit 8c0e1b50e960d3e8e51d0618c462eed2bb4936f0)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/029e931d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/029e931d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/029e931d

Branch: refs/heads/branch-1.6
Commit: 029e931dae82b9843ac0fe9348fe6f64ae6556db
Parents: 0f03bd1
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Mon Nov 9 08:56:22 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Nov 9 08:56:32 2015 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 | 22 +++++++--
 R/pkg/inst/tests/test_mllib.R                   | 31 +++++++++---
 .../org/apache/spark/ml/r/SparkRWrappers.scala  | 50 ++++++++++++++++++--
 3 files changed, 88 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/029e931d/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b0d73dd..7ff8597 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -91,12 +91,26 @@ setMethod("predict", signature(object = "PipelineModel"),
 #'}
 setMethod("summary", signature(x = "PipelineModel"),
           function(x, ...) {
+            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                   "getModelName", x@model)
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                    "getModelFeatures", x@model)
             coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                    "getModelCoefficients", x@model)
-            coefficients <- as.matrix(unlist(coefficients))
-            colnames(coefficients) <- c("Estimate")
-            rownames(coefficients) <- unlist(features)
-            return(list(coefficients = coefficients))
+            if (modelName == "LinearRegressionModel") {
+              devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                               "getModelDevianceResiduals", x@model)
+              devianceResiduals <- matrix(devianceResiduals, nrow = 1)
+              colnames(devianceResiduals) <- c("Min", "Max")
+              rownames(devianceResiduals) <- rep("", times = 1)
+              coefficients <- matrix(coefficients, ncol = 4)
+              colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
+              rownames(coefficients) <- unlist(features)
+              return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients))
+            } else {
+              coefficients <- as.matrix(unlist(coefficients))
+              colnames(coefficients) <- c("Estimate")
+              rownames(coefficients) <- unlist(features)
+              return(list(coefficients = coefficients))
+            }
           })

http://git-wip-us.apache.org/repos/asf/spark/blob/029e931d/R/pkg/inst/tests/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 4761e28..2606407 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", {
 
 test_that("summary coefficients match with native glm", {
   training <- createDataFrame(sqlContext, iris)
-  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs"))
-  coefs <- as.vector(stats$coefficients)
+  stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
+  coefs <- unlist(stats$Coefficients)
+  devianceResiduals <- unlist(stats$DevianceResiduals)
+
   rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
-  expect_true(all(abs(rCoefs - coefs) < 1e-6))
+  rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331)
+  rTValue <- c(7.123, 7.557, -13.644, -10.798)
+  rPValue <- c(0.0, 0.0, 0.0, 0.0)
+  rDevianceResiduals <- c(-0.95096, 0.72918)
+
+  expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6))
+  expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5))
+  expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3))
+  expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6))
+  expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
   expect_true(all(
-    as.character(stats$features) ==
+    rownames(stats$Coefficients) ==
     c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
 })
 
@@ -85,14 +96,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", {
   training <- filter(df, df$Species != "setosa")
   stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
     family = "binomial"))
-  coefs <- as.vector(stats$coefficients)
+  coefs <- as.vector(stats$Coefficients)
 
   rTraining <- iris[iris$Species %in% c("versicolor","virginica"),]
   rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
     family = binomial(link = "logit"))))
+  rStdError <- c(3.0974, 0.5169, 0.8628)
+  rTValue <- c(-4.212, 3.680, 0.469)
+  rPValue <- c(0.000, 0.000, 0.639)
 
-  expect_true(all(abs(rCoefs - coefs) < 1e-4))
+  expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4))
+  expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4))
+  expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3))
+  expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3))
   expect_true(all(
-    as.character(stats$features) ==
+    rownames(stats$Coefficients) ==
     c("(Intercept)", "Sepal_Length", "Sepal_Width")))
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/029e931d/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 5be2f86..4d82b90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -53,10 +53,35 @@ private[r] object SparkRWrappers {
 
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
+      case m: LinearRegressionModel => {
+        val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
+          m.summary.coefficientStandardErrors.dropRight(1)
+        val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
+        val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
+        if (m.getFitIntercept) {
+          Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
+            tValuesR ++ pValuesR
+        } else {
+          m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
+        }
+      }
+      case m: LogisticRegressionModel => {
+        if (m.getFitIntercept) {
+          Array(m.intercept) ++ m.coefficients.toArray
+        } else {
+          m.coefficients.toArray
+        }
+      }
+    }
+  }
+
+  def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
+    model.stages.last match {
       case m: LinearRegressionModel =>
-        Array(m.intercept) ++ m.coefficients.toArray
+        m.summary.devianceResiduals
       case m: LogisticRegressionModel =>
-        Array(m.intercept) ++ m.coefficients.toArray
+        throw new UnsupportedOperationException(
+          "No deviance residuals available for LogisticRegressionModel")
     }
   }
 
@@ -65,11 +90,28 @@ private[r] object SparkRWrappers {
       case m: LinearRegressionModel =>
         val attrs = AttributeGroup.fromStructField(
           m.summary.predictions.schema(m.summary.featuresCol))
-        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+        if (m.getFitIntercept) {
+          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+        } else {
+          attrs.attributes.get.map(_.name.get)
+        }
       case m: LogisticRegressionModel =>
         val attrs = AttributeGroup.fromStructField(
           m.summary.predictions.schema(m.summary.featuresCol))
-        Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+        if (m.getFitIntercept) {
+          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+        } else {
+          attrs.attributes.get.map(_.name.get)
+        }
+    }
+  }
+
+  def getModelName(model: PipelineModel): String = {
+    model.stages.last match {
+      case m: LinearRegressionModel =>
+        "LinearRegressionModel"
+      case m: LogisticRegressionModel =>
+        "LogisticRegressionModel"
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message