Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id DD299200B85 for ; Thu, 1 Sep 2016 06:39:54 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id D58BB160AB5; Thu, 1 Sep 2016 04:39:39 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 00884160AB4 for ; Thu, 1 Sep 2016 06:39:38 +0200 (CEST) Received: (qmail 4596 invoked by uid 500); 1 Sep 2016 04:39:38 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 4586 invoked by uid 99); 1 Sep 2016 04:39:38 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 01 Sep 2016 04:39:38 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 95040E009E; Thu, 1 Sep 2016 04:39:37 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: shivaram@apache.org To: commits@spark.apache.org Message-Id: <1d961a628b33455cb987d5402a8b96f5@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter Date: Thu, 1 Sep 2016 04:39:37 +0000 (UTC) archived-at: Thu, 01 Sep 2016 04:39:55 -0000 Repository: spark Updated Branches: refs/heads/master d008638fb -> 7a5000f39 [SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter https://issues.apache.org/jira/browse/SPARK-17241 ## What changes were proposed in this pull request? 
Spark has a configurable L2 regularization parameter for generalized linear regression. It is very important to have it in SparkR so that users can run ridge regression. ## How was this patch tested? Tested manually on a local laptop. Author: Xin Ren Closes #14856 from keypointt/SPARK-17241. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7a5000f3 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7a5000f3 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7a5000f3 Branch: refs/heads/master Commit: 7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2 Parents: d008638 Author: Xin Ren Authored: Wed Aug 31 21:39:31 2016 -0700 Committer: Shivaram Venkataraman Committed: Wed Aug 31 21:39:31 2016 -0700 ---------------------------------------------------------------------- R/pkg/R/mllib.R | 10 +++-- R/pkg/inst/tests/testthat/test_mllib.R | 6 +++ .../r/GeneralizedLinearRegressionWrapper.scala | 4 +- .../GeneralizedLinearRegressionSuite.scala | 40 ++++++++++++++++++++ 4 files changed, 55 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7a5000f3/R/pkg/R/mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 64d19fa..9a53f75 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -138,10 +138,11 @@ predict_internal <- function(object, newData) { #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance -#' weights as 1.0. #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. 
+#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance +#' weights as 1.0. +#' @param regParam regularization parameter for L2 regularization. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model @@ -171,7 +172,8 @@ predict_internal <- function(object, newData) { #' @note spark.glm since 2.0.0 #' @seealso \link{glm}, \link{read.ml} setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL) { + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, + regParam = 0.0) { if (is.character(family)) { family <- get(family, mode = "function", envir = parent.frame()) } @@ -190,7 +192,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, family$family, family$link, - tol, as.integer(maxIter), as.character(weightCol)) + tol, as.integer(maxIter), as.character(weightCol), regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) http://git-wip-us.apache.org/repos/asf/spark/blob/7a5000f3/R/pkg/inst/tests/testthat/test_mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1e6da65..825a240 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -148,6 +148,12 @@ test_that("spark.glm summary", { baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) baseSummary <- summary(baseModel) expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + + # Test spark.glm works with regularization parameter + data <- as.data.frame(cbind(a1, a2, b)) + df <- 
suppressWarnings(createDataFrame(data)) + regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) + expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result }) test_that("spark.glm save/load", { http://git-wip-us.apache.org/repos/asf/spark/blob/7a5000f3/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 0d3181d..7a6ab61 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper link: String, tol: Double, maxIter: Int, - weightCol: String): GeneralizedLinearRegressionWrapper = { + weightCol: String, + regParam: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) val rFormulaModel = rFormula.fit(data) @@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setTol(tol) .setMaxIter(maxIter) .setWeightCol(weightCol) + .setRegParam(regParam) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) http://git-wip-us.apache.org/repos/asf/spark/blob/7a5000f3/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index a4568e83..d8032c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite .setFamily("gaussian") .fit(datasetGaussianIdentity.as[LabeledPoint]) } + + test("generalized linear regression: regularization parameter") { + /* + R code: + + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, b)) + df <- suppressWarnings(createDataFrame(data)) + + for (regParam in c(0.0, 0.1, 1.0)) { + model <- spark.glm(df, b ~ a1 + a2, regParam = regParam) + print(as.vector(summary(model)$aic)) + } + + [1] 12.88188 + [1] 12.92681 + [1] 13.32836 + */ + val dataset = spark.createDataFrame(Seq( + LabeledPoint(1, Vectors.dense(5, 0)), + LabeledPoint(0, Vectors.dense(2, 1)), + LabeledPoint(1, Vectors.dense(1, 2)), + LabeledPoint(0, Vectors.dense(3, 3)) + )) + val expected = Seq(12.88188, 12.92681, 13.32836) + + var idx = 0 + for (regParam <- Seq(0.0, 0.1, 1.0)) { + val trainer = new GeneralizedLinearRegression() + .setRegParam(regParam) + .setLabelCol("label") + .setFeaturesCol("features") + val model = trainer.fit(dataset) + val actual = model.summary.aic + assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with regParam = $regParam.") + idx += 1 + } + } } object GeneralizedLinearRegressionSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org