Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id EC07A198F4 for ; Fri, 25 Mar 2016 05:29:37 +0000 (UTC) Received: (qmail 66404 invoked by uid 500); 25 Mar 2016 05:29:37 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 66371 invoked by uid 500); 25 Mar 2016 05:29:37 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 66362 invoked by uid 99); 25 Mar 2016 05:29:37 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 25 Mar 2016 05:29:37 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id B9086DFBDE; Fri, 25 Mar 2016 05:29:37 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: meng@apache.org To: commits@spark.apache.org Message-Id: <59728eb02a6547ad85def5c52e22d56f@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR Date: Fri, 25 Mar 2016 05:29:37 +0000 (UTC) Repository: spark Updated Branches: refs/heads/master 05f652d6c -> 13cbb2de7 [SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR ## What changes were proposed in this pull request? This PR continues the work in #11447, we implemented the wrapper of ```AFTSurvivalRegression``` named ```survreg``` in SparkR. ## How was this patch tested? Test against output from R package survival's survreg. 
cc mengxr felixcheung Close #11447 Author: Yanbo Liang Closes #11932 from yanboliang/spark-13010-new. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/13cbb2de Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/13cbb2de Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/13cbb2de Branch: refs/heads/master Commit: 13cbb2de709d0ec2707eebf36c5c97f7d44fb84f Parents: 05f652d Author: Yanbo Liang Authored: Thu Mar 24 22:29:34 2016 -0700 Committer: Xiangrui Meng Committed: Thu Mar 24 22:29:34 2016 -0700 ---------------------------------------------------------------------- R/pkg/DESCRIPTION | 3 +- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 75 +++++++++++++++ R/pkg/inst/tests/testthat/test_mllib.R | 49 ++++++++++ .../ml/r/AFTSurvivalRegressionWrapper.scala | 99 ++++++++++++++++++++ 6 files changed, 231 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/DESCRIPTION ---------------------------------------------------------------------- diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index e26f9a7..7179438 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -12,7 +12,8 @@ Depends: methods, Suggests: testthat, - e1071 + e1071, + survival Description: R frontend for Spark License: Apache License (== 2.0) Collate: http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/NAMESPACE ---------------------------------------------------------------------- diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5d8a4b1..fa3fb0b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -16,7 +16,8 @@ exportMethods("glm", "summary", "kmeans", "fitted", - "naiveBayes") + "naiveBayes", + "survreg") # Job group lifecycle management methods export("setJobGroup", http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/R/generics.R 
---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 46b115f..c6990f4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1179,3 +1179,7 @@ setGeneric("fitted") #' @rdname naiveBayes #' @export setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") }) + +#' @rdname survreg +#' @export +setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") }) http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/R/mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 2555019..33654d5 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj")) #' @export setClass("NaiveBayesModel", representation(jobj = "jobj")) +#' @title S4 class that represents a AFTSurvivalRegressionModel +#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper +#' @export +setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) + #' Fits a generalized linear model #' #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. @@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"), formula, data@sdf, laplace) return(new("NaiveBayesModel", jobj = jobj)) }) + +#' Fit an accelerated failure time (AFT) survival regression model. +#' +#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg(). +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' Note that operator '.' is not supported currently. +#' @param data DataFrame for training. 
+#' @return a fitted AFT survival regression model +#' @rdname survreg +#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(sqlContext, ovarian) +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df) +#' } +setMethod("survreg", signature(formula = "formula", data = "DataFrame"), + function(formula, data, ...) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", + "fit", formula, data@sdf) + return(new("AFTSurvivalRegressionModel", jobj = jobj)) + }) + +#' Get the summary of an AFT survival regression model +#' +#' Returns the summary of an AFT survival regression model produced by survreg(), +#' similarly to R's summary(). +#' +#' @param object a fitted AFT survival regression model +#' @return coefficients the model's coefficients, intercept and log(scale). +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData) +#' summary(model) +#' } +setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Value") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) + +#' Make predictions from an AFT survival regression model +#' +#' Make predictions from a model produced by survreg(), similarly to R package survival's predict. 
+#' +#' @param object A fitted AFT survival regression model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#' } +setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/R/pkg/inst/tests/testthat/test_mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 44b4836..fdb5917 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -200,3 +200,52 @@ test_that("naiveBayes", { expect_equal(as.character(predict(m, t1[1, ])), "Yes") } }) + +test_that("survreg", { + # R code to reproduce the result. + # + #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + #' library(survival) + #' model <- survreg(Surv(time, status) ~ x + sex, rData) + #' summary(model) + #' predict(model, data) + # + # -- output of 'summary(model)' + # + # Value Std. 
Error z p + # (Intercept) 1.315 0.270 4.88 1.07e-06 + # x -0.190 0.173 -1.10 2.72e-01 + # sex -0.253 0.329 -0.77 4.42e-01 + # Log(scale) -1.160 0.396 -2.93 3.41e-03 + # + # -- output of 'predict(model, data)' + # + # 1 2 3 4 5 6 7 + # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 + # + data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), + list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) + df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) + model <- survreg(Surv(time, status) ~ x + sex, df) + stats <- summary(model) + coefs <- as.vector(stats$coefficients[, 1]) + rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) + expect_equal(coefs, rCoefs, tolerance = 1e-4) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "x", "sex", "Log(scale)"))) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, + 2.390146, 2.891269, 2.891269), tolerance = 1e-4) + + # Test survival::survreg + if (requireNamespace("survival", quietly = TRUE)) { + rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + expect_that( + model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), + not(throws_error())) + expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + } +}) http://git-wip-us.apache.org/repos/asf/spark/blob/13cbb2de/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala new file mode 100644 index 0000000..40590e7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -0,0 
+1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Wraps a fitted pipeline (RFormula preprocessing followed by an AFT model) so that
+ * SparkR's survreg() can query R-style coefficients, feature names and predictions.
+ */
+private[r] class AFTSurvivalRegressionWrapper private (
+    pipeline: PipelineModel,
+    features: Array[String]) {
+
+  // Stage 0 is the RFormulaModel, stage 1 the fitted AFT model (see fit() below).
+  private val aftModel: AFTSurvivalRegressionModel =
+    pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]
+
+  // Coefficients in the order R's survreg reports them:
+  // intercept (if fitted), features, then Log(scale).
+  lazy val rCoefficients: Array[Double] = if (aftModel.getFitIntercept) {
+    Array(aftModel.intercept) ++ aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
+  } else {
+    aftModel.coefficients.toArray ++ Array(math.log(aftModel.scale))
+  }
+
+  // Row names aligned one-to-one with rCoefficients.
+  lazy val rFeatures: Array[String] = if (aftModel.getFitIntercept) {
+    Array("(Intercept)") ++ features ++ Array("Log(scale)")
+  } else {
+    features ++ Array("Log(scale)")
+  }
+
+  /** Applies the whole pipeline (formula transform + AFT model) to new data. */
+  def transform(dataset: DataFrame): DataFrame = {
+    pipeline.transform(dataset)
+  }
+}
+
+private[r] object AFTSurvivalRegressionWrapper {
+
+  // Matches survival formulas of the form "Surv(time, status) ~ terms".
+  // Compiled once instead of on every formulaRewrite call.
+  private val formulaRegex = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r
+
+  /**
+   * Splits an R survival formula into an RFormula string and the censor column,
+   * e.g. "Surv(time, status) ~ x + sex" becomes ("time~x + sex", "status").
+   * Throws SparkException when the formula does not match the expected shape.
+   */
+  private def formulaRewrite(formula: String): (String, String) = {
+    formula match {
+      case formulaRegex(label, censor, features) =>
+        // TODO: Support dot operator.
+        if (features.contains(".")) {
+          throw new UnsupportedOperationException(
+            "Terms of survreg formula can not support dot operator.")
+        }
+        (label.trim + "~" + features.trim, censor.trim)
+      case _ =>
+        throw new SparkException(s"Could not parse formula: $formula")
+    }
+  }
+
+  /**
+   * Fits an AFT survival regression model on the given data, mirroring R's survreg().
+   */
+  def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+    val (rewrittenFormula, censorCol) = formulaRewrite(formula)
+
+    val rFormula = new RFormula().setFormula(rewrittenFormula)
+    val rFormulaModel = rFormula.fit(data)
+
+    // Get feature names from the ML attributes of the RFormula output column.
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    val aft = new AFTSurvivalRegression()
+      .setCensorCol(censorCol)
+      .setFitIntercept(rFormula.hasIntercept)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, aft))
+      .fit(data)
+
+    new AFTSurvivalRegressionWrapper(pipeline, features)
+  }
+}
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org