spark-commits mailing list archives

From m...@apache.org
Subject git commit: [SPARK-3097][MLlib] Word2Vec performance improvement
Date Mon, 18 Aug 2014 06:29:56 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.1 0506539b0 -> 708cde99a


[SPARK-3097][MLlib] Word2Vec performance improvement

mengxr Please review the code. I will add weights in reduceByKey soon.

Only output model entries for words that appeared in the partition before merging, and use
reduceByKey to combine the partial models. In general, this implementation is about 30 seconds
faster than the implementation using a big array.
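
As an illustration of the pattern (not the patch itself), the following standalone
sketch sums per-partition sparse deltas with reduceByKey; the element-wise loop
stands in for the blas.saxpy call used in the actual code, and the function name
is hypothetical:

    import org.apache.spark.rdd.RDD

    // Sum per-partition model rows keyed by row index; only rows a partition
    // actually modified are emitted, so the shuffled data stays small.
    def mergeDeltas(partial: RDD[(Int, Array[Float])]): Array[(Int, Array[Float])] = {
      partial.reduceByKey { (v1, v2) =>
        var i = 0
        while (i < v1.length) { // element-wise sum, reusing v1 as the accumulator
          v1(i) += v2(i)
          i += 1
        }
        v1
      }.collect()
    }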

Author: Liquan Pei <liquanpei@gmail.com>

Closes #1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits:

d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model
cad2011 [Liquan Pei] bug fix for synModify array out of bound
083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size
9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal
aa2ab36 [Liquan Pei] use reduceByKey to combine models

(cherry picked from commit 3c8fa505900ac158d57de36f6b0fd6da05f8893b)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/708cde99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/708cde99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/708cde99

Branch: refs/heads/branch-1.1
Commit: 708cde99a142c90f5a06c7aa326b622d80022e3d
Parents: 0506539
Author: Liquan Pei <liquanpei@gmail.com>
Authored: Sun Aug 17 23:29:44 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun Aug 17 23:29:52 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/feature/Word2Vec.scala   | 50 ++++++++++++++------
 1 file changed, 35 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/708cde99/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index ecd49ea..d2ae62b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 
 /**
  *  Entry in vocabulary 
@@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging {
     var syn0Global =
       Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     var syn1Global = new Array[Float](vocabSize * vectorSize)
-
     var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val syn0Modify = new Array[Int](vocabSize)
+        val syn1Modify = new Array[Int](vocabSize)
         val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging {
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
-                      val l2 = bcVocab.value(word).point(d) * vectorSize
+                      val inner = bcVocab.value(word).point(d)
+                      val l2 = inner * vectorSize
                       // Propagate hidden -> output
                       var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
@@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging {
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                         blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
                         blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+                        syn1Modify(inner) += 1
                       }
                       d += 1
                     }
                     blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    syn0Modify(lastWord) += 1
                   }
                 }
                 a += 1
@@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging {
             }
             (syn0, syn1, lwc, wc)
         }
-        Iterator(model)
+        val syn0Local = model._1
+        val syn1Local = model._2
+        val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
+        var index = 0
+        while(index < vocabSize) {
+          if (syn0Modify(index) != 0) {
+            synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          if (syn1Modify(index) != 0) {
+            synOut.update(index + vocabSize,
+              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          index += 1
+        }
+        Iterator(synOut)
       }
-      val (aggSyn0, aggSyn1, _, _) =
-        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-          blas.sscal(n, weight1, syn0_1, 1)
-          blas.sscal(n, weight1, syn1_1, 1)
-          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
+          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+          v1
+      }.collect()
+      var i = 0
+      while (i < synAgg.length) {
+        val index = synAgg(i)._1
+        if (index < vocabSize) {
+          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+        } else {
+          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
         }
-      syn0Global = aggSyn0
-      syn1Global = aggSyn1
+        i += 1
+      }
     }
     newSentences.unpersist()
     


