spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-7475] [MLLIB] adjust ldaExample for online LDA
Date Sat, 09 May 2015 22:40:59 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.4 5110f3efe -> e96fc8630


[SPARK-7475] [MLLIB] adjust ldaExample for online LDA

jira: https://issues.apache.org/jira/browse/SPARK-7475

Add a new argument for selecting the inference algorithm applied to LDA, to demonstrate basic usage of LDAOptimizer.
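
For context, a minimal standalone sketch of the pattern this change demonstrates
(hypothetical code, not part of the commit; "corpus" is assumed to be a prepared
RDD[(Long, Vector)] of term-count vectors, as in the example):

    import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA, OnlineLDAOptimizer}

    val algorithm = "online"  // or "em" (the example's default)

    // Pick an LDAOptimizer by name, mirroring the new --algorithm flag in the diff below.
    // "em" trains a DistributedLDAModel; "online" trains a LocalLDAModel.
    val optimizer = algorithm.toLowerCase match {
      case "em" => new EMLDAOptimizer
      case "online" => new OnlineLDAOptimizer()
      case other => throw new IllegalArgumentException(s"Unknown algorithm: $other")
    }

    val model = new LDA()
      .setK(10)
      .setOptimizer(optimizer)
      .run(corpus)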

cc jkbradley

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #6000 from hhbyyh/ldaExample and squashes the following commits:

0a7e2bc [Yuhao Yang] fix according to comments
5810b0f [Yuhao Yang] adjust ldaExample for online LDA

(cherry picked from commit b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e96fc863
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e96fc863
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e96fc863

Branch: refs/heads/branch-1.4
Commit: e96fc8630ebe85bee51fbe7795773419fdf174b9
Parents: 5110f3e
Author: Yuhao Yang <hhbyyh@gmail.com>
Authored: Sat May 9 15:40:46 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Sat May 9 15:40:54 2015 -0700

----------------------------------------------------------------------
 .../spark/examples/mllib/LDAExample.scala       | 31 ++++++++++++++++----
 1 file changed, 25 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e96fc863/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a185039..31d629f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
+import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 
@@ -48,6 +48,7 @@ object LDAExample {
       topicConcentration: Double = -1,
       vocabSize: Int = 10000,
       stopwordFile: String = "",
+      algorithm: String = "em",
       checkpointDir: Option[String] = None,
       checkpointInterval: Int = 10) extends AbstractParams[Params]
 
@@ -78,6 +79,10 @@ object LDAExample {
         .text(s"filepath for a list of stopwords. Note: This must fit on a single machine."
+
         s"  default: ${defaultParams.stopwordFile}")
         .action((x, c) => c.copy(stopwordFile = x))
+      opt[String]("algorithm")
+        .text(s"inference algorithm to use. em and online are supported." +
+        s" default: ${defaultParams.algorithm}")
+        .action((x, c) => c.copy(algorithm = x))
       opt[String]("checkpointDir")
         .text(s"Directory for checkpointing intermediate results." +
         s"  Checkpointing helps with recovery and eliminates temporary shuffle files on disk."
+
@@ -128,7 +133,17 @@ object LDAExample {
 
     // Run LDA.
     val lda = new LDA()
-    lda.setK(params.k)
+
+    val optimizer = params.algorithm.toLowerCase match {
+      case "em" => new EMLDAOptimizer
+      // add (1.0 / actualCorpusSize) to miniBatchFraction to be more robust on tiny datasets.
+      case "online" => new OnlineLDAOptimizer().setMiniBatchFraction(0.05 + 1.0 / actualCorpusSize)
+      case _ => throw new IllegalArgumentException(
+        s"Only em, online are supported but got ${params.algorithm}.")
+    }
+
+    lda.setOptimizer(optimizer)
+      .setK(params.k)
       .setMaxIterations(params.maxIterations)
       .setDocConcentration(params.docConcentration)
       .setTopicConcentration(params.topicConcentration)
@@ -137,14 +152,18 @@ object LDAExample {
       sc.setCheckpointDir(params.checkpointDir.get)
     }
     val startTime = System.nanoTime()
-    val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+    val ldaModel = lda.run(corpus)
     val elapsed = (System.nanoTime() - startTime) / 1e9
 
     println(s"Finished training LDA model.  Summary:")
     println(s"\t Training time: $elapsed sec")
-    val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
-    println(s"\t Training data average log likelihood: $avgLogLikelihood")
-    println()
+
+    if (ldaModel.isInstanceOf[DistributedLDAModel]) {
+      val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
+      val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
+      println(s"\t Training data average log likelihood: $avgLogLikelihood")
+      println()
+    }
 
     // Print the topics, showing the top-weighted terms for each topic.
     val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
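
A hedged usage sketch of the new flag (input path and option values here are
hypothetical; LDAExample takes one or more plain-text files as its corpus):

    # default EM optimizer
    bin/run-example mllib.LDAExample --k 20 --maxIterations 50 docs/corpus.txt

    # online variational Bayes instead
    bin/run-example mllib.LDAExample --algorithm online --k 20 docs/corpus.txt

Note that only the EM path yields a DistributedLDAModel, which is why the
training log-likelihood printout is now guarded by the isInstanceOf check above.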

