spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-11612][ML] Pipeline and PipelineModel persistence
Date Tue, 17 Nov 2015 01:12:53 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 32a69e4c1 -> 505eceef3


[SPARK-11612][ML] Pipeline and PipelineModel persistence

Pipeline and PipelineModel extend Readable and Writable.  Persistence succeeds only when all
stages are Writable.

Note: This PR reinstates tests for other read/write functionality.  It should probably not
get merged until [https://issues.apache.org/jira/browse/SPARK-11672] gets fixed.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9674 from jkbradley/pipeline-io.

(cherry picked from commit 1c5475f1401d2233f4c61f213d1e2c2ee9673067)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/505eceef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/505eceef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/505eceef

Branch: refs/heads/branch-1.6
Commit: 505eceef303e3291253b35164fbec7e4390e8252
Parents: 32a69e4
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Mon Nov 16 17:12:39 2015 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Nov 16 17:12:48 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    | 175 ++++++++++++++++++-
 .../org/apache/spark/ml/util/ReadWrite.scala    |   4 +-
 .../org/apache/spark/ml/PipelineSuite.scala     | 120 ++++++++++++-
 .../spark/ml/util/DefaultReadWriteTest.scala    |  25 +--
 4 files changed, 306 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/505eceef/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a3e5940..25f0c69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,12 +22,19 @@ import java.{util => ju}
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
-import org.apache.spark.Logging
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.Reader
+import org.apache.spark.ml.util.Writer
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -82,7 +89,7 @@ abstract class PipelineStage extends Params with Logging {
  * an identity transformer.
  */
 @Experimental
-class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
+class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable {
 
   def this() = this(Identifiable.randomUID("pipeline"))
 
@@ -166,6 +173,131 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
       "Cannot have duplicate components in a pipeline.")
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
   }
+
+  override def write: Writer = new Pipeline.PipelineWriter(this)
+}
+
+object Pipeline extends Readable[Pipeline] {
+
+  override def read: Reader[Pipeline] = new PipelineReader
+
+  override def load(path: String): Pipeline = read.load(path)
+
+  private[ml] class PipelineWriter(instance: Pipeline) extends Writer {
+
+    SharedReadWrite.validateStages(instance.getStages)
+
+    override protected def saveImpl(path: String): Unit =
+      SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
+  }
+
+  private[ml] class PipelineReader extends Reader[Pipeline] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.Pipeline"
+
+    override def load(path: String): Pipeline = {
+      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      new Pipeline(uid).setStages(stages)
+    }
+  }
+
+  /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */
+  private[ml] object SharedReadWrite {
+
+    import org.json4s.JsonDSL._
+
+    /** Check that all stages are Writable */
+    def validateStages(stages: Array[PipelineStage]): Unit = {
+      stages.foreach {
+        case stage: Writable => // good
+        case other =>
+          throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
+            s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
+            s" ${other.uid} of type ${other.getClass}")
+      }
+    }
+
+    /**
+     * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+     *  - save metadata to path/metadata
+     *  - save stages to stages/IDX_UID
+     */
+    def saveImpl(
+        instance: Params,
+        stages: Array[PipelineStage],
+        sc: SparkContext,
+        path: String): Unit = {
+      // Copied and edited from DefaultParamsWriter.saveMetadata
+      // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
+      val uid = instance.uid
+      val cls = instance.getClass.getName
+      val stageUids = stages.map(_.uid)
+      val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
+      val metadata = ("class" -> cls) ~
+        ("timestamp" -> System.currentTimeMillis()) ~
+        ("sparkVersion" -> sc.version) ~
+        ("uid" -> uid) ~
+        ("paramMap" -> jsonParams)
+      val metadataPath = new Path(path, "metadata").toString
+      val metadataJson = compact(render(metadata))
+      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+
+      // Save stages
+      val stagesDir = new Path(path, "stages").toString
+      stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) =>
+        stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir))
+      }
+    }
+
+    /**
+     * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
+     * @return  (UID, list of stages)
+     */
+    def load(
+        expectedClassName: String,
+        sc: SparkContext,
+        path: String): (String, Array[PipelineStage]) = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+      implicit val format = DefaultFormats
+      val stagesDir = new Path(path, "stages").toString
+      val stageUids: Array[String] = metadata.params match {
+        case JObject(pairs) =>
+          if (pairs.length != 1) {
+            // Should not happen unless file is corrupted or we have a bug.
+            throw new RuntimeException(
+              s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
+          }
+          pairs.head match {
+            case ("stageUids", jsonValue) =>
+              jsonValue.extract[Seq[String]].toArray
+            case (paramName, jsonValue) =>
+              // Should not happen unless file is corrupted or we have a bug.
+              throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
+                s" in metadata: ${metadata.metadataStr}")
+          }
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+      }
+      val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
+        val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
+        val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
+        val cls = Utils.classForName(stageMetadata.className)
+        cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath)
+      }
+      (metadata.uid, stages)
+    }
+
+    /** Get path for saving the given stage. */
+    def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
+      val stageIdxDigits = numStages.toString.length
+      val idxFormat = s"%0${stageIdxDigits}d"
+      val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
+      new Path(stagesDir, stageDir).toString
+    }
+  }
 }
 
 /**
@@ -176,7 +308,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {
 class PipelineModel private[ml] (
     override val uid: String,
     val stages: Array[Transformer])
-  extends Model[PipelineModel] with Logging {
+  extends Model[PipelineModel] with Writable with Logging {
 
   /** A Java/Python-friendly auxiliary constructor. */
   private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
@@ -200,4 +332,39 @@ class PipelineModel private[ml] (
   override def copy(extra: ParamMap): PipelineModel = {
     new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
   }
+
+  override def write: Writer = new PipelineModel.PipelineModelWriter(this)
+}
+
+object PipelineModel extends Readable[PipelineModel] {
+
+  import Pipeline.SharedReadWrite
+
+  override def read: Reader[PipelineModel] = new PipelineModelReader
+
+  override def load(path: String): PipelineModel = read.load(path)
+
+  private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer {
+
+    SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
+
+    override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
+      instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
+  }
+
+  private[ml] class PipelineModelReader extends Reader[PipelineModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.PipelineModel"
+
+    override def load(path: String): PipelineModel = {
+      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
+      val transformers = stages map {
+        case stage: Transformer => stage
+        case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
+          s" was not a Transformer.  Bad stage ${other.uid} of type ${other.getClass}")
+      }
+      new PipelineModel(uid, transformers)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/505eceef/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ca896ed..3169c9e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -164,6 +164,8 @@ trait Readable[T] {
 
   /**
    * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+   *
+   * Note: Implementing classes should override this to be Java-friendly.
    */
   @Since("1.6.0")
   def load(path: String): T = read.load(path)
@@ -190,7 +192,7 @@ private[ml] object DefaultParamsWriter {
    *  - timestamp
    *  - sparkVersion
    *  - uid
-   *  - paramMap
+   *  - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
   def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
     val uid = instance.uid

http://git-wip-us.apache.org/repos/asf/spark/blob/505eceef/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 1f2c9b7..484026b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,19 +17,25 @@
 
 package org.apache.spark.ml
 
+import java.io.File
+
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito.when
 import org.scalatest.mock.MockitoSugar.mock
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.Pipeline.SharedReadWrite
 import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
 
-class PipelineSuite extends SparkFunSuite {
+class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   abstract class MyModel extends Model[MyModel]
 
@@ -111,4 +117,112 @@ class PipelineSuite extends SparkFunSuite {
     assert(pipelineModel1.uid === "pipeline1")
     assert(pipelineModel1.stages === stages)
   }
+
+  test("Pipeline read/write") {
+    val writableStage = new WritableStage("writableStage").setIntParam(56)
+    val pipeline = new Pipeline().setStages(Array(writableStage))
+
+    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+    assert(pipeline2.getStages.length === 1)
+    assert(pipeline2.getStages(0).isInstanceOf[WritableStage])
+    val writableStage2 = pipeline2.getStages(0).asInstanceOf[WritableStage]
+    assert(writableStage.getIntParam === writableStage2.getIntParam)
+  }
+
+  test("Pipeline read/write with non-Writable stage") {
+    val unWritableStage = new UnWritableStage("unwritableStage")
+    val unWritablePipeline = new Pipeline().setStages(Array(unWritableStage))
+    withClue("Pipeline.write should fail when Pipeline contains non-Writable stage") {
+      intercept[UnsupportedOperationException] {
+        unWritablePipeline.write
+      }
+    }
+  }
+
+  test("PipelineModel read/write") {
+    val writableStage = new WritableStage("writableStage").setIntParam(56)
+    val pipeline =
+      new PipelineModel("pipeline_89329327", Array(writableStage.asInstanceOf[Transformer]))
+
+    val pipeline2 = testDefaultReadWrite(pipeline, testParams = false)
+    assert(pipeline2.stages.length === 1)
+    assert(pipeline2.stages(0).isInstanceOf[WritableStage])
+    val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
+    assert(writableStage.getIntParam === writableStage2.getIntParam)
+
+    val path = new File(tempDir, pipeline.uid).getPath
+    val stagesDir = new Path(path, "stages").toString
+    val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
+    assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
+      s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with uid ${pipeline.uid}" +
+        s" to be saved to path: $expectedStagePath")
+  }
+
+  test("PipelineModel read/write: getStagePath") {
+    val stageUid = "myStage"
+    val stagesDir = new Path("pipeline", "stages").toString
+    def testStage(stageIdx: Int, numStages: Int, expectedPrefix: String): Unit = {
+      val path = SharedReadWrite.getStagePath(stageUid, stageIdx, numStages, stagesDir)
+      val expected = new Path(stagesDir, expectedPrefix + "_" + stageUid).toString
+      assert(path === expected)
+    }
+    testStage(0, 1, "0")
+    testStage(0, 9, "0")
+    testStage(0, 10, "00")
+    testStage(1, 10, "01")
+    testStage(12, 999, "012")
+  }
+
+  test("PipelineModel read/write with non-Writable stage") {
+    val unWritableStage = new UnWritableStage("unwritableStage")
+    val unWritablePipeline =
+      new PipelineModel("pipeline_328957", Array(unWritableStage.asInstanceOf[Transformer]))
+    withClue("PipelineModel.write should fail when PipelineModel contains non-Writable stage") {
+      intercept[UnsupportedOperationException] {
+        unWritablePipeline.write
+      }
+    }
+  }
+}
+
+
+/** Used to test [[Pipeline]] with [[Writable]] stages */
+class WritableStage(override val uid: String) extends Transformer with Writable {
+
+  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+  def getIntParam: Int = $(intParam)
+
+  def setIntParam(value: Int): this.type = set(intParam, value)
+
+  setDefault(intParam -> 0)
+
+  override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+
+  override def transform(dataset: DataFrame): DataFrame = dataset
+
+  override def transformSchema(schema: StructType): StructType = schema
+}
+
+object WritableStage extends Readable[WritableStage] {
+
+  override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]
+
+  override def load(path: String): WritableStage = read.load(path)
+}
+
+/** Used to test [[Pipeline]] with non-[[Writable]] stages */
+class UnWritableStage(override val uid: String) extends Transformer {
+
+  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+
+  setDefault(intParam -> 0)
+
+  override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
+
+  override def transform(dataset: DataFrame): DataFrame = dataset
+
+  override def transformSchema(schema: StructType): StructType = schema
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/505eceef/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index cac4bd9..c37f050 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -30,10 +30,13 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
   /**
    * Checks "overwrite" option and params.
    * @param instance ML instance to test saving/loading
+   * @param testParams  If true, then test values of Params.  Otherwise, just test overwrite option.
    * @tparam T ML instance type
    * @return  Instance loaded from file
    */
-  def testDefaultReadWrite[T <: Params with Writable](instance: T): T = {
+  def testDefaultReadWrite[T <: Params with Writable](
+      instance: T,
+      testParams: Boolean = true): T = {
     val uid = instance.uid
     val path = new File(tempDir, uid).getPath
 
@@ -46,16 +49,18 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     val newInstance = loader.load(path)
 
     assert(newInstance.uid === instance.uid)
-    instance.params.foreach { p =>
-      if (instance.isDefined(p)) {
-        (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
-          case (Array(values), Array(newValues)) =>
-            assert(values === newValues, s"Values do not match on param ${p.name}.")
-          case (value, newValue) =>
-            assert(value === newValue, s"Values do not match on param ${p.name}.")
+    if (testParams) {
+      instance.params.foreach { p =>
+        if (instance.isDefined(p)) {
+          (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+            case (Array(values), Array(newValues)) =>
+              assert(values === newValues, s"Values do not match on param ${p.name}.")
+            case (value, newValue) =>
+              assert(value === newValue, s"Values do not match on param ${p.name}.")
+          }
+        } else {
+          assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
         }
-      } else {
-        assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
       }
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message