spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel
Date Wed, 11 Nov 2015 02:46:02 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 b34819c7b -> bc3bfa04d


[SPARK-6726][ML] Import/export for spark.ml LogisticRegressionModel

This PR adds model save/load for spark.ml's LogisticRegressionModel.  It also does minor refactoring
of the default save/load classes to reuse code.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9606 from jkbradley/logreg-io2.

(cherry picked from commit 6e101d2e9d6e08a6a63f7065c1e87a5338f763ea)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc3bfa04
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc3bfa04
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc3bfa04

Branch: refs/heads/branch-1.6
Commit: bc3bfa04d39992b7451c0a15c2f1cb9cc22c15e7
Parents: b34819c
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Tue Nov 10 18:45:48 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Nov 10 18:45:57 2015 -0800

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 68 +++++++++++++++++-
 .../org/apache/spark/ml/util/ReadWrite.scala    | 74 ++++++++++++++++++--
 .../LogisticRegressionSuite.scala               | 17 ++++-
 .../spark/ml/util/DefaultReadWriteTest.scala    |  4 +-
 4 files changed, 152 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc3bfa04/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f5fca68..a88f526 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,13 +21,14 @@ import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
@@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] (
     val coefficients: Vector,
     val intercept: Double)
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
-  with LogisticRegressionParams {
+  with LogisticRegressionParams with Writable {
 
   @deprecated("Use coefficients instead.", "1.6.0")
   def weights: Vector = coefficients
@@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] (
     // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
     if (probability(1) > getThreshold) 1 else 0
   }
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   *
+   * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
+   * An option to save [[summary]] may be added in the future.
+   */
+  override def write: Writer = new LogisticRegressionWriter(this)
+}
+
+
+/** [[Writer]] instance for [[LogisticRegressionModel]] */
+private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
+  extends Writer with Logging {
+
+  private case class Data(
+      numClasses: Int,
+      numFeatures: Int,
+      intercept: Double,
+      coefficients: Vector)
+
+  override protected def saveImpl(path: String): Unit = {
+    // Save metadata and Params
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    // Save model data: numClasses, numFeatures, intercept, coefficients
+    val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+      instance.coefficients)
+    val dataPath = new Path(path, "data").toString
+    sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+  }
+}
+
+
+object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
+
+  override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+
+  override def load(path: String): LogisticRegressionModel = read.load(path)
 }
 
+
+private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+
+  /** Checked against metadata when loading model */
+  private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+  override def load(path: String): LogisticRegressionModel = {
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+    val dataPath = new Path(path, "data").toString
+    val data = sqlContext.read.format("parquet").load(dataPath)
+      .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+    // We will need numClasses, numFeatures in the future for multinomial logreg support.
+    // val numClasses = data.getInt(0)
+    // val numFeatures = data.getInt(1)
+    val intercept = data.getDouble(2)
+    val coefficients = data.getAs[Vector](3)
+    val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+
+    DefaultParamsReader.getAndSetParams(model, metadata)
+    model
+  }
+}
+
+
 /**
  * MultiClassSummarizer computes the number of distinct labels and corresponding counts,
  * and validates the data to see if the labels used for k class multi-label classification

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3bfa04/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index cbdf913..85f888c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -175,6 +175,21 @@ trait Readable[T] {
 private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
 
   override protected def saveImpl(path: String): Unit = {
+    DefaultParamsWriter.saveMetadata(instance, path, sc)
+  }
+}
+
+private[ml] object DefaultParamsWriter {
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - paramMap
+   */
+  def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -201,14 +216,61 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
 private[ml] class DefaultParamsReader[T] extends Reader[T] {
 
   override def load(path: String): T = {
-    implicit val format = DefaultFormats
+    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    val instance =
+      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
+    DefaultParamsReader.getAndSetParams(instance, metadata)
+    instance.asInstanceOf[T]
+  }
+}
+
+private[ml] object DefaultParamsReader {
+
+  /**
+   * All info from metadata file.
+   * @param params  paramMap, as a [[JValue]]
+   * @param metadataStr  Full metadata file String (for debugging)
+   */
+  case class Metadata(
+      className: String,
+      uid: String,
+      timestamp: Long,
+      sparkVersion: String,
+      params: JValue,
+      metadataStr: String)
+
+  /**
+   * Load metadata from file.
+   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
+   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+   */
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
     val metadataStr = sc.textFile(metadataPath, 1).first()
     val metadata = parse(metadataStr)
-    val cls = Utils.classForName((metadata \ "class").extract[String])
+
+    implicit val format = DefaultFormats
+    val className = (metadata \ "class").extract[String]
     val uid = (metadata \ "uid").extract[String]
-    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
-    (metadata \ "paramMap") match {
+    val timestamp = (metadata \ "timestamp").extract[Long]
+    val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val params = metadata \ "paramMap"
+    if (expectedClassName.nonEmpty) {
+      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
+        s" $expectedClassName but found class name $className")
+    }
+
+    Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+  }
+
+  /**
+   * Extract Params from metadata, and set them in the instance.
+   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   */
+  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+    implicit val format = DefaultFormats
+    metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
           val param = instance.getParam(paramName)
@@ -216,8 +278,8 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] {
           instance.set(param, value)
         }
       case _ =>
-        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
     }
-    instance.asInstanceOf[T]
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3bfa04/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 325faf3..51b06b7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -23,7 +23,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LogisticRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var dataset: DataFrame = _
   @transient var binaryDataset: DataFrame = _
@@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
     assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3)
     assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+  }
 
+  test("read/write") {
+    // Set some Params to make sure set Params are serialized.
+    val lr = new LogisticRegression()
+      .setElasticNetParam(0.1)
+      .setMaxIter(2)
+      .fit(dataset)
+    val lr2 = testDefaultReadWrite(lr)
+    assert(lr.intercept === lr2.intercept)
+    assert(lr.coefficients.toArray === lr2.coefficients.toArray)
+    assert(lr.numClasses === lr2.numClasses)
+    assert(lr.numFeatures === lr2.numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bc3bfa04/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 4545b0f..cac4bd9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * Checks "overwrite" option and params.
    * @param instance ML instance to test saving/loading
    * @tparam T ML instance type
+   * @return  Instance loaded from file
    */
-  def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
     val uid = instance.uid
     val path = new File(tempDir, uid).getPath
 
@@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val load = instance.getClass.getMethod("load", classOf[String])
     val another = load.invoke(instance, path).asInstanceOf[T]
     assert(another.uid === instance.uid)
+    another
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message