spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6790][ML] Add spark.ml LinearRegression import/export
Date Wed, 18 Nov 2015 21:06:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master 09ad9533d -> 045a4f045


[SPARK-6790][ML] Add spark.ml LinearRegression import/export

This replaces [https://github.com/apache/spark/pull/9656] with updates.

fayeshine should be the main author when this PR is committed.

CC: mengxr fayeshine

Author: Wenjian Huang <nextrush@163.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9814 from jkbradley/fayeshine-patch-6790.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/045a4f04
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/045a4f04
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/045a4f04

Branch: refs/heads/master
Commit: 045a4f045821dcf60442f0600c2df1b79bddb536
Parents: 09ad953
Author: Wenjian Huang <nextrush@163.com>
Authored: Wed Nov 18 13:06:25 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Nov 18 13:06:25 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 77 +++++++++++++++++++-
 .../ml/regression/LinearRegressionSuite.scala   | 34 ++++++++-
 2 files changed, 106 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/045a4f04/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 913140e..ca55d59 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 import breeze.stats.distributions.StudentsT
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.ml.feature.Instance
@@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
 @Experimental
 class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
   extends Regressor[Vector, LinearRegression, LinearRegressionModel]
-  with LinearRegressionParams with Logging {
+  with LinearRegressionParams with Writable with Logging {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("linReg"))
@@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegression extends Readable[LinearRegression] {
+
+  @Since("1.6.0")
+  override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression]
+
+  @Since("1.6.0")
+  override def load(path: String): LinearRegression = read.load(path)
 }
 
 /**
@@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] (
     val coefficients: Vector,
     val intercept: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
-  with LinearRegressionParams {
+  with LinearRegressionParams with Writable {
 
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
@@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] (
     if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
     newModel.setParent(parent)
   }
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   *
+   * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
+   * An option to save [[summary]] may be added in the future.
+   *
+   * This also does not save the [[parent]] currently.
+   */
+  @Since("1.6.0")
+  override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegressionModel extends Readable[LinearRegressionModel] {
+
+  @Since("1.6.0")
+  override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LinearRegressionModel = read.load(path)
+
+  /** [[Writer]] instance for [[LinearRegressionModel]] */
+  private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
+    extends Writer with Logging {
+
+    private case class Data(intercept: Double, coefficients: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: intercept, coefficients
+      val data = Data(instance.intercept, instance.coefficients)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+    }
+  }
+
+  private class LinearRegressionModelReader extends Reader[LinearRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+    override def load(path: String): LinearRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.format("parquet").load(dataPath)
+        .select("intercept", "coefficients").head()
+      val intercept = data.getDouble(0)
+      val coefficients = data.getAs[Vector](1)
+      val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/045a4f04/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index a1d86fe..2bdc0e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,14 +22,15 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   private val seed: Int = 42
   @transient var datasetWithDenseFeature: DataFrame = _
@@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
     model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
   }
+
+  test("read/write") {
+    def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients === model2.coefficients)
+    }
+    val lr = new LinearRegression()
+    testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
+      checkModelData)
+  }
+}
+
+object LinearRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "regParam" -> 0.01,
+    "elasticNetParam" -> 0.1,
+    "maxIter" -> 2,  // intentionally small
+    "fitIntercept" -> true,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "solver" -> "l-bfgs"
+  )
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message