spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5987] [MLlib] Save/load for GaussianMixtureModels
Date Wed, 25 Mar 2015 21:45:27 GMT
Repository: spark
Updated Branches:
  refs/heads/master 435337381 -> 4fc4d0369


[SPARK-5987] [MLlib] Save/load for GaussianMixtureModels

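Adds save/load support to GaussianMixtureModel (JSON metadata plus Parquet component data), documents it in the clustering guide, and adds a save/load round-trip test.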

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4986 from MechCoder/spark-5987 and squashes the following commits:

7d2cd56 [MechCoder] Iterate over dataframe in a better way
e7a14cb [MechCoder] Minor
33c84f9 [MechCoder] Store as Array[Data] instead of Data[Array]
505bd57 [MechCoder] Rebased over master and used MatrixUDT
7422bb4 [MechCoder] Store sigmas as Array[Double] instead of Array[Array[Double]]
b9794e4 [MechCoder] Minor
cb77095 [MechCoder] [SPARK-5987] Save/load for GaussianMixtureModels


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4fc4d036
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4fc4d036
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4fc4d036

Branch: refs/heads/master
Commit: 4fc4d0369e8240defe0ee83252426402f1a28a36
Parents: 4353373
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Wed Mar 25 14:45:23 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Mar 25 14:45:23 2015 -0700

----------------------------------------------------------------------
 docs/mllib-clustering.md                        |  8 ++
 .../mllib/clustering/GaussianMixtureModel.scala | 96 +++++++++++++++++++-
 .../mllib/clustering/GaussianMixtureSuite.scala | 52 ++++++++---
 3 files changed, 136 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4fc4d036/docs/mllib-clustering.md
----------------------------------------------------------------------
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 0b6db4f..f5aa15b 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model.
 
 {% highlight scala %}
 import org.apache.spark.mllib.clustering.GaussianMixture
+import org.apache.spark.mllib.clustering.GaussianMixtureModel
 import org.apache.spark.mllib.linalg.Vectors
 
 // Load and parse the data
@@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
 // Cluster the data into two classes using GaussianMixture
 val gmm = new GaussianMixture().setK(2).run(parsedData)
 
+// Save and load model
+gmm.save(sc, "myGMMModel")
+val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
+
 // output parameters of max-likelihood model
 for (i <- 0 until gmm.k) {
   println("weight=%f\nmu=%s\nsigma=\n%s\n" format
@@ -231,6 +236,9 @@ public class GaussianMixtureExample {
     // Cluster the data into two classes using GaussianMixture
     GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
 
+    // Save and load GaussianMixtureModel
+    gmm.save(sc.sc(), "myGMMModel");
+    GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel");
     // Output the parameters of the mixture model
     for(int j=0; j<gmm.k(); j++) {
         System.out.println("weight=%f\nmu=%s\nsigma=\n%s\n",

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc4d036/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index af6f83c..ec65a3d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering
 
 import breeze.linalg.{DenseVector => BreezeVector}
 
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
 
 /**
  * :: Experimental ::
@@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class GaussianMixtureModel(
   val weights: Array[Double], 
-  val gaussians: Array[MultivariateGaussian]) extends Serializable {
+  val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
   
   require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must
match")
-  
+
+  override protected def formatVersion: String =
+    GaussianMixtureModel.SaveLoadV1_0.formatVersionV1_0
+  /** Persists this model to `path` using the v1.0 save format (JSON metadata + Parquet). */
+  override def save(sc: SparkContext, path: String): Unit =
+    GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)
+
   /** Number of gaussians in mixture */
   def k: Int = weights.length
 
@@ -83,5 +95,79 @@ class GaussianMixtureModel(
       p(i) /= pSum
     }
     p
-  }  
+  }
+}
+
+@Experimental
+object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
+
+  /** Version 1.0 persistence: JSON metadata plus one Parquet row per mixture component. */
+  private object SaveLoadV1_0 {
+
+    /** One mixture component: its weight, mean vector and covariance matrix. */
+    case class Data(weight: Double, mu: Vector, sigma: Matrix)
+
+    val formatVersionV1_0 = "1.0"
+    val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel"
+
+    /** Writes JSON metadata (class, version, k) and Parquet component rows for `path`. */
+    def save(
+        sc: SparkContext,
+        path: String,
+        weights: Array[Double],
+        gaussians: Array[MultivariateGaussian]): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      // Create JSON metadata; `k` lets load() verify how many components were read back.
+      val metadata = compact(render(
+        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data, one row per mixture component.
+      val dataArray = Array.tabulate(weights.length) { i =>
+        Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
+      }
+      sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    /** Reads the component rows written by `save` and rebuilds the model. */
+    def load(sc: SparkContext, path: String): GaussianMixtureModel = {
+      val sqlContext = new SQLContext(sc)
+      val dataFrame = sqlContext.parquetFile(Loader.dataPath(path))
+      // Check schema explicitly, before reading rows, since erasure defeats match-case checks.
+      Loader.checkSchema[Data](dataFrame.schema)
+      val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
+      val (weights, gaussians) = dataArray.map {
+        case Row(weight: Double, mu: Vector, sigma: Matrix) =>
+          (weight, new MultivariateGaussian(mu, sigma))
+      }.unzip
+      new GaussianMixtureModel(weights.toArray, gaussians.toArray)
+    }
+  }
+
+  /** Loads a saved model; throws if its (class name, format version) is not supported. */
+  override def load(sc: SparkContext, path: String) : GaussianMixtureModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    implicit val formats = DefaultFormats
+    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+    (loadedClassName, version) match {
+      // Backticks compare against the val; a bare lowercase name would match any class.
+      case (`classNameV1_0`, "1.0") =>
+        // Read `k` only once the metadata is known to be in our format.
+        val k = (metadata \ "k").extract[Int]
+        val model = SaveLoadV1_0.load(sc, path)
+        require(model.weights.length == k,
+          s"GaussianMixtureModel requires weights of length $k, " +
+          s"got weights of length ${model.weights.length}")
+        require(model.gaussians.length == k,
+          s"GaussianMixtureModel requires gaussians of length $k, " +
+          s"got gaussians of length ${model.gaussians.length}")
+        model
+      case _ => throw new Exception(
+        s"GaussianMixtureModel.load did not recognize model with (className, format version): " +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4fc4d036/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index 1b46a40..f356ffa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
 import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
   test("single cluster") {
@@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext
{
   }
   
   test("two clusters") {
-    val data = sc.parallelize(Array(
-      Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
-      Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
-      Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
-      Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
-      Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
-    ))
+    val data = sc.parallelize(GaussianTestData.data)
 
     // we set an initial gaussian to induce expected results
     val initialGmm = new GaussianMixtureModel(
@@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext
{
   }
 
   test("two clusters with sparse data") {
-    val data = sc.parallelize(Array(
-      Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
-      Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
-      Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
-      Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
-      Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
-    ))
-
+    val data = sc.parallelize(GaussianTestData.data)
     val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
     // we set an initial gaussian to induce expected results
     val initialGmm = new GaussianMixtureModel(
@@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext
{
     assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
     assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
   }
+
+  test("model save / load") {
+    // Round-trip a fitted model through save/load and compare every component field.
+    val data = sc.parallelize(GaussianTestData.data)
+    val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    try {
+      gmm.save(sc, path)
+      // TODO: GaussianMixtureModel should implement equals/hashcode directly.
+      val sameModel = GaussianMixtureModel.load(sc, path)
+      assert(sameModel.k === gmm.k)
+      (0 until sameModel.k).foreach { i =>
+        assert(sameModel.weights(i) === gmm.weights(i))
+        assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
+        assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  /** Five negative and ten positive 1-D points shared by several tests in this suite. */
+  object GaussianTestData {
+
+    private val values = Array(
+      -5.1971, -2.5359, -3.8220, -5.2211, -5.0602,
+       4.7118,  6.8989,  3.4592,  4.6322,  5.7048,
+       4.6567,  5.5026,  4.5605,  5.2043,  6.2734)
+
+    val data = values.map(v => Vectors.dense(v))
+
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message