spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow
Date Wed, 11 Nov 2015 17:43:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1510c527b -> 27524a3a9


[SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow

org.apache.spark.ml.feature.Word2Vec.transform() is very slow. We should not read the
broadcast variable for every sentence.

Author: Yuming Wang <q79969786@gmail.com>
Author: yuming.wang <q79969786@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #9592 from 979969786/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27524a3a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27524a3a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27524a3a

Branch: refs/heads/master
Commit: 27524a3a9ccee6fbe56149180ebfb3f74e0957e7
Parents: 1510c52
Author: Yuming Wang <q79969786@gmail.com>
Authored: Wed Nov 11 09:43:26 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Nov 11 09:43:26 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 34 +++++++++-----------
 1 file changed, 16 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27524a3a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9edab3a..5c64cb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 @Experimental
 class Word2VecModel private[ml] (
     override val uid: String,
-    wordVectors: feature.Word2VecModel)
+    @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
-
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
    * and the vector the DenseVector that it is mapped to.
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val vectors = wordVectors.getVectors
+      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+      .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
-        Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
       } else {
-        val cum = Vectors.zeros($(vectorSize))
-        val model = bWordVectors.value.getVectors
-        for (word <- sentence) {
-          if (model.contains(word)) {
-            axpy(1.0, bWordVectors.value.transform(word), cum)
-          } else {
-            // pass words which not belong to model
+        val sum = Vectors.zeros(d)
+        sentence.foreach { word =>
+          bVectors.value.get(word).foreach { v =>
+            BLAS.axpy(1.0, v, sum)
           }
         }
-        scal(1.0 / sentence.size, cum)
-        cum
+        BLAS.scal(1.0 / sentence.size, sum)
+        sum
       }
     }
     dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message