spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM
Date Tue, 03 Feb 2015 07:58:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0cc7b88c9 -> 980764f3c


[SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM

**This PR introduces an API + simple implementation for Latent Dirichlet Allocation (LDA).**

The [design doc for this PR](https://docs.google.com/document/d/1kSsDqTeZMEB94Bs4GTd0mvdAmduvZSSkpoSfn-seAzo) has been updated since I initially posted it.  In particular, see the API and Planning for the Future sections.

Goals:
* Settle on a public API which may eventually include:
  * more inference algorithms
  * more options / functionality
* Have an initial easy-to-understand implementation which others may improve.
* This is NOT intended to support every topic model out there.  However, if there are suggestions for making this extensible or pluggable in the future, that could be nice, as long as it does not complicate the API or implementation too much.
* This may not be very scalable currently.  It will be important to check and improve accuracy.  For correctness of the implementation, please check against the Asuncion et al. (2009) paper in the design doc.

**Dependency: This makes MLlib depend on GraphX.**

Files and classes:
* LDA.scala (441 lines):
  * class LDA (main estimator class)
  * LDA.Document  (text + document ID)
* LDAModel.scala (266 lines)
  * abstract class LDAModel
  * class LocalLDAModel
  * class DistributedLDAModel
* LDAExample.scala (245 lines): script to run LDA + a simple (private) Tokenizer
* LDASuite.scala (144 lines)

Data/model representation and algorithm:
* Data/model: Uses GraphX, with term vertices + document vertices
* Algorithm: EM, following [Asuncion, Welling, Smyth, and Teh.  "On Smoothing and Inference for Topic Models."  UAI, 2009.](http://arxiv-web3.library.cornell.edu/abs/1205.2662v1)
* For more details, please see the description in the “DEVELOPERS NOTE” in LDA.scala
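The EM description above can be made concrete with a minimal, Spark-free sketch of one iteration over an in-memory corpus. It uses the notation from the "DEVELOPERS NOTE" in LDA.scala (N_{wj}, N_{kj}, N_{wk}, N_k, gamma_{wjk}) and the same vertex-ID scheme for terms. It is not the GraphX implementation: `EmLdaSketch`, `emStep`, and the nested-array layout are hypothetical. The gamma update is my reading of the MAP-style EM update in Asuncion et al. (2009), gamma_{wjk} ∝ (N_{wk} + eta - 1)(N_{kj} + alpha - 1) / (N_k + W(eta - 1)), which is consistent with the "+1 adjustment for EM" the LDA docs mention.

```scala
object EmLdaSketch {

  // Vertex-ID scheme from the DEVELOPERS NOTE: documents keep IDs >= 0,
  // and term w becomes vertex -(1 + w), so terms occupy {-1, -2, ..., -vocabSize}.
  def term2index(term: Int): Long = -(1 + term.toLong)
  def index2term(termIndex: Long): Int = -(1 + termIndex).toInt

  /**
   * One EM iteration (hypothetical helper, not the MLlib API).
   *
   * @param docs   docs(j) maps term w -> N_{wj} (token count of w in document j)
   * @param nkj    nkj(j)(k) = N_{kj}, current document-topic counts
   * @param nwk    nwk(w)(k) = N_{wk}, current term-topic counts
   * @param alpha  docConcentration (assumed > 1.0)
   * @param eta    topicConcentration (assumed > 1.0)
   * @return       updated (N_{kj}, N_{wk})
   */
  def emStep(
      docs: Array[Map[Int, Double]],
      nkj: Array[Array[Double]],
      nwk: Array[Array[Double]],
      alpha: Double,
      eta: Double): (Array[Array[Double]], Array[Array[Double]]) = {
    val numTopics = nkj.head.length
    val vocabSize = nwk.length
    // E-step, part 1: aggregate N_k from the term counts.
    val nk = Array.tabulate(numTopics)(t => nwk.map(_(t)).sum)
    val newNkj = Array.fill(docs.length, numTopics)(0.0)
    val newNwk = Array.fill(vocabSize, numTopics)(0.0)
    // Each (document j, term w) pair plays the role of one graph edge.
    for ((doc, j) <- docs.zipWithIndex; (w, nwj) <- doc) {
      // E-step, part 2: unnormalized gamma_{wjk} for every topic k.
      val unnorm = Array.tabulate(numTopics) { t =>
        (nwk(w)(t) + eta - 1.0) * (nkj(j)(t) + alpha - 1.0) /
          (nk(t) + vocabSize * (eta - 1.0))
      }
      val total = unnorm.sum
      // M-step: add N_{wj} * gamma_{wjk} to both endpoints of the edge.
      for (t <- 0 until numTopics) {
        val gamma = unnorm(t) / total
        newNkj(j)(t) += nwj * gamma
        newNwk(w)(t) += nwj * gamma
      }
    }
    (newNkj, newNwk)
  }
}
```

Because gamma is normalized over topics for every edge, one step redistributes token counts across topics without changing totals: each document's N_{kj} still sums to its token count, and each term's N_{wk} sums to that term's corpus count.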

Please refer to the JIRA for more discussion + the [design doc for this PR](https://docs.google.com/document/d/1kSsDqTeZMEB94Bs4GTd0mvdAmduvZSSkpoSfn-seAzo).

Here, I list the main changes AFTER the design doc was posted.

Design decisions:
* logLikelihood() computes the log likelihood of the data and the current point estimate of parameters.  This is different from the likelihood of the data given the hyperparameters, which would be harder to compute.  I’d describe the current approach as more frequentist, whereas the harder approach would be more Bayesian.
* The current API takes Documents as token count vectors.  I believe there should be an extended API taking RDD[String] or RDD[Array[String]] in a future PR.  I have sketched this out in the design doc (as well as handier versions of getTopics returning Strings).
* Hyperparameters should be set differently for different inference/learning algorithms.  See Asuncion et al. (2009) in the design doc for a good demonstration.  I encourage good behavior via defaults and warning messages.
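The default-and-validation behavior referred to above is spelled out in the `setDocConcentration` / `setTopicConcentration` docs in this diff: -1 means "auto", which resolves to (50 / k) + 1 for alpha and 0.1 + 1 for eta under EM, and explicit values must be > 1.0. A standalone sketch of that rule follows; `LdaDefaultsSketch` and its methods are hypothetical names, not part of MLlib.

```scala
object LdaDefaultsSketch {

  /** Sentinel meaning "choose automatically", as in the LDA class. */
  val Auto: Double = -1.0

  /** alpha: explicit values must be > 1.0; Auto resolves to (50 / k) + 1 for EM. */
  def resolveDocConcentration(alpha: Double, k: Int): Double = {
    require(k > 0, s"k must be > 0, but was $k")
    require(alpha > 1.0 || alpha == Auto,
      s"docConcentration must be > 1.0 (or -1 for auto), but was $alpha")
    if (alpha == Auto) 50.0 / k + 1.0 else alpha
  }

  /** eta: explicit values must be > 1.0; Auto resolves to 0.1 + 1 = 1.1 for EM. */
  def resolveTopicConcentration(eta: Double): Double = {
    require(eta > 1.0 || eta == Auto,
      s"topicConcentration must be > 1.0 (or -1 for auto), but was $eta")
    if (eta == Auto) 1.1 else eta
  }
}
```

Resolving the sentinel when the value is read (as `getDocConcentration` does) rather than when it is set keeps the automatic alpha in sync if k is changed after alpha was left on auto.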

Items planned for future PRs:
* perplexity
* API taking Strings

Questions for reviewers:
* Should LDA be called LatentDirichletAllocation (and LDAModel be LatentDirichletAllocationModel)?
  * Pro: We may someday want LinearDiscriminantAnalysis.
  * Con: Very long names

* Should LDA reside in clustering?  Or do we want a sub-package?
  * mllib.topicmodel
  * mllib.clustering.topicmodel

* Does the API seem reasonable and extensible?

* Unit tests:
  * Should there be a test which checks clustering results?  E.g., train on a small, fake dataset with 2 very distinct topics/clusters, and ensure LDA finds those 2 topics/clusters.  Does that sound useful or too flaky?

This has not been tested much for scaling.  I have run it on a laptop for 200 iterations on a 5MB dataset with 1000 terms and 5 topics.  Running it for 500 iterations made it fail because of GC problems.  I'm running larger scale tests & will put results here, but future PRs may need to improve the scaling.

Thanks to:
* dlwh for the initial implementation
  * + jegonzal for some code in the initial implementation
* The many contributors towards topic model implementations in Spark which were referenced as a basis for this PR: akopich witgo yinxusen dlwh EntilZha jegonzal IlyaKozlov
  * Note: The plan is to include this full list in the authors if this PR gets merged.  Please notify me if you prefer otherwise.

CC: mengxr

Authors:
  Joseph K. Bradley <joseph@databricks.com>
  Joseph Gonzalez <joseph.e.gonzalez@gmail.com>
  David Hall <david.lw.hall@gmail.com>
  Guoqiang Li <witgo@qq.com>
  Xiangrui Meng <meng@databricks.com>
  Pedro Rodriguez <pedro@snowgeek.org>
  Avanesov Valeriy <acopich@gmail.com>
  Xusen Yin <yinxusen@gmail.com>

Closes #2388
Closes #4047 from jkbradley/davidhall-lda and squashes the following commits:

77e8814 [Joseph K. Bradley] small doc fix
5c74345 [Joseph K. Bradley] cleaned up doc based on code review
589728b [Joseph K. Bradley] Updates per code review.  Main change was in LDAExample for faster vocab computation.  Also updated PeriodicGraphCheckpointerSuite.scala to clean up checkpoint files at end
e3980d2 [Joseph K. Bradley] cleaned up PeriodicGraphCheckpointerSuite.scala
74487e5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into davidhall-lda
4ae2a7d [Joseph K. Bradley] removed duplicate graphx dependency in mllib/pom.xml
e391474 [Joseph K. Bradley] Removed LDATiming.  Added PeriodicGraphCheckpointerSuite.scala.  Small LDA cleanups.
e8d8acf [Joseph K. Bradley] Added catch for BreakIterator exception.  Improved preprocessing to reduce passes over data
1a231b4 [Joseph K. Bradley] fixed scalastyle
91aadfe [Joseph K. Bradley] Added Java-friendly run method to LDA. Added Java test suite for LDA. Changed LDAModel.describeTopics to return Java-friendly type
b75472d [Joseph K. Bradley] merged improvements from LDATiming into LDAExample.  Will remove LDATiming after done testing
993ca56 [Joseph K. Bradley] * Removed Document type in favor of (Long, Vector) * Changed doc ID restriction to be: id must be nonnegative and unique in the doc (instead of 0,1,2,...) * Add checks for valid ranges of eta, alpha * Rename “LearningState” to “EMOptimizer” * Renamed params: termSmoothing -> topicConcentration, topicSmoothing -> docConcentration   * Also added aliases alpha, beta
cb5a319 [Joseph K. Bradley] Added checkpointing to LDA * new class PeriodicGraphCheckpointer * params checkpointDir, checkpointInterval to LDA
43c1c40 [Joseph K. Bradley] small cleanup
0b90393 [Joseph K. Bradley] renamed LDA LearningState.collectTopicTotals to globalTopicTotals
77a2c85 [Joseph K. Bradley] Moved auto term,topic smoothing computation to get*Smoothing methods.  Changed word to term in some places.  Updated LDAExample to use default smoothing amounts.
fb1e7b5 [Xiangrui Meng] minor
08d59a3 [Xiangrui Meng] reset spacing
9fe0b95 [Xiangrui Meng] optimize aggregateMessages
cec0a9c [Xiangrui Meng] * -> *=
6cb11b0 [Xiangrui Meng] optimize computePTopic
9eb3d02 [Xiangrui Meng] + -> +=
892530c [Xiangrui Meng] use axpy
45cc7f2 [Xiangrui Meng] mapPart -> flatMap
ce53be9 [Joseph K. Bradley] fixed example name
75749e7 [Joseph K. Bradley] scala style fix
9f2a492 [Joseph K. Bradley] Unit tests and fixes for LDA, now ready for PR
377ebd9 [Joseph K. Bradley] separated LDA models into own file.  more cleanups before PR
2d40006 [Joseph K. Bradley] cleanups before PR
2891e89 [Joseph K. Bradley] Prepped LDA main class for PR, but some cleanups remain
0cb7187 [Joseph K. Bradley] Added 3 files from dlwh LDA implementation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/980764f3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/980764f3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/980764f3

Branch: refs/heads/master
Commit: 980764f3c0c065cc32454a036e8d0ead5a92037b
Parents: 0cc7b88
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Mon Feb 2 23:57:35 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Feb 2 23:57:37 2015 -0800

----------------------------------------------------------------------
 .../spark/examples/mllib/LDAExample.scala       | 283 ++++++++++
 .../org/apache/spark/mllib/clustering/LDA.scala | 519 +++++++++++++++++++
 .../spark/mllib/clustering/LDAModel.scala       | 351 +++++++++++++
 .../mllib/impl/PeriodicGraphCheckpointer.scala  | 179 +++++++
 .../spark/mllib/clustering/JavaLDASuite.java    | 119 +++++
 .../spark/mllib/clustering/LDASuite.scala       | 153 ++++++
 .../impl/PeriodicGraphCheckpointerSuite.scala   | 187 +++++++
 7 files changed, 1791 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
new file mode 100644
index 0000000..f4c545a
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.text.BreakIterator
+
+import scala.collection.mutable
+
+import scopt.OptionParser
+
+import org.apache.log4j.{Level, Logger}
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * An example Latent Dirichlet Allocation (LDA) app. Run with
+ * {{{
+ * ./bin/run-example mllib.LDAExample [options] <input>
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LDAExample {
+
+  private case class Params(
+      input: Seq[String] = Seq.empty,
+      k: Int = 20,
+      maxIterations: Int = 10,
+      docConcentration: Double = -1,
+      topicConcentration: Double = -1,
+      vocabSize: Int = 10000,
+      stopwordFile: String = "",
+      checkpointDir: Option[String] = None,
+      checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("LDAExample") {
+      head("LDAExample: an example LDA app for plain text data.")
+      opt[Int]("k")
+        .text(s"number of topics. default: ${defaultParams.k}")
+        .action((x, c) => c.copy(k = x))
+      opt[Int]("maxIterations")
+        .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
+        .action((x, c) => c.copy(maxIterations = x))
+      opt[Double]("docConcentration")
+        .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
+        s"  default: ${defaultParams.docConcentration}")
+        .action((x, c) => c.copy(docConcentration = x))
+      opt[Double]("topicConcentration")
+        .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
+        s"  default: ${defaultParams.topicConcentration}")
+        .action((x, c) => c.copy(topicConcentration = x))
+      opt[Int]("vocabSize")
+        .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
+          s"  default: ${defaultParams.vocabSize}")
+        .action((x, c) => c.copy(vocabSize = x))
+      opt[String]("stopwordFile")
+        .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
+        s"  default: ${defaultParams.stopwordFile}")
+        .action((x, c) => c.copy(stopwordFile = x))
+      opt[String]("checkpointDir")
+        .text(s"Directory for checkpointing intermediate results." +
+        s"  Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
+        s"  default: ${defaultParams.checkpointDir}")
+        .action((x, c) => c.copy(checkpointDir = Some(x)))
+      opt[Int]("checkpointInterval")
+        .text(s"Iterations between each checkpoint.  Only used if checkpointDir is set." +
+        s" default: ${defaultParams.checkpointInterval}")
+        .action((x, c) => c.copy(checkpointInterval = x))
+      arg[String]("<input>...")
+        .text("input paths (directories) to plain text corpora." +
+        "  Each text file line should hold 1 document.")
+        .unbounded()
+        .required()
+        .action((x, c) => c.copy(input = c.input :+ x))
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      parser.showUsageAsError
+      sys.exit(1)
+    }
+  }
+
+  private def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"LDAExample with $params")
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    // Load documents, and prepare them for LDA.
+    val preprocessStart = System.nanoTime()
+    val (corpus, vocabArray, actualNumTokens) =
+      preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
+    corpus.cache()
+    val actualCorpusSize = corpus.count()
+    val actualVocabSize = vocabArray.size
+    val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9
+
+    println()
+    println(s"Corpus summary:")
+    println(s"\t Training set size: $actualCorpusSize documents")
+    println(s"\t Vocabulary size: $actualVocabSize terms")
+    println(s"\t Training set size: $actualNumTokens tokens")
+    println(s"\t Preprocessing time: $preprocessElapsed sec")
+    println()
+
+    // Run LDA.
+    val lda = new LDA()
+    lda.setK(params.k)
+      .setMaxIterations(params.maxIterations)
+      .setDocConcentration(params.docConcentration)
+      .setTopicConcentration(params.topicConcentration)
+      .setCheckpointInterval(params.checkpointInterval)
+    if (params.checkpointDir.nonEmpty) {
+      lda.setCheckpointDir(params.checkpointDir.get)
+    }
+    val startTime = System.nanoTime()
+    val ldaModel = lda.run(corpus)
+    val elapsed = (System.nanoTime() - startTime) / 1e9
+
+    println(s"Finished training LDA model.  Summary:")
+    println(s"\t Training time: $elapsed sec")
+    val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
+    println(s"\t Training data average log likelihood: $avgLogLikelihood")
+    println()
+
+    // Print the topics, showing the top-weighted terms for each topic.
+    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
+    val topics = topicIndices.map { case (terms, termWeights) =>
+      terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
+    }
+    println(s"${params.k} topics:")
+    topics.zipWithIndex.foreach { case (topic, i) =>
+      println(s"TOPIC $i")
+      topic.foreach { case (term, weight) =>
+        println(s"$term\t$weight")
+      }
+      println()
+    }
+
+  }
+
+  /**
+   * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
+   * @return (corpus, vocabulary as array, total token count in corpus)
+   */
+  private def preprocess(
+      sc: SparkContext,
+      paths: Seq[String],
+      vocabSize: Int,
+      stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
+
+    // Get dataset of document texts
+    // One document per line in each text file.
+    val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
+
+    // Split text into words
+    val tokenizer = new SimpleTokenizer(sc, stopwordFile)
+    val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
+      id -> tokenizer.getWords(text)
+    }
+    tokenized.cache()
+
+    // Counts words: RDD[(word, wordCount)]
+    val wordCounts: RDD[(String, Long)] = tokenized
+      .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
+      .reduceByKey(_ + _)
+    wordCounts.cache()
+    val fullVocabSize = wordCounts.count()
+    // Select vocab
+    //  (vocab: Map[word -> id], total tokens after selecting vocab)
+    val (vocab: Map[String, Int], selectedTokenCount: Long) = {
+      val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
+        // Use all terms
+        wordCounts.collect().sortBy(-_._2)
+      } else {
+        // Sort terms to select vocab
+        wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
+      }
+      (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
+    }
+
+    val documents = tokenized.map { case (id, tokens) =>
+      // Filter tokens by vocabulary, and create word count vector representation of document.
+      val wc = new mutable.HashMap[Int, Int]()
+      tokens.foreach { term =>
+        if (vocab.contains(term)) {
+          val termIndex = vocab(term)
+          wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
+        }
+      }
+      val indices = wc.keys.toArray.sorted
+      val values = indices.map(i => wc(i).toDouble)
+
+      val sb = Vectors.sparse(vocab.size, indices, values)
+      (id, sb)
+    }
+
+    val vocabArray = new Array[String](vocab.size)
+    vocab.foreach { case (term, i) => vocabArray(i) = term }
+
+    (documents, vocabArray, selectedTokenCount)
+  }
+}
+
+/**
+ * Simple Tokenizer.
+ *
+ * TODO: Formalize the interface, and make this a public class in mllib.feature
+ */
+private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
+
+  private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
+    Set.empty[String]
+  } else {
+    val stopwordText = sc.textFile(stopwordFile).collect()
+    stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
+  }
+
+  // Matches sequences of Unicode letters
+  private val allWordRegex = "^(\\p{L}*)$".r
+
+  // Ignore words shorter than this length.
+  private val minWordLength = 3
+
+  def getWords(text: String): IndexedSeq[String] = {
+
+    val words = new mutable.ArrayBuffer[String]()
+
+    // Use Java BreakIterator to tokenize text into words.
+    val wb = BreakIterator.getWordInstance
+    wb.setText(text)
+
+    // current, end: start and end indices of each candidate word
+    var current = wb.first()
+    var end = wb.next()
+    while (end != BreakIterator.DONE) {
+      // Convert to lowercase
+      val word: String = text.substring(current, end).toLowerCase
+      // Remove short words and strings that aren't only letters
+      word match {
+        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
+          words += w
+        case _ =>
+      }
+
+      current = end
+      try {
+        end = wb.next()
+      } catch {
+        case e: Exception =>
+          // Ignore remaining text in line.
+          // This is a known bug in BreakIterator (for some Java versions),
+          // which fails when it sees certain characters.
+          end = BreakIterator.DONE
+      }
+    }
+    words
+  }
+
+}

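For checking preprocessing on a handful of lines without a cluster, here is a Spark-free sketch of the pipeline that `preprocess` and `SimpleTokenizer` implement above: BreakIterator tokenization with the same letters-only and minimum-length filters, frequency-based vocabulary selection, and per-document term counts. `PreprocessSketch`, `tokenize`, and `buildCorpus` are hypothetical names. It returns raw (id, sorted indices, counts) triples instead of MLlib `Vector`s, and it breaks count ties alphabetically, which the example does not do.

```scala
import java.text.BreakIterator

import scala.collection.mutable

object PreprocessSketch {

  // Same filter as SimpleTokenizer: keep tokens made only of Unicode letters.
  private val allWordRegex = "^(\\p{L}*)$".r

  /** Tokenize one line with BreakIterator, lowercasing and filtering as the example does. */
  def tokenize(text: String, stopwords: Set[String], minWordLength: Int = 3): IndexedSeq[String] = {
    val words = mutable.ArrayBuffer[String]()
    val wb = BreakIterator.getWordInstance
    wb.setText(text)
    var current = wb.first()
    var end = wb.next()
    while (end != BreakIterator.DONE) {
      text.substring(current, end).toLowerCase match {
        case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => words += w
        case _ =>
      }
      current = end
      // Mirror the example's guard against BreakIterator failures on some JVMs.
      end = try wb.next() catch { case _: Exception => BreakIterator.DONE }
    }
    words.toIndexedSeq
  }

  /**
   * Keep the vocabSize most frequent terms (-1 = all) and turn each document into an
   * (id, sorted term indices, counts) triple, analogous to Vectors.sparse.
   */
  def buildCorpus(
      docs: Seq[IndexedSeq[String]],
      vocabSize: Int): (Seq[(Long, Array[Int], Array[Double])], Array[String]) = {
    val counts = docs.flatten.groupBy(identity).map { case (w, ws) => (w, ws.size.toLong) }.toSeq
    // Sort by descending count; break ties alphabetically so the vocabulary is deterministic.
    val sorted = counts.sortBy { case (w, c) => (-c, w) }
    val selected = if (vocabSize == -1) sorted else sorted.take(vocabSize)
    val vocab: Map[String, Int] = selected.map(_._1).zipWithIndex.toMap
    val corpus = docs.zipWithIndex.map { case (tokens, id) =>
      // Count only in-vocabulary tokens.
      val wc = mutable.HashMap[Int, Int]()
      tokens.foreach { t => vocab.get(t).foreach { i => wc(i) = wc.getOrElse(i, 0) + 1 } }
      val indices = wc.keys.toArray.sorted
      (id.toLong, indices, indices.map(i => wc(i).toDouble))
    }
    val vocabArray = new Array[String](vocab.size)
    vocab.foreach { case (term, i) => vocabArray(i) = term }
    (corpus, vocabArray)
  }
}
```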
http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
new file mode 100644
index 0000000..d8f8286
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,519 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+
+/**
+ * :: Experimental ::
+ *
+ * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ *
+ * Terminology:
+ *  - "word" = "term": an element of the vocabulary
+ *  - "token": instance of a term appearing in a document
+ *  - "topic": multinomial distribution over words representing some concept
+ *
+ * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
+ * according to the Asuncion et al. (2009) paper referenced below.
+ *
+ * References:
+ *  - Original LDA paper (journal version):
+ *    Blei, Ng, and Jordan.  "Latent Dirichlet Allocation."  JMLR, 2003.
+ *     - This class implements their "smoothed" LDA model.
+ *  - Paper which clearly explains several algorithms, including EM:
+ *    Asuncion, Welling, Smyth, and Teh.
+ *    "On Smoothing and Inference for Topic Models."  UAI, 2009.
+ */
+@Experimental
+class LDA private (
+    private var k: Int,
+    private var maxIterations: Int,
+    private var docConcentration: Double,
+    private var topicConcentration: Double,
+    private var seed: Long,
+    private var checkpointDir: Option[String],
+    private var checkpointInterval: Int) extends Logging {
+
+  def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
+    seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10)
+
+  /**
+   * Number of topics to infer.  I.e., the number of soft cluster centers.
+   */
+  def getK: Int = k
+
+  /**
+   * Number of topics to infer.  I.e., the number of soft cluster centers.
+   * (default = 10)
+   */
+  def setK(k: Int): this.type = {
+    require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
+    this.k = k
+    this
+  }
+
+  /**
+   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+   * distributions over topics ("theta").
+   *
+   * This is the parameter to a symmetric Dirichlet distribution.
+   */
+  def getDocConcentration: Double = {
+    if (this.docConcentration == -1) {
+      (50.0 / k) + 1.0
+    } else {
+      this.docConcentration
+    }
+  }
+
+  /**
+   * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+   * distributions over topics ("theta").
+   *
+   * This is the parameter to a symmetric Dirichlet distribution.
+   *
+   * This value should be > 1.0, where larger values mean more smoothing (more regularization).
+   * If set to -1, then docConcentration is set automatically.
+   *  (default = -1 = automatic)
+   *
+   * Automatic setting of parameter:
+   *  - For EM: default = (50 / k) + 1.
+   *     - The 50/k is common in LDA libraries.
+   *     - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+   *
+   * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
+   *       but values in (0,1) are not yet supported.
+   */
+  def setDocConcentration(docConcentration: Double): this.type = {
+    require(docConcentration > 1.0 || docConcentration == -1.0,
+      s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration")
+    this.docConcentration = docConcentration
+    this
+  }
+
+  /** Alias for [[getDocConcentration]] */
+  def getAlpha: Double = getDocConcentration
+
+  /** Alias for [[setDocConcentration()]] */
+  def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
+
+  /**
+   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+   * distributions over terms.
+   *
+   * This is the parameter to a symmetric Dirichlet distribution.
+   *
+   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+   */
+  def getTopicConcentration: Double = {
+    if (this.topicConcentration == -1) {
+      1.1
+    } else {
+      this.topicConcentration
+    }
+  }
+
+  /**
+   * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+   * distributions over terms.
+   *
+   * This is the parameter to a symmetric Dirichlet distribution.
+   *
+   * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+   * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+   *
+   * This value should be > 1.0.
+   * If set to -1, then topicConcentration is set automatically.
+   *  (default = -1 = automatic)
+   *
+   * Automatic setting of parameter:
+   *  - For EM: default = 0.1 + 1.
+   *     - The 0.1 gives a small amount of smoothing.
+   *     - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+   *
+   * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
+   *       but values in (0,1) are not yet supported.
+   */
+  def setTopicConcentration(topicConcentration: Double): this.type = {
+    require(topicConcentration > 1.0 || topicConcentration == -1.0,
+      s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration")
+    this.topicConcentration = topicConcentration
+    this
+  }
+
+  /** Alias for [[getTopicConcentration]] */
+  def getBeta: Double = getTopicConcentration
+
+  /** Alias for [[setTopicConcentration()]] */
+  def setBeta(beta: Double): this.type = setTopicConcentration(beta)
+
+  /**
+   * Maximum number of iterations for learning.
+   */
+  def getMaxIterations: Int = maxIterations
+
+  /**
+   * Maximum number of iterations for learning.
+   * (default = 20)
+   */
+  def setMaxIterations(maxIterations: Int): this.type = {
+    this.maxIterations = maxIterations
+    this
+  }
+
+  /** Random seed */
+  def getSeed: Long = seed
+
+  /** Random seed */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
+  /**
+   * Directory for storing checkpoint files during learning.
+   * This is not necessary, but checkpointing helps with recovery (when nodes fail).
+   * It also helps with eliminating temporary shuffle files on disk, which can be important when
+   * LDA is run for many iterations.
+   */
+  def getCheckpointDir: Option[String] = checkpointDir
+
+  /**
+   * Directory for storing checkpoint files during learning.
+   * This is not necessary, but checkpointing helps with recovery (when nodes fail).
+   * It also helps with eliminating temporary shuffle files on disk, which can be important when
+   * LDA is run for many iterations.
+   *
+   * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value
+   *       given to LDA is ignored, and the existing directory is kept.
+   *
+   * (default = None)
+   */
+  def setCheckpointDir(checkpointDir: String): this.type = {
+    this.checkpointDir = Some(checkpointDir)
+    this
+  }
+
+  /**
+   * Clear the directory for storing checkpoint files during learning.
+   * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still
+   * occur; otherwise, no checkpointing will be used.
+   */
+  def clearCheckpointDir(): this.type = {
+    this.checkpointDir = None
+    this
+  }
+
+  /**
+   * Period (in iterations) between checkpoints.
+   * @see [[getCheckpointDir]]
+   */
+  def getCheckpointInterval: Int = checkpointInterval
+
+  /**
+   * Period (in iterations) between checkpoints.
+   * (default = 10)
+   * @see [[getCheckpointDir]]
+   */
+  def setCheckpointInterval(checkpointInterval: Int): this.type = {
+    this.checkpointInterval = checkpointInterval
+    this
+  }
+
+  /**
+   * Learn an LDA model using the given dataset.
+   *
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   Document IDs must be unique and >= 0.
+   * @return  Inferred LDA model
+   */
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointDir, checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
+    }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  /** Java-friendly version of [[run()]] */
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+    run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
+  }
+}
+
+
+private[clustering] object LDA {
+
+  /*
+    DEVELOPERS NOTE:
+
+    This implementation uses GraphX, where the graph is bipartite with 2 types of vertices:
+     - Document vertices
+        - indexed with unique indices >= 0
+        - Store vectors of length k (# topics).
+     - Term vertices
+        - indexed {-1, -2, ..., -vocabSize}
+        - Store vectors of length k (# topics).
+     - Edges correspond to terms appearing in documents.
+        - Edges are directed Document -> Term.
+        - Edges are partitioned by documents.
+
+    Info on EM implementation.
+     - We follow Section 2.2 from Asuncion et al., 2009.  We use some of their notation.
+     - In this implementation, there is one edge for every unique term appearing in a document,
+       i.e., for every unique (document, term) pair.
+     - Notation:
+        - N_{wkj} = count of tokens of term w currently assigned to topic k in document j
+        - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s)
+        - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j),
+          the probability of term x_i in document d_i having topic z_i.
+     - Data graph
+        - Document vertices store N_{kj}
+        - Term vertices store N_{wk}
+        - Edges store N_{wj}.
+        - Global data N_k
+     - Algorithm
+        - Initial state:
+           - Document and term vertices store random counts N_{kj} and N_{wk}, respectively.
+        - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i).
+           - Aggregate N_k from term vertices.
+           - Compute gamma_{wjk} for each possible topic k from each triplet,
+             using inputs N_{wk}, N_{kj}, N_k.
+        - M-step: Compute sufficient statistics for hidden parameters phi and theta
+          (counts N_{wk}, N_{kj}, N_k).
+           - Document update:
+              - N_{kj} <- sum_w N_{wj} gamma_{wjk}
+              - N_j <- sum_k N_{kj}  (only needed to output predictions)
+           - Term update:
+              - N_{wk} <- sum_j N_{wj} gamma_{wjk}
+              - N_k <- sum_w N_{wk}
+
+    TODO: Add simplex constraints to allow alpha in (0,1).
+          See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling: Additive
+               Regularization for Stochastic Matrix Factorization." 2014.
+   */
+
+  /**
+   * Vector over topics (length k) of token counts.
+   * The meaning of these counts can vary, and it may or may not be normalized to be a distribution.
+   */
+  type TopicCounts = BDV[Double]
+
+  type TokenCount = Double
+
+  /** Term vertex IDs are {-1, -2, ..., -vocabSize} */
+  def term2index(term: Int): Long = -(1 + term.toLong)
+
+  def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+
+  def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+
+  def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+
+  /**
+   * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
+   *
+   * @param graph  EM graph, storing current parameter estimates in vertex descriptors and
+   *               data (token counts) in edge descriptors.
+   * @param k  Number of topics
+   * @param vocabSize  Number of unique terms
+   * @param docConcentration  "alpha"
+   * @param topicConcentration  "beta" or "eta"
+   */
+  class EMOptimizer(
+      var graph: Graph[TopicCounts, TokenCount],
+      val k: Int,
+      val vocabSize: Int,
+      val docConcentration: Double,
+      val topicConcentration: Double,
+      checkpointDir: Option[String],
+      checkpointInterval: Int) {
+
+    private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+      graph, checkpointDir, checkpointInterval)
+
+    def next(): EMOptimizer = {
+      val eta = topicConcentration
+      val W = vocabSize
+      val alpha = docConcentration
+
+      val N_k = globalTopicTotals
+      val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
+        (edgeContext) => {
+          // Compute N_{wj} gamma_{wjk}
+          val N_wj = edgeContext.attr
+          // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
+          // N_{wj}.
+          val scaledTopicDistribution: TopicCounts =
+            computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
+          edgeContext.sendToDst((false, scaledTopicDistribution))
+          edgeContext.sendToSrc((false, scaledTopicDistribution))
+        }
+      // This is a hack to detect whether we could modify the values in-place.
+      // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
+      val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
+        (m0, m1) => {
+          val sum =
+            if (m0._1) {
+              m0._2 += m1._2
+            } else if (m1._1) {
+              m1._2 += m0._2
+            } else {
+              m0._2 + m1._2
+            }
+          (true, sum)
+        }
+      // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
+      val docTopicDistributions: VertexRDD[TopicCounts] =
+        graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
+          .mapValues(_._2)
+      // Update the vertex descriptors with the new counts.
+      val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
+      graph = newGraph
+      graphCheckpointer.updateGraph(newGraph)
+      globalTopicTotals = computeGlobalTopicTotals()
+      this
+    }
+
+    /**
+     * Aggregate distributions over topics from all term vertices.
+     *
+     * Note: This executes an action on the graph RDDs.
+     */
+    var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
+
+    private def computeGlobalTopicTotals(): TopicCounts = {
+      val numTopics = k
+      graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
+    }
+
+  }
+
+  /**
+   * Compute gamma_{wjk}, a distribution over topics k.
+   */
+  private def computePTopic(
+      docTopicCounts: TopicCounts,
+      termTopicCounts: TopicCounts,
+      totalTopicCounts: TopicCounts,
+      vocabSize: Int,
+      eta: Double,
+      alpha: Double): TopicCounts = {
+    val K = docTopicCounts.length
+    val N_j = docTopicCounts.data
+    val N_w = termTopicCounts.data
+    val N = totalTopicCounts.data
+    val eta1 = eta - 1.0
+    val alpha1 = alpha - 1.0
+    val Weta1 = vocabSize * eta1
+    var sum = 0.0
+    val gamma_wj = new Array[Double](K)
+    var k = 0
+    while (k < K) {
+      val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1)
+      gamma_wj(k) = gamma_wjk
+      sum += gamma_wjk
+      k += 1
+    }
+    // normalize
+    BDV(gamma_wj) /= sum
+  }
+
+  /**
+   * Compute bipartite term/doc graph.
+   */
+  private def initialState(
+      docs: RDD[(Long, Vector)],
+      k: Int,
+      docConcentration: Double,
+      topicConcentration: Double,
+      randomSeed: Long,
+      checkpointDir: Option[String],
+      checkpointInterval: Int): EMOptimizer = {
+    // For each document, create an edge (Document -> Term) for each unique term in the document.
+    val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
+      // Add edges for terms with non-zero counts.
+      termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
+        Edge(docID, term2index(term), cnt)
+      }
+    }
+
+    val vocabSize = docs.take(1).head._2.size
+
+    // Create vertices.
+    // Initially, we use random soft assignments of tokens to topics (random gamma).
+    val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
+      edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+        val random = new Random(partIndex + randomSeed)
+        partEdges.map { edge =>
+          // Create a random gamma_{wjk}
+          (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
+        }
+      }
+    def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
+      val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
+        edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
+          (sendToWhere(edge), (edge.attr, gamma))
+        }
+      verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
+        (sum, t) => {
+          brzAxpy(t._1, t._2, sum)
+          sum
+        },
+        (sum0, sum1) => {
+          sum0 += sum1
+        }
+      )
+    }
+    val docVertices = createVertices(_.srcId)
+    val termVertices = createVertices(_.dstId)
+
+    // Partition such that edges are grouped by document
+    val graph = Graph(docVertices ++ termVertices, edges)
+      .partitionBy(PartitionStrategy.EdgePartition1D)
+
+    new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir,
+      checkpointInterval)
+  }
+
+}

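The sketch below makes the EM update from the LDA.scala developers note concrete on a single machine. It makes the following assumptions:

* A dense in-memory edge list stands in for the GraphX graph.
* alpha and eta are both greater than 1, which is what `DistributedLDAModel.logLikelihood` asserts.
* `LdaEmSketch`, `Edge`, `Counts`, and `emStep` are hypothetical names for illustration, not MLlib API.

For each (document, term) edge, the sketch computes gamma_{wjk} from the current counts (E-step). It then sends N_{wj} * gamma_{wjk} to both the document and the term (M-step), mirroring `EMOptimizer.next()`.

```java
import java.util.List;

/**
 * Hypothetical single-machine sketch of one EM iteration from the LDA.scala developers note.
 * Plain arrays stand in for the GraphX graph; this is not MLlib API.
 * Assumes docConcentration (alpha) > 1 and topicConcentration (eta) > 1.
 */
final class LdaEmSketch {

  /** One (document, term) edge carrying the token count N_{wj}. */
  static final class Edge {
    final int doc;
    final int term;
    final double count;

    Edge(int doc, int term, double count) {
      this.doc = doc;
      this.term = term;
      this.count = count;
    }
  }

  /** Vertex data: docs[j] holds N_{kj} for document j, terms[w] holds N_{wk} for term w. */
  static final class Counts {
    final double[][] docs;
    final double[][] terms;

    Counts(double[][] docs, double[][] terms) {
      this.docs = docs;
      this.terms = terms;
    }
  }

  /** Term vertex IDs are {-1, -2, ..., -vocabSize}, mirroring LDA.term2index. */
  static long term2index(int term) {
    return -(1L + term);
  }

  /** Inverse of term2index, mirroring LDA.index2term. */
  static int index2term(long termIndex) {
    return (int) -(1L + termIndex);
  }

  /** gamma_{wjk} proportional to (N_wk + eta - 1)(N_kj + alpha - 1) / (N_k + W(eta - 1)). */
  static double[] computePTopic(double[] nKj, double[] nWk, double[] nK,
                                int vocabSize, double eta, double alpha) {
    double[] gamma = new double[nKj.length];
    double sum = 0.0;
    for (int k = 0; k < gamma.length; k++) {
      gamma[k] = (nWk[k] + eta - 1.0) * (nKj[k] + alpha - 1.0)
          / (nK[k] + vocabSize * (eta - 1.0));
      sum += gamma[k];
    }
    // Normalize so gamma is a distribution over topics.
    for (int k = 0; k < gamma.length; k++) {
      gamma[k] /= sum;
    }
    return gamma;
  }

  /** One EM iteration: returns new (N_{kj}, N_{wk}) counts computed from the current ones. */
  static Counts emStep(List<Edge> edges, Counts current, double eta, double alpha) {
    int numTopics = current.docs[0].length;
    int vocabSize = current.terms.length;
    // N_k = sum_w N_{wk}, aggregated from the term counts (cf. computeGlobalTopicTotals).
    double[] nK = new double[numTopics];
    for (double[] nWk : current.terms) {
      for (int k = 0; k < numTopics; k++) {
        nK[k] += nWk[k];
      }
    }
    double[][] newDocs = new double[current.docs.length][numTopics];
    double[][] newTerms = new double[vocabSize][numTopics];
    for (Edge e : edges) {
      // E-step: gamma_{wjk} for this (document, term) pair from the current counts.
      double[] gamma = computePTopic(current.docs[e.doc], current.terms[e.term], nK,
          vocabSize, eta, alpha);
      // M-step: accumulate N_{wj} * gamma_{wjk} at both endpoints of the edge.
      for (int k = 0; k < numTopics; k++) {
        newDocs[e.doc][k] += e.count * gamma[k];
        newTerms[e.term][k] += e.count * gamma[k];
      }
    }
    return new Counts(newDocs, newTerms);
  }
}
```

Each gamma_{wjk} sums to 1 over k, so an iteration only redistributes a document's tokens across topics. Every document's new counts therefore still sum to its token count.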
http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
new file mode 100644
index 0000000..19e8aab
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * :: Experimental ::
+ *
+ * Latent Dirichlet Allocation (LDA) model.
+ *
+ * This abstraction permits different underlying representations,
+ * including local and distributed data structures.
+ */
+@Experimental
+abstract class LDAModel private[clustering] {
+
+  /** Number of topics */
+  def k: Int
+
+  /** Vocabulary size (number of terms in the vocabulary) */
+  def vocabSize: Int
+
+  /**
+   * Inferred topics, where each topic is represented by a distribution over terms.
+   * This is a matrix of size vocabSize x k, where each column is a topic.
+   * No guarantees are given about the ordering of the topics.
+   */
+  def topicsMatrix: Matrix
+
+  /**
+   * Return the topics described by weighted terms.
+   *
+   * This limits the number of terms per topic.
+   * This is approximate; it may not return exactly the top-weighted terms for each topic.
+   * To get a more precise set of top terms, increase maxTermsPerTopic.
+   *
+   * @param maxTermsPerTopic  Maximum number of terms to collect for each topic.
+   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
+   *          (term indices, term weights in topic).
+   *          Each topic's terms are sorted in order of decreasing weight.
+   */
+  def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]
+
+  /**
+   * Return the topics described by weighted terms.
+   *
+   * WARNING: If vocabSize and k are large, this can return a large object!
+   *
+   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
+   *          (term indices, term weights in topic).
+   *          Each topic's terms are sorted in order of decreasing weight.
+   */
+  def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
+
+  /* TODO (once LDA can be trained with Strings or given a dictionary)
+   * Return the topics described by weighted terms.
+   *
+   * This is similar to [[describeTopics()]] but returns String values for terms.
+   * If this model was trained using Strings or was given a dictionary, then this method returns
+   * terms as text.  Otherwise, this method returns terms as term indices.
+   *
+   * This limits the number of terms per topic.
+   * This is approximate; it may not return exactly the top-weighted terms for each topic.
+   * To get a more precise set of top terms, increase maxTermsPerTopic.
+   *
+   * @param maxTermsPerTopic  Maximum number of terms to collect for each topic.
+   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
+   *          (terms, term weights in topic) where terms are either the actual term text
+   *          (if available) or the term indices.
+   *          Each topic's terms are sorted in order of decreasing weight.
+   */
+  // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[String], Array[Double])]
+
+  /* TODO (once LDA can be trained with Strings or given a dictionary)
+   * Return the topics described by weighted terms.
+   *
+   * This is similar to [[describeTopics()]] but returns String values for terms.
+   * If this model was trained using Strings or was given a dictionary, then this method returns
+   * terms as text.  Otherwise, this method returns terms as term indices.
+   *
+   * WARNING: If vocabSize and k are large, this can return a large object!
+   *
+   * @return  Array over topics.  Each topic is represented as a pair of matching arrays:
+   *          (terms, term weights in topic) where terms are either the actual term text
+   *          (if available) or the term indices.
+   *          Each topic's terms are sorted in order of decreasing weight.
+   */
+  // def describeTopicsAsStrings(): Array[(Array[String], Array[Double])] =
+  //  describeTopicsAsStrings(vocabSize)
+
+  /* TODO
+   * Compute the log likelihood of the observed tokens, given the current parameter estimates:
+   *  log P(docs | topics, topic distributions for docs, alpha, eta)
+   *
+   * Note:
+   *  - This excludes the prior.
+   *  - Even with the prior, this is NOT the same as the data log likelihood given the
+   *    hyperparameters.
+   *
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   This must use the same vocabulary (ordering of term counts) as in training.
+   *                   Document IDs must be unique and >= 0.
+   * @return  Estimated log likelihood of the data under this model
+   */
+  // def logLikelihood(documents: RDD[(Long, Vector)]): Double
+
+  /* TODO
+   * Compute the estimated topic distribution for each document.
+   * This is often called “theta” in the literature.
+   *
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   This must use the same vocabulary (ordering of term counts) as in training.
+   *                   Document IDs must be unique and >= 0.
+   * @return  Estimated topic distribution for each document.
+   *          The returned RDD may be zipped with the given RDD, where each returned vector
+   *          is a multinomial distribution over topics.
+   */
+  // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
+
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Local LDA model.
+ * This model stores only the inferred topics.
+ * It may be used for computing topics for new documents, but it may give less accurate answers
+ * than the [[DistributedLDAModel]].
+ *
+ * @param topics Inferred topics (vocabSize x k matrix).
+ */
+@Experimental
+class LocalLDAModel private[clustering] (
+    private val topics: Matrix) extends LDAModel with Serializable {
+
+  override def k: Int = topics.numCols
+
+  override def vocabSize: Int = topics.numRows
+
+  override def topicsMatrix: Matrix = topics
+
+  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
+    val brzTopics = topics.toBreeze.toDenseMatrix
+    Range(0, k).map { topicIndex =>
+      val topic = normalize(brzTopics(::, topicIndex), 1.0)
+      val (termWeights, terms) =
+        topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
+      (terms.toArray, termWeights.toArray)
+    }.toArray
+  }
+
+  // TODO
+  // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
+
+  // TODO:
+  // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed LDA model.
+ * This model stores the inferred topics, the full training dataset, and the topic distributions.
+ * When computing topics for new documents, it may give more accurate answers
+ * than the [[LocalLDAModel]].
+ */
+@Experimental
+class DistributedLDAModel private (
+    private val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+    private val globalTopicTotals: LDA.TopicCounts,
+    val k: Int,
+    val vocabSize: Int,
+    private val docConcentration: Double,
+    private val topicConcentration: Double,
+    private[spark] val iterationTimes: Array[Double]) extends LDAModel {
+
+  import LDA._
+
+  private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
+    this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
+      state.topicConcentration, iterationTimes)
+  }
+
+  /**
+   * Convert model to a local model.
+   * The local model stores the inferred topics but not the topic distributions for training
+   * documents.
+   */
+  def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
+
+  /**
+   * Inferred topics, where each topic is represented by a distribution over terms.
+   * This is a matrix of size vocabSize x k, where each column is a topic.
+   * No guarantees are given about the ordering of the topics.
+   *
+   * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize and k
+   *          are large.
+   */
+  override lazy val topicsMatrix: Matrix = {
+    // Collect row-major topics
+    val termTopicCounts: Array[(Int, TopicCounts)] =
+      graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) =>
+        (index2term(termIndex), cnts)
+      }.collect()
+    // Convert to Matrix
+    val brzTopics = BDM.zeros[Double](vocabSize, k)
+    termTopicCounts.foreach { case (term, cnts) =>
+      var j = 0
+      while (j < k) {
+        brzTopics(term, j) = cnts(j)
+        j += 1
+      }
+    }
+    Matrices.fromBreeze(brzTopics)
+  }
+
+  override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
+    val numTopics = k
+    // Note: N_k is not needed to find the top terms, but it is needed to normalize weights
+    //       to a distribution over terms.
+    val N_k: TopicCounts = globalTopicTotals
+    val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] =
+      graph.vertices.filter(isTermVertex)
+        .mapPartitions { termVertices =>
+          // For this partition, collect the most common terms for each topic in queues:
+          //  queues(topic) = queue of (term weight, term index).
+          // Term weights are N_{wk} / N_k.
+          val queues =
+            Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic))
+          for ((termId, n_wk) <- termVertices) {
+            var topic = 0
+            while (topic < numTopics) {
+              queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt))
+              topic += 1
+            }
+          }
+          Iterator(queues)
+        }.reduce { (q1, q2) =>
+          q1.zip(q2).foreach { case (a, b) => a ++= b }
+          q1
+        }
+    topicsInQueues.map { q =>
+      val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip
+      (terms.toArray, termWeights.toArray)
+    }
+  }
+
+  // TODO
+  // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
+
+  /**
+   * Log likelihood of the observed tokens in the training set,
+   * given the current parameter estimates:
+   *  log P(docs | topics, topic distributions for docs, alpha, eta)
+   *
+   * Note:
+   *  - This excludes the prior; for that, use [[logPrior]].
+   *  - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
+   *    hyperparameters.
+   */
+  lazy val logLikelihood: Double = {
+    val eta = topicConcentration
+    val alpha = docConcentration
+    assert(eta > 1.0)
+    assert(alpha > 1.0)
+    val N_k = globalTopicTotals
+    val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
+    // Edges: Compute token log probability from phi_{wk}, theta_{kj}.
+    val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => {
+      val N_wj = edgeContext.attr
+      val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0)
+      val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0)
+      val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+      val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
+      val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj))
+      edgeContext.sendToDst(tokenLogLikelihood)
+    }
+    graph.aggregateMessages[Double](sendMsg, _ + _)
+      .map(_._2).fold(0.0)(_ + _)
+  }
+
+  /**
+   * Log probability of the current parameter estimate:
+   *  log P(topics, topic distributions for docs | alpha, eta)
+   */
+  lazy val logPrior: Double = {
+    val eta = topicConcentration
+    val alpha = docConcentration
+    // Term vertices: Compute phi_{wk}.  Use to compute prior log probability.
+    // Doc vertex: Compute theta_{kj}.  Use to compute prior log probability.
+    val N_k = globalTopicTotals
+    val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
+    val seqOp: (Double, (VertexId, TopicCounts)) => Double = {
+      case (sumPrior: Double, vertex: (VertexId, TopicCounts)) =>
+        if (isTermVertex(vertex)) {
+          val N_wk = vertex._2
+          val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
+          val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+          sumPrior + (eta - 1.0) * brzSum(phi_wk.map(math.log))
+        } else {
+          val N_kj = vertex._2
+          val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
+          val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
+          sumPrior + (alpha - 1.0) * brzSum(theta_kj.map(math.log))
+        }
+    }
+    graph.vertices.aggregate(0.0)(seqOp, _ + _)
+  }
+
+  /**
+   * For each document in the training set, return the distribution over topics for that document
+   * (i.e., "theta_doc").
+   *
+   * @return  RDD of (document ID, topic distribution) pairs
+   */
+  def topicDistributions: RDD[(Long, Vector)] = {
+    graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
+      (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0)))
+    }
+  }
+
+  // TODO:
+  // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+
+}

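`DistributedLDAModel.describeTopics` works per partition. For each topic it keeps a bounded queue of (N_{wk} / N_k, term) pairs, and it then merges the per-partition queues. The sketch below shows the same ranking on one machine. Its assumptions:

* `java.util.PriorityQueue`, used as a min-heap, stands in for Spark's `BoundedPriorityQueue`.
* The input is a dense vocabSize x k matrix of term-topic counts with nonzero column sums.
* `TopTermsSketch` and `TermWeight` are hypothetical names, not MLlib API.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Hypothetical sketch of the top-terms-per-topic ranking in DistributedLDAModel.describeTopics,
 * on a local vocabSize x k count matrix. Not MLlib API.
 */
final class TopTermsSketch {

  /** A (term index, weight) pair. */
  static final class TermWeight {
    final int term;
    final double weight;

    TermWeight(int term, double weight) {
      this.term = term;
      this.weight = weight;
    }
  }

  /**
   * For each topic k, returns up to maxTermsPerTopic terms sorted by decreasing weight
   * N_{wk} / N_k, where N_k is the column sum of termTopicCounts.
   */
  static List<List<TermWeight>> describeTopics(double[][] termTopicCounts, int maxTermsPerTopic) {
    int vocabSize = termTopicCounts.length;
    int numTopics = termTopicCounts[0].length;
    double[] nK = new double[numTopics];
    for (double[] row : termTopicCounts) {
      for (int k = 0; k < numTopics; k++) {
        nK[k] += row[k];
      }
    }
    List<List<TermWeight>> topics = new ArrayList<>();
    for (int k = 0; k < numTopics; k++) {
      // Min-heap on weight: the head is the weakest term kept so far.
      PriorityQueue<TermWeight> queue =
          new PriorityQueue<>(Comparator.comparingDouble((TermWeight tw) -> tw.weight));
      for (int w = 0; w < vocabSize; w++) {
        queue.add(new TermWeight(w, termTopicCounts[w][k] / nK[k]));
        if (queue.size() > maxTermsPerTopic) {
          queue.poll();  // drop the smallest weight to stay within the bound
        }
      }
      List<TermWeight> top = new ArrayList<>(queue);
      top.sort(Comparator.comparingDouble((TermWeight tw) -> tw.weight).reversed());
      topics.add(top);
    }
    return topics;
  }
}
```

A min-heap of size maxTermsPerTopic per topic costs O(vocabSize * k * log(maxTermsPerTopic)) time and only O(k * maxTermsPerTopic) memory. That bounded memory is what lets the distributed version merge small per-partition queues instead of collecting every term weight.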
http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
new file mode 100644
index 0000000..76672fe
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx.Graph
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing Graphs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * before the graph has been materialized.  After updating [[PeriodicGraphCheckpointer]], users are
+ * responsible for materializing the graph to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ *  - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
+ *  - Unpersist graphs from queue until there are at most 3 persisted graphs.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new graph, and put in a queue of checkpointed graphs.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Graphs should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later graphs have been checkpointed.
+ *    However, references to the older graphs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (graph1, graph2, graph3, ...) = ...
+ *  val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ *  graph1.vertices.count(); graph1.edges.count()
+ *  // persisted: graph1
+ *  cp.updateGraph(graph2)
+ *  graph2.vertices.count(); graph2.edges.count()
+ *  // persisted: graph1, graph2
+ *  // checkpointed: graph2
+ *  cp.updateGraph(graph3)
+ *  graph3.vertices.count(); graph3.edges.count()
+ *  // persisted: graph1, graph2, graph3
+ *  // checkpointed: graph2
+ *  cp.updateGraph(graph4)
+ *  graph4.vertices.count(); graph4.edges.count()
+ *  // persisted: graph2, graph3, graph4
+ *  // checkpointed: graph4
+ *  cp.updateGraph(graph5)
+ *  graph5.vertices.count(); graph5.edges.count()
+ *  // persisted: graph3, graph4, graph5
+ *  // checkpointed: graph4
+ * }}}
+ *
+ * @param currentGraph  Initial graph
+ * @param checkpointDir The directory for storing checkpoint files
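+ * @param checkpointInterval Graphs will be checkpointed on every checkpointInterval-th call to
+ *                           [[updateGraph()]], including the call made by the constructor,
+ *                           if a checkpoint directory is set.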
+ * @tparam VD  Vertex descriptor type
+ * @tparam ED  Edge descriptor type
+ *
+ * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ */
+private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+    var currentGraph: Graph[VD, ED],
+    val checkpointDir: Option[String],
+    val checkpointInterval: Int) extends Logging {
+
+  /** FIFO queue of past checkpointed Graphs */
+  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
+
+  /** FIFO queue of past persisted Graphs */
+  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
+
+  /** Number of times [[updateGraph()]] has been called */
+  private var updateCount = 0
+
+  /**
+   * Spark Context for the Graphs given to this checkpointer.
+   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
+   */
+  private val sc = currentGraph.vertices.sparkContext
+
+  // If a checkpoint directory is given, and there's no prior checkpoint directory,
+  // then set the checkpoint directory with the given one.
+  if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) {
+    sc.setCheckpointDir(checkpointDir.get)
+  }
+
+  updateGraph(currentGraph)
+
+  /**
+   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
+   * Since this handles persistence and checkpointing, this should be called before the graph
+   * has been materialized.
+   *
+   * @param newGraph  New graph created from previous graphs in the lineage.
+   */
+  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
+    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
+      newGraph.persist()
+    }
+    persistedQueue.enqueue(newGraph)
+    // We keep up to 3 Graphs in persistedQueue to support the semantics of this class:
+    // Users should call [[updateGraph()]] when a new graph has been created,
+    // before the graph has been materialized.
+    while (persistedQueue.size > 3) {
+      val graphToUnpersist = persistedQueue.dequeue()
+      graphToUnpersist.unpersist(blocking = false)
+    }
+    updateCount += 1
+
+    // Handle checkpointing (after persisting)
+    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+      // Add new checkpoint before removing old checkpoints.
+      newGraph.checkpoint()
+      checkpointQueue.enqueue(newGraph)
+      // Remove checkpoints before the latest one.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // Delete the oldest checkpoint only if the next checkpoint exists.
+        if (checkpointQueue.get(1).get.isCheckpointed) {
+          removeCheckpointFile()
+        } else {
+          canDelete = false
+        }
+      }
+    }
+  }
+
+  /**
+   * Call this at the end to delete any remaining checkpoint files.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.size > 0) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
+   * This prints a warning but does not fail if the files cannot be removed.
+   */
+  private def removeCheckpointFile(): Unit = {
+    val old = checkpointQueue.dequeue()
+    // Since the old checkpoint is not deleted by Spark, we manually delete it.
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    old.getCheckpointFiles.foreach { checkpointFile =>
+      try {
+        fs.delete(new Path(checkpointFile), true)
+      } catch {
+        case e: Exception =>
+          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
+            checkpointFile)
+      }
+    }
+  }
+
+}

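The bookkeeping in `PeriodicGraphCheckpointer.updateGraph()` can be simulated without Spark. The sketch below makes these assumptions:

* `FakeGraph` is a hypothetical stand-in. Its `materialize()` plays the role of an action such as `graph.vertices.count()`, which is the moment a graph marked for checkpointing is actually written.
* A checkpoint directory is set.
* `CheckpointerSketch` is a hypothetical name, not MLlib API.

The queue logic follows the class above:

* at most 3 persisted graphs;
* a checkpoint every `checkpointInterval` updates;
* an older checkpoint is deleted only after the next one has been written.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Hypothetical local simulation of PeriodicGraphCheckpointer.updateGraph bookkeeping.
 * Assumes a checkpoint directory is set. Not MLlib API.
 */
final class CheckpointerSketch {

  /** Stand-in for a Graph; materialize() models running an action on it. */
  static final class FakeGraph {
    final String name;
    boolean persisted = false;
    boolean markedForCheckpoint = false;
    boolean checkpointed = false;            // like isCheckpointed: stays true after deletion
    boolean checkpointFilesDeleted = false;

    FakeGraph(String name) {
      this.name = name;
    }

    void materialize() {
      if (markedForCheckpoint) {
        checkpointed = true;  // a marked graph is written on its first action
      }
    }
  }

  private final Deque<FakeGraph> persistedQueue = new ArrayDeque<>();
  private final List<FakeGraph> checkpointQueue = new ArrayList<>();  // oldest first
  private final int checkpointInterval;
  private int updateCount = 0;

  CheckpointerSketch(FakeGraph initial, int checkpointInterval) {
    this.checkpointInterval = checkpointInterval;
    updateGraph(initial);  // the constructor counts as the first update
  }

  void updateGraph(FakeGraph newGraph) {
    newGraph.persisted = true;
    persistedQueue.addLast(newGraph);
    // Keep at most 3 persisted graphs.
    while (persistedQueue.size() > 3) {
      persistedQueue.removeFirst().persisted = false;
    }
    updateCount++;
    if (updateCount % checkpointInterval == 0) {
      newGraph.markedForCheckpoint = true;
      checkpointQueue.add(newGraph);
      // Delete the oldest checkpoint only once the next one has actually been written.
      while (checkpointQueue.size() > 1 && checkpointQueue.get(1).checkpointed) {
        checkpointQueue.remove(0).checkpointFilesDeleted = true;
      }
    }
  }
}
```

Replaying the class-doc example (interval 2, graph1 through graph5) in this simulation shows that graph2's checkpoint files are still present after graph5 is added. When graph4 is enqueued, it has not been materialized yet, so the oldest checkpoint is removed only at the next checkpoint (graph6).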
http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
new file mode 100644
index 0000000..dc10aa6
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.apache.spark.api.java.JavaRDD;
+import scala.Tuple2;
+
+import org.junit.After;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+
+
+public class JavaLDASuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaLDA");
+    ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<Tuple2<Long, Vector>>();
+    for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
+      tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
+          LDASuite$.MODULE$.tinyCorpus()[i]._2()));
+    }
+    JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+    corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void localLDAModel() {
+    LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+
+    // Check: basic parameters
+    assertEquals(tinyK, model.k());
+    assertEquals(tinyVocabSize, model.vocabSize());
+    assertEquals(tinyTopics, model.topicsMatrix());
+
+    // Check: describeTopics() with all terms
+    Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
+    assertEquals(tinyK, fullTopicSummary.length);
+    for (int i = 0; i < fullTopicSummary.length; i++) {
+      assertArrayEquals(tinyTopicDescription[i]._1(), fullTopicSummary[i]._1());
+      assertArrayEquals(tinyTopicDescription[i]._2(), fullTopicSummary[i]._2(), 1e-5);
+    }
+  }
+
+  @Test
+  public void distributedLDAModel() {
+    int k = 3;
+    double topicSmoothing = 1.2;
+    double termSmoothing = 1.2;
+
+    // Train a model
+    LDA lda = new LDA();
+    lda.setK(k)
+      .setDocConcentration(topicSmoothing)
+      .setTopicConcentration(termSmoothing)
+      .setMaxIterations(5)
+      .setSeed(12345);
+
+    DistributedLDAModel model = lda.run(corpus);
+
+    // Check: basic parameters
+    LocalLDAModel localModel = model.toLocal();
+    assertEquals(k, model.k());
+    assertEquals(k, localModel.k());
+    assertEquals(tinyVocabSize, model.vocabSize());
+    assertEquals(tinyVocabSize, localModel.vocabSize());
+    assertEquals(model.topicsMatrix(), localModel.topicsMatrix());
+
+    // Check: topic summaries
+    Tuple2<int[], double[]>[] topicSummary = model.describeTopics();
+    assertEquals(k, topicSummary.length);
+    Tuple2<int[], double[]>[] localTopicSummary = localModel.describeTopics();
+    assertEquals(k, localTopicSummary.length);
+
+    // Check: log probabilities
+    org.junit.Assert.assertTrue(model.logLikelihood() < 0.0);
+    org.junit.Assert.assertTrue(model.logPrior() < 0.0);
+  }
+
+  private static int tinyK = LDASuite$.MODULE$.tinyK();
+  private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
+  private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
+  private static Tuple2<int[], double[]>[] tinyTopicDescription =
+      LDASuite$.MODULE$.tinyTopicDescription();
+  JavaPairRDD<Long, Vector> corpus;
+
+}
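`localLDAModel` compares `describeTopics()` with `LDASuite.tinyTopicDescription`. That description is derived from `tinyTopics`, a `DenseMatrix` with `numRows = tinyVocabSize` and `numCols = tinyK`, filled by concatenating the topic arrays so that each column holds one topic (MLlib dense matrices are column-major). Each description lists term indices by descending weight. The sketch below restates both steps in plain Java so the expected values can be checked by hand. `TopicLayout` and its methods are hypothetical helpers, not MLlib API.

```java
import java.util.Arrays;
import java.util.Comparator;

final class TopicLayout {
  private TopicLayout() {}

  /** Flattens k topics (each of length vocabSize) column-major: (term, topic) -> topic * vocabSize + term. */
  static double[] toColumnMajor(double[][] topics) {
    int k = topics.length;
    int vocabSize = topics[0].length;
    double[] values = new double[vocabSize * k];
    for (int topic = 0; topic < k; topic++) {
      System.arraycopy(topics[topic], 0, values, topic * vocabSize, vocabSize);
    }
    return values;
  }

  /** Term indices of one topic sorted by descending weight; ties keep ascending term order. */
  static int[] termsByWeight(double[] topic) {
    Integer[] order = new Integer[topic.length];
    for (int i = 0; i < order.length; i++) {
      order[i] = i;
    }
    // Arrays.sort on objects is stable, like Scala's sortBy used in LDASuite.
    Arrays.sort(order, Comparator.comparingDouble((Integer i) -> -topic[i]));
    return Arrays.stream(order).mapToInt(Integer::intValue).toArray();
  }

  /** Weights of the given terms, in the given order. */
  static double[] weightsFor(double[] topic, int[] terms) {
    double[] weights = new double[terms.length];
    for (int i = 0; i < terms.length; i++) {
      weights[i] = topic[terms[i]];
    }
    return weights;
  }
}
```

For topic 1 (`0.5, 0.05, 0.05, 0.1, 0.3`) this yields terms `0, 4, 3, 1, 2`. The two 0.05 weights keep ascending term order because both sorts are stable.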

http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 0000000..302d751
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class LDASuite extends FunSuite with MLlibTestSparkContext {
+
+  import LDASuite._
+
+  test("LocalLDAModel") {
+    val model = new LocalLDAModel(tinyTopics)
+
+    // Check: basic parameters
+    assert(model.k === tinyK)
+    assert(model.vocabSize === tinyVocabSize)
+    assert(model.topicsMatrix === tinyTopics)
+
+    // Check: describeTopics() with all terms
+    val fullTopicSummary = model.describeTopics()
+    assert(fullTopicSummary.size === tinyK)
+    fullTopicSummary.zip(tinyTopicDescription).foreach {
+      case ((algTerms, algTermWeights), (terms, termWeights)) =>
+        assert(algTerms === terms)
+        assert(algTermWeights === termWeights)
+    }
+
+    // Check: describeTopics() with some terms
+    val smallNumTerms = 3
+    val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms)
+    smallTopicSummary.zip(tinyTopicDescription).foreach {
+      case ((algTerms, algTermWeights), (terms, termWeights)) =>
+        assert(algTerms === terms.slice(0, smallNumTerms))
+        assert(algTermWeights === termWeights.slice(0, smallNumTerms))
+    }
+  }
+
+  test("running and DistributedLDAModel") {
+    val k = 3
+    val topicSmoothing = 1.2
+    val termSmoothing = 1.2
+
+    // Train a model
+    val lda = new LDA()
+    lda.setK(k)
+      .setDocConcentration(topicSmoothing)
+      .setTopicConcentration(termSmoothing)
+      .setMaxIterations(5)
+      .setSeed(12345)
+    val corpus = sc.parallelize(tinyCorpus, 2)
+
+    val model: DistributedLDAModel = lda.run(corpus)
+
+    // Check: basic parameters
+    val localModel = model.toLocal
+    assert(model.k === k)
+    assert(localModel.k === k)
+    assert(model.vocabSize === tinyVocabSize)
+    assert(localModel.vocabSize === tinyVocabSize)
+    assert(model.topicsMatrix === localModel.topicsMatrix)
+
+    // Check: topic summaries
+    //  Round weights and sort topics so the comparison ignores numeric noise and topic order.
+    val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
+      // round weights to 3 decimal places (Locale.ROOT keeps '.' as the decimal separator)
+      terms.zip(termWeights).map { case (term, weight) =>
+        ("%.3f".formatLocal(java.util.Locale.ROOT, weight).toDouble, term.toInt)
+      }
+    }.sortBy(_.mkString(""))
+    val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
+      // round weights to 3 decimal places (Locale.ROOT keeps '.' as the decimal separator)
+      terms.zip(termWeights).map { case (term, weight) =>
+        ("%.3f".formatLocal(java.util.Locale.ROOT, weight).toDouble, term.toInt)
+      }
+    }.sortBy(_.mkString(""))
+    roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
+      assert(t1 === t2)
+    }
+
+    // Check: per-doc topic distributions
+    val topicDistributions = model.topicDistributions.collect()
+    //  Ensure all documents are covered.
+    assert(topicDistributions.size === tinyCorpus.size)
+    assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+    //  Ensure we have proper distributions
+    topicDistributions.foreach { case (docId, topicDistribution) =>
+      assert(topicDistribution.size === tinyK)
+      assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
+    }
+
+    // Check: log probabilities
+    assert(model.logLikelihood < 0.0)
+    assert(model.logPrior < 0.0)
+  }
+
+  test("vertex indexing") {
+    // Check vertex ID indexing and conversions.
+    val docIds = Array(0, 1, 2)
+    val docVertexIds = docIds
+    val termIds = Array(0, 1, 2)
+    val termVertexIds = Array(-1, -2, -3)
+    assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0))))
+    assert(termIds.map(LDA.term2index) === termVertexIds)
+    assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds)
+    assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0))))
+  }
+}
+
+private[clustering] object LDASuite {
+
+  def tinyK: Int = 3
+  def tinyVocabSize: Int = 5
+  def tinyTopicsAsArray: Array[Array[Double]] = Array(
+    Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0
+    Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1
+    Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2
+  )
+  def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK,
+    values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _))
+  def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic =>
+    val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip
+    (terms.toArray, termWeights.toArray)
+  }
+
+  def tinyCorpus = Array(
+    Vectors.dense(1, 3, 0, 2, 8),
+    Vectors.dense(0, 2, 1, 0, 4),
+    Vectors.dense(2, 3, 12, 3, 1),
+    Vectors.dense(0, 3, 1, 9, 8),
+    Vectors.dense(1, 1, 4, 2, 6)
+  ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+  assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
+
+}
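The `vertex indexing` test pins down how documents and terms share one GraphX vertex-ID space. Document IDs 0, 1, 2 are used as-is and are not term vertices, while terms 0, 1, 2 map to vertex IDs -1, -2, -3. The following is a minimal Java sketch of a mapping consistent with those expected values. The formulas are inferred from the test rather than copied from `LDA.scala`, and `VertexIdScheme` is a hypothetical name.

```java
final class VertexIdScheme {
  private VertexIdScheme() {}

  // Assumption (inferred from the "vertex indexing" test): term t gets vertex ID -(1 + t),
  // so terms occupy -1, -2, -3, ... while documents keep their non-negative IDs.
  static long termToVertexId(int term) {
    return -(1L + term);
  }

  // Inverse of termToVertexId: -1 -> 0, -2 -> 1, ...
  static int vertexIdToTerm(long vertexId) {
    return (int) (-(1L + vertexId));
  }

  // Under this scheme, term vertices are exactly the negative IDs.
  static boolean isTermVertex(long vertexId) {
    return vertexId < 0;
  }
}
```

Offsetting by one keeps term 0 away from vertex ID 0, so non-negative document IDs and negative term IDs never collide.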

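The `running and DistributedLDAModel` test compares the distributed and local topic summaries after rounding each weight to three decimals and sorting the topics. This prevents tiny numeric differences or a different topic order from causing a spurious failure. Below is a hypothetical Java restatement of that canonicalization. It uses string keys instead of the Scala `(weight, term)` tuples, and the class and method names are illustrative only. As with the original, two weights that straddle a rounding boundary can still compare unequal.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

final class TopicSummaryComparison {
  private TopicSummaryComparison() {}

  /** Renders one topic as "(weight,term)" pairs with weights rounded to 3 decimals. */
  static String roundedKey(int[] terms, double[] weights) {
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < terms.length; i++) {
      // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM default locale.
      sb.append(String.format(Locale.ROOT, "(%.3f,%d)", weights[i], terms[i]));
    }
    return sb.toString();
  }

  /** One key per topic, sorted so that topic order does not affect equality. */
  static List<String> canonicalize(int[][] terms, double[][] weights) {
    List<String> keys = new ArrayList<>();
    for (int t = 0; t < terms.length; t++) {
      keys.add(roundedKey(terms[t], weights[t]));
    }
    Collections.sort(keys);
    return keys;
  }

  static boolean sameTopics(int[][] termsA, double[][] weightsA,
                            int[][] termsB, double[][] weightsB) {
    return canonicalize(termsA, weightsA).equals(canonicalize(termsB, weightsB));
  }
}
```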
http://git-wip-us.apache.org/repos/asf/spark/blob/980764f3/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
new file mode 100644
index 0000000..dac28a3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+
+  import PeriodicGraphCheckpointerSuite._
+
+  // TODO: Do I need to call count() on the graphs' RDDs?
+
+  test("Persisting") {
+    var graphsToCheck = Seq.empty[GraphToCheck]
+
+    val graph1 = createGraph(sc)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+    checkPersistence(graphsToCheck, 1)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val graph = createGraph(sc)
+      checkpointer.updateGraph(graph)
+      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+      checkPersistence(graphsToCheck, iteration)
+      iteration += 1
+    }
+  }
+
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val checkpointInterval = 2
+    var graphsToCheck = Seq.empty[GraphToCheck]
+
+    val graph1 = createGraph(sc)
+    val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval)
+    graph1.edges.count()
+    graph1.vertices.count()
+    graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+    checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val graph = createGraph(sc)
+      checkpointer.updateGraph(graph)
+      graph.vertices.count()
+      graph.edges.count()
+      graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+      checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+      iteration += 1
+    }
+
+    checkpointer.deleteAllCheckpoints()
+    graphsToCheck.foreach { graph =>
+      confirmCheckpointRemoved(graph.graph)
+    }
+
+    Utils.deleteRecursively(tempDir)
+  }
+}
+
+private object PeriodicGraphCheckpointerSuite {
+
+  case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
+
+  val edges = Seq(
+    Edge[Double](0, 1, 0),
+    Edge[Double](1, 2, 0),
+    Edge[Double](2, 3, 0),
+    Edge[Double](3, 4, 0))
+
+  def createGraph(sc: SparkContext): Graph[Double, Double] = {
+    Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+  }
+
+  def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
+    graphs.foreach { g =>
+      checkPersistence(g.graph, g.gIndex, iteration)
+    }
+  }
+
+  /**
+   * Check storage level of graph.
+   * @param gIndex  Index of graph in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of graphs inserted into checkpointer.
+   */
+  def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = {
+    try {
+      if (gIndex + 2 < iteration) {
+        assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
+        assert(graph.edges.getStorageLevel == StorageLevel.NONE)
+      } else {
+        assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
+        assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+      }
+    } catch {
+      case _: AssertionError =>
+        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" +
+          s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n")
+    }
+  }
+
+  def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+    graphs.reverse.foreach { g =>
+      checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval)
+    }
+  }
+
+  def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = {
+    // Note: We cannot check graph.isCheckpointed since that value is never updated.
+    //       Instead, we check for the presence of the checkpoint files.
+    //       This test should continue to work even after this graph.isCheckpointed issue
+    //       is fixed (though it can then be simplified and not look for the files).
+    val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration)
+    graph.getCheckpointFiles.foreach { checkpointFile =>
+      assert(!fs.exists(new Path(checkpointFile)),
+        "Graph checkpoint file should have been removed")
+    }
+  }
+
+  /**
+   * Check checkpointed status of graph.
+   * @param gIndex  Index of graph in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of graphs inserted into checkpointer.
+   */
+  def checkCheckpoint(
+      graph: Graph[_, _],
+      gIndex: Int,
+      iteration: Int,
+      checkpointInterval: Int): Unit = {
+    try {
+      if (gIndex % checkpointInterval == 0) {
+        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph)
+        // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint.
+        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+          assert(graph.isCheckpointed, "Graph should be checkpointed")
+          assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files")
+        } else {
+          confirmCheckpointRemoved(graph)
+        }
+      } else {
+        // Graph should never be checkpointed
+        assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
+        assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+      }
+    } catch {
+      case e: AssertionError =>
+        throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t checkpointInterval = $checkpointInterval\n" +
+          s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" +
+          s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" +
+          s"  AssertionError message: ${e.getMessage}")
+    }
+  }
+
+}
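`checkPersistence` encodes the retention rule under test. After `iteration` graphs have been passed to the checkpointer, the graph with 1-based index `gIndex` should still be persisted exactly when `gIndex + 2 >= iteration`. In other words, only the three most recent graphs keep a storage level other than `StorageLevel.NONE`. A small hypothetical Java helper makes the expected set explicit:

```java
import java.util.stream.IntStream;

final class PersistenceRule {
  private PersistenceRule() {}

  /** Mirrors checkPersistence: graph gIndex (1-based) stays persisted iff gIndex + 2 >= iteration. */
  static boolean expectPersisted(int gIndex, int iteration) {
    return gIndex + 2 >= iteration;
  }

  /** All 1-based graph indices expected to be persisted after `iteration` insertions. */
  static int[] persistedIndices(int iteration) {
    return IntStream.rangeClosed(1, iteration)
        .filter(g -> expectPersisted(g, iteration))
        .toArray();
  }
}
```

For example, after five graphs only graphs 3, 4, and 5 are expected to be persisted.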

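`checkCheckpoint` encodes a second rule. Only every `checkpointInterval`-th graph is ever checkpointed, and its checkpoint files are expected to survive only while `iteration - 2 * checkpointInterval < gIndex <= iteration`. Here is a hypothetical Java restatement of that expectation; the names are illustrative only.

```java
final class CheckpointRule {
  private CheckpointRule() {}

  enum Expected { NEVER_CHECKPOINTED, CHECKPOINTED, CHECKPOINT_REMOVED }

  /** Expected checkpoint state of graph gIndex (1-based) after `iteration` insertions. */
  static Expected expected(int gIndex, int iteration, int checkpointInterval) {
    if (gIndex % checkpointInterval != 0) {
      return Expected.NEVER_CHECKPOINTED;
    }
    // Same window as checkCheckpoint: up to 2 intervals, because removal of the previous
    // checkpoint is decided before the action that materializes the next one.
    boolean withinWindow = iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration;
    return withinWindow ? Expected.CHECKPOINTED : Expected.CHECKPOINT_REMOVED;
  }
}
```

With `checkpointInterval = 2` and `iteration = 8`, only graphs 6 and 8 are expected to hold checkpoint files. Graphs 2 and 4 should have had theirs removed.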


