spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [SPARK-19155][ML] MLlib GeneralizedLinearRegression family and link should be case insensitive
Date Sun, 22 Jan 2017 05:16:48 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 4c2065d0a -> 886f73737


[SPARK-19155][ML] MLlib GeneralizedLinearRegression family and link should be case insensitive

## What changes were proposed in this pull request?
MLlib ```GeneralizedLinearRegression``` ```family``` and ```link``` should be case insensitive.
This is consistent with some other MLlib params such as [```featureSubsetStrategy```](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala#L415).

## How was this patch tested?
Update corresponding tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16516 from yanboliang/spark-19133.

(cherry picked from commit 3dcad9fab17297f9966026f29fefb5c726965a13)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/886f7373
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/886f7373
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/886f7373

Branch: refs/heads/branch-2.0
Commit: 886f73737b2aa2e8202d78104f33e894c7f81578
Parents: 4c2065d
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Sat Jan 21 21:15:57 2017 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Sat Jan 21 21:16:42 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/regression/GeneralizedLinearRegression.scala    | 8 ++++----
 .../ml/regression/GeneralizedLinearRegressionSuite.scala     | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/886f7373/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 7f88c12..537738e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   final val family: Param[String] = new Param(this, "family",
     "The name of family which is a description of the error distribution to be used in the " +
       s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
-    ParamValidators.inArray[String](supportedFamilyNames.toArray))
+    (value: String) => supportedFamilyNames.contains(value.toLowerCase))
 
   /** @group getParam */
   @Since("2.0.0")
@@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   final val link: Param[String] = new Param(this, "link", "The name of link function " +
     "which provides the relationship between the linear predictor and the mean of the " +
     s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
-    ParamValidators.inArray[String](supportedLinkNames.toArray))
+    (value: String) => supportedLinkNames.contains(value.toLowerCase))
 
   /** @group getParam */
   @Since("2.0.0")
@@ -401,7 +401,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
      * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
      */
     def fromName(name: String): Family = {
-      name match {
+      name.toLowerCase match {
         case Gaussian.name => Gaussian
         case Binomial.name => Binomial
         case Poisson.name => Poisson
@@ -601,7 +601,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
      *             "inverse", "probit", "cloglog" or "sqrt".
      */
     def fromName(name: String): Link = {
-      name match {
+      name.toLowerCase match {
         case Identity.name => Identity
         case Logit.name => Logit
         case Log.name => Log

http://git-wip-us.apache.org/repos/asf/spark/blob/886f7373/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 9d10215..0ce7d11 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -507,7 +507,7 @@ class GeneralizedLinearRegressionSuite
     for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
       ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
       for (fitIntercept <- Seq(false, true)) {
-        val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link)
+        val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link)
           .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
         val model = trainer.fit(dataset)
         val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
@@ -943,7 +943,7 @@ class GeneralizedLinearRegressionSuite
        -0.6344390  0.3172195  0.2114797 -0.1586097
      */
     val trainer = new GeneralizedLinearRegression()
-      .setFamily("gamma")
+      .setFamily("Gamma")
       .setWeightCol("weight")
 
     val model = trainer.fit(datasetWithWeight)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message