mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r1100858 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/ core/src/test/java/org/apache/mahout/clustering/ examples/src/main/java/org/apache/mahout/clustering/display/
Date Mon, 09 May 2011 01:49:20 GMT
Author: jeastman
Date: Mon May  9 01:49:20 2011
New Revision: 1100858

URL: http://svn.apache.org/viewvc?rev=1100858&view=rev
Log:
MAHOUT-479: added a new iterate method to ClusterIterator. The method accepts 3
Hadoop Paths for input, prior and output information plus the number of desired iterations. All
algorithm data is pulled-from/pushed-to SequenceFiles. Added a unit test and improved the
examples DisplayKMeans, DisplayFuzzyKMeans and DisplayDirichlet to use the new file-based implementation.
Check out Dirichlet.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java Mon May  9 01:49:20 2011
@@ -16,10 +16,21 @@
  */
 package org.apache.mahout.clustering;
 
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 
 /**
  * This is an experimental clustering iterator which works with a
@@ -50,8 +61,7 @@ public class ClusterIterator {
    * @return the posterior ClusterClassifier
    */
   public ClusterClassifier iterate(List<Vector> data,
-                                   ClusterClassifier classifier,
-                                   int numIterations) {
+      ClusterClassifier classifier, int numIterations) {
     for (int iteration = 1; iteration <= numIterations; iteration++) {
       for (Vector vector : data) {
         // classification yields probabilities
@@ -59,7 +69,8 @@ public class ClusterIterator {
         // policy selects weights for models given those probabilities
         Vector weights = policy.select(probabilities);
         // training causes all models to observe data
-        for (Iterator<Vector.Element> it = weights.iterateNonZero(); it.hasNext();) {
+        for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
+            .hasNext();) {
           int index = it.next().index();
           classifier.train(index, vector, weights.get(index));
         }
@@ -71,4 +82,69 @@ public class ClusterIterator {
     }
     return classifier;
   }
+  
+  /**
+   * Iterate over data using a prior-trained ClusterClassifier, for a number of
+   * iterations
+   * 
+   * @param inPath
+   *          a Path to input VectorWritables
+   * @param priorPath
+   *          a Path to the prior classifier
+   * @param outPath
+   *          a Path of output directory
+   * @param numIterations
+   *          the int number of iterations to perform
+   * @throws IOException
+   */
+  public void iterate(Path inPath, Path priorPath, Path outPath,
+      int numIterations) throws IOException {
+    ClusterClassifier classifier = readClassifier(priorPath);
+    Configuration conf = new Configuration();
+    for (int iteration = 1; iteration <= numIterations; iteration++) {
+      for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+          inPath, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
+        Vector vector = vw.get();
+        // classification yields probabilities
+        Vector probabilities = classifier.classify(vector);
+        // policy selects weights for models given those probabilities
+        Vector weights = policy.select(probabilities);
+        // training causes all models to observe data
+        for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
+            .hasNext();) {
+          int index = it.next().index();
+          classifier.train(index, vector, weights.get(index));
+        }
+      }
+      // compute the posterior models
+      classifier.close();
+      // update the policy
+      policy.update(classifier);
+      // output the classifier
+      writeClassifier(classifier, new Path(outPath, "classifier-" + iteration),
+          String.valueOf(iteration));
+    }
+  }
+  
+  private void writeClassifier(ClusterClassifier classifier, Path outPath, String k)
+      throws IOException {
+    Configuration config = new Configuration();
+    FileSystem fs = FileSystem.get(outPath.toUri(), config);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, outPath,
+        Text.class, ClusterClassifier.class);
+    Writable key = new Text(k);
+    writer.append(key, classifier);
+    writer.close();
+  }
+  
+  private ClusterClassifier readClassifier(Path inPath) throws IOException {
+    Configuration config = new Configuration();
+    FileSystem fs = FileSystem.get(inPath.toUri(), config);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, inPath, config);
+    Writable key = new Text();
+    ClusterClassifier classifierOut = new ClusterClassifier();
+    reader.next(key, classifierOut);
+    reader.close();
+    return classifierOut;
+  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java Mon May  9 01:49:20 2011
@@ -38,6 +38,7 @@ import org.apache.mahout.common.distance
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 import org.junit.Test;
 
 public final class TestClusterClassifier extends MahoutTestCase {
@@ -94,12 +95,22 @@ public final class TestClusterClassifier
     Configuration config = new Configuration();
     Path path = new Path(getTestTempDirPath(), "output");
     FileSystem fs = FileSystem.get(path.toUri(), config);
+    writeClassifier(classifier, config, path, fs);
+    return readClassifier(config, path, fs);
+  }
+  
+  private void writeClassifier(ClusterClassifier classifier,
+      Configuration config, Path path, FileSystem fs) throws IOException {
     SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, path,
         Text.class, ClusterClassifier.class);
     Writable key = new Text("test");
     writer.append(key, classifier);
     writer.close();
-    
+  }
+  
+  private ClusterClassifier readClassifier(Configuration config, Path path,
+      FileSystem fs) throws IOException {
+    Writable key;
     SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
     key = new Text();
     ClusterClassifier classifierOut = new ClusterClassifier();
@@ -232,11 +243,10 @@ public final class TestClusterClassifier
     ClusterClassifier posterior = iterator.iterate(data, prior, 5);
     assertEquals(3, posterior.getModels().size());
     for (Cluster cluster : posterior.getModels()) {
-      System.out
-          .println(cluster.asFormatString(null));
+      System.out.println(cluster.asFormatString(null));
     }
   }
-
+  
   @Test
   public void testClusterIteratorDirichlet() {
     List<Vector> data = TestKmeansClustering
@@ -247,8 +257,42 @@ public final class TestClusterClassifier
     ClusterClassifier posterior = iterator.iterate(data, prior, 5);
     assertEquals(3, posterior.getModels().size());
     for (Cluster cluster : posterior.getModels()) {
-      System.out
-          .println(cluster.asFormatString(null));
+      System.out.println(cluster.asFormatString(null));
+    }
+  }
+  
+  @Test
+  public void testSeqFileClusterIteratorKMeans() throws IOException {
+    Path pointsPath = getTestTempDirPath("points");
+    Path priorPath = getTestTempDirPath("prior");
+    Path outPath = getTestTempDirPath("output");
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(conf);
+    List<VectorWritable> points = TestKmeansClustering
+        .getPointsWritable(TestKmeansClustering.REFERENCE);
+    ClusteringTestUtils.writePointsToFile(points,
+        new Path(pointsPath, "file1"), fs, conf);
+    Path path = new Path(priorPath, "priorClassifier");
+    ClusterClassifier prior = newClusterClassifier();
+    writeClassifier(prior, conf, path, fs);
+    assertEquals(3, prior.getModels().size());
+    System.out.println("Prior");
+    for (Cluster cluster : prior.getModels()) {
+      System.out.println(cluster.asFormatString(null));
+    }
+    ClusteringPolicy policy = new KMeansClusteringPolicy();
+    ClusterIterator iterator = new ClusterIterator(policy);
+    iterator.iterate(pointsPath, path, outPath, 5);
+    
+    for (int i = 1; i <= 5; i++) {
+      System.out.println("Classifier-" + i);
+      ClusterClassifier posterior = readClassifier(conf, new Path(outPath,
+          "classifier-" + i), fs);
+      assertEquals(3, posterior.getModels().size());
+      for (Cluster cluster : posterior.getModels()) {
+        System.out.println(cluster.asFormatString(null));
+      }
+      
     }
   }
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java Mon May  9 01:49:20 2011
@@ -39,8 +39,10 @@ import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
 import org.apache.mahout.clustering.AbstractCluster;
 import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterClassifier;
 import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
@@ -297,4 +299,26 @@ public class DisplayClustering extends F
   protected static boolean isSignificant(Cluster cluster) {
     return (double) cluster.getNumPoints() / SAMPLE_DATA.size() > significance;
   }
+
+  protected static ClusterClassifier readClassifier(Configuration config, Path path)
+      throws IOException {
+        Writable key;
+        SequenceFile.Reader reader = new SequenceFile.Reader(
+            FileSystem.get(config), path, config);
+        key = new Text();
+        ClusterClassifier classifierOut = new ClusterClassifier();
+        reader.next(key, classifierOut);
+        reader.close();
+        return classifierOut;
+      }
+
+  protected static void writeClassifier(ClusterClassifier classifier, Configuration config, Path path)
+      throws IOException {
+        SequenceFile.Writer writer = new SequenceFile.Writer(
+            FileSystem.get(config), config, path, Text.class,
+            ClusterClassifier.class);
+        Writable key = new Text("test");
+        writer.append(key, classifier);
+        writer.close();
+      }
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Mon May  9 01:49:20 2011
@@ -19,10 +19,12 @@ package org.apache.mahout.clustering.dis
 
 import java.awt.Graphics;
 import java.awt.Graphics2D;
+import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Iterator;
 import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ClusterClassifier;
 import org.apache.mahout.clustering.ClusterIterator;
@@ -34,7 +36,6 @@ import org.apache.mahout.clustering.diri
 import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -77,44 +78,63 @@ public class DisplayDirichlet extends Di
   
   protected static void generateResults(
       ModelDistribution<VectorWritable> modelDist, int numClusters,
-      int numIterations, double alpha0, int thin, int burnin) {
-    boolean b = false;
-    if (b) {
-      DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
-          alpha0, numClusters, thin, burnin);
-      List<Cluster[]> result = dc.cluster(numIterations);
-      printModels(result, burnin);
-      for (Cluster[] models : result) {
-        List<Cluster> clusters = new ArrayList<Cluster>();
-        for (Cluster cluster : models) {
-          if (isSignificant(cluster)) {
-            clusters.add(cluster);
-          }
-        }
-        CLUSTERS.add(clusters);
-      }
+      int numIterations, double alpha0, int thin, int burnin)
+      throws IOException {
+    boolean runClusterer = false;
+    if (runClusterer) {
+      runSequentialDirichletClusterer(modelDist, numClusters, numIterations, alpha0,
+          thin, burnin);
     } else {
-      List<Vector> points = new ArrayList<Vector>();
-      for (VectorWritable sample : SAMPLE_DATA) {
-        points.add(sample.get());
-      }
-      ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters,
-          numIterations);
-      List<Cluster> models = new ArrayList<Cluster>();
-      for (Model<VectorWritable> cluster : modelDist
-          .sampleFromPrior(numClusters)) {
-        models.add((Cluster) cluster);
+      runSequentialDirichletClassifier(modelDist, numClusters, numIterations);
+    }
+  }
+  
+  private static void runSequentialDirichletClassifier(
+      ModelDistribution<VectorWritable> modelDist, int numClusters,
+      int numIterations) throws IOException {
+    List<Cluster> models = new ArrayList<Cluster>();
+    for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) {
+      models.add((Cluster) cluster);
+    }
+    ClusterClassifier prior = new ClusterClassifier(models);
+    Path samples = new Path("samples");
+    Path output = new Path("output");
+    Path priorClassifier = new Path(output, "clusters-0");
+    Configuration conf = new Configuration();
+    writeClassifier(prior, conf, priorClassifier);
+    
+    ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters,
+        numIterations);
+    new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+        numIterations);
+    for (int i = 1; i <= numIterations; i++) {
+      ClusterClassifier posterior = readClassifier(conf, new Path(output,
+          "classifier-" + i));
+      List<Cluster> clusters = new ArrayList<Cluster>();    
+      for (Cluster cluster : posterior.getModels()) {
+        if (isSignificant(cluster)) {
+          clusters.add(cluster);
+        }
       }
-      ClusterClassifier prior = new ClusterClassifier(models);
-      ClusterIterator iterator = new ClusterIterator(policy);
-      ClusterClassifier posterior = iterator.iterate(points, prior, 5);
-      List<Cluster> models2 = posterior.getModels();
-      for (Iterator<Cluster> it = models2.iterator(); it.hasNext();) {
-        if (!isSignificant(it.next())) {
-          it.remove();
+      CLUSTERS.add(clusters);
+    }
+  }
+  
+  private static void runSequentialDirichletClusterer(
+      ModelDistribution<VectorWritable> modelDist, int numClusters,
+      int numIterations, double alpha0, int thin, int burnin) {
+    DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
+        alpha0, numClusters, thin, burnin);
+    List<Cluster[]> result = dc.cluster(numIterations);
+    printModels(result, burnin);
+    for (Cluster[] models : result) {
+      List<Cluster> clusters = new ArrayList<Cluster>();
+      for (Cluster cluster : models) {
+        if (isSignificant(cluster)) {
+          clusters.add(cluster);
         }
       }
-      CLUSTERS.add(models2);
+      CLUSTERS.add(clusters);
     }
   }
   

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Mon May  9 01:49:20 2011
@@ -19,6 +19,7 @@ package org.apache.mahout.clustering.dis
 
 import java.awt.Graphics;
 import java.awt.Graphics2D;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -37,7 +38,6 @@ import org.apache.mahout.common.RandomUt
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
 
 class DisplayFuzzyKMeans extends DisplayClustering {
   
@@ -59,45 +59,61 @@ class DisplayFuzzyKMeans extends Display
     
     Path samples = new Path("samples");
     Path output = new Path("output");
+    int numClusters = 3;
+    int maxIterations = 10;
     Configuration conf = new Configuration();
     HadoopUtil.delete(conf, samples);
     HadoopUtil.delete(conf, output);
     RandomUtils.useTestSeed();
     DisplayClustering.generateSamples();
-    boolean b = false;
-    if (b) {
-      writeSampleData(samples);
-      Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
-          output, "clusters-0"), 3, measure);
-      double threshold = 0.001;
-      int numIterations = 10;
-      int m = 3;
-      FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold,
-          numIterations, m, true, true, threshold, true);
-      
-      loadClusters(output);
+    writeSampleData(samples);
+    boolean runClusterer = false;
+    if (runClusterer) {
+      runSequentialFuzzyKClusterer(conf, samples, output, measure, numClusters,
+          maxIterations);
     } else {
-      List<Vector> points = new ArrayList<Vector>();
-      for (VectorWritable sample : SAMPLE_DATA) {
-        points.add(sample.get());
-      }
-      List<Cluster> initialClusters = new ArrayList<Cluster>();
-      int id = 0;
-      int numClusters = 4;
-      for (Vector point : points) {
-        if (initialClusters.size() < Math.min(numClusters, points.size())) {
-          initialClusters.add(new SoftCluster(point, id++, measure));
-        } else {
-          break;
-        }
-      }
-      
-      ClusterClassifier prior = new ClusterClassifier(initialClusters);
-      ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy();
-      ClusterClassifier posterior = new ClusterIterator(policy).iterate(points,
-          prior, 10);
-      CLUSTERS.add(posterior.getModels());
+      runSequentialFuzzyKClassifier(conf, samples, output, measure,
+          numClusters, maxIterations);
     }
     new DisplayFuzzyKMeans();
   }
+  
+  private static void runSequentialFuzzyKClassifier(Configuration conf,
+      Path samples, Path output, DistanceMeasure measure, int numClusters,
+      int maxIterations) throws IOException {
+    List<Vector> points = new ArrayList<Vector>();
+    for (int i = 0; i < numClusters; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = new ArrayList<Cluster>();
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new SoftCluster(point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters);
+    Path priorClassifier = new Path(output, "classifier-0");
+    writeClassifier(prior, conf, priorClassifier);
+    
+    ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy();
+    new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+        maxIterations);
+    for (int i = 1; i <= maxIterations; i++) {
+      ClusterClassifier posterior = readClassifier(conf, new Path(output,
+          "classifier-" + i));
+      CLUSTERS.add(posterior.getModels());
+    }
+  }
+  
+  private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples,
+      Path output, DistanceMeasure measure, int numClusters, int maxIterations)
+      throws IOException, ClassNotFoundException, InterruptedException {
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
+        output, "clusters-0"), 3, measure);
+    double threshold = 0.001;
+    int m = 3;
+    FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold,
+        maxIterations, m, true, true, threshold, true);
+    
+    loadClusters(output);
+  }
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Mon May  9 01:49:20 2011
@@ -19,16 +19,17 @@ package org.apache.mahout.clustering.dis
 
 import java.awt.Graphics;
 import java.awt.Graphics2D;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ClusterClassifier;
 import org.apache.mahout.clustering.ClusterIterator;
 import org.apache.mahout.clustering.ClusteringPolicy;
 import org.apache.mahout.clustering.KMeansClusteringPolicy;
-import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
 import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
 import org.apache.mahout.common.HadoopUtil;
@@ -36,7 +37,6 @@ import org.apache.mahout.common.RandomUt
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
 
 class DisplayKMeans extends DisplayClustering {
   
@@ -53,47 +53,64 @@ class DisplayKMeans extends DisplayClust
     Path samples = new Path("samples");
     Path output = new Path("output");
     Configuration conf = new Configuration();
+    int numClusters = 3;
+    int maxIterations = 10;
     HadoopUtil.delete(conf, samples);
     HadoopUtil.delete(conf, output);
     
     RandomUtils.useTestSeed();
     DisplayClustering.generateSamples();
     writeSampleData(samples);
-    boolean b = false;
-    if (b) {
-      Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
-          output, "clusters-0"), 3, measure);
-      int maxIter = 10;
-      double distanceThreshold = 0.001;
-      KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
-          maxIter, true, true);
-      loadClusters(output);
+    boolean runClusterer = false;
+    if (runClusterer) {
+      runSequentialKMeansClusterer(conf, samples, output, measure, numClusters,
+          maxIterations);
     } else {
-      List<Vector> points = new ArrayList<Vector>();
-      for (VectorWritable sample : SAMPLE_DATA) {
-        points.add(sample.get());
-      }
-      List<Cluster> initialClusters = new ArrayList<Cluster>();
-      int id = 0;
-      int numClusters = 4;
-      for (Vector point : points) {
-        if (initialClusters.size() < Math.min(numClusters, points.size())) {
-          initialClusters.add(new org.apache.mahout.clustering.kmeans.Cluster(
-              point, id++, measure));
-        } else {
-          break;
-        }
-      }
-      
-      ClusterClassifier prior = new ClusterClassifier(initialClusters);
-      ClusteringPolicy policy = new KMeansClusteringPolicy();
-      ClusterClassifier posterior = new ClusterIterator(policy).iterate(points,
-          prior, 10);
-      CLUSTERS.add(posterior.getModels());
+      runSequentialKMeansClassifier(conf, samples, output, measure,
+          numClusters, maxIterations);
     }
     new DisplayKMeans();
   }
   
+  private static void runSequentialKMeansClassifier(Configuration conf,
+      Path samples, Path output, DistanceMeasure measure, int numClusters,
+      int maxIterations) throws IOException {
+    List<Vector> points = new ArrayList<Vector>();
+    for (int i = 0; i < numClusters; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = new ArrayList<Cluster>();
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new org.apache.mahout.clustering.kmeans.Cluster(
+          point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters);
+    Path priorClassifier = new Path(output, "clusters-0");
+    writeClassifier(prior, conf, priorClassifier);
+    
+    int maxIter = 10;
+    ClusteringPolicy policy = new KMeansClusteringPolicy();
+    new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+        maxIter);
+    for (int i = 1; i <= maxIter; i++) {
+      ClusterClassifier posterior = readClassifier(conf, new Path(output,
+          "classifier-" + i));
+      CLUSTERS.add(posterior.getModels());
+    }
+  }
+  
+  private static void runSequentialKMeansClusterer(Configuration conf, Path samples,
+      Path output, DistanceMeasure measure, int numClusters, int maxIterations)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
+        output, "clusters-0"), 3, measure);
+    double distanceThreshold = 0.001;
+    KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
+        maxIterations, true, true);
+    loadClusters(output);
+  }
+  
   // Override the paint() method
   @Override
   public void paint(Graphics g) {



Mime
View raw message