spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5598][MLLIB] model save/load for ALS
Date Mon, 09 Feb 2015 00:26:34 GMT
Repository: spark
Updated Branches:
  refs/heads/master 804949d51 -> 5c299c58f


[SPARK-5598][MLLIB] model save/load for ALS

following #4233. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4422 from mengxr/SPARK-5598 and squashes the following commits:

a059394 [Xiangrui Meng] SaveLoad not extending Loader
14b7ea6 [Xiangrui Meng] address comments
f487cb2 [Xiangrui Meng] add unit tests
62fc43c [Xiangrui Meng] implement save/load for MFM


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c299c58
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c299c58
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c299c58

Branch: refs/heads/master
Commit: 5c299c58fb9a5434a40be82150d4725bba805adf
Parents: 804949d
Author: Xiangrui Meng <meng@databricks.com>
Authored: Sun Feb 8 16:26:20 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun Feb 8 16:26:20 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/recommendation/ALS.scala |  2 +-
 .../MatrixFactorizationModel.scala              | 82 +++++++++++++++++++-
 .../MatrixFactorizationModelSuite.scala         | 19 +++++
 3 files changed, 100 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 4bb28d1..caacab9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.mllib.recommendation
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.recommendation.{ALS => NewALS}
 import org.apache.spark.rdd.RDD

http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index ed2f8b4..9ff06ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.mllib.recommendation
 
+import java.io.IOException
 import java.lang.{Integer => JavaInteger}
 
+import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
 class MatrixFactorizationModel(
     val rank: Int,
     val userFeatures: RDD[(Int, Array[Double])],
-    val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+    val productFeatures: RDD[(Int, Array[Double])])
+  extends Saveable with Serializable with Logging {
 
   require(rank > 0)
   validateFeatures("User", userFeatures)
@@ -125,6 +130,12 @@ class MatrixFactorizationModel(
     recommend(productFeatures.lookup(product).head, userFeatures, num)
       .map(t => Rating(t._1, product, t._2))
 
+  protected override val formatVersion: String = "1.0"
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
+  }
+
   private def recommend(
       recommendToFeatures: Array[Double],
       recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +147,70 @@ class MatrixFactorizationModel(
     scored.top(num)(Ordering.by(_._2))
   }
 }
+
+object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
+
+  import org.apache.spark.mllib.util.Loader._
+
+  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, formatVersion) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path)
+      case _ =>
+        throw new IOException("MatrixFactorizationModel.load did not recognize model with" +
+          s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  private[recommendation]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[recommendation]
+    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+
+    /**
+     * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
+     * product features are saved under `data/products`.
+     */
+    def save(model: MatrixFactorizationModel, path: String): Unit = {
+      val sc = model.userFeatures.sparkContext
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits.createDataFrame
+      val metadata = (thisClassName, thisFormatVersion, model.rank)
+      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
+      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
+      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val rank = metadata.select("rank").first().getInt(0)
+      val userFeatures = sqlContext.parquetFile(userPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+          (id, features.toArray)
+        }
+      val productFeatures = sqlContext.parquetFile(productPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+        (id, features.toArray)
+      }
+      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
+    }
+
+    private def userPath(path: String): String = {
+      new Path(dataPath(path), "user").toUri.toString
+    }
+
+    private def productPath(path: String): String = {
+      new Path(dataPath(path), "product").toUri.toString
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index b9caecc..9801e87 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
       new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
     }
   }
+
+  test("save/load") {
+    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
+      features.mapValues(_.toSeq).collect().toSet
+    }
+    try {
+      model.save(sc, path)
+      val newModel = MatrixFactorizationModel.load(sc, path)
+      assert(newModel.rank === rank)
+      assert(collect(newModel.userFeatures) === collect(userFeatures))
+      assert(collect(newModel.productFeatures) === collect(prodFeatures))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message