spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-16240][ML] ML persistence backward compatibility for LDA - 2.0 backport
Date Fri, 23 Sep 2016 05:44:29 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 22216d6bd -> 54d4eee51


[SPARK-16240][ML] ML persistence backward compatibility for LDA - 2.0 backport

## What changes were proposed in this pull request?

Allow Spark 2.x to load instances of LDA, LocalLDAModel, and DistributedLDAModel saved from
Spark 1.6.
Backport of https://github.com/apache/spark/pull/15034 for branch-2.0

## How was this patch tested?

I tested this manually, saving the 3 types from 1.6 and loading them into master (2.x).  In
the future, we can add generic tests for testing backwards compatibility across all ML models
in SPARK-15573.

Author: Gayathri Murali <gayathri.m.softie@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #15205 from jkbradley/lda-backward-2.0.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54d4eee5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54d4eee5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54d4eee5

Branch: refs/heads/branch-2.0
Commit: 54d4eee51eca364d9334141f62e0478343345d06
Parents: 22216d6
Author: Gayathri Murali <gayathri.m.softie@gmail.com>
Authored: Thu Sep 22 22:44:20 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Sep 22 22:44:20 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/LDA.scala    | 86 ++++++++++++++++----
 project/MimaExcludes.scala                      |  3 +
 2 files changed, 72 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54d4eee5/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 034f2c3..8e23325 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -18,6 +18,9 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JObject
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 import org.apache.spark.internal.Logging
@@ -26,19 +29,21 @@ import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
   EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
   LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
   OnlineLDAOptimizer => OldOnlineLDAOptimizer}
 import org.apache.spark.mllib.impl.PeriodicCheckpointer
-import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Vector => OldVector,
-  Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.MatrixImplicits._
 import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils
 
 
 private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
@@ -80,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    *     - Values should be >= 0
    *     - default = uniformly (1.0 / k), following the implementation from
    *       [[https://github.com/Blei-Lab/onlineldavb]].
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -121,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    *     - Value should be >= 0
    *     - default = (1.0 / k), following the implementation from
    *       [[https://github.com/Blei-Lab/onlineldavb]].
+   *
    * @group param
    */
   @Since("1.6.0")
@@ -354,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   }
 }
 
+private object LDAParams {
+
+  /**
+   * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+   * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
+   *
+   * @param model    [[LDA]] or [[LDAModel]] instance.  This instance will be modified with
+   *                 [[Param]] values extracted from metadata.
+   * @param metadata Loaded model metadata
+   */
+  def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = {
+    VersionUtils.majorMinorVersion(metadata.sparkVersion) match {
+      case (1, 6) =>
+        implicit val format = DefaultFormats
+        metadata.params match {
+          case JObject(pairs) =>
+            pairs.foreach { case (paramName, jsonValue) =>
+              val origParam =
+                if (paramName == "topicDistribution") "topicDistributionCol" else paramName
+              val param = model.getParam(origParam)
+              val value = param.jsonDecode(compact(render(jsonValue)))
+              model.set(param, value)
+            }
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
+        }
+      case _ => // 2.0+
+        DefaultParamsReader.getAndSetParams(model, metadata)
+    }
+  }
+}
+
 
 /**
  * :: Experimental ::
@@ -414,11 +454,11 @@ sealed abstract class LDAModel private[ml] (
       val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)
 
       val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
-      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
+      dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
     } else {
       logWarning("LDAModel.transform was called without any output columns. Set an output column" +
         " such as topicDistributionCol to produce results.")
-      dataset.toDF
+      dataset.toDF()
     }
   }
 
@@ -574,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
-        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
-          "gammaShape")
-        .head()
-      val vocabSize = data.getAs[Int](0)
-      val topicsMatrix = data.getAs[Matrix](1)
-      val docConcentration = data.getAs[Vector](2)
-      val topicConcentration = data.getAs[Double](3)
-      val gammaShape = data.getAs[Double](4)
+      val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration")
+      val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix")
+      val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector,
+          topicConcentration: Double, gammaShape: Double) =
+        matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration",
+          "topicConcentration", "gammaShape").head()
       val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
         gammaShape)
       val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      LDAParams.getAndSetParams(model, metadata)
       model
     }
   }
@@ -731,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val modelPath = new Path(path, "oldModel").toString
       val oldModel = OldDistributedLDAModel.load(sc, modelPath)
-      val model = new DistributedLDAModel(
-        metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
+        oldModel, sparkSession, None)
+      LDAParams.getAndSetParams(model, metadata)
       model
     }
   }
@@ -881,7 +919,7 @@ class LDA @Since("1.6.0") (
 }
 
 @Since("2.0.0")
-object LDA extends DefaultParamsReadable[LDA] {
+object LDA extends MLReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
   private[clustering] def getOldDataset(
@@ -896,6 +934,20 @@ object LDA extends DefaultParamsReadable[LDA] {
       }
   }
 
+  private class LDAReader extends MLReader[LDA] {
+
+    private val className = classOf[LDA].getName
+
+    override def load(path: String): LDA = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val model = new LDA(metadata.uid)
+      LDAParams.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  override def read: MLReader[LDA] = new LDAReader
+
   @Since("2.0.0")
   override def load(path: String): LDA = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/54d4eee5/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c38a49a..423cbd4 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -784,6 +784,9 @@ object MimaExcludes {
       // SPARK-17096: Improve exception string reported through the StreamingQueryListener
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this")
+    ) ++ Seq(
+      // SPARK-16240: ML persistence backward compatibility for LDA
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$")
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message