spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5692] [MLlib] Word2Vec save/load
Date Tue, 31 Mar 2015 23:01:12 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2036bc599 -> 0e00f12d3


[SPARK-5692] [MLlib] Word2Vec save/load

Word2Vec model now supports saving and loading.

a] The metadata, stored in JSON format, consists of "class", "version", "vectorSize" and
"numWords".
b] The data, stored in Parquet file format, consists of rows with two columns each: the
first is the word (a String) and the second is its vector (an Array of Floats).

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5291 from MechCoder/spark-5692 and squashes the following commits:

1142f3a [MechCoder] Add numWords to metaData
bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e00f12d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e00f12d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e00f12d

Branch: refs/heads/master
Commit: 0e00f12d33d28d064c166262b14e012a1aeaa7b0
Parents: 2036bc5
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Tue Mar 31 16:01:08 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Mar 31 16:01:08 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/feature/Word2Vec.scala   | 87 +++++++++++++++++++-
 .../spark/mllib/feature/Word2VecSuite.scala     | 26 ++++++
 2 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e00f12d/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 59a79e5..9ee7e4a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.Logging
+import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.sql.{SQLContext, Row}
 
 /**
  *  Entry in vocabulary 
@@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging {
  */
 @Experimental
 class Word2VecModel private[mllib] (
-    private val model: Map[String, Array[Float]]) extends Serializable {
+    private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
 
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
@@ -432,7 +439,13 @@ class Word2VecModel private[mllib] (
     if (norm1 == 0 || norm2 == 0) return 0.0
     blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
   }
-  
+
+  override protected def formatVersion = "1.0"
+
+  def save(sc: SparkContext, path: String): Unit = {
+    Word2VecModel.SaveLoadV1_0.save(sc, path, model)
+  }
+
   /**
    * Transforms a word to its vector representation
    * @param word a word 
@@ -475,7 +488,7 @@ class Word2VecModel private[mllib] (
       .tail
       .toArray
   }
-  
+
   /**
    * Returns a map of words to their vector representations.
    */
@@ -483,3 +496,71 @@ class Word2VecModel private[mllib] (
     model
   }
 }
+
+@Experimental
+object Word2VecModel extends Loader[Word2VecModel] {
+
+  private object SaveLoadV1_0 {
+
+    val formatVersionV1_0 = "1.0"
+
+    val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"
+
+    case class Data(word: String, vector: Array[Float])
+
+    def load(sc: SparkContext, path: String): Word2VecModel = {
+      val dataPath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      val dataFrame = sqlContext.parquetFile(dataPath)
+
+      val dataArray = dataFrame.select("word", "vector").collect()
+
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      Loader.checkSchema[Data](dataFrame.schema)
+
+      val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
+      new Word2VecModel(word2VecMap)
+    }
+
+    def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = {
+
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      val vectorSize = model.values.head.size
+      val numWords = model.size
+      val metadata = compact(render
+        (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+         ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
+      sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): Word2VecModel = {
+
+    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+    implicit val formats = DefaultFormats
+    val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
+    val expectedNumWords = (metadata \ "numWords").extract[Int]
+    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+    (loadedClassName, loadedVersion) match {
+      case (classNameV1_0, "1.0") =>
+        val model = SaveLoadV1_0.load(sc, path)
+        val vectorSize = model.getVectors.values.head.size
+        val numWords = model.getVectors.size
+        require(expectedVectorSize == vectorSize,
+          s"Word2VecModel requires each word to be mapped to a vector of size " +
+          s"$expectedVectorSize, got vector of size $vectorSize")
+        require(expectedNumWords == numWords,
+          s"Word2VecModel requires $expectedNumWords words, but got $numWords")
+        model
+      case _ => throw new Exception(
+        s"Word2VecModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $loadedVersion).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0e00f12d/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 5227869..98a98a7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
 class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
 
   // TODO: add more tests
@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
     assert(syms(0)._1 == "taiwan")
     assert(syms(1)._1 == "japan")
   }
+
+  test("model load / save") {
+
+    val word2VecMap = Map(
+      ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+    )
+    val model = new Word2VecModel(word2VecMap)
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    try {
+      model.save(sc, path)
+      val sameModel = Word2VecModel.load(sc, path)
+      assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message