mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jman...@apache.org
Subject svn commit: r1209794 [2/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/lda/ core/src/main/java/org/apache/mahout/clustering/lda/cvb/ core/src/main/java/org/apache/mahout/common/ core/src/main/java/org/apache/mahout/math/ core/sr...
Date Sat, 03 Dec 2011 00:18:47 GMT
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,502 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DistributedRowMatrixWriter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.Sampler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Thin wrapper around a {@link Matrix} of counts of occurrences of (topic, term) pairs.  Dividing
+ * {code topicTermCount.viewRow(topic).get(term)} by the sum over the values for all terms in that
+ * row yields p(term | topic).  Instead dividing it by all topic columns for that term yields
+ * p(topic | term).
+ *
+ * Multithreading is enabled for the {@code update(Matrix)} method: this method is async, and
+ * merely submits the matrix to a work queue.  When all work has been submitted,
+ * {@code awaitTermination()} should be called, which will block until updates have been
+ * accumulated.
+ */
+public class TopicModel implements Configurable, Iterable<MatrixSlice> {
+  private static final Logger log = LoggerFactory.getLogger(TopicModel.class);
+  private final String[] dictionary;
+  private final Matrix topicTermCounts;
+  private final Vector topicSums;
+  private final int numTopics;
+  private final int numTerms;
+  private final double eta;
+  private final double alpha;
+
+  private Configuration conf;
+
+  private Sampler sampler;
+  private final int numThreads;
+  private ThreadPoolExecutor threadPool;
+  private Updater[] updaters;
+
+  public int getNumTerms() {
+    return numTerms;
+  }
+
+  public int getNumTopics() {
+    return numTopics;
+  }
+
+  public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
+      double modelWeight) {
+    this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight);
+  }
+
+  public TopicModel(Configuration conf, double eta, double alpha,
+      String[] dictionary, int numThreads, double modelWeight, Path... modelpath) throws IOException {
+    this(loadModel(conf, modelpath), eta, alpha, dictionary, numThreads, modelWeight);
+  }
+
+  public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
+      int numThreads, double modelWeight) {
+    this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha, dictionary,
+        numThreads, modelWeight);
+  }
+
+  public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random,
+      String[] dictionary, int numThreads, double modelWeight) {
+    this(randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads, modelWeight);
+  }
+
+  private TopicModel(Pair<Matrix, Vector> model, double eta, double alpha, String[] dict,
+      int numThreads, double modelWeight) {
+    this(model.getFirst(), model.getSecond(), eta, alpha, dict, numThreads, modelWeight);
+  }
+
+  public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
+    String[] dictionary, double modelWeight) {
+    this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight);
+  }
+
+  public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary,
+      int numThreads, double modelWeight) {
+    this(topicTermCounts, viewRowSums(topicTermCounts),
+        eta, alpha, dictionary, numThreads, modelWeight);
+  }
+
+  public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
+    String[] dictionary, int numThreads, double modelWeight) {
+    this.dictionary = dictionary;
+    this.topicTermCounts = topicTermCounts;
+    this.topicSums = topicSums;
+    this.numTopics = topicSums.size();
+    this.numTerms = topicTermCounts.numCols();
+    this.eta = eta;
+    this.alpha = alpha;
+    this.sampler = new Sampler(new Random(1234));
+    this.numThreads = numThreads;
+    if(modelWeight != 1) {
+      topicSums.assign(Functions.mult(modelWeight));
+      for(int x = 0; x < numTopics; x++) {
+        topicTermCounts.viewRow(x).assign(Functions.mult(modelWeight));
+      }
+    }
+    initializeThreadPool();
+  }
+
+  private static Vector viewRowSums(Matrix m) {
+    Vector v = new DenseVector(m.numRows());
+    for(MatrixSlice slice : m) {
+      v.set(slice.index(), slice.vector().norm(1));
+    }
+    return v;
+  }
+
+  private void initializeThreadPool() {
+    threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0, TimeUnit.SECONDS,
+        new ArrayBlockingQueue<Runnable>(numThreads * 10));
+    threadPool.allowCoreThreadTimeOut(false);
+    updaters = new Updater[numThreads];
+    for(int i = 0; i < numThreads; i++) {
+      updaters[i] = new Updater();
+      threadPool.submit(updaters[i]);
+    }
+  }
+
+  Matrix topicTermCounts() {
+    return topicTermCounts;
+  }
+
+  public Iterator<MatrixSlice> iterator() {
+    return topicTermCounts.iterateAll();
+  }
+
+  public Vector topicSums() {
+    return topicSums;
+  }
+
+  private static Pair<Matrix,Vector> randomMatrix(int numTopics, int numTerms, Random random) {
+    Matrix topicTermCounts = new DenseMatrix(numTopics, numTerms);
+    Vector topicSums = new DenseVector(numTopics);
+    if(random != null) {
+      for(int x = 0; x < numTopics; x++) {
+        for(int term = 0; term < numTerms; term++) {
+          topicTermCounts.viewRow(x).set(term, random.nextDouble());
+        }
+      }
+    }
+    for(int x = 0; x < numTopics; x++) {
+      topicSums.set(x, random == null ? 1d : topicTermCounts.viewRow(x).norm(1));
+    }
+    return Pair.of(topicTermCounts, topicSums);
+  }
+
+  public static Pair<Matrix, Vector> loadModel(Configuration conf, Path... modelPaths)
+      throws IOException {
+    int numTopics = -1;
+    int numTerms = -1;
+    List<Pair<Integer, Vector>> rows = Lists.newArrayList();
+    for(Path modelPath : modelPaths) {
+      for(Pair<IntWritable, VectorWritable> row :
+          new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
+        rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
+        numTopics = Math.max(numTopics, row.getFirst().get());
+        if(numTerms < 0) {
+          numTerms = row.getSecond().get().size();
+        }
+      }
+    }
+    if(rows.isEmpty()) {
+      throw new IOException(modelPaths + " have no vectors in it");
+    }
+    numTopics++;
+    Matrix model = new DenseMatrix(numTopics, numTerms);
+    Vector topicSums = new DenseVector(numTopics);
+    for(Pair<Integer, Vector> pair : rows) {
+      model.viewRow(pair.getFirst()).assign(pair.getSecond());
+      topicSums.set(pair.getFirst(), pair.getSecond().norm(1));
+    }
+    return Pair.of(model, topicSums);
+  }
+
+  public String toString() {
+    String buf = "";
+    for(int x = 0; x < numTopics; x++) {
+      String v = dictionary != null
+          ? vectorToSortedString(topicTermCounts.viewRow(x).normalize(1), dictionary)
+          : topicTermCounts.viewRow(x).asFormatString();
+      buf += v + "\n";
+    }
+    return buf;
+  }
+
+  public int sampleTerm(Vector topicDistribution) {
+    return sampler.sample(topicTermCounts.viewRow(sampler.sample(topicDistribution)));
+  }
+
+  public int sampleTerm(int topic) {
+    return sampler.sample(topicTermCounts.viewRow(topic));
+  }
+
+  public void reset() {
+    for(int x = 0; x < numTopics; x++) {
+      topicTermCounts.assignRow(x, new SequentialAccessSparseVector(numTerms));
+    }
+    topicSums.assign(1d);
+    initializeThreadPool();
+  }
+
+  public void awaitTermination() {
+    for(Updater updater : updaters) {
+      updater.shutdown();
+    }
+  }
+
+  public void renormalize() {
+    for(int x = 0; x < numTopics; x++) {
+      topicTermCounts.assignRow(x, topicTermCounts.viewRow(x).normalize(1));
+      topicSums.assign(1d);
+    }
+  }
+
+  public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
+    // first calculate p(topic|term,document) for all terms in original, and all topics,
+    // using p(term|topic) and p(topic|doc)
+    pTopicGivenTerm(original, topics, docTopicModel);
+    normalizeByTopic(docTopicModel);
+    // now multiply, term-by-term, by the document, to get the weighted distribution of
+    // term-topic pairs from this document.
+    Iterator<Vector.Element> it = original.iterateNonZero();
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      for(int x = 0; x < numTopics; x++) {
+        Vector docTopicModelRow = docTopicModel.viewRow(x);
+        docTopicModelRow.setQuick(e.index(), docTopicModelRow.getQuick(e.index()) * e.get());
+      }
+    }
+    // now recalculate p(topic|doc) by summing contributions from all of pTopicGivenTerm
+    topics.assign(0d);
+    for(int x = 0; x < numTopics; x++) {
+      topics.set(x, docTopicModel.viewRow(x).norm(1));
+    }
+    // now renormalize so that sum_x(p(x|doc)) = 1
+    topics.assign(Functions.mult(1/topics.norm(1)));
+  }
+
+  public Vector infer(Vector original, Vector docTopics) {
+    Vector pTerm = original.like();
+    Iterator<Vector.Element> it = original.iterateNonZero();
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      int term = e.index();
+      // p(a) = sum_x (p(a|x) * p(x|i))
+      double pA = 0;
+      for(int x = 0; x < numTopics; x++) {
+        pA += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x);
+      }
+      pTerm.set(term, pA);
+    }
+    return pTerm;
+  }
+
+  public void update(Matrix docTopicCounts) {
+    for(int x = 0; x < numTopics; x++) {
+      updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x));
+    }
+  }
+
+  public void updateTopic(int topic, Vector docTopicCounts) {
+    topicTermCounts.viewRow(topic).assign(docTopicCounts, Functions.PLUS);
+    topicSums.set(topic, topicSums.get(topic) + docTopicCounts.norm(1));
+  }
+
+  public void update(int termId, Vector topicCounts) {
+    for(int x = 0; x < numTopics; x++) {
+      Vector v = topicTermCounts.viewRow(x);
+      v.set(termId, v.get(termId) + topicCounts.get(x));
+    }
+    topicSums.assign(topicCounts, Functions.PLUS);
+  }
+
+  public void persist(Path outputDir, boolean overwrite) throws IOException {
+    FileSystem fs = outputDir.getFileSystem(conf);
+    if(overwrite) {
+      fs.delete(outputDir, true); // CHECK second arg
+    }
+    DistributedRowMatrixWriter.write(outputDir, conf, topicTermCounts);
+  }
+
+  /**
+   * Computes {@code p(topic x|term a, document i)} distributions given input document {@code i}.
+   * {@code pTGT[x][a]} is the (un-normalized) {@code p(x|a,i)}, or if docTopics is {@code null},
+   * {@code p(a|x)} (also un-normalized).
+   *
+   * @param document doc-term vector encoding {@code w(term a|document i)}.
+   * @param docTopics {@code docTopics[x]} is the overall weight of topic {@code x} in given
+   *          document. If {@code null}, a topic weight of {@code 1.0} is used for all topics.
+   * @param termTopicDist storage for output {@code p(x|a,i)} distributions.
+   */
+  private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) {
+    // for each topic x
+    for(int x = 0; x < numTopics; x++) {
+      // get p(topic x | document i), or 1.0 if docTopics is null
+      double topicWeight = docTopics == null ? 1d : docTopics.get(x);
+      // get w(term a | topic x)
+      Vector topicTermRow = topicTermCounts.viewRow(x);
+      // get \sum_a w(term a | topic x)
+      double topicSum = topicSums.get(x);
+      // get p(topic x | term a) distribution to update
+      Vector termTopicRow = termTopicDist.viewRow(x);
+
+      // for each term a in document i with non-zero weight
+      Iterator<Vector.Element> it = document.iterateNonZero();
+      while(it.hasNext()) {
+        Vector.Element e = it.next();
+        int termIndex = e.index();
+
+        // calc un-normalized p(topic x | term a, document i)
+        double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * (topicWeight + alpha) / (topicSum + eta * numTerms);
+        termTopicRow.set(termIndex, termTopicLikelihood);
+      }
+    }
+  }
+
+  /**
+   * sum_x sum_a (c_ai * log(p(x|i) * p(a|x)))
+   * @param document
+   * @param docTopics
+   * @return
+   */
+  public double perplexity(Vector document, Vector docTopics) {
+    double perplexity = 0;
+    double norm = docTopics.norm(1) + (docTopics.size() * alpha);
+    Iterator<Vector.Element> it = document.iterateNonZero();
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      int term = e.index();
+      double prob = 0;
+      for(int x = 0; x < numTopics; x++) {
+        double d = (docTopics.get(x) + alpha) / norm;
+        double p = d * (topicTermCounts.viewRow(x).get(term) + eta)
+                   / (topicSums.get(x) + eta * numTerms);
+        prob += p;
+      }
+      perplexity += e.get() * Math.log(prob);
+    }
+    return -perplexity;
+  }
+
+  private void normalizeByTopic(Matrix perTopicSparseDistributions) {
+    Iterator<Vector.Element> it = perTopicSparseDistributions.viewRow(0).iterateNonZero();
+    // then make sure that each of these is properly normalized by topic: sum_x(p(x|t,d)) = 1
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      int a = e.index();
+      double sum = 0;
+      for(int x = 0; x < numTopics; x++) {
+        sum += perTopicSparseDistributions.viewRow(x).get(a);
+      }
+      for(int x = 0; x < numTopics; x++) {
+        perTopicSparseDistributions.viewRow(x).set(a,
+            perTopicSparseDistributions.viewRow(x).get(a) / sum);
+      }
+    }
+  }
+
+  public static String vectorToSortedString(Vector vector, String[] dictionary) {
+    List<Pair<String,Double>> vectorValues =
+        new ArrayList<Pair<String, Double>>(vector.getNumNondefaultElements());
+    Iterator<Vector.Element> it = vector.iterateNonZero();
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      vectorValues.add(Pair.of(dictionary != null ? dictionary[e.index()] : String.valueOf(e.index()),
+                               e.get()));
+    }
+    Collections.sort(vectorValues, new Comparator<Pair<String, Double>>() {
+      @Override public int compare(Pair<String, Double> x, Pair<String, Double> y) {
+        return y.getSecond().compareTo(x.getSecond());
+      }
+    });
+    Iterator<Pair<String,Double>> listIt = vectorValues.iterator();
+    StringBuilder bldr = new StringBuilder(2048);
+    bldr.append("{");
+    int i = 0;
+    while(listIt.hasNext() && i < 25) {
+      i++;
+      Pair<String,Double> p = listIt.next();
+      bldr.append(p.getFirst());
+      bldr.append(":");
+      bldr.append(p.getSecond());
+      bldr.append(",");
+    }
+    if(bldr.length() > 1) {
+      bldr.setCharAt(bldr.length() - 1, '}');
+    }
+    return bldr.toString();
+  }
+
+  @Override
+  public void setConf(Configuration configuration) {
+    this.conf = configuration;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return conf;
+  }
+
+  private final class Updater implements Runnable {
+    private ArrayBlockingQueue<Pair<Integer, Vector>> queue =
+        new ArrayBlockingQueue<Pair<Integer, Vector>>(100);
+    private boolean shutdown = false;
+    private boolean shutdownComplete = false;
+
+    public void shutdown() {
+      try {
+        synchronized (this) {
+          while(!shutdownComplete) {
+            shutdown = true;
+            wait();
+          }
+        }
+      } catch (InterruptedException e) {
+        log.warn("Interrupted waiting to shutdown() : ", e);
+      }
+    }
+
+    public boolean update(int topic, Vector v) {
+      if(shutdown) { // maybe don't do this?
+        throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks");
+      }
+      while(true) { // keep trying if interrupted
+        try {
+          // start async operation by submitting to the queue
+          queue.put(Pair.of(topic, v));
+          // return once you got access to the queue
+          return true;
+        } catch (InterruptedException e) {
+          log.warn("Interrupted trying to queue update:", e);
+        }
+      }
+    }
+
+    @Override public void run() {
+      while(!shutdown) {
+        try {
+          Pair<Integer, Vector> pair = queue.poll(1, TimeUnit.SECONDS);
+          if(pair != null) {
+            updateTopic(pair.getFirst(), pair.getSecond());
+          }
+        } catch (InterruptedException e) {
+          log.warn("Interrupted waiting to poll for update", e);
+        }
+      }
+      // in shutdown mode, finish remaining tasks!
+      for(Pair<Integer, Vector> pair : queue) {
+        updateTopic(pair.getFirst(), pair.getSecond());
+      }
+      synchronized (this) {
+        shutdownComplete = true;
+        notifyAll();
+      }
+    }
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/common/MemoryUtil.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/MemoryUtil.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/MemoryUtil.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/MemoryUtil.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.common;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Memory utilities.
+ */
+public class MemoryUtil {
+  private static final Logger log = LoggerFactory.getLogger(MemoryUtil.class);
+
+  /**
+   * Logs current heap memory statistics.
+   *
+   * @see Runtime
+   */
+  public static void logMemoryStatistics() {
+    Runtime runtime = Runtime.getRuntime();
+    long freeBytes = runtime.freeMemory();
+    long maxBytes = runtime.maxMemory();
+    long totalBytes = runtime.totalMemory();
+    long usedBytes = totalBytes - freeBytes;
+    log.info("Memory (bytes): {} used, {} heap, {} max", new Object[] { usedBytes, totalBytes,
+            maxBytes });
+  }
+
+  private static ScheduledExecutorService scheduler;
+
+  /**
+   * Constructs and starts a memory logger thread.
+   *
+   * @param rateInMillis how often memory info should be logged.
+   */
+  public static void startMemoryLogger(long rateInMillis) {
+    stopMemoryLogger();
+    scheduler = Executors.newScheduledThreadPool(1, new ThreadFactory() {
+      private final ThreadFactory delegate = Executors.defaultThreadFactory();
+
+      @Override
+      public Thread newThread(Runnable r) {
+        Thread t = delegate.newThread(r);
+        t.setDaemon(true);
+        return t;
+      }
+    });
+    Runnable memoryLoogerRunnable = new Runnable() {
+      public void run() {
+        logMemoryStatistics();
+      }
+    };
+    scheduler.scheduleAtFixedRate(memoryLoogerRunnable, rateInMillis, rateInMillis,
+        TimeUnit.MILLISECONDS);
+  }
+
+  /**
+   * Constructs and starts a memory logger thread with a logging rate of 1000 milliseconds.
+   */
+  public static void startMemoryLogger() {
+    startMemoryLogger(1000);
+  }
+
+  /**
+   * Stops the memory logger, if any, started via {@link #startMemoryLogger(long)} or
+   * {@link #startMemoryLogger()}.
+   */
+  public static void stopMemoryLogger() {
+    if (scheduler == null) {
+      return;
+    }
+    scheduler.shutdownNow();
+    scheduler = null;
+  }
+
+  /**
+   * Tests {@link MemoryLoggerThread}.
+   *
+   * @param args
+   * @throws InterruptedException
+   */
+  public static void main(String[] args) throws InterruptedException {
+    startMemoryLogger();
+    Thread.sleep(10000);
+  }
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/Pair.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/Pair.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/Pair.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/Pair.java Sat Dec  3 00:18:46 2011
@@ -41,6 +41,10 @@ public final class Pair<A,B> implements 
  /**
   * Returns a new {@code Pair} with this pair's elements in reversed order; this pair is
   * unchanged.
   */
  public Pair<B, A> swap() {
    return new Pair<B, A>(second, first);
  }
+
  /**
   * Static factory, equivalent to {@code new Pair<A, B>(a, b)} but with type inference at the
   * call site.
   *
   * @param a the first element.
   * @param b the second element.
   * @return a new pair holding {@code a} and {@code b}.
   */
  public static <A,B> Pair<A,B> of(A a, B b) {
    return new Pair<A, B>(a, b);
  }
   
   @Override
   public boolean equals(Object obj) {

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+
+import java.io.IOException;
+
+public class DistributedRowMatrixWriter {
+
+  public static void write(Path outputDir, Configuration conf, VectorIterable matrix)
+      throws IOException {
+    FileSystem fs = outputDir.getFileSystem(conf);
+    SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+        IntWritable.class, VectorWritable.class);
+    IntWritable topic = new IntWritable();
+    VectorWritable vector = new VectorWritable();
+    for(MatrixSlice slice : matrix) {
+      topic.set(slice.index());
+      vector.set(slice.vector());
+      writer.append(topic, vector);
+    }
+    writer.close();
+
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixUtils.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixUtils.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixUtils.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import java.io.IOException;
+import java.util.List;
+
+public class MatrixUtils {
+
+  public static void write(Path outputDir, Configuration conf, VectorIterable matrix)
+      throws IOException {
+    FileSystem fs = outputDir.getFileSystem(conf);
+    fs.delete(outputDir, true);
+    SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+        IntWritable.class, VectorWritable.class);
+    IntWritable topic = new IntWritable();
+    VectorWritable vector = new VectorWritable();
+    for(MatrixSlice slice : matrix) {
+      topic.set(slice.index());
+      vector.set(slice.vector());
+      writer.append(topic, vector);
+    }
+    writer.close();
+  }
+
+  public static Matrix read(Configuration conf, Path... modelPaths) throws IOException {
+    int numRows = -1;
+    int numCols = -1;
+    boolean sparse = false;
+    List<Pair<Integer, Vector>> rows = Lists.newArrayList();
+    for(Path modelPath : modelPaths) {
+      for(Pair<IntWritable, VectorWritable> row :
+          new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
+        rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
+        numRows = Math.max(numRows, row.getFirst().get());
+        sparse = !row.getSecond().get().isDense();
+        if(numCols < 0) {
+          numCols = row.getSecond().get().size();
+        }
+      }
+    }
+    if(rows.isEmpty()) {
+      throw new IOException(modelPaths + " have no vectors in it");
+    }
+    numRows++;
+    Vector[] arrayOfRows = new Vector[numRows];
+    for(Pair<Integer, Vector> pair : rows) {
+      arrayOfRows[pair.getFirst()] = pair.getSecond();
+    }
+    Matrix matrix;
+    if(sparse) {
+      matrix = new SparseRowMatrix(numRows, numCols, arrayOfRows);
+    } else {
+      matrix = new DenseMatrix(numRows, numCols);
+      for(int i = 0; i < numRows; i++) {
+        matrix.assignRow(i, arrayOfRows[i]);
+      }
+    }
+    return matrix;
+  }
+
+  public static OpenObjectIntHashMap<String> readDictionary(Configuration conf, Path... dictPath)
+    throws IOException {
+    OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<String>();
+    for(Path dictionaryFile : dictPath) {
+      for (Pair<Writable, IntWritable> record
+              : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+        dictionary.put(record.getFirst().toString(), record.getSecond().get());
+      }
+    }
+    return dictionary;
+  }
+
+  public static String[] invertDictionary(OpenObjectIntHashMap<String> termIdMap) {
+    int maxTermId = -1;
+    for(String term : termIdMap.keys()) {
+      maxTermId = Math.max(maxTermId, termIdMap.get(term));
+    }
+    maxTermId++;
+    String[] dictionary = new String[maxTermId];
+    for(String term : termIdMap.keys()) {
+      dictionary[termIdMap.get(term)] = term;
+    }
+    return dictionary;
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/Sampler.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/Sampler.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/Sampler.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/Sampler.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,69 @@
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Discrete distribution sampler:
+ *
+ * Samples from a given discrete distribution: you provide a source of randomness and a Vector
+ * (cardinality N) which describes a distribution over [0,N), and calls to sample() sample
+ * from 0 to N using this distribution
+ *
+ */
+public class Sampler {
+
+  private Random random;
+  private double[] sampler;
+
+  public Sampler(Random random) {
+    this.random = random;
+    sampler = null;
+  }
+
+  public Sampler(Random random, double[] sampler) {
+    this.random = random;
+    this.sampler = sampler;
+  }
+
+  public Sampler(Random random, Vector distribution) {
+    this.random = random;
+    this.sampler = samplerFor(distribution);
+  }
+
+  public int sample(Vector distribution) {
+    return sample(samplerFor(distribution));
+  }
+
+  public int sample() {
+    if(sampler == null) {
+      throw new NullPointerException("Sampler must have been constructed with a distribution, or"
+        + " else sample(Vector) should be used to sample");
+    }
+    return sample(sampler);
+  }
+
+  private double[] samplerFor(double[] distribution) {
+    return samplerFor(new DenseVector(distribution));
+  }
+
+  private double[] samplerFor(Vector vectorDistribution) {
+    int size = vectorDistribution.size();
+    double[] partition = new double[size];
+    double norm = vectorDistribution.norm(1);
+    double sum = 0;
+    for(int i = 0; i < size; i++) {
+      sum += (vectorDistribution.get(i) / norm);
+      partition[i] = sum;
+    }
+    return partition;
+  }
+
+  private int sample(double[] sampler) {
+    int index = Arrays.binarySearch(sampler, random.nextDouble());
+    return index < 0 ? -(index+1) : index;
+  }
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java Sat Dec  3 00:18:46 2011
@@ -24,9 +24,17 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.clustering.lda.LDASampler;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
 
 import java.io.IOException;
+import java.util.Random;
 
 public final class ClusteringTestUtils {
 
@@ -60,4 +68,45 @@ public final class ClusteringTestUtils {
     }
   }
 
+  public static Matrix sampledCorpus(Matrix matrix, Random random,
+      int numDocs, int numSamples, int numTopicsPerDoc) {
+    Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
+    LDASampler modelSampler = new LDASampler(matrix, random);
+    Vector topicVector = new DenseVector(matrix.numRows());
+    for(int i = 0; i < numTopicsPerDoc; i++) {
+      int topic = random.nextInt(topicVector.size());
+      topicVector.set(topic, topicVector.get(topic) + 1);
+    }
+    for(int docId = 0; docId < numDocs; docId++) {
+      for(int sample : modelSampler.sample(topicVector, numSamples)) {
+        corpus.set(docId, sample, corpus.get(docId, sample) + 1);
+      }
+    }
+    return corpus;
+  }
+
+  public static Matrix randomStructuredModel(int numTopics, int numTerms) {
+    return randomStructuredModel(numTopics, numTerms, new DoubleFunction() {
+      @Override public double apply(double d) {
+        return 1.0 / (1 + Math.abs(d));
+      }
+    });
+  }
+
+  public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) {
+    Matrix model = new DenseMatrix(numTopics, numTerms);
+    int width = numTerms / numTopics;
+    for(int topic = 0; topic < numTopics; topic++) {
+      int topicCentroid = width * (1+topic);
+      for(int i = 0; i < numTerms; i++) {
+        int distance = Math.abs(topicCentroid - i);
+        if(distance > numTerms / 2) {
+          distance = numTerms - distance;
+        }
+        double v = decay.apply(distance);
+        model.set(topic, i, v);
+      }
+    }
+    return model;
+  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Sat Dec  3 00:18:46 2011
@@ -16,15 +16,13 @@
  */
 package org.apache.mahout.clustering.lda;
 
-import org.apache.commons.math.distribution.IntegerDistribution;
-import org.easymock.EasyMock;
-
-import java.util.Iterator;
-import java.util.Random;
-
+import com.google.common.base.Joiner;
+import com.google.common.collect.Lists;
 import org.apache.commons.math.MathException;
-
+import org.apache.commons.math.distribution.IntegerDistribution;
 import org.apache.commons.math.distribution.PoissonDistributionImpl;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.mahout.common.IntPairWritable;
@@ -32,11 +30,21 @@ import org.apache.mahout.common.MahoutTe
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixUtils;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.easymock.EasyMock;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.Ignore;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.mahout.clustering.ClusteringTestUtils.*;
 
 public final class TestMapReduce extends MahoutTestCase {
 
@@ -109,6 +117,55 @@ public final class TestMapReduce extends
     }
   }
 
+  @Test
+  @Ignore("MAHOUT-399")
+  public void testEndToEnd() throws Exception {
+    double eta = 0.1;
+    int numGeneratingTopics = 5;
+    int numTerms = 26;
+    Matrix matrix = randomStructuredModel(numGeneratingTopics, numTerms,
+        new DoubleFunction() {
+      @Override public double apply(double d) {
+        return 1d / Math.pow(d+1, 3);
+      }
+    });
+
+    int numDocs = 500;
+    int numSamples = 10;
+    int numTopicsPerDoc = 1;
+
+    Matrix sampledCorpus = sampledCorpus(matrix, new Random(1234),
+        numDocs, numSamples, numTopicsPerDoc);
+
+    Path sampleCorpusPath = getTestTempDirPath("corpus");
+    MatrixUtils.write(sampleCorpusPath, new Configuration(), sampledCorpus);
+
+    int numIterations = 10;
+    List<Double> perplexities = Lists.newArrayList();
+    int startTopic = numGeneratingTopics - 2;
+    int numTestTopics = startTopic;
+    while(numTestTopics < numGeneratingTopics + 3) {
+      LDADriver driver = new LDADriver();
+      driver.setConf(new Configuration());
+      Path outputPath = getTestTempDirPath("output" + numTestTopics);
+      perplexities.add(driver.run(driver.getConf(), sampleCorpusPath, outputPath, numTestTopics,
+          numTerms, eta, numIterations, false));
+      numTestTopics++;
+    }
+    
+    int bestTopic = -1;
+    double lowestPerplexity = Double.MAX_VALUE;
+    for(int t = 0; t < perplexities.size(); t++) {
+      if(perplexities.get(t) < lowestPerplexity) {
+        lowestPerplexity = perplexities.get(t);
+        bestTopic = t + startTopic;
+      }
+    }
+    assertEquals("The optimal number of topics is not that of the generating distribution",
+        bestTopic, numGeneratingTopics);
+    System.out.println("Perplexities: " + Joiner.on(", ").join(perplexities));
+  }
+
   private static int numNonZero(Vector v) {
     int count = 0;
     for(Iterator<Vector.Element> iter = v.iterateNonZero();

Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixUtils;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.mahout.clustering.ClusteringTestUtils.randomStructuredModel;
+import static org.apache.mahout.clustering.ClusteringTestUtils.sampledCorpus;
+
+public class TestCVBModelTrainer extends MahoutTestCase {
+  private double eta = 0.1;
+  private double alpha = 0.1;
+
+  @Test
+  public void testInMemoryCVB0() throws Exception {
+    int numGeneratingTopics = 5;
+    int numTerms = 26;
+    String[] terms = new String[26];
+    for(int i=0; i<terms.length; i++) {
+      terms[i] = "" + ((char)(i + 97));
+    }
+    Matrix matrix = randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
+      @Override public double apply(double d) {
+        return 1d / Math.pow(d+1, 2);
+      }
+    });
+
+    int numDocs = 100;
+    int numSamples = 20;
+    int numTopicsPerDoc = 1;
+
+    Matrix sampledCorpus = sampledCorpus(matrix, new Random(12345),
+        numDocs, numSamples, numTopicsPerDoc);
+
+    List<Double> perplexities = Lists.newArrayList();
+    int numTrials = 2;
+    for(int numTestTopics = 1; numTestTopics < 2 * numGeneratingTopics; numTestTopics++) {
+      double[] perps = new double[numTrials];
+      for(int trial = 0; trial < numTrials; trial++) {
+        InMemoryCollapsedVariationalBayes0 cvb =
+          new InMemoryCollapsedVariationalBayes0(sampledCorpus, terms, numTestTopics, alpha, eta,
+              2, 1, 0, (trial+1) * 123456L);
+        cvb.setVerbose(true);
+        perps[trial] = cvb.iterateUntilConvergence(0, 20, 0, 0.2);
+        System.out.println(perps[trial]);
+      }
+      Arrays.sort(perps);
+      System.out.println(Arrays.toString(perps));
+      perplexities.add(perps[0]);
+    }
+    System.out.println(Joiner.on(",").join(perplexities));
+  }
+
+  @Test
+  public void testRandomStructuredModelViaMR() throws Exception {
+    int numGeneratingTopics = 3;
+    int numTerms = 9;
+    Matrix matrix = randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
+      @Override public double apply(double d) {
+        return 1d / Math.pow(d+1, 3);
+      }
+    });
+
+    int numDocs = 500;
+    int numSamples = 10;
+    int numTopicsPerDoc = 1;
+
+    Matrix sampledCorpus = sampledCorpus(matrix, new Random(1234),
+        numDocs, numSamples, numTopicsPerDoc);
+
+    Path sampleCorpusPath = getTestTempDirPath("corpus");
+    MatrixUtils.write(sampleCorpusPath, new Configuration(), sampledCorpus);
+    int numIterations = 5;
+    List<Double> perplexities = Lists.newArrayList();
+    int startTopic = numGeneratingTopics - 1;
+    int numTestTopics = startTopic;
+    while(numTestTopics < numGeneratingTopics + 2) {
+      CVB0Driver driver = new CVB0Driver();
+      Path topicModelStateTempPath = getTestTempDirPath("topicTemp" + numTestTopics);
+      Configuration conf = new Configuration();
+      driver.run(conf, sampleCorpusPath, null, numTestTopics, numTerms,
+          alpha, eta, numIterations, 1, 0, null, null, topicModelStateTempPath, 1234, 0.2f, 2,
+          1, 10, 1, false);
+      perplexities.add(lowestPerplexity(conf, topicModelStateTempPath));
+      numTestTopics++;
+    }
+    int bestTopic = -1;
+    double lowestPerplexity = Double.MAX_VALUE;
+    for(int t = 0; t < perplexities.size(); t++) {
+      if(perplexities.get(t) < lowestPerplexity) {
+        lowestPerplexity = perplexities.get(t);
+        bestTopic = t + startTopic;
+      }
+    }
+    assertEquals("The optimal number of topics is not that of the generating distribution",
+        bestTopic, numGeneratingTopics);
+    System.out.println("Perplexities: " + Joiner.on(", ").join(perplexities));
+  }
+
+  private static double lowestPerplexity(Configuration conf, Path topicModelTemp)
+      throws IOException {
+    double lowest = Double.MAX_VALUE;
+    double current;
+    int iteration = 2;
+    while(!Double.isNaN(current = CVB0Driver.readPerplexity(conf, topicModelTemp, iteration))) {
+      lowest = Math.min(current, lowest);
+      iteration++;
+    }
+    return lowest;
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/SamplerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/SamplerTest.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/SamplerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/SamplerTest.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class SamplerTest extends MahoutTestCase {
+
+  @Test
+  public void testDiscreteSampler() {
+    Vector distribution = new DenseVector(new double[] {1, 0, 2, 3, 5, 0});
+    Sampler sampler = new Sampler(RandomUtils.getRandom(1234), distribution);
+    Vector sampledDistribution = distribution.like();
+    int i = 0;
+    while(i < 10000) {
+      int index = sampler.sample();
+      sampledDistribution.set(index, sampledDistribution.get(index) + 1);
+      i++;
+    }
+    assertTrue("sampled distribution is far from the original",
+        l1Dist(distribution, sampledDistribution) < 1e-2);
+  }
+
+  private double l1Dist(Vector v, Vector w) {
+    return v.normalize(1d).minus(w.normalize(1)).norm(1d);
+  }
+}

Modified: mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java (original)
+++ mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java Sat Dec  3 00:18:46 2011
@@ -30,9 +30,12 @@ import org.apache.commons.cli2.builder.G
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.commons.cli2.util.HelpFormatter;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.Pair;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
 import org.apache.mahout.math.NamedVector;
@@ -88,21 +91,31 @@ public final class VectorDumper {
     Option namesAsCommentsOpt = obuilder.withLongName("namesAsComments").withRequired(false).withDescription(
             "If using CSV output, optionally add a comment line for each NamedVector (if the vector is one) printing out the name")
             .withShortName("n").create();
+    Option sortVectorsOpt = obuilder.withLongName("sortVectors").withRequired(false).withDescription(
+            "Sort output key/value pairs of the vector entries in abs magnitude descending order")
+            .withShortName("sort").create();
     Option sizeOpt = obuilder.withLongName("sizeOnly").withRequired(false).
             withDescription("Dump only the size of the vector").withShortName("sz").create();
-    Option numItemsOpt = obuilder.withLongName("n").withRequired(false).withArgument(
-            abuilder.withName("numItems").withMinimum(1).withMaximum(1).create()).
-            withDescription("Output at most <n> key value pairs").withShortName("n").create();
+    Option numItemsOpt = obuilder.withLongName("numItems").withRequired(false).withArgument(
+            abuilder.withName("n").withMinimum(1).withMaximum(1).create()).
+            withDescription("Output at most <n> vecors").withShortName("n").create();
+    Option numIndexesPerVectorOpt = obuilder.withLongName("vectorSize").withShortName("vs")
+        .withRequired(false).withArgument(abuilder.withName("vs").withMinimum(1)
+        .withMaximum(1).create())
+        .withDescription("Truncate vectors to <vs> length when dumping (most useful when in"
+                          + " conjunction with -sort").create();
     Option filtersOpt = obuilder.withLongName("filter").withRequired(false).withArgument(
             abuilder.withName("filter").withMinimum(1).withMaximum(100).create()).
-            withDescription("Only dump out those vectors whose name matches the filter.  Multiple items may be specified by repeating the argument.").withShortName("fi").create();
+            withDescription("Only dump out those vectors whose name matches the filter." +
+            "  Multiple items may be specified by repeating the argument.").withShortName("fi").create();
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
             .create();
 
-    Group group = gbuilder.withName("Options").withOption(seqOpt).withOption(outputOpt).withOption(
-            dictTypeOpt).withOption(dictOpt).withOption(csvOpt).withOption(vectorAsKeyOpt).withOption(
-            printKeyOpt).withOption(sizeOpt).withOption(numItemsOpt).withOption(filtersOpt)
-            .withOption(helpOpt).create();
+    Group group = gbuilder.withName("Options").withOption(seqOpt).withOption(outputOpt)
+                          .withOption(dictTypeOpt).withOption(dictOpt).withOption(csvOpt)
+                          .withOption(vectorAsKeyOpt).withOption(printKeyOpt).withOption(sortVectorsOpt)
+                          .withOption(filtersOpt).withOption(helpOpt).withOption(numItemsOpt)
+                          .withOption(sizeOpt).withOption(numIndexesPerVectorOpt).create();
 
     try {
       Parser parser = new Parser();
@@ -110,21 +123,24 @@ public final class VectorDumper {
       CommandLine cmdLine = parser.parse(args);
 
       if (cmdLine.hasOption(helpOpt)) {
-
-        printHelp(group);
+        CommandLineUtil.printHelpWithGenericOptions(group);
         return;
       }
 
       if (cmdLine.hasOption(seqOpt)) {
-        Path path = new Path(cmdLine.getValue(seqOpt).toString());
-        //System.out.println("Input Path: " + path); interferes with output?
         Configuration conf = new Configuration();
+        Path pathPattern = new Path(cmdLine.getValue(seqOpt).toString());
+        FileSystem fs = FileSystem.get(conf);
+        FileStatus[] inputPaths = fs.globStatus(pathPattern);
 
         String dictionaryType = "text";
         if (cmdLine.hasOption(dictTypeOpt)) {
           dictionaryType = cmdLine.getValue(dictTypeOpt).toString();
         }
 
+        boolean sortVectors = cmdLine.hasOption(sortVectorsOpt);
+        log.info("Sort? " + sortVectors);
+
         String[] dictionary = null;
         if (cmdLine.hasOption(dictOpt)) {
           if ("text".equals(dictionaryType)) {
@@ -168,55 +184,70 @@ public final class VectorDumper {
             }
             writer.write('\n');
           }
-          long numItems = Long.MAX_VALUE;
+          Long numItems = null;
           if (cmdLine.hasOption(numItemsOpt)) {
             numItems = Long.parseLong(cmdLine.getValue(numItemsOpt).toString());
             writer.append("#Max Items to dump: ").append(String.valueOf(numItems)).append('\n');
           }
-          SequenceFileIterable<Writable, Writable> iterable = new SequenceFileIterable<Writable, Writable>(path, true, conf);
-          Iterator<Pair<Writable,Writable>> iterator = iterable.iterator();
-          long i = 0;
-          long count = 0;
-          while (iterator.hasNext() && count < numItems) {
-            Pair<Writable, Writable> record = iterator.next();
-            Writable keyWritable = record.getFirst();
-            Writable valueWritable = record.getSecond();
-            if (printKey) {
-              Writable notTheVectorWritable = transposeKeyValue ? valueWritable : keyWritable;
-              writer.write(notTheVectorWritable.toString());
-              writer.write('\t');
+          int maxIndexesPerVector = cmdLine.hasOption(numIndexesPerVectorOpt)
+              ? Integer.parseInt(cmdLine.getValue(numIndexesPerVectorOpt).toString())
+              : Integer.MAX_VALUE;
+          long itemCount = 0;
+          int fileCount = 0;
+          for (FileStatus stat : inputPaths) {
+            if (numItems != null && numItems <= itemCount) {
+              break;
             }
-            VectorWritable vectorWritable = (VectorWritable) (transposeKeyValue ? keyWritable : valueWritable);
-            Vector vector = vectorWritable.get();
-            if (filters != null && (vector instanceof NamedVector && filters.contains(((NamedVector)vector).getName()) == false)){
-              //we are filtering out this item, skip
-              continue;
-            }
-            if (sizeOnly) {
-              if (vector instanceof NamedVector) {
-                writer.write(((NamedVector) vector).getName());
-                writer.write(":");
-              } else {
-                writer.write(String.valueOf(i++));
-                writer.write(":");
+            Path path = stat.getPath();
+            log.info("Processing file '{}' ({}/{})",
+                new Object[]{path, ++fileCount, inputPaths.length});
+            SequenceFileIterable<Writable, Writable> iterable =
+                new SequenceFileIterable<Writable, Writable>(path, true, conf);
+            Iterator<Pair<Writable,Writable>> iterator = iterable.iterator();
+            long i = 0;
+            while (iterator.hasNext() && (numItems == null || itemCount < numItems)) {
+              Pair<Writable, Writable> record = iterator.next();
+              Writable keyWritable = record.getFirst();
+              Writable valueWritable = record.getSecond();
+              if (printKey) {
+                Writable notTheVectorWritable = transposeKeyValue ? valueWritable : keyWritable;
+                writer.write(notTheVectorWritable.toString());
+                writer.write('\t');
               }
-              writer.write(String.valueOf(vector.size()));
-              writer.write('\n');
-            } else {
-              String fmtStr;
-              if (useCSV) {
-                fmtStr = VectorHelper.vectorToCSVString(vector, namesAsComments);
+              VectorWritable vectorWritable =
+                  (VectorWritable) (transposeKeyValue ? keyWritable : valueWritable);
+              Vector vector = vectorWritable.get();
+              if (filters != null
+                  && vector instanceof NamedVector
+                  && filters.contains(((NamedVector)vector).getName()) == false){
+                //we are filtering out this item, skip
+                continue;
+              }
+              if (sizeOnly) {
+                if (vector instanceof NamedVector) {
+                  writer.write(((NamedVector) vector).getName());
+                  writer.write(":");
+                } else {
+                  writer.write(String.valueOf(i++));
+                  writer.write(":");
+                }
+                writer.write(String.valueOf(vector.size()));
+                writer.write('\n');
               } else {
-                fmtStr = vector.asFormatString();
+                String fmtStr;
+                if (useCSV) {
+                  fmtStr = VectorHelper.vectorToCSVString(vector, namesAsComments);
+                } else {
+                  fmtStr = VectorHelper.vectorToJson(vector, dictionary, maxIndexesPerVector,
+                      sortVectors);
+                }
+                writer.write(fmtStr);
+                writer.write('\n');
               }
-              writer.write(fmtStr);
-              writer.write('\n');
+              itemCount++;
             }
-            count++;
           }
-
           writer.flush();
-
         } finally {
           if (shouldClose) {
             Closeables.closeQuietly(writer);

Modified: mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorHelper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorHelper.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorHelper.java (original)
+++ mahout/trunk/integration/src/main/java/org/apache/mahout/utils/vectors/VectorHelper.java Sat Dec  3 00:18:46 2011
@@ -17,18 +17,16 @@
 
 package org.apache.mahout.utils.vectors;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.Iterator;
-import java.util.regex.Pattern;
-
+import com.google.common.base.Function;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Ordering;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
+import org.apache.lucene.util.PriorityQueue;
 import org.apache.mahout.common.Pair;
 import org.apache.mahout.common.iterator.FileLineIterator;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -37,6 +35,15 @@ import org.apache.mahout.math.NamedVecto
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.map.OpenObjectIntHashMap;
 
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.regex.Pattern;
+
 public final class VectorHelper {
 
   private static final Pattern TAB_PATTERN = Pattern.compile("\t");
@@ -50,6 +57,77 @@ public final class VectorHelper {
     return bldr.toString();
   }
 
/**
 * Serializes (term, weight) pairs into a JSON-like object string using a fresh buffer.
 *
 * @see #buildJson(Iterable, StringBuilder)
 */
public static String buildJson(Iterable<Pair<String,Double>> iterable) {
  return buildJson(iterable, new StringBuilder(2048));
}
+
+  public static String buildJson(Iterable<Pair<String,Double>> iterable, StringBuilder bldr) {
+    bldr.append("{");
+    Iterator<Pair<String, Double>> listIt = iterable.iterator();
+    while(listIt.hasNext()) {
+      Pair<String,Double> p = listIt.next();
+      bldr.append(p.getFirst());
+      bldr.append(":");
+      bldr.append(p.getSecond());
+      bldr.append(",");
+    }
+    if(bldr.length() > 1) {
+      bldr.setCharAt(bldr.length() - 1, '}');
+    }
+    return bldr.toString();
+  }
+
/**
 * Renders the entire vector as a JSON-like string with entries sorted (no truncation).
 *
 * @see #vectorToJson(Vector, String[], int, boolean)
 */
public static String vectorToSortedString(Vector vector, String[] dictionary) {
  return vectorToJson(vector, dictionary, Integer.MAX_VALUE, true);
}
+
/**
 * Collects the up-to-{@code maxEntries} largest-valued non-zero entries of {@code vector}
 * as (index, value) pairs, sorted in reverse of the pairs' natural ordering.
 *
 * NOTE(review): the final sort uses Pair's natural order reversed — confirm that Pair
 * compares by value (second element) rather than by index, otherwise the result is ordered
 * by descending index, not descending weight.
 */
public static List<Pair<Integer, Double>> topEntries(Vector vector, int maxEntries) {
  // Bounded priority queue pre-filled with sentinel pairs (index -1, weight -Infinity);
  // real entries displace the smallest element on overflow. See TDoublePQ.
  PriorityQueue<Pair<Integer,Double>> queue = new TDoublePQ<Integer>(-1, maxEntries);
  Iterator<Vector.Element> it = vector.iterateNonZero();
  while(it.hasNext()) {
    Vector.Element e = it.next();
    queue.insertWithOverflow(Pair.of(e.index(), e.get()));
  }
  List<Pair<Integer, Double>> entries = Lists.newArrayList();
  Pair<Integer, Double> pair = null;
  while((pair = queue.pop()) != null) {
    // Filter out leftover sentinel entries (marked with index -1).
    if(pair.getFirst() > -1) {
      entries.add(pair);
    }
  }
  Collections.sort(entries, Ordering.natural().reverse());
  return entries;
}
+
+  public static List<Pair<Integer, Double>> firstEntries(Vector vector, int maxEntries) {
+    List<Pair<Integer, Double>> entries = Lists.newArrayList();
+    Iterator<Vector.Element> it = vector.iterateNonZero();
+    int i = 0;
+    while(it.hasNext() && i++ < maxEntries) {
+      Vector.Element e = it.next();
+      entries.add(Pair.of(e.index(), e.get()));
+    }
+    return entries;
+  }
+
+  public static List<Pair<String, Double>> toWeightedTerms(List<Pair<Integer, Double>> entries,
+      final String[] dictionary) {
+    return Lists.newArrayList(Collections2.transform(entries,
+          new Function<Pair<Integer, Double>, Pair<String, Double>>() {
+            @Override
+            public Pair<String, Double> apply(Pair<Integer, Double> p) {
+              return Pair.of(dictionary[p.getFirst()], p.getSecond());
+            }
+          }));
+  }
+
+  public static String vectorToJson(Vector vector, final String[] dictionary, int maxEntries,
+      boolean sort) {
+    return buildJson(toWeightedTerms(sort
+                                     ? topEntries(vector, maxEntries)
+                                     : firstEntries(vector, maxEntries), dictionary));
+  }
+
   public static void vectorToCSVString(Vector vector,
                                        boolean namesAsComments,
                                        Appendable bldr) throws IOException {
@@ -91,7 +169,8 @@ public final class VectorHelper {
   public static String[] loadTermDictionary(Configuration conf, String filePattern) {
     OpenObjectIntHashMap<String> dict = new OpenObjectIntHashMap<String>();
     for (Pair<Text,IntWritable> record :
-         new SequenceFileDirIterable<Text,IntWritable>(new Path(filePattern), PathType.GLOB, null, null, true, conf)) {
+         new SequenceFileDirIterable<Text,IntWritable>(new Path(filePattern), PathType.GLOB,
+             null, null, true, conf)) {
       dict.put(record.getFirst().toString(), record.getSecond().get());
     }
     String[] dictionary = new String[dict.size()];
@@ -128,4 +207,21 @@ public final class VectorHelper {
     }
     return result;
   }
+
+  /**
+   * Bounded priority queue of (item, weight) pairs ordered by weight, with a
+   * caller-supplied sentinel item used for the pre-filled heap slots.
+   */
+  private static class TDoublePQ<T> extends PriorityQueue<Pair<T, Double>> {
+    final T sentinel;
+    public TDoublePQ(T sentinel, int size) {
+      // Assign the field BEFORE initialize(): Lucene's PriorityQueue calls
+      // getSentinelObject() from initialize() to pre-fill the heap, so the
+      // original order produced Pair(null, -Inf) padding that NPEs when a
+      // caller unboxes the pair's first element (e.g. the index > -1 filter
+      // in topEntries) on a queue that was never completely filled.
+      this.sentinel = sentinel;
+      initialize(size);
+    }
+    @Override
+    protected boolean lessThan(Pair<T, Double> a,
+        Pair<T, Double> b) {
+      // Order solely by weight; ties keep heap order.
+      return a.getSecond().compareTo(b.getSecond()) < 0;
+    }
+    @Override
+    protected Pair<T, Double> getSentinelObject() {
+      // -Inf guarantees sentinels sort below (and are evicted before) any
+      // real entry.
+      return Pair.of(sentinel, Double.NEGATIVE_INFINITY);
+    }
+  }
 }

Added: mahout/trunk/integration/src/test/java/org/apache/mahout/utils/vectors/VectorHelperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/utils/vectors/VectorHelperTest.java?rev=1209794&view=auto
==============================================================================
--- mahout/trunk/integration/src/test/java/org/apache/mahout/utils/vectors/VectorHelperTest.java (added)
+++ mahout/trunk/integration/src/test/java/org/apache/mahout/utils/vectors/VectorHelperTest.java Sat Dec  3 00:18:46 2011
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.utils.vectors;
+
+import junit.framework.TestCase;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Tests {@link VectorHelper#vectorToJson} in both sorted (top-weight) and
+ * unsorted (first-nonzero) modes.
+ */
+public class VectorHelperTest extends TestCase {
+
+  public void testJsonFormatting() throws Exception {
+    Vector v = new SequentialAccessSparseVector(10);
+    v.set(2, 3.1);
+    v.set(4, 1.0);
+    v.set(6, 8.1);
+    v.set(7, -100);
+    v.set(9, 12.2);
+    String UNUSED = "UNUSED";
+    String[] dictionary = {
+        UNUSED, UNUSED, "two", UNUSED, "four", UNUSED, "six", "seven", UNUSED, "nine"
+    };
+
+    // sort=true: the three heaviest entries, descending by weight.
+    assertEquals("sorted json form incorrect: ", "{nine:12.2,six:8.1,two:3.1}",
+        VectorHelper.vectorToJson(v, dictionary, 3, true));
+    // sort=false: the first two nonzero entries, in index order.
+    assertEquals("unsorted form incorrect: ", "{two:3.1,four:1.0}",
+        VectorHelper.vectorToJson(v, dictionary, 2, false));
+  }
+
+}

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Sat Dec  3 00:18:46 2011
@@ -624,12 +624,16 @@ public abstract class AbstractVector imp
 
   @Override
   public String toString() {
+    return toString(null);
+  }
+
+  public String toString(String[] dictionary) {
     StringBuilder result = new StringBuilder();
     result.append('{');
     for (int index = 0; index < size; index++) {
       double value = getQuick(index);
       if (value != 0.0) {
-        result.append(index);
+        result.append(dictionary != null && dictionary.length > index ? dictionary[index] : index);
         result.append(':');
         result.append(value);
         result.append(',');

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/NamedVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/NamedVector.java?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/NamedVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/NamedVector.java Sat Dec  3 00:18:46 2011
@@ -17,11 +17,11 @@
 
 package org.apache.mahout.math;
 
-import java.util.Iterator;
-
 import org.apache.mahout.math.function.DoubleDoubleFunction;
 import org.apache.mahout.math.function.DoubleFunction;
 
+import java.util.Iterator;
+
 public class NamedVector implements Vector {
 
   private Vector delegate;

Modified: mahout/trunk/src/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1209794&r1=1209793&r2=1209794&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Sat Dec  3 00:18:46 2011
@@ -27,6 +27,8 @@ org.apache.mahout.clustering.fuzzykmeans
 org.apache.mahout.clustering.minhash.MinHashDriver = minhash : Run Minhash clustering
 org.apache.mahout.clustering.lda.LDADriver = lda : Latent Dirchlet Allocation
 org.apache.mahout.clustering.lda.LDAPrintTopics = ldatopics : LDA Print Topics
+org.apache.mahout.clustering.lda.cvb.CVB0Driver = cvb : LDA via Collapsed Variational Bayes (0th deriv. approx)
+org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0 = cvb0_local : LDA via Collapsed Variational Bayes, in memory locally.
 org.apache.mahout.clustering.dirichlet.DirichletDriver = dirichlet : Dirichlet Clustering
 org.apache.mahout.clustering.meanshift.MeanShiftCanopyDriver = meanshift : Mean Shift clustering
 org.apache.mahout.clustering.canopy.CanopyDriver = canopy : Canopy clustering
@@ -68,4 +70,4 @@ org.apache.mahout.cf.taste.hadoop.als.Re
 
 #Link Analysis
 org.apache.mahout.graph.linkanalysis.PageRankJob = pagerank : compute the PageRank of a graph
-org.apache.mahout.graph.linkanalysis.RandomWalkWithRestartJob = randomwalkwithrestart : compute all other vertices' proximity to a source vertex in a graph
\ No newline at end of file
+org.apache.mahout.graph.linkanalysis.RandomWalkWithRestartJob = randomwalkwithrestart : compute all other vertices' proximity to a source vertex in a graph



Mime
View raw message