spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR
Date Fri, 25 Mar 2016 05:29:37 GMT
Repository: spark
Updated Branches:
  refs/heads/master 05f652d6c -> 13cbb2de7


[SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR

## What changes were proposed in this pull request?
This PR continues the work in #11447: we implemented the wrapper of ```AFTSurvivalRegression```
named ```survreg``` in SparkR.

## How was this patch tested?
Tested against the output of survreg from the R package survival.

cc mengxr felixcheung

Close #11447

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11932 from yanboliang/spark-13010-new.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/13cbb2de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/13cbb2de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/13cbb2de

Branch: refs/heads/master
Commit: 13cbb2de709d0ec2707eebf36c5c97f7d44fb84f
Parents: 05f652d
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Thu Mar 24 22:29:34 2016 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Mar 24 22:29:34 2016 -0700

----------------------------------------------------------------------
 R/pkg/DESCRIPTION                               |  3 +-
 R/pkg/NAMESPACE                                 |  3 +-
 R/pkg/R/generics.R                              |  4 +
 R/pkg/R/mllib.R                                 | 75 +++++++++++++++
 R/pkg/inst/tests/testthat/test_mllib.R          | 49 ++++++++++
 .../ml/r/AFTSurvivalRegressionWrapper.scala     | 99 ++++++++++++++++++++
 6 files changed, 231 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/DESCRIPTION
----------------------------------------------------------------------
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index e26f9a7..7179438 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -12,7 +12,8 @@ Depends:
     methods,
 Suggests:
     testthat,
-    e1071
+    e1071,
+    survival
 Description: R frontend for Spark
 License: Apache License (== 2.0)
 Collate:

http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5d8a4b1..fa3fb0b 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -16,7 +16,8 @@ exportMethods("glm",
               "summary",
               "kmeans",
               "fitted",
-              "naiveBayes")
+              "naiveBayes",
+              "survreg")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 46b115f..c6990f4 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1179,3 +1179,7 @@ setGeneric("fitted")
 #' @rdname naiveBayes
 #' @export
 setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+
+# S4 generic for survreg(formula, data, ...); the DataFrame method is defined in mllib.R.
+#' @rdname survreg
+#' @export
+setGeneric("survreg",
+           function(formula, data, ...) standardGeneric("survreg"))

http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 2555019..33654d5 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' @export
 setClass("NaiveBayesModel", representation(jobj = "jobj"))
 
+#' @title S4 class that represents an AFTSurvivalRegressionModel
+#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
+#' @export
+setClass("AFTSurvivalRegressionModel",
+         representation = representation(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
                                  formula, data@sdf, laplace)
             return(new("NaiveBayesModel", jobj = jobj))
           })
+
+#' Fit an accelerated failure time (AFT) survival regression model.
+#'
+#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#'
+#' @param formula A symbolic description of the model to be fitted, of the form
+#'                \code{Surv(time, status) ~ terms}. Currently only a few formula
+#'                operators are supported, including '~', ':', '+', and '-'.
+#'                Note that operator '.' is not supported currently.
+#' @param data DataFrame for training.
+#' @return a fitted AFT survival regression model
+#' @rdname survreg
+#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, ovarian)
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' }
+setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
+          function(formula, data, ...) {
+            # deparse() may split a long formula over several strings; the JVM side
+            # expects a single formula string, so join the pieces.
+            formulaString <- paste(deparse(formula), collapse = "")
+            jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
+                                "fit", formulaString, data@sdf)
+            new("AFTSurvivalRegressionModel", jobj = jobj)
+          })
+
+#' Get the summary of an AFT survival regression model
+#'
+#' Returns the summary of an AFT survival regression model produced by survreg(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted AFT survival regression model
+#' @return a list with one element, \code{coefficients}: a one-column matrix (column
+#'         "Value") with one named row per term, holding the intercept (when fitted),
+#'         the feature coefficients and log(scale).
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            termNames <- unlist(callJMethod(jobj, "rFeatures"))
+            values <- unlist(callJMethod(jobj, "rCoefficients"))
+            # Row names come from the JVM side and line up with the coefficient values.
+            coefficients <- matrix(values, ncol = 1,
+                                   dimnames = list(termNames, "Value"))
+            list(coefficients = coefficients)
+          })
+
+#' Make predictions from an AFT survival regression model
+#'
+#' Make predictions from a model produced by survreg(), similarly to the predict()
+#' method of R package survival.
+#'
+#' @param object A fitted AFT survival regression model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted values in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
+          function(object, newData) {
+            # Run the fitted pipeline on the JVM and wrap the result as a SparkR DataFrame.
+            transformed <- callJMethod(object@jobj, "transform", newData@sdf)
+            dataFrame(transformed)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 44b4836..fdb5917 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -200,3 +200,52 @@ test_that("naiveBayes", {
     expect_equal(as.character(predict(m, t1[1, ])), "Yes")
   }
 })
+
+test_that("survreg", {
+  # R code to reproduce the result.
+  #
+  # rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+  #               x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+  # library(survival)
+  # model <- survreg(Surv(time, status) ~ x + sex, rData)
+  # summary(model)
+  # predict(model, rData)
+  #
+  # -- output of 'summary(model)'
+  #
+  #              Value Std. Error     z        p
+  # (Intercept)  1.315      0.270  4.88 1.07e-06
+  # x           -0.190      0.173 -1.10 2.72e-01
+  # sex         -0.253      0.329 -0.77 4.42e-01
+  # Log(scale)  -1.160      0.396 -2.93 3.41e-03
+  #
+  # -- output of 'predict(model, rData)'
+  #
+  #        1        2        3        4        5        6        7
+  # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
+  #
+  # Same observations as rData above, one inner list per row: (time, status, x, sex).
+  data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
+          list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
+  df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
+  model <- survreg(Surv(time, status) ~ x + sex, df)
+  stats <- summary(model)
+  # Compare the "Value" column with survival::survreg's estimates (unrounded).
+  coefs <- as.vector(stats$coefficients[, 1])
+  rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
+  expect_equal(coefs, rCoefs, tolerance = 1e-4)
+  # Row names must follow R's layout: intercept, features, then Log(scale).
+  expect_true(all(
+    rownames(stats$coefficients) ==
+    c("(Intercept)", "x", "sex", "Log(scale)")))
+  # Predictions must match R's predict() output listed above.
+  p <- collect(select(predict(model, df), "prediction"))
+  expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035,
+               2.390146, 2.891269, 2.891269), tolerance = 1e-4)
+
+  # Test survival::survreg
+  # (the survival package is only in Suggests, so this part runs only when it is installed;
+  # calls are namespace-qualified to avoid the SparkR survreg generic.)
+  # NOTE(review): expect_that()/not() are deprecated in newer testthat releases —
+  # consider expect_error(..., NA) instead; confirm against the testthat version in use.
+  if (requireNamespace("survival", quietly = TRUE)) {
+    rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
+                 x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
+    expect_that(
+      model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
+      not(throws_error()))
+    expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4)
+  }
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
new file mode 100644
index 0000000..40590e7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * R-facing wrapper around a fitted pipeline whose second stage is an
+ * [[AFTSurvivalRegressionModel]]. Coefficients and their labels are exposed in the
+ * row layout of R's survreg() summary: `(Intercept)`, features, `Log(scale)`.
+ *
+ * @param pipeline fitted pipeline; stage 1 must be an AFTSurvivalRegressionModel
+ * @param features names of the feature-vector slots, as produced by RFormula
+ */
+private[r] class AFTSurvivalRegressionWrapper private (
+    pipeline: PipelineModel,
+    features: Array[String]) {
+
+  private val aftModel: AFTSurvivalRegressionModel =
+    pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
+
+  /** Intercept (only when it was fit), feature coefficients, then log(scale). */
+  lazy val rCoefficients: Array[Double] = {
+    val rest = aftModel.coefficients.toArray :+ math.log(aftModel.scale)
+    if (aftModel.getFitIntercept) aftModel.intercept +: rest else rest
+  }
+
+  /** Labels aligned element-for-element with [[rCoefficients]]. */
+  lazy val rFeatures: Array[String] = {
+    val rest = features :+ "Log(scale)"
+    if (aftModel.getFitIntercept) "(Intercept)" +: rest else rest
+  }
+
+  /** Applies the whole fitted pipeline (formula encoding + AFT model) to `dataset`. */
+  def transform(dataset: DataFrame): DataFrame = pipeline.transform(dataset)
+}
+
+private[r] object AFTSurvivalRegressionWrapper {
+
+  // Shape of the formula string deparsed on the R side, e.g.
+  // "Surv(futime, fustat) ~ ecog_ps + rx". Groups: label, censor column, feature terms.
+  private val SurvFormula = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r
+
+  /**
+   * Splits a survival formula `Surv(label, censor) ~ terms` into an RFormula string
+   * (`label~terms`) and the censor column name.
+   *
+   * @param formula formula string as deparsed by SparkR
+   * @return (formula for RFormula, censor column name)
+   * @throws UnsupportedOperationException if the feature terms contain any '.' character
+   *                                       (the dot operator is not supported)
+   * @throws SparkException if the formula does not have the expected shape
+   */
+  private def formulaRewrite(formula: String): (String, String) = formula match {
+    case SurvFormula(label, censor, features) =>
+      // TODO: Support dot operator.
+      if (features.contains(".")) {
+        throw new UnsupportedOperationException(
+          "Terms of survreg formula can not support dot operator.")
+      }
+      (label.trim + "~" + features.trim, censor.trim)
+    case _ =>
+      throw new SparkException(s"Could not parse formula: $formula")
+  }
+
+  /**
+   * Fits an RFormula + AFTSurvivalRegression pipeline from a survival formula.
+   *
+   * @param formula formula of the form `Surv(label, censor) ~ terms`
+   * @param data training data
+   * @return a wrapper exposing the fitted coefficients and a `transform` method
+   */
+  def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+    val (rewrittenFormula, censorCol) = formulaRewrite(formula)
+
+    val rFormula = new RFormula().setFormula(rewrittenFormula)
+    val rFormulaModel = rFormula.fit(data)
+
+    // Names of the encoded features are read from the ML attributes that RFormula
+    // attaches to its features column.
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    // Fit an intercept only if the formula asks for one.
+    val aft = new AFTSurvivalRegression()
+      .setCensorCol(censorCol)
+      .setFitIntercept(rFormula.hasIntercept)
+
+    // rFormulaModel is already fitted (a Transformer), so only the AFT stage is trained here.
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, aft))
+      .fit(data)
+
+    new AFTSurvivalRegressionWrapper(pipeline, features)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message