Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 1A4D51774A for ; Mon, 9 Feb 2015 04:26:34 +0000 (UTC) Received: (qmail 77083 invoked by uid 500); 9 Feb 2015 00:26:34 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 77048 invoked by uid 500); 9 Feb 2015 00:26:34 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 77039 invoked by uid 99); 9 Feb 2015 00:26:34 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 09 Feb 2015 00:26:34 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 631D0E04AF; Mon, 9 Feb 2015 00:26:34 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: meng@apache.org To: commits@spark.apache.org Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-5598][MLLIB] model save/load for ALS Date: Mon, 9 Feb 2015 00:26:34 +0000 (UTC) Repository: spark Updated Branches: refs/heads/master 804949d51 -> 5c299c58f [SPARK-5598][MLLIB] model save/load for ALS following #4233. jkbradley Author: Xiangrui Meng Closes #4422 from mengxr/SPARK-5598 and squashes the following commits: a059394 [Xiangrui Meng] SaveLoad not extending Loader 14b7ea6 [Xiangrui Meng] address comments f487cb2 [Xiangrui Meng] add unit tests 62fc43c [Xiangrui Meng] implement save/load for MFM Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c299c58 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c299c58 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c299c58 Branch: refs/heads/master Commit: 5c299c58fb9a5434a40be82150d4725bba805adf Parents: 804949d Author: Xiangrui Meng Authored: Sun Feb 8 16:26:20 2015 -0800 Committer: Xiangrui Meng Committed: Sun Feb 8 16:26:20 2015 -0800 ---------------------------------------------------------------------- .../apache/spark/mllib/recommendation/ALS.scala | 2 +- .../MatrixFactorizationModel.scala | 82 +++++++++++++++++++- .../MatrixFactorizationModelSuite.scala | 19 +++++ 3 files changed, 100 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 4bb28d1..caacab9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b4..9ff06ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,17 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} +import org.apache.hadoop.fs.Path import org.jblas.DoubleMatrix -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.storage.StorageLevel /** @@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + val productFeatures: RDD[(Int, Array[Double])]) + extends Saveable with Serializable with Logging { require(rank > 0) validateFeatures("User", userFeatures) @@ -125,6 +130,12 @@ class MatrixFactorizationModel( recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + protected override val formatVersion: String = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) + } + private def recommend( recommendToFeatures: Array[Double], recommendableFeatures: RDD[(Int, Array[Double])], @@ -136,3 +147,70 @@ class MatrixFactorizationModel( scored.top(num)(Ordering.by(_._2)) } } + +object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { + + import org.apache.spark.mllib.util.Loader._ + + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private[recommendation] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[recommendation] + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. + */ + def save(model: MatrixFactorizationModel, path: String): Unit = { + val sc = model.userFeatures.sparkContext + val sqlContext = new SQLContext(sc) + import sqlContext.implicits.createDataFrame + val metadata = (thisClassName, thisFormatVersion, model.rank) + val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank") + metadataRDD.toJSON.saveAsTextFile(metadataPath(path)) + model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path)) + model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path)) + } + + def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rank = metadata.select("rank").first().getInt(0) + val userFeatures = sqlContext.parquetFile(userPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + val productFeatures = sqlContext.parquetFile(productPath(path)) + .map { case Row(id: Int, features: Seq[Double]) => + (id, features.toArray) + } + new MatrixFactorizationModel(rank, userFeatures, productFeatures) + } + + private def userPath(path: String): String = { + new Path(dataPath(path), "user").toUri.toString + } + + private def productPath(path: String): String = { + new Path(dataPath(path), "product").toUri.toString + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/5c299c58/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc..9801e87 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org