mahout-commits mailing list archives

From jeast...@apache.org
Subject svn commit: r1336424 [2/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/classify/ core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ core/src/main/java/org...
Date Wed, 09 May 2012 22:02:52 GMT
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java Wed May  9 22:02:50 2012
@@ -37,7 +37,6 @@ import org.apache.hadoop.conf.Configurat
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
@@ -329,20 +328,6 @@ public class DisplayClustering extends F
     }
   }
   
-  protected static List<Cluster> readClusters(Path clustersIn) {
-    List<Cluster> clusters = Lists.newArrayList();
-    Configuration conf = new Configuration();
-    for (Cluster value : new SequenceFileDirValueIterable<Cluster>(clustersIn, PathType.LIST,
-        PathFilters.logsCRCFilter(), conf)) {
-      log.info(
-          "Reading Cluster:{} center:{} numPoints:{} radius:{}",
-          new Object[] {value.getId(), AbstractCluster.formatVector(value.getCenter(), null),
-              value.getNumObservations(), AbstractCluster.formatVector(value.getRadius(), null)});
-      clusters.add(value);
-    }
-    return clusters;
-  }
-  
   protected static List<Cluster> readClustersWritable(Path clustersIn) {
     List<Cluster> clusters = Lists.newArrayList();
     Configuration conf = new Configuration();
@@ -358,15 +343,6 @@ public class DisplayClustering extends F
     return clusters;
   }
   
-  protected static void loadClusters(Path output) throws IOException {
-    Configuration conf = new Configuration();
-    FileSystem fs = FileSystem.get(output.toUri(), conf);
-    for (FileStatus s : fs.listStatus(output, new ClustersFilter())) {
-      List<Cluster> clusters = readClusters(s.getPath());
-      CLUSTERS.add(clusters);
-    }
-  }
-  
   protected static void loadClustersWritable(Path output) throws IOException {
     Configuration conf = new Configuration();
     FileSystem fs = FileSystem.get(output.toUri(), conf);
@@ -376,15 +352,6 @@ public class DisplayClustering extends F
     }
   }
   
-  protected static void loadClusters(Path output, PathFilter filter) throws IOException {
-    Configuration conf = new Configuration();
-    FileSystem fs = FileSystem.get(output.toUri(), conf);
-    for (FileStatus s : fs.listStatus(output, filter)) {
-      List<Cluster> clusters = readClusters(s.getPath());
-      CLUSTERS.add(clusters);
-    }
-  }
-  
   /**
    * Generate random samples and add them to the sampleData
    * 

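For reference, the net effect of the DisplayClustering change above: the Cluster-valued readClusters/loadClusters helpers are removed, and callers go through the writable-based helpers instead. A minimal caller-side sketch, assuming only the signatures visible in this diff; the class name, the "output" path, and the "clusters-0" subdirectory are illustrative, not taken from the commit.

    import java.util.List;

    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.Cluster;
    import org.apache.mahout.clustering.display.DisplayClustering;

    // Illustrative sketch only. DisplayClustering's protected static helpers are
    // reachable from a subclass; the path values are example placeholders.
    class DisplaySketch extends DisplayClustering {
      static void loadForDisplay() throws Exception {
        Path output = new Path("output");
        loadClustersWritable(output);               // replaces the removed loadClusters(output)
        List<Cluster> last = readClustersWritable(new Path(output, "clusters-0"));
        System.out.println("read " + last.size() + " clusters");
      }
    }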
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Wed May  9 22:02:50 2012
@@ -22,28 +22,28 @@ import java.awt.Graphics2D;
 import java.io.IOException;
 import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.Model;
 import org.apache.mahout.clustering.ModelDistribution;
 import org.apache.mahout.clustering.classify.ClusterClassifier;
-import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
+import org.apache.mahout.clustering.dirichlet.DirichletDriver;
+import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
 import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
 import org.apache.mahout.clustering.iterator.ClusterIterator;
-import org.apache.mahout.clustering.iterator.ClusteringPolicy;
 import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
+import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.VectorWritable;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import com.google.common.collect.Lists;
 
 public class DisplayDirichlet extends DisplayClustering {
   
-  private static final Logger log = LoggerFactory.getLogger(DisplayDirichlet.class);
-  
   public DisplayDirichlet() {
     initialize();
     this.setTitle("Dirichlet Process Clusters - Normal Distribution (>" + (int) (significance * 100)
@@ -57,50 +57,19 @@ public class DisplayDirichlet extends Di
     plotClusters((Graphics2D) g);
   }
   
-  protected static void printModels(Iterable<Cluster[]> result, int significant) {
-    int row = 0;
-    StringBuilder models = new StringBuilder(100);
-    for (Cluster[] r : result) {
-      models.append("sample[").append(row++).append("]= ");
-      for (int k = 0; k < r.length; k++) {
-        Cluster model = r[k];
-        if (model.getNumObservations() > significant) {
-          models.append('m').append(k).append(model.asFormatString(null)).append(", ");
-        }
-      }
-      models.append('\n');
-    }
-    models.append('\n');
-    log.info(models.toString());
-  }
-  
-  protected static void generateResults(ModelDistribution<VectorWritable> modelDist, int numClusters,
-      int numIterations, double alpha0, int thin, int burnin) throws IOException {
-    boolean runClusterer = false;
+  protected static void generateResults(Path input, Path output,
+      ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0, int thin, int burnin) throws IOException, ClassNotFoundException,
+      InterruptedException {
+    boolean runClusterer = true;
     if (runClusterer) {
-      runSequentialDirichletClusterer(modelDist, numClusters, numIterations, alpha0, thin, burnin);
+      runSequentialDirichletClusterer(input, output, modelDist, numClusters, numIterations, alpha0);
     } else {
-      runSequentialDirichletClassifier(modelDist, numClusters, numIterations, alpha0);
+      runSequentialDirichletClassifier(input, output, modelDist, numClusters, numIterations, alpha0);
     }
-  }
-  
-  private static void runSequentialDirichletClassifier(ModelDistribution<VectorWritable> modelDist, int numClusters,
-      int numIterations, double alpha0) throws IOException {
-    List<Cluster> models = Lists.newArrayList();
-    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) {
-      models.add((Cluster) cluster);
-    }
-    ClusterClassifier prior = new ClusterClassifier(models, new DirichletClusteringPolicy(numClusters, alpha0));
-    Path samples = new Path("samples");
-    Path output = new Path("output");
-    Path priorPath = new Path(output, "clusters-0");
-    prior.writeToSeqFiles(priorPath);
-    
-    new ClusterIterator().iterateSeq(samples, priorPath, output, numIterations);
     for (int i = 1; i <= numIterations; i++) {
       ClusterClassifier posterior = new ClusterClassifier();
       String name = i == numIterations ? "clusters-" + i + "-final" : "clusters-" + i;
-      posterior.readFromSeqFiles(new Path(output, name));
+      posterior.readFromSeqFiles(new Configuration(), new Path(output, name));
       List<Cluster> clusters = Lists.newArrayList();
       for (Cluster cluster : posterior.getModels()) {
         if (isSignificant(cluster)) {
@@ -111,33 +80,47 @@ public class DisplayDirichlet extends Di
     }
   }
   
-  private static void runSequentialDirichletClusterer(ModelDistribution<VectorWritable> modelDist, int numClusters,
-      int numIterations, double alpha0, int thin, int burnin) {
-    DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist, alpha0, numClusters, thin, burnin);
-    List<Cluster[]> result = dc.cluster(numIterations);
-    printModels(result, burnin);
-    for (Cluster[] models : result) {
-      List<Cluster> clusters = Lists.newArrayList();
-      for (Cluster cluster : models) {
-        if (isSignificant(cluster)) {
-          clusters.add(cluster);
-        }
-      }
-      CLUSTERS.add(clusters);
+  private static void runSequentialDirichletClassifier(Path input, Path output,
+      ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0)
+      throws IOException {
+    List<Cluster> models = Lists.newArrayList();
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) {
+      models.add((Cluster) cluster);
     }
+    ClusterClassifier prior = new ClusterClassifier(models, new DirichletClusteringPolicy(numClusters, alpha0));
+    Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+    prior.writeToSeqFiles(priorPath);
+    Configuration conf = new Configuration();
+    new ClusterIterator().iterateSeq(conf, input, priorPath, output, numIterations);
+  }
+  
+  private static void runSequentialDirichletClusterer(Path input, Path output,
+      ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0)
+      throws IOException, ClassNotFoundException, InterruptedException {
+    DistributionDescription description = new DistributionDescription(modelDist.getClass().getName(),
+        RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), 2);
+    
+    DirichletDriver.run(new Configuration(), input, output, description, numClusters, numIterations, alpha0, true,
+        true, 0, false);
   }
   
   public static void main(String[] args) throws Exception {
     VectorWritable modelPrototype = new VectorWritable(new DenseVector(2));
     ModelDistribution<VectorWritable> modelDist = new GaussianClusterDistribution(modelPrototype);
+    Configuration conf = new Configuration();
+    Path output = new Path("output");
+    HadoopUtil.delete(conf, output);
+    Path samples = new Path("samples");
+    HadoopUtil.delete(conf, samples);
     RandomUtils.useTestSeed();
     generateSamples();
+    writeSampleData(samples);
     int numIterations = 20;
     int numClusters = 10;
     int alpha0 = 1;
     int thin = 3;
     int burnin = 5;
-    generateResults(modelDist, numClusters, numIterations, alpha0, thin, burnin);
+    generateResults(samples, output, modelDist, numClusters, numIterations, alpha0, thin, burnin);
     new DisplayDirichlet();
   }
   

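For reference, the new driver-based flow in DisplayDirichlet, condensed into one sketch. The calls and their argument order are taken from the hunks above; the class name is invented, the numeric values mirror main(), and the trailing boolean/threshold arguments of DirichletDriver.run are copied verbatim from the diff rather than interpreted.

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.dirichlet.DirichletDriver;
    import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
    import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
    import org.apache.mahout.common.HadoopUtil;
    import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
    import org.apache.mahout.math.RandomAccessSparseVector;

    // Illustrative sketch of the sequential Dirichlet run as wired up after this
    // commit; "samples" must hold VectorWritable sequence files before run() is called.
    class DirichletRunSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path samples = new Path("samples");
        Path output = new Path("output");
        HadoopUtil.delete(conf, output);            // clear any previous run
        HadoopUtil.delete(conf, samples);
        // ... write sample vectors to "samples" here (the display calls writeSampleData) ...

        DistributionDescription description = new DistributionDescription(
            GaussianClusterDistribution.class.getName(),
            RandomAccessSparseVector.class.getName(),
            ManhattanDistanceMeasure.class.getName(), 2);

        // 10 clusters, 20 iterations, alpha0 = 1.0 as in main(); the remaining
        // arguments (true, true, 0, false) are exactly as written in the diff.
        DirichletDriver.run(conf, samples, output, description, 10, 20, 1.0, true, true, 0, false);
      }
    }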
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Wed May  9 22:02:50 2012
@@ -60,12 +60,12 @@ public class DisplayFuzzyKMeans extends 
     Path samples = new Path("samples");
     Path output = new Path("output");
     Configuration conf = new Configuration();
-    HadoopUtil.delete(conf, samples);
     HadoopUtil.delete(conf, output);
+    HadoopUtil.delete(conf, samples);
     RandomUtils.useTestSeed();
     DisplayClustering.generateSamples();
     writeSampleData(samples);
-    boolean runClusterer = false;
+    boolean runClusterer = true;
     int maxIterations = 10;
     float threshold = 0.001F;
     float m = 1.1F;
@@ -93,16 +93,17 @@ public class DisplayFuzzyKMeans extends 
     Path priorPath = new Path(output, "classifier-0");
     prior.writeToSeqFiles(priorPath);
     
-    new ClusterIterator().iterateSeq(samples, priorPath, output, maxIterations);
+    new ClusterIterator().iterateSeq(conf, samples, priorPath, output, maxIterations);
     loadClustersWritable(output);
   }
   
   private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
       DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
       ClassNotFoundException, InterruptedException {
-    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure);
-    FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold, maxIterations, m, true, true, threshold, true);
+    Path clustersIn = new Path(output, "random-seeds");
+    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
+    FuzzyKMeansDriver.run(samples, clustersIn, output, measure, threshold, maxIterations, m, true, true, threshold, true);
     
-    loadClusters(output);
+    loadClustersWritable(output);
   }
 }

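For reference, the driver-based fuzzy k-means path after this change, gathered into one sketch: random seeds now land under output/random-seeds rather than output/clusters-0, and results are read back with loadClustersWritable. The argument list of FuzzyKMeansDriver.run is copied from the diff; the distance measure, seed count, and package locations of classes not shown in the import hunks are assumptions.

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
    import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
    import org.apache.mahout.common.distance.DistanceMeasure;
    import org.apache.mahout.common.distance.ManhattanDistanceMeasure;

    // Illustrative sketch; "samples" must already contain vectorized input data.
    class FuzzyKMeansRunSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path samples = new Path("samples");
        Path output = new Path("output");
        DistanceMeasure measure = new ManhattanDistanceMeasure();   // measure choice is illustrative
        int maxIterations = 10;
        float m = 1.1F;
        float threshold = 0.001F;

        // Seeds are written under output/random-seeds (previously output/clusters-0).
        Path clustersIn = new Path(output, "random-seeds");
        RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
        // Argument list copied verbatim from the hunk above.
        FuzzyKMeansDriver.run(samples, clustersIn, output, measure, threshold, maxIterations, m,
            true, true, threshold, true);
        // The display then calls loadClustersWritable(output) to read the results back.
      }
    }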
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Wed May  9 22:02:50 2012
@@ -55,9 +55,9 @@ public class DisplayKMeans extends Displ
     HadoopUtil.delete(conf, output);
     
     RandomUtils.useTestSeed();
-    DisplayClustering.generateSamples();
+    generateSamples();
     writeSampleData(samples);
-    boolean runClusterer = false;
+    boolean runClusterer = true;
     double convergenceDelta = 0.001;
     if (runClusterer) {
       int numClusters = 3;
@@ -81,20 +81,21 @@ public class DisplayKMeans extends Displ
       initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure));
     }
     ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
-    Path priorPath = new Path(output, "clusters-0");
+    Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
     prior.writeToSeqFiles(priorPath);
     
     int maxIter = 10;
-    new ClusterIterator().iterateSeq(samples, priorPath, output, maxIter);
+    new ClusterIterator().iterateSeq(conf, samples, priorPath, output, maxIter);
     loadClustersWritable(output);
   }
   
   private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
       DistanceMeasure measure, int maxIterations, double convergenceDelta) throws IOException, InterruptedException,
       ClassNotFoundException {
-    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure);
-    KMeansDriver.run(samples, clusters, output, measure, convergenceDelta, maxIterations, true, 0.0, true);
-    loadClusters(output);
+    Path clustersIn = new Path(output, "random-seeds");
+    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
+    KMeansDriver.run(samples, clustersIn, output, measure, convergenceDelta, maxIterations, true, 0.0, true);
+    loadClustersWritable(output);
   }
   
   // Override the paint() method

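For reference, the classifier-based sequential k-means path after this change, condensed from the hunk above: the prior is written under Cluster.INITIAL_CLUSTERS_DIR and ClusterIterator.iterateSeq now takes the Hadoop Configuration as its first argument. The seed points are left as a parameter, the sketch's class and method names are invented, and the package of KMeansClusteringPolicy is not shown in this diff, so that import is an assumption.

    import java.util.List;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.Cluster;
    import org.apache.mahout.clustering.classify.ClusterClassifier;
    import org.apache.mahout.clustering.iterator.ClusterIterator;
    import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;   // package assumed
    import org.apache.mahout.clustering.kmeans.Kluster;
    import org.apache.mahout.common.distance.DistanceMeasure;
    import org.apache.mahout.math.Vector;

    import com.google.common.collect.Lists;

    // Illustrative sketch of the classifier-based flow restructured above.
    class KMeansClassifierSketch {
      static void run(Configuration conf, Path samples, Path output,
          List<Vector> seedPoints, DistanceMeasure measure, double convergenceDelta)
          throws Exception {
        List<Cluster> initialClusters = Lists.newArrayList();
        int id = 0;
        for (Vector point : seedPoints) {
          initialClusters.add(new Kluster(point, id++, measure));
        }
        ClusterClassifier prior =
            new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
        Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
        prior.writeToSeqFiles(priorPath);

        int maxIter = 10;
        // Configuration is now the first argument of iterateSeq.
        new ClusterIterator().iterateSeq(conf, samples, priorPath, output, maxIter);
        // DisplayKMeans then calls loadClustersWritable(output).
      }
    }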
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java Wed May  9 22:02:50 2012
@@ -111,7 +111,7 @@ public class DisplayMeanShift extends Di
     // if (b) {
     MeanShiftCanopyDriver.run(conf, samples, output, measure, kernelProfile,
         t1, t2, 0.005, 20, false, true, true);
-    loadClusters(output);
+    loadClustersWritable(output);
     // } else {
     // Collection<Vector> points = new ArrayList<Vector>();
     // for (VectorWritable sample : SAMPLE_DATA) {

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java Wed May  9 22:02:50 2012
@@ -139,16 +139,8 @@ public final class Job extends AbstractJ
           throws Exception{
     Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT);
     InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector");
-    DirichletDriver.run(directoryContainingConvertedInput,
-                        output,
-                        description,
-                        numModels,
-                        maxIterations,
-                        alpha0,
-                        true,
-                        emitMostLikely,
-                        threshold,
-                        false);
+    DirichletDriver.run(new Configuration(), directoryContainingConvertedInput, output, description, numModels, maxIterations, alpha0, true,
+    emitMostLikely, threshold, false);
     // run ClusterDumper
     ClusterDumper clusterDumper =
         new ClusterDumper(new Path(output, "clusters-" + maxIterations), new Path(output, "clusteredPoints"));

Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java (original)
+++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java Wed May  9 22:02:50 2012
@@ -412,8 +412,8 @@ public final class TestClusterEvaluator 
     DistributionDescription description = new DistributionDescription(
         GaussianClusterDistribution.class.getName(),
         DenseVector.class.getName(), null, 2);
-    DirichletDriver.run(testdata, output, description, 15, 5, 1.0, true, true,
-        0, true);
+    DirichletDriver.run(new Configuration(), testdata, output, description, 15, 5, 1.0, true,
+    true, (double) 0, true);
     int numIterations = 10;
     Configuration conf = new Configuration();
     Path clustersIn = new Path(output, "clusters-5-final");

Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java (original)
+++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java Wed May  9 22:02:50 2012
@@ -427,8 +427,8 @@ public final class TestCDbwEvaluator ext
     DistributionDescription description = new DistributionDescription(
         GaussianClusterDistribution.class.getName(),
         DenseVector.class.getName(), null, 2);
-    DirichletDriver.run(testdata, output, description, 15, 5, 1.0, true, true,
-        0, true);
+    DirichletDriver.run(new Configuration(), testdata, output, description, 15, 5, 1.0, true,
+    true, (double) 0, true);
     int numIterations = 10;
     Path clustersIn = new Path(output, "clusters-0");
     RepresentativePointsDriver.run(conf, clustersIn, new Path(output,

Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?rev=1336424&r1=1336423&r2=1336424&view=diff
==============================================================================
--- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (original)
+++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Wed May  9 22:02:50 2012
@@ -23,8 +23,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 
-import com.google.common.collect.Lists;
-import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
@@ -35,8 +34,16 @@ import org.apache.lucene.store.RAMDirect
 import org.apache.lucene.util.Version;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.Model;
+import org.apache.mahout.clustering.ModelDistribution;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
 import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
 import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
+import org.apache.mahout.common.distance.CosineDistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.utils.MahoutTestCase;
@@ -49,6 +56,9 @@ import org.apache.mahout.vectorizer.TFID
 import org.apache.mahout.vectorizer.Weight;
 import org.junit.Test;
 
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
 public final class TestL1ModelClustering extends MahoutTestCase {
   
   private class MapElement implements Comparable<MapElement> {
@@ -118,7 +128,7 @@ public final class TestL1ModelClustering
       "The robber wore a white fleece jacket and a baseball cap.",
       "The English Springer Spaniel is the best of all dogs."};
   
-  private List<VectorWritable> sampleData;
+  private List<Vector> sampleData;
   
   private void getSampleData(String[] docs2) throws IOException {
     sampleData = Lists.newArrayList();
@@ -148,11 +158,11 @@ public final class TestL1ModelClustering
     for (Vector vector : iterable) {
       assertNotNull(vector);
       System.out.println("Vector[" + i++ + "]=" + formatVector(vector));
-      sampleData.add(new VectorWritable(vector));
+      sampleData.add(vector);
     }
   }
   
-  private static String formatVector(Vector v) {
+  private String formatVector(Vector v) {
     StringBuilder buf = new StringBuilder();
     int nzero = 0;
     Iterator<Vector.Element> iterateNonZero = v.iterateNonZero();
@@ -179,7 +189,7 @@ public final class TestL1ModelClustering
     return buf.toString();
   }
   
-  private static void printSamples(Iterable<Cluster[]> result, int significant) {
+  private void printSamples(Iterable<Cluster[]> result, int significant) {
     int row = 0;
     for (Cluster[] r : result) {
       int sig = 0;
@@ -199,19 +209,19 @@ public final class TestL1ModelClustering
     System.out.println();
   }
   
-  private void printClusters(Model<VectorWritable>[] models, List<VectorWritable> samples, String[] docs) {
-    for (int m = 0; m < models.length; m++) {
-      Model<VectorWritable> model = models[m];
+  private void printClusters(List<Cluster> models, String[] docs) {
+    for (int m = 0; m < models.size(); m++) {
+      Cluster model = models.get(m);
       long count = model.getNumObservations();
       if (count == 0) {
         continue;
       }
-      System.out.println("Model[" + m + "] had " + count + " hits (!) and " + (samples.size() - count)
+      System.out.println("Model[" + m + "] had " + count + " hits (!) and " + (sampleData.size() - count)
           + " misses (? in pdf order) during the last iteration:");
-      MapElement[] map = new MapElement[samples.size()];
+      MapElement[] map = new MapElement[sampleData.size()];
       // sort the samples by pdf
-      for (int i = 0; i < samples.size(); i++) {
-        VectorWritable sample = samples.get(i);
+      for (int i = 0; i < sampleData.size(); i++) {
+        VectorWritable sample = new VectorWritable(sampleData.get(i));
         map[i] = new MapElement(model.pdf(sample), docs[i]);
       }
       Arrays.sort(map);
@@ -230,45 +240,81 @@ public final class TestL1ModelClustering
   @Test
   public void testDocs() throws Exception {
     getSampleData(DOCS);
-    DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(sampleData.get(0)), 1.0,
-        15, 1, 0);
-    List<Cluster[]> result = dc.cluster(10);
-    assertNotNull(result);
-    printSamples(result, 0);
-    printClusters(result.get(result.size() - 1), sampleData, DOCS);
+    DistributionDescription description = new DistributionDescription(GaussianClusterDistribution.class.getName(),
+        RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), sampleData.get(0).size());
+    
+    List<Cluster> models = Lists.newArrayList();
+    ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration());
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) {
+      models.add((Cluster) cluster);
+    }
+    
+    ClusterIterator iterator = new ClusterIterator();
+    ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0));
+    ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10);
+    
+    printClusters(posterior.getModels(), DOCS);
   }
   
   @Test
   public void testDMDocs() throws Exception {
+    
     getSampleData(DOCS);
-    DirichletClusterer dc = new DirichletClusterer(sampleData,
-        new DistanceMeasureClusterDistribution(sampleData.get(0)), 1.0, 15, 1, 0);
-    List<Cluster[]> result = dc.cluster(10);
-    assertNotNull(result);
-    printSamples(result, 0);
-    printClusters(result.get(result.size() - 1), sampleData, DOCS);
+    DistributionDescription description = new DistributionDescription(
+        DistanceMeasureClusterDistribution.class.getName(), RandomAccessSparseVector.class.getName(),
+        CosineDistanceMeasure.class.getName(), sampleData.get(0).size());
+    
+    List<Cluster> models = Lists.newArrayList();
+    ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration());
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) {
+      models.add((Cluster) cluster);
+    }
+    
+    ClusterIterator iterator = new ClusterIterator();
+    ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0));
+    ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10);
+    
+    printClusters(posterior.getModels(), DOCS);
   }
   
   @Test
   public void testDocs2() throws Exception {
     getSampleData(DOCS2);
-    DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(sampleData.get(0)), 1.0,
-        15, 1, 0);
-    List<Cluster[]> result = dc.cluster(10);
-    assertNotNull(result);
-    printSamples(result, 0);
-    printClusters(result.get(result.size() - 1), sampleData, DOCS2);
+    DistributionDescription description = new DistributionDescription(GaussianClusterDistribution.class.getName(),
+        RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), sampleData.get(0).size());
+    
+    List<Cluster> models = Lists.newArrayList();
+    ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration());
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) {
+      models.add((Cluster) cluster);
+    }
+    
+    ClusterIterator iterator = new ClusterIterator();
+    ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0));
+    ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10);
+    
+    printClusters(posterior.getModels(), DOCS2);
   }
   
   @Test
   public void testDMDocs2() throws Exception {
-    getSampleData(DOCS2);
-    DirichletClusterer dc = new DirichletClusterer(sampleData,
-        new DistanceMeasureClusterDistribution(sampleData.get(0)), 1.0, 15, 1, 0);
-    List<Cluster[]> result = dc.cluster(10);
-    assertNotNull(result);
-    printSamples(result, 0);
-    printClusters(result.get(result.size() - 1), sampleData, DOCS2);
+    
+    getSampleData(DOCS);
+    DistributionDescription description = new DistributionDescription(
+        DistanceMeasureClusterDistribution.class.getName(), RandomAccessSparseVector.class.getName(),
+        CosineDistanceMeasure.class.getName(), sampleData.get(0).size());
+    
+    List<Cluster> models = Lists.newArrayList();
+    ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration());
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) {
+      models.add((Cluster) cluster);
+    }
+    
+    ClusterIterator iterator = new ClusterIterator();
+    ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0));
+    ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10);
+    
+    printClusters(posterior.getModels(), DOCS2);
   }
   
 }

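For reference, the pattern shared by the four rewritten tests above, gathered into one self-contained sketch: prior models are sampled from a DistributionDescription and iterated in memory with ClusterIterator.iterate, replacing the removed DirichletClusterer flow. Parameter values (15 models, alpha0 = 1.0, 10 iterations) come from the diff; the class and method names of the sketch itself are invented.

    import java.util.List;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.mahout.clustering.Cluster;
    import org.apache.mahout.clustering.Model;
    import org.apache.mahout.clustering.ModelDistribution;
    import org.apache.mahout.clustering.classify.ClusterClassifier;
    import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
    import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
    import org.apache.mahout.clustering.iterator.ClusterIterator;
    import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
    import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.VectorWritable;

    import com.google.common.collect.Lists;

    class InMemoryDirichletSketch {
      // Sample prior models from a DistributionDescription, then iterate a
      // ClusterClassifier in memory over the sample vectors.
      static ClusterClassifier cluster(List<Vector> sampleData) {
        DistributionDescription description = new DistributionDescription(
            GaussianClusterDistribution.class.getName(),
            RandomAccessSparseVector.class.getName(),
            ManhattanDistanceMeasure.class.getName(), sampleData.get(0).size());

        List<Cluster> models = Lists.newArrayList();
        ModelDistribution<VectorWritable> modelDist =
            description.createModelDistribution(new Configuration());
        for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) {
          models.add((Cluster) cluster);
        }

        ClusterClassifier classifier =
            new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0));
        return new ClusterIterator().iterate(sampleData, classifier, 10);
      }
    }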

