spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel
Date Mon, 11 Jan 2016 22:55:49 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4f8eefa36 -> bbea88852


[SPARK-10809][MLLIB] Single-document topicDistributions method for LocalLDAModel

jira: https://issues.apache.org/jira/browse/SPARK-10809

We could provide a single-document topicDistributions method for LocalLDAModel to allow for
quick queries which avoid RDD operations. Currently, the user must use an RDD of documents.

add some missing assert too.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #9484 from hhbyyh/ldaTopicPre.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bbea8885
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bbea8885
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bbea8885

Branch: refs/heads/master
Commit: bbea88852ce6a3127d071ca40dbca2d042f9fbcf
Parents: 4f8eefa
Author: Yuhao Yang <hhbyyh@gmail.com>
Authored: Mon Jan 11 14:55:44 2016 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Jan 11 14:55:44 2016 -0800

----------------------------------------------------------------------
 .../spark/mllib/clustering/LDAModel.scala       | 26 ++++++++++++++++++++
 .../spark/mllib/clustering/LDASuite.scala       | 15 ++++++++---
 2 files changed, 38 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bbea8885/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 2fce3ff..b30ecb8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -388,6 +388,32 @@ class LocalLDAModel private[spark] (
   }
 
   /**
+   * Predicts the topic mixture distribution for a document (often called "theta" in the
+   * literature).  Returns a vector of zeros for an empty document.
+   *
+   * Note this means to allow quick query for single document. For batch documents, please refer
+   * to [[topicDistributions()]] to avoid overhead.
+   *
+   * @param document document to predict topic mixture distributions for
+   * @return topic mixture distribution for the document
+   */
+  @Since("2.0.0")
+  def topicDistribution(document: Vector): Vector = {
+    val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+    if (document.numNonzeros == 0) {
+      Vectors.zeros(this.k)
+    } else {
+      val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
+        document,
+        expElogbeta,
+        this.docConcentration.toBreeze,
+        gammaShape,
+        this.k)
+      Vectors.dense(normalize(gamma, 1.0).toArray)
+    }
+  }
+
+  /**
    * Java-friendly version of [[topicDistributions]]
    */
   @Since("1.4.1")

http://git-wip-us.apache.org/repos/asf/spark/blob/bbea8885/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index faef60e..ea23196 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -366,7 +366,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       (0, 0.99504), (1, 0.99504),
       (1, 0.99504), (1, 0.99504))
 
-    val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) =>
+    val actualPredictions = ldaModel.topicDistributions(docs).cache()
+    val topTopics = actualPredictions.map { case (id, topics) =>
         // convert results to expectedPredictions format, which only has highest probability topic
         val topicsBz = topics.toBreeze.toDenseVector
         (id, (argmax(topicsBz), max(topicsBz)))
@@ -374,9 +375,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       .values
       .collect()
 
-    expectedPredictions.zip(actualPredictions).forall { case (expected, actual) =>
-      expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D)
+    expectedPredictions.zip(topTopics).foreach { case (expected, actual) =>
+      assert(expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D))
     }
+
+    docs.collect()
+      .map(doc => ldaModel.topicDistribution(doc._2))
+      .zip(actualPredictions.map(_._2).collect())
+      .foreach { case (single, batch) =>
+        assert(single ~== batch relTol 1E-3D)
+      }
+    actualPredictions.unpersist()
   }
 
   test("OnlineLDAOptimizer with asymmetric prior") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message