mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r907842 - in /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet: DirichletMapper.java DirichletState.java UncommonDistributions.java
Date Mon, 08 Feb 2010 23:11:51 GMT
Author: jeastman
Date: Mon Feb  8 23:11:51 2010
New Revision: 907842

URL: http://svn.apache.org/viewvc?rev=907842&view=rev
Log:
MAHOUT-276

- added alpha_0 parameter to rDirichlet and incorporated into rBeta arguments
- passed alpha_0 argument in DirichletMapper and DirichletState calls to rDirichlet
- removed totalCount = alpha_0/k initialization in DirichletState

all tests still run and seem to produce reasonable outputs

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=907842&r1=907841&r2=907842&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Mon Feb  8 23:11:51 2010
@@ -86,9 +86,10 @@
     String alpha_0 = job.get(DirichletDriver.ALPHA_0_KEY);
 
     try {
+      double alpha = Double.parseDouble(alpha_0);
       DirichletState<VectorWritable> state = DirichletDriver.createState(
           modelFactory, modelPrototype, Integer.parseInt(prototypeSize),
-          Integer.parseInt(numClusters), Double.parseDouble(alpha_0));
+          Integer.parseInt(numClusters), alpha);
       Path path = new Path(statePath);
       FileSystem fs = FileSystem.get(path.toUri(), job);
       FileStatus[] status = fs.listStatus(path, new OutputLogFilter());
@@ -108,7 +109,7 @@
         }
       }
       // TODO: with more than one mapper, they will all have different mixtures. Will this matter?
-      state.setMixture(UncommonDistributions.rDirichlet(state.totalCounts()));
+      state.setMixture(UncommonDistributions.rDirichlet(state.totalCounts(), alpha));
       return state;
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java?rev=907842&r1=907841&r2=907842&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java Mon Feb  8 23:11:51 2010
@@ -35,21 +35,20 @@
 
   private Vector mixture; // the mixture vector
 
-  private double offset; // alpha_0 / numClusters
+  private double alpha_0; // alpha_0
 
   public DirichletState(ModelDistribution<O> modelFactory,
                         int numClusters, double alpha_0, int thin, int burnin) {
     this.numClusters = numClusters;
     this.modelFactory = modelFactory;
-    // initialize totalCounts
-    offset = alpha_0 / numClusters;
+    this.alpha_0 = alpha_0;
     // sample initial prior models
     clusters = new ArrayList<DirichletCluster<O>>();
     for (Model<O> m : modelFactory.sampleFromPrior(numClusters)) {
-      clusters.add(new DirichletCluster<O>(m, offset));
+      clusters.add(new DirichletCluster<O>(m, 0.0));
     }
     // sample the mixture parameters from a Dirichlet distribution on the totalCounts 
-    mixture = UncommonDistributions.rDirichlet(totalCounts());
+    mixture = UncommonDistributions.rDirichlet(totalCounts(), alpha_0);
   }
 
   public DirichletState() {
@@ -87,14 +86,6 @@
     this.mixture = mixture;
   }
 
-  public double getOffset() {
-    return offset;
-  }
-
-  public void setOffset(double offset) {
-    this.offset = offset;
-  }
-
   public Vector totalCounts() {
     Vector result = new DenseVector(numClusters);
     for (int i = 0; i < numClusters; i++) {
@@ -115,7 +106,7 @@
       clusters.get(i).setModel(newModels[i]);
     }
     // update the mixture
-    mixture = UncommonDistributions.rDirichlet(totalCounts());
+    mixture = UncommonDistributions.rDirichlet(totalCounts(), alpha_0);
   }
 
   /**
@@ -131,6 +122,7 @@
     return mix * pdf;
   }
 
+  @SuppressWarnings("unchecked")
   public Model<O>[] getModels() {
     Model<O>[] result = (Model<O>[]) new Model[numClusters];
     for (int i = 0; i < numClusters; i++) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java?rev=907842&r1=907841&r2=907842&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java Mon Feb  8 23:11:51 2010
@@ -228,25 +228,26 @@
   }
 
   /**
-   * Sample from a Dirichlet distribution over the given alpha, returning a vector of probabilities using a
+   * Sample from a Dirichlet distribution, returning a vector of probabilities using a
    * stick-breaking algorithm
    *
-   * @param alpha an unnormalized count Vector
+   * @param totalCounts an unnormalized count Vector
+   * @param alpha_0 a double
    * @return a Vector of probabilities
    */
-  public static Vector rDirichlet(Vector alpha) {
-    Vector r = alpha.like();
-    double total = alpha.zSum();
-    double remainder = 1;
-    for (int i = 0; i < r.size(); i++) {
-      double a = alpha.get(i);
-      total -= a;
-      double beta = rBeta(a, Math.max(0, total));
-      double p = beta * remainder;
-      r.set(i, p);
-      remainder -= p;
+  public static Vector rDirichlet(Vector totalCounts, double alpha_0) {
+    Vector pi = totalCounts.like();
+    double total = totalCounts.zSum();
+    double remainder = 1.0;
+    for (int k = 0; k < pi.size(); k++) {
+      double countK = totalCounts.get(k);
+      total -= countK;
+      double betaK = rBeta(1 + countK, Math.max(0, alpha_0 + total));
+      double piK = betaK * remainder;
+      pi.set(k, piK);
+      remainder -= piK;
     }
-    return r;
+    return pi;
   }
 
 }



Mime
View raw message