spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-13615][ML] GeneralizedLinearRegression supports save/load
Date Wed, 09 Mar 2016 19:59:27 GMT
Repository: spark
Updated Branches:
  refs/heads/master cad29a40b -> 0dd06485c


[SPARK-13615][ML] GeneralizedLinearRegression supports save/load

## What changes were proposed in this pull request?
```GeneralizedLinearRegression``` supports ```save/load```.
cc mengxr
## How was this patch tested?
unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11465 from yanboliang/spark-13615.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0dd06485
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0dd06485
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0dd06485

Branch: refs/heads/master
Commit: 0dd06485c4222a896c0d1ee6a04d30043de3626c
Parents: cad29a4
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Wed Mar 9 11:59:22 2016 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Mar 9 11:59:22 2016 -0800

----------------------------------------------------------------------
 .../GeneralizedLinearRegression.scala           | 74 +++++++++++++++++---
 .../GeneralizedLinearRegressionSuite.scala      | 32 ++++++++-
 2 files changed, 96 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0dd06485/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index a850dfe..de1dff9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.regression
 
 import breeze.stats.distributions.{Gaussian => GD}
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.{Experimental, Since}
@@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.optim._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
@@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
 @Since("2.0.0")
 class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
   extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
-  with GeneralizedLinearRegressionBase with Logging {
+  with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
 
   import GeneralizedLinearRegression._
 
@@ -236,10 +237,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
 }
 
 @Since("2.0.0")
-private[ml] object GeneralizedLinearRegression {
+object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] {
+
+  @Since("2.0.0")
+  override def load(path: String): GeneralizedLinearRegression = super.load(path)
 
   /** Set of family and link pairs that GeneralizedLinearRegression supports. */
-  lazy val supportedFamilyAndLinkPairs = Set(
+  private[ml] lazy val supportedFamilyAndLinkPairs = Set(
     Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
     Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
     Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
@@ -247,12 +251,12 @@ private[ml] object GeneralizedLinearRegression {
   )
 
   /** Set of family names that GeneralizedLinearRegression supports. */
-  lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
+  private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
 
   /** Set of link names that GeneralizedLinearRegression supports. */
-  lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
+  private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
 
-  val epsilon: Double = 1E-16
+  private[ml] val epsilon: Double = 1E-16
 
   /**
    * Wrapper of family and link combination used in the model.
@@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] (
     @Since("2.0.0") val coefficients: Vector,
     @Since("2.0.0") val intercept: Double)
   extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
-  with GeneralizedLinearRegressionBase {
+  with GeneralizedLinearRegressionBase with MLWritable {
 
   import GeneralizedLinearRegression._
 
@@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] (
     copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
       .setParent(parent)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter =
+    new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this)
+}
+
+@Since("2.0.0")
+object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[GeneralizedLinearRegressionModel] =
+    new GeneralizedLinearRegressionModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): GeneralizedLinearRegressionModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */
+  private[GeneralizedLinearRegressionModel]
+  class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel)
+    extends MLWriter with Logging {
+
+    private case class Data(intercept: Double, coefficients: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: intercept, coefficients
+      val data = Data(instance.intercept, instance.coefficients)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class GeneralizedLinearRegressionModelReader
+    extends MLReader[GeneralizedLinearRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[GeneralizedLinearRegressionModel].getName
+
+    override def load(path: String): GeneralizedLinearRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("intercept", "coefficients").head()
+      val intercept = data.getDouble(0)
+      val coefficients = data.getAs[Vector](1)
+
+      val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0dd06485/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 8bfa985..618304a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors}
 import org.apache.spark.mllib.random._
@@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
 
-class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GeneralizedLinearRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   private val seed: Int = 42
   @transient var datasetGaussianIdentity: DataFrame = _
@@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark
       }
     }
   }
+
+  test("read/write") {
+    def checkModelData(
+        model: GeneralizedLinearRegressionModel,
+        model2: GeneralizedLinearRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients.toArray === model2.coefficients.toArray)
+    }
+
+    val glr = new GeneralizedLinearRegression()
+    testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
+  }
 }
 
 object GeneralizedLinearRegressionSuite {
 
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "family" -> "poisson",
+    "link" -> "log",
+    "fitIntercept" -> true,
+    "maxIter" -> 2,  // intentionally small
+    "tol" -> 0.8,
+    "regParam" -> 0.01,
+    "predictionCol" -> "myPrediction")
+
   def generateGeneralizedLinearRegressionInput(
       intercept: Double,
       coefficients: Array[Double],


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message