mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r1292563 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/ core/src/main/java/org/apache/mahout/clustering/classify/ core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/ core/src/test/java/org/apache/mahout/cluste...
Date Wed, 22 Feb 2012 22:46:00 GMT
Author: jeastman
Date: Wed Feb 22 22:45:59 2012
New Revision: 1292563

URL: http://svn.apache.org/viewvc?rev=1292563&view=rev
Log:
MAHOUT-933: Refactored the actual classification out of ClusterClassifier and into the ClusteringPolicy
implementations. This allows the classifier to be completely generic as to the algorithm and lets each policy
correctly use its own parameters, e.g. the fuzzy k-means 'm'.
Introduced Canopy and MeanShift clustering policies for classification, though they are not used by the cluster iterator.
Modified serialization of ClusterClassifiers to include the ClusteringPolicy.
Added ClusterClassifier methods to serialize to and from the exploded SequenceFile representation needed for MR.
Updated Display examples and unit tests. All run.

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java   (with props)
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java   (with props)
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java Wed Feb 22 22:45:59 2012
@@ -28,9 +28,9 @@ public class CIMapper extends Mapper<Wri
   @Override
   protected void setup(Context context) throws IOException, InterruptedException {
     String priorClustersPath = context.getConfiguration().get(ClusterIterator.PRIOR_PATH_KEY);
-    String policyPath = context.getConfiguration().get(ClusterIterator.POLICY_PATH_KEY);
-    classifier = ClusterIterator.readClassifier(new Path(priorClustersPath));
-    policy = ClusterIterator.readPolicy(new Path(policyPath));
+    classifier = new ClusterClassifier();
+    classifier.readFromSeqFiles(new Path(priorClustersPath));
+    policy = classifier.getPolicy();
     policy.update(classifier);
     super.setup(context);
   }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java?rev=1292563&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
+
+/**
+ * This is a simple maximum likelihood clustering policy, suitable for canopy
+ * clustering
+ * 
+ */
+public class CanopyClusteringPolicy implements ClusteringPolicy {
+  
+  public CanopyClusteringPolicy() {
+    super();
+  }
+  
+  private double t1, t2;
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#select(org.apache.mahout.
+   * math.Vector)
+   */
+  @Override
+  public Vector select(Vector probabilities) {
+    int maxValueIndex = probabilities.maxValueIndex();
+    Vector weights = new SequentialAccessSparseVector(probabilities.size());
+    weights.set(maxValueIndex, 1.0);
+    return weights;
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#update(org.apache.mahout.
+   * clustering.ClusterClassifier)
+   */
+  @Override
+  public void update(ClusterClassifier posterior) {
+    // nothing to do here
+  }
+  
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    int i = 0;
+    Vector pdfs = new DenseVector(models.size());
+    for (Cluster model : models) {
+      pdfs.set(i++, model.pdf(new VectorWritable(data)));
+    }
+    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.io.Writable#write(java.io.DataOutput)
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(t1);
+    out.writeDouble(t2);
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.io.Writable#readFields(java.io.DataInput)
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.t1 = in.readDouble();
+    this.t2 = in.readDouble();
+  }
+  
+}

Propchange: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CanopyClusteringPolicy.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java Wed Feb 22 22:45:59 2012
@@ -18,20 +18,27 @@ package org.apache.mahout.clustering;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
-import java.util.Collection;
 import java.util.List;
+import java.util.Locale;
 
-import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.classifier.OnlineLearner;
-import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
-import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
 import org.apache.mahout.common.ClassUtils;
-import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.function.TimesFunction;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
 
 /**
  * This classifier works with any clustering Cluster. It is initialized with a
@@ -49,39 +56,33 @@ public class ClusterClassifier extends A
   
   private String modelClass;
   
+  private ClusteringPolicy policy;
+  
   /**
    * The public constructor accepts a list of clusters to become the models
    * 
    * @param models
    *          a List<Cluster>
+   * @param policy
+   *          a ClusteringPolicy
    */
-  public ClusterClassifier(List<Cluster> models) {
+  public ClusterClassifier(List<Cluster> models, ClusteringPolicy policy) {
     this.models = models;
     modelClass = models.get(0).getClass().getName();
+    this.policy = policy;
   }
   
   // needed for serialization/deserialization
   public ClusterClassifier() {}
   
+  // only used by MR ClusterIterator
+  protected ClusterClassifier(ClusteringPolicy policy) {
+    this.policy = policy;
+  }
+  
   @Override
   public Vector classify(Vector instance) {
-    if (models.get(0) instanceof SoftCluster) {
-      Collection<SoftCluster> clusters = Lists.newArrayList();
-      List<Double> distances = Lists.newArrayList();
-      for (Cluster model : models) {
-        SoftCluster sc = (SoftCluster) model;
-        clusters.add(sc);
-        distances.add(sc.getMeasure().distance(instance, sc.getCenter()));
-      }
-      return new FuzzyKMeansClusterer().computePi(clusters, distances);
-    } else {
-      int i = 0;
-      Vector pdfs = new DenseVector(models.size());
-      for (Cluster model : models) {
-        pdfs.set(i++, model.pdf(new VectorWritable(instance)));
-      }
-      return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
-    }
+    return policy.classify(instance, models);
   }
   
   @Override
@@ -103,6 +104,7 @@ public class ClusterClassifier extends A
   public void write(DataOutput out) throws IOException {
     out.writeInt(models.size());
     out.writeUTF(modelClass);
+    new ClusteringPolicyWritable(policy).write(out);
     for (Cluster cluster : models) {
       cluster.write(out);
     }
@@ -113,6 +115,9 @@ public class ClusterClassifier extends A
     int size = in.readInt();
     modelClass = in.readUTF();
     models = Lists.newArrayList();
+    ClusteringPolicyWritable clusteringPolicyWritable = new ClusteringPolicyWritable();
+    clusteringPolicyWritable.readFields(in);
+    policy = clusteringPolicyWritable.getValue();
     for (int i = 0; i < size; i++) {
       Cluster element = ClassUtils.instantiateAs(modelClass, Cluster.class);
       element.readFields(in);
@@ -159,4 +164,62 @@ public class ClusterClassifier extends A
   public List<Cluster> getModels() {
     return models;
   }
+  
+  public ClusteringPolicy getPolicy() {
+    return policy;
+  }
+  
+  public void writeToSeqFiles(Path path) throws IOException {
+    writePolicy(path);
+    Configuration config = new Configuration();
+    FileSystem fs = FileSystem.get(path.toUri(), config);
+    SequenceFile.Writer writer = null;
+    ClusterWritable cw = new ClusterWritable();
+    for (int i = 0; i < models.size(); i++) {
+      try {
+        Cluster cluster = models.get(i);
+        cw.setValue(cluster);
+        writer = new SequenceFile.Writer(fs, config,
+            new Path(path, "part-" + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class,
+            ClusterWritable.class);
+        Writable key = new IntWritable(i);
+        writer.append(key, cw);
+      } finally {
+        Closeables.closeQuietly(writer);
+      }
+    }
+  }
+  
+  public void readFromSeqFiles(Path path) throws IOException {
+    Configuration config = new Configuration();
+    List<Cluster> clusters = Lists.newArrayList();
+    for (ClusterWritable cw : new SequenceFileDirValueIterable<ClusterWritable>(path, PathType.LIST,
+        PathFilters.logsCRCFilter(), config)) {
+      clusters.add(cw.getValue());
+    }
+    this.models = clusters;
+    modelClass = models.get(0).getClass().getName();
+    this.policy = readPolicy(path);
+  }
+  
+  private ClusteringPolicy readPolicy(Path path) throws IOException {
+    Path policyPath = new Path(path, "_policy");
+    Configuration config = new Configuration();
+    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, policyPath, config);
+    Text key = new Text();
+    ClusteringPolicyWritable cpw = new ClusteringPolicyWritable();
+    reader.next(key, cpw);
+    return cpw.getValue();
+  }
+  
+  protected void writePolicy(Path path) throws IOException {
+    Path policyPath = new Path(path, "_policy");
+    Configuration config = new Configuration();
+    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, policyPath, Text.class,
+        ClusteringPolicyWritable.class);
+    writer.append(new Text(), new ClusteringPolicyWritable(policy));
+    writer.close();
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java Wed Feb 22 22:45:59 2012
@@ -18,17 +18,12 @@ package org.apache.mahout.clustering;
 
 import java.io.IOException;
 import java.util.Iterator;
-import java.util.List;
-import java.util.Locale;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Job;
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -42,7 +37,6 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
-import com.google.common.collect.Lists;
 import com.google.common.io.Closeables;
 
 /**
@@ -55,8 +49,6 @@ import com.google.common.io.Closeables;
 public class ClusterIterator {
   
   public static final String PRIOR_PATH_KEY = "org.apache.mahout.clustering.prior.path";
-  public static final String POLICY_PATH_KEY = "org.apache.mahout.clustering.policy.path";
-  
   public ClusterIterator(ClusteringPolicy policy) {
     this.policy = policy;
   }
@@ -111,7 +103,8 @@ public class ClusterIterator {
    * @throws IOException
    */
   public void iterateSeq(Path inPath, Path priorPath, Path outPath, int numIterations) throws IOException {
-    ClusterClassifier classifier = readClassifier(priorPath);
+    ClusterClassifier classifier = new ClusterClassifier();
+    classifier.readFromSeqFiles(priorPath);
     Configuration conf = new Configuration();
     for (int iteration = 1; iteration <= numIterations; iteration++) {
       for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(inPath, PathType.LIST,
@@ -132,7 +125,7 @@ public class ClusterIterator {
       // update the policy
       policy.update(classifier);
       // output the classifier
-      writeClassifier(classifier, new Path(outPath, "classifier-" + iteration));
+      classifier.writeToSeqFiles(new Path(outPath, "classifier-" + iteration));
     }
   }
   
@@ -153,9 +146,7 @@ public class ClusterIterator {
       InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     HadoopUtil.delete(conf, outPath);
-    Path policyPath = new Path(outPath, "policy.seq");
-    writePolicy(policy, policyPath);
-    conf.set(POLICY_PATH_KEY, policyPath.toString());
+    ClusterClassifier classifier = new ClusterClassifier(policy);
     for (int iteration = 1; iteration <= numIterations; iteration++) {
       conf.set(PRIOR_PATH_KEY, priorPath.toString());
       
@@ -181,6 +172,7 @@ public class ClusterIterator {
       if (!job.waitForCompletion(true)) {
         throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
       }
+      classifier.writePolicy(clustersOut);
       FileSystem fs = FileSystem.get(outPath.toUri(), conf);
       if (isConverged(clustersOut, conf, fs)) {
         break;
@@ -212,53 +204,4 @@ public class ClusterIterator {
     }
     return true;
   }
-  
-  public static void writeClassifier(ClusterClassifier classifier, Path outPath) throws IOException {
-    Configuration config = new Configuration();
-    FileSystem fs = FileSystem.get(outPath.toUri(), config);
-    SequenceFile.Writer writer = null;
-    ClusterWritable cw = new ClusterWritable();
-    for (int i = 0; i < classifier.getModels().size(); i++) {
-      try {
-        Cluster cluster = classifier.getModels().get(i);
-        cw.setValue(cluster);
-        writer = new SequenceFile.Writer(fs, config, new Path(outPath, "part-"
-            + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class, ClusterWritable.class);
-        Writable key = new IntWritable(i);
-        writer.append(key, cw);
-      } finally {
-        Closeables.closeQuietly(writer);
-      }
-    }
-  }
-  
-  public static ClusterClassifier readClassifier(Path inPath) throws IOException {
-    Configuration config = new Configuration();
-    List<Cluster> clusters = Lists.newArrayList();
-    for (ClusterWritable cw : new SequenceFileDirValueIterable<ClusterWritable>(inPath, PathType.LIST,
-        PathFilters.logsCRCFilter(), config)) {
-      clusters.add(cw.getValue());
-    }
-    ClusterClassifier classifierOut = new ClusterClassifier(clusters);
-    return classifierOut;
-  }
-  
-  public static ClusteringPolicy readPolicy(Path policyPath) throws IOException {
-    Configuration config = new Configuration();
-    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
-    SequenceFile.Reader reader = new SequenceFile.Reader(fs, policyPath, config);
-    Text key = new Text();
-    ClusteringPolicyWritable cpw = new ClusteringPolicyWritable();
-    reader.next(key, cpw);
-    return cpw.getValue();
-  }
-  
-  public static void writePolicy(ClusteringPolicy policy, Path policyPath) throws IOException {
-    Configuration config = new Configuration();
-    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
-    SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, policyPath, Text.class,
-        ClusteringPolicyWritable.class);
-    writer.append(new Text(), new ClusteringPolicyWritable(policy));
-    writer.close();
-  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -16,6 +16,8 @@
  */
 package org.apache.mahout.clustering;
 
+import java.util.List;
+
 import org.apache.hadoop.io.Writable;
 import org.apache.mahout.math.Vector;
 
@@ -23,10 +25,10 @@ import org.apache.mahout.math.Vector;
  * A ClusteringPolicy captures the semantics of assignment of points to clusters
  * 
  */
-public interface ClusteringPolicy extends Writable{
+public interface ClusteringPolicy extends Writable {
   
   /**
-   * Return the index of the most appropriate model
+   * Return a vector of weights for each of the models given those probabilities
    * 
    * @param probabilities
    *          a Vector of pdfs
@@ -42,4 +44,14 @@ public interface ClusteringPolicy extend
    */
   void update(ClusterClassifier posterior);
   
+  /**
+   * @param data
+   *          a data Vector
+   * @param models
+   *          a list of Cluster models
+   * @return a Vector of probabilities that the data is described by each of the
+   *         models
+   */
+  Vector classify(Vector data, List<Cluster> models);
+  
 }
\ No newline at end of file

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -19,19 +19,21 @@ package org.apache.mahout.clustering;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.List;
 
 import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
 
 public class DirichletClusteringPolicy implements ClusteringPolicy {
   
   public DirichletClusteringPolicy() {
     super();
   }
-
+  
   public DirichletClusteringPolicy(int k, double alpha0) {
     this.alpha0 = alpha0;
     this.mixture = UncommonDistributions.rDirichlet(new DenseVector(k), alpha0);
@@ -44,8 +46,12 @@ public class DirichletClusteringPolicy i
   // Alpha_0 primes the Dirichlet distribution
   private double alpha0;
   
-  /* (non-Javadoc)
-   * @see org.apache.mahout.clustering.ClusteringPolicy#select(org.apache.mahout.math.Vector)
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#select(org.apache.mahout.
+   * math.Vector)
    */
   @Override
   public Vector select(Vector probabilities) {
@@ -56,8 +62,12 @@ public class DirichletClusteringPolicy i
   }
   
   // update the total counts and then the mixture
-  /* (non-Javadoc)
-   * @see org.apache.mahout.clustering.ClusteringPolicy#update(org.apache.mahout.clustering.ClusterClassifier)
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#update(org.apache.mahout.
+   * clustering.ClusterClassifier)
    */
   @Override
   public void update(ClusterClassifier prior) {
@@ -67,8 +77,20 @@ public class DirichletClusteringPolicy i
     }
     mixture = UncommonDistributions.rDirichlet(totalCounts, alpha0);
   }
+  
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    int i = 0;
+    Vector pdfs = new DenseVector(models.size());
+    for (Cluster model : models) {
+      pdfs.set(i++, model.pdf(new VectorWritable(data)));
+    }
+    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+  }
 
-  /* (non-Javadoc)
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.hadoop.io.Writable#write(java.io.DataOutput)
    */
   @Override
@@ -76,8 +98,10 @@ public class DirichletClusteringPolicy i
     out.writeDouble(alpha0);
     VectorWritable.writeVector(out, mixture);
   }
-
-  /* (non-Javadoc)
+  
+  /*
+   * (non-Javadoc)
+   * 
    * @see org.apache.hadoop.io.Writable#readFields(java.io.DataInput)
    */
   @Override

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -19,9 +19,15 @@ package org.apache.mahout.clustering;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
 
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
 import org.apache.mahout.math.Vector;
 
+import com.google.common.collect.Lists;
+
 /**
  * This is a probability-weighted clustering policy, suitable for fuzzy k-means
  * clustering
@@ -33,9 +39,9 @@ public class FuzzyKMeansClusteringPolicy
     super();
   }
   
-  private double m;
+  private double m = 2;
   
-  private double convergenceDelta;
+  private double convergenceDelta = 0.05;
   
   public FuzzyKMeansClusteringPolicy(double m, double convergenceDelta) {
     this.m = m;
@@ -87,4 +93,18 @@ public class FuzzyKMeansClusteringPolicy
     this.convergenceDelta = in.readDouble();
   }
   
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    Collection<SoftCluster> clusters = Lists.newArrayList();
+    List<Double> distances = Lists.newArrayList();
+    for (Cluster model : models) {
+      SoftCluster sc = (SoftCluster) model;
+      clusters.add(sc);
+      distances.add(sc.getMeasure().distance(data, sc.getCenter()));
+    }
+    FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
+    fuzzyKMeansClusterer.setM(m);
+    return fuzzyKMeansClusterer.computePi(clusters, distances);
+  }
+  
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -19,9 +19,13 @@ package org.apache.mahout.clustering;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.List;
 
+import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
 
 /**
  * This is a simple maximum likelihood clustering policy, suitable for k-means
@@ -39,7 +43,7 @@ public class KMeansClusteringPolicy impl
     this.convergenceDelta = convergenceDelta;
   }
 
-  private double convergenceDelta;
+  private double convergenceDelta = 0.05;
   
   /* (non-Javadoc)
    * @see org.apache.mahout.clustering.ClusteringPolicy#select(org.apache.mahout.math.Vector)
@@ -60,6 +64,16 @@ public class KMeansClusteringPolicy impl
     // nothing to do here
   }
 
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    int i = 0;
+    Vector pdfs = new DenseVector(models.size());
+    for (Cluster model : models) {
+      pdfs.set(i++, model.pdf(new VectorWritable(data)));
+    }
+    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+  }
+
   /* (non-Javadoc)
    * @see org.apache.hadoop.io.Writable#write(java.io.DataOutput)
    */

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java?rev=1292563&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java Wed Feb 22 22:45:59 2012
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
+
+/**
+ * This is a simple maximum likelihood clustering policy, suitable for mean
+ * shift clustering
+ * 
+ */
+public class MeanShiftClusteringPolicy implements ClusteringPolicy {
+  
+  public MeanShiftClusteringPolicy() {
+    super();
+  }
+  
+  private double t1, t2, t3, t4;
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#select(org.apache.mahout.
+   * math.Vector)
+   */
+  @Override
+  public Vector select(Vector probabilities) {
+    int maxValueIndex = probabilities.maxValueIndex();
+    Vector weights = new SequentialAccessSparseVector(probabilities.size());
+    weights.set(maxValueIndex, 1.0);
+    return weights;
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see
+   * org.apache.mahout.clustering.ClusteringPolicy#update(org.apache.mahout.
+   * clustering.ClusterClassifier)
+   */
+  @Override
+  public void update(ClusterClassifier posterior) {
+    // nothing to do here
+  }
+  
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    int i = 0;
+    Vector pdfs = new DenseVector(models.size());
+    for (Cluster model : models) {
+      pdfs.set(i++, model.pdf(new VectorWritable(data)));
+    }
+    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.io.Writable#write(java.io.DataOutput)
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(t1);
+    out.writeDouble(t2);
+    out.writeDouble(t3);
+    out.writeDouble(t4);
+  }
+  
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.io.Writable#readFields(java.io.DataInput)
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.t1 = in.readDouble();
+    this.t2 = in.readDouble();
+    this.t3 = in.readDouble();
+    this.t4 = in.readDouble();
+  }
+  
+}

Propchange: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/MeanShiftClusteringPolicy.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Wed Feb 22 22:45:59 2012
@@ -132,7 +132,7 @@ public class ClusterClassificationDriver
 	  
 	  private static void classifyClusterSeq(Path input, Path clusters, Path output, Double clusterClassificationThreshold) throws IOException {
 	    List<Cluster> clusterModels = populateClusterModels(clusters);
-	    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels);
+	    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels, null);
       selectCluster(input, clusterModels, clusterClassifier, output, clusterClassificationThreshold);
       
 	  }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Wed Feb 22 22:45:59 2012
@@ -64,7 +64,7 @@ public class ClusterClassificationMapper
       if (clustersIn != null && !clustersIn.isEmpty()) {
         Path clustersInPath = new Path(clustersIn, "*");
         populateClusterModels(clustersInPath);
-        clusterClassifier = new ClusterClassifier(clusterModels);
+        clusterClassifier = new ClusterClassifier(clusterModels, null);
       }
       threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
       clusterId = new IntWritable();

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Wed Feb 22 22:45:59 2012
@@ -327,4 +327,8 @@ public class FuzzyKMeansClusterer {
     // System.out.println("cluster-" + clusterId + ": " + ClusterBase.formatVector(point, null));
     writer.append(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point));
   }
+
+  public void setM(double m) {
+    this.m = m;
+  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java Wed Feb 22 22:45:59 2012
@@ -46,16 +46,16 @@ public final class TestClusterClassifier
     models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure));
     models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
     models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure));
-    return new ClusterClassifier(models);
+    return new ClusterClassifier(models, new KMeansClusteringPolicy());
   }
   
-  private static ClusterClassifier newClusterClassifier() {
+  private static ClusterClassifier newKlusterClassifier() {
     List<Cluster> models = Lists.newArrayList();
     DistanceMeasure measure = new ManhattanDistanceMeasure();
     models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
     models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
     models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
-    return new ClusterClassifier(models);
+    return new ClusterClassifier(models, new KMeansClusteringPolicy());
   }
   
   private static ClusterClassifier newSoftClusterClassifier() {
@@ -64,7 +64,7 @@ public final class TestClusterClassifier
     models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure));
     models.add(new SoftCluster(new DenseVector(2), 1, measure));
     models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
-    return new ClusterClassifier(models);
+    return new ClusterClassifier(models, new FuzzyKMeansClusteringPolicy());
   }
   
   private static ClusterClassifier newGaussianClassifier() {
@@ -72,13 +72,15 @@ public final class TestClusterClassifier
     models.add(new GaussianCluster(new DenseVector(2).assign(1), new DenseVector(2).assign(1), 0));
     models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2).assign(1), 1));
     models.add(new GaussianCluster(new DenseVector(2).assign(-1), new DenseVector(2).assign(1), 2));
-    return new ClusterClassifier(models);
+    return new ClusterClassifier(models, new DirichletClusteringPolicy(3, 1.0));
   }
   
   private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws IOException {
     Path path = new Path(getTestTempDirPath(), "output");
-    ClusterIterator.writeClassifier(classifier, path);
-    return ClusterIterator.readClassifier(path);
+    classifier.writeToSeqFiles(path);
+    ClusterClassifier newClassifier = new ClusterClassifier();
+    newClassifier.readFromSeqFiles(path);
+    return newClassifier;
   }
   
   @Test
@@ -97,7 +99,7 @@ public final class TestClusterClassifier
     models.add(new Canopy(new DenseVector(2).assign(1), 0, measure));
     models.add(new Canopy(new DenseVector(2), 1, measure));
     models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure));
-    ClusterClassifier classifier = new ClusterClassifier(models);
+    ClusterClassifier classifier = new ClusterClassifier(models, new CanopyClusteringPolicy());
     Vector pdf = classifier.classify(new DenseVector(2));
     assertEquals("[0,0]", "[0.200, 0.600, 0.200]", AbstractCluster.formatVector(pdf, null));
     pdf = classifier.classify(new DenseVector(2).assign(2));
@@ -106,7 +108,7 @@ public final class TestClusterClassifier
   
   @Test
   public void testClusterClassification() {
-    ClusterClassifier classifier = newClusterClassifier();
+    ClusterClassifier classifier = newKlusterClassifier();
     Vector pdf = classifier.classify(new DenseVector(2));
     assertEquals("[0,0]", "[0.200, 0.600, 0.200]", AbstractCluster.formatVector(pdf, null));
     pdf = classifier.classify(new DenseVector(2).assign(2));
@@ -120,7 +122,7 @@ public final class TestClusterClassifier
     models.add(new MeanShiftCanopy(new DenseVector(2).assign(1), 0, measure));
     models.add(new MeanShiftCanopy(new DenseVector(2), 1, measure));
     models.add(new MeanShiftCanopy(new DenseVector(2).assign(-1), 2, measure));
-    ClusterClassifier classifier = new ClusterClassifier(models);
+    ClusterClassifier classifier = new ClusterClassifier(models, new MeanShiftClusteringPolicy());
     classifier.classify(new DenseVector(2));
   }
   
@@ -153,7 +155,7 @@ public final class TestClusterClassifier
   
   @Test
   public void testClusterClassifierSerialization() throws Exception {
-    ClusterClassifier classifier = newClusterClassifier();
+    ClusterClassifier classifier = newKlusterClassifier();
     ClusterClassifier classifierOut = writeAndRead(classifier);
     assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
     assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
@@ -182,7 +184,7 @@ public final class TestClusterClassifier
   public void testClusterIteratorKMeans() {
     List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
     ClusteringPolicy policy = new KMeansClusteringPolicy();
-    ClusterClassifier prior = newClusterClassifier();
+    ClusterClassifier prior = newKlusterClassifier();
     ClusterIterator iterator = new ClusterIterator(policy);
     ClusterClassifier posterior = iterator.iterate(data, prior, 5);
     assertEquals(3, posterior.getModels().size());
@@ -195,7 +197,7 @@ public final class TestClusterClassifier
   public void testClusterIteratorDirichlet() {
     List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
     ClusteringPolicy policy = new DirichletClusteringPolicy(3, 1);
-    ClusterClassifier prior = newClusterClassifier();
+    ClusterClassifier prior = newKlusterClassifier();
     ClusterIterator iterator = new ClusterIterator(policy);
     ClusterClassifier posterior = iterator.iterate(data, prior, 5);
     assertEquals(3, posterior.getModels().size());
@@ -214,20 +216,20 @@ public final class TestClusterClassifier
     List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
     ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
     Path path = new Path(priorPath, "priorClassifier");
-    ClusterClassifier prior = newClusterClassifier();
-    ClusterIterator.writeClassifier(prior, path);
+    ClusterClassifier prior = newKlusterClassifier();
+    prior.writeToSeqFiles(path);
     assertEquals(3, prior.getModels().size());
     System.out.println("Prior");
     for (Cluster cluster : prior.getModels()) {
       System.out.println(cluster.asFormatString(null));
     }
-    ClusteringPolicy policy = new KMeansClusteringPolicy();
-    ClusterIterator iterator = new ClusterIterator(policy);
+    ClusterIterator iterator = new ClusterIterator(prior.getPolicy());
     iterator.iterateSeq(pointsPath, path, outPath, 5);
     
     for (int i = 1; i <= 5; i++) {
       System.out.println("Classifier-" + i);
-      ClusterClassifier posterior = ClusterIterator.readClassifier(new Path(outPath, "classifier-" + i));
+      ClusterClassifier posterior = new ClusterClassifier();
+      posterior.readFromSeqFiles(new Path(outPath, "classifier-" + i));
       assertEquals(3, posterior.getModels().size());
       for (Cluster cluster : posterior.getModels()) {
         System.out.println(cluster.asFormatString(null));
@@ -246,8 +248,8 @@ public final class TestClusterClassifier
     List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
     ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
     Path path = new Path(priorPath, "priorClassifier");
-    ClusterClassifier prior = newClusterClassifier();
-    ClusterIterator.writeClassifier(prior, path);
+    ClusterClassifier prior = newKlusterClassifier();
+    prior.writeToSeqFiles(path);
     assertEquals(3, prior.getModels().size());
     System.out.println("Prior");
     for (Cluster cluster : prior.getModels()) {
@@ -255,11 +257,12 @@ public final class TestClusterClassifier
     }
     ClusteringPolicy policy = new KMeansClusteringPolicy();
     ClusterIterator iterator = new ClusterIterator(policy);
-    iterator.iterateMR(pointsPath, path, outPath, 5);
+    iterator.iterateMR(pointsPath, path, outPath, 3);
     
-    for (int i = 1; i <= 5; i++) {
+    for (int i = 1; i <= 3; i++) {
       System.out.println("Classifier-" + i);
-      ClusterClassifier posterior = ClusterIterator.readClassifier(new Path(outPath, "clusters-" + i));
+      ClusterClassifier posterior = new ClusterClassifier();
+      posterior.readFromSeqFiles(new Path(outPath, "clusters-" + i));
       assertEquals(3, posterior.getModels().size());
       for (Cluster cluster : posterior.getModels()) {
         System.out.println(cluster.asFormatString(null));

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Wed Feb 22 22:45:59 2012
@@ -46,8 +46,8 @@ public class DisplayDirichlet extends Di
   
   public DisplayDirichlet() {
     initialize();
-    this.setTitle("Dirichlet Process Clusters - Normal Distribution (>"
-        + (int) (significance * 100) + "% of population)");
+    this.setTitle("Dirichlet Process Clusters - Normal Distribution (>" + (int) (significance * 100)
+        + "% of population)");
   }
   
   // Override the paint() method
@@ -74,37 +74,33 @@ public class DisplayDirichlet extends Di
     log.info(models.toString());
   }
   
-  protected static void generateResults(ModelDistribution<VectorWritable> modelDist,
-                                        int numClusters,
-                                        int numIterations,
-                                        double alpha0,
-                                        int thin,
-                                        int burnin) throws IOException {
+  protected static void generateResults(ModelDistribution<VectorWritable> modelDist, int numClusters,
+      int numIterations, double alpha0, int thin, int burnin) throws IOException {
     boolean runClusterer = false;
     if (runClusterer) {
       runSequentialDirichletClusterer(modelDist, numClusters, numIterations, alpha0, thin, burnin);
     } else {
-      runSequentialDirichletClassifier(modelDist, numClusters, numIterations);
+      runSequentialDirichletClassifier(modelDist, numClusters, numIterations, alpha0);
     }
   }
   
-  private static void runSequentialDirichletClassifier(ModelDistribution<VectorWritable> modelDist,
-                                                       int numClusters,
-                                                       int numIterations) throws IOException {
+  private static void runSequentialDirichletClassifier(ModelDistribution<VectorWritable> modelDist, int numClusters,
+      int numIterations, double alpha0) throws IOException {
     List<Cluster> models = Lists.newArrayList();
     for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) {
       models.add((Cluster) cluster);
     }
-    ClusterClassifier prior = new ClusterClassifier(models);
+    ClusterClassifier prior = new ClusterClassifier(models, new DirichletClusteringPolicy(numClusters, alpha0));
     Path samples = new Path("samples");
     Path output = new Path("output");
-    Path priorClassifier = new Path(output, "clusters-0");
-    ClusterIterator.writeClassifier(prior, priorClassifier);
+    Path priorPath = new Path(output, "clusters-0");
+    prior.writeToSeqFiles(priorPath);
     
     ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters, numIterations);
-    new ClusterIterator(policy).iterateSeq(samples, priorClassifier, output, numIterations);
+    new ClusterIterator(policy).iterateSeq(samples, priorPath, output, numIterations);
     for (int i = 1; i <= numIterations; i++) {
-      ClusterClassifier posterior = ClusterIterator.readClassifier(new Path(output, "classifier-" + i));
+      ClusterClassifier posterior = new ClusterClassifier();
+      posterior.readFromSeqFiles(new Path(output, "classifier-" + i));
       List<Cluster> clusters = Lists.newArrayList();
       for (Cluster cluster : posterior.getModels()) {
         if (isSignificant(cluster)) {
@@ -115,12 +111,8 @@ public class DisplayDirichlet extends Di
     }
   }
   
-  private static void runSequentialDirichletClusterer(ModelDistribution<VectorWritable> modelDist,
-                                                      int numClusters,
-                                                      int numIterations,
-                                                      double alpha0,
-                                                      int thin,
-                                                      int burnin) {
+  private static void runSequentialDirichletClusterer(ModelDistribution<VectorWritable> modelDist, int numClusters,
+      int numIterations, double alpha0, int thin, int burnin) {
     DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist, alpha0, numClusters, thin, burnin);
     List<Cluster[]> result = dc.cluster(numIterations);
     printModels(result, burnin);

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Wed Feb 22 22:45:59 2012
@@ -44,8 +44,7 @@ public class DisplayFuzzyKMeans extends 
   
   DisplayFuzzyKMeans() {
     initialize();
-    this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100)
-        + "% of population)");
+    this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
   }
   
   // Override the paint() method
@@ -68,21 +67,19 @@ public class DisplayFuzzyKMeans extends 
     writeSampleData(samples);
     boolean runClusterer = false;
     int maxIterations = 10;
+    float threshold = 0.001F;
+    float m = 1.1F;
     if (runClusterer) {
-      runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations);
+      runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations, m, threshold);
     } else {
       int numClusters = 3;
-      runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations);
+      runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations, m, threshold);
     }
     new DisplayFuzzyKMeans();
   }
   
-  private static void runSequentialFuzzyKClassifier(Configuration conf,
-                                                    Path samples,
-                                                    Path output,
-                                                    DistanceMeasure measure,
-                                                    int numClusters,
-                                                    int maxIterations) throws IOException {
+  private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException {
     Collection<Vector> points = Lists.newArrayList();
     for (int i = 0; i < numClusters; i++) {
       points.add(SAMPLE_DATA.get(i).get());
@@ -92,30 +89,24 @@ public class DisplayFuzzyKMeans extends 
     for (Vector point : points) {
       initialClusters.add(new SoftCluster(point, id++, measure));
     }
-    ClusterClassifier prior = new ClusterClassifier(initialClusters);
-    Path priorClassifier = new Path(output, "classifier-0");
-    ClusterIterator.writeClassifier(prior, priorClassifier);
+    ClusterClassifier prior = new ClusterClassifier(initialClusters, new FuzzyKMeansClusteringPolicy(m, threshold));
+    Path priorPath = new Path(output, "classifier-0");
+    prior.writeToSeqFiles(priorPath);
     
     ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy(1.1, 0.001);
-    new ClusterIterator(policy).iterateSeq(samples, priorClassifier, output, maxIterations);
+    new ClusterIterator(policy).iterateSeq(samples, priorPath, output, maxIterations);
     for (int i = 1; i <= maxIterations; i++) {
-      ClusterClassifier posterior = ClusterIterator.readClassifier(new Path(output, "classifier-" + i));
+      ClusterClassifier posterior = new ClusterClassifier();
+      posterior.readFromSeqFiles(new Path(output, "classifier-" + i));
       CLUSTERS.add(posterior.getModels());
     }
   }
   
-  private static void runSequentialFuzzyKClusterer(Configuration conf,
-                                                   Path samples,
-                                                   Path output,
-                                                   DistanceMeasure measure,
-                                                   int maxIterations)
-    throws IOException, ClassNotFoundException, InterruptedException {
-    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
-        output, "clusters-0"), 3, measure);
-    double threshold = 0.001;
-    float m = 1.1F;
-    FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold,
-        maxIterations, m, true, true, threshold, true);
+  private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
+      ClassNotFoundException, InterruptedException {
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure);
+    FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold, maxIterations, m, true, true, threshold, true);
     
     loadClusters(output);
   }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1292563&r1=1292562&r2=1292563&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Wed Feb 22 22:45:59 2012
@@ -40,7 +40,7 @@ import org.apache.mahout.common.distance
 import org.apache.mahout.math.Vector;
 
 public class DisplayKMeans extends DisplayClustering {
-
+  
   DisplayKMeans() {
     initialize();
     this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
@@ -58,21 +58,19 @@ public class DisplayKMeans extends Displ
     DisplayClustering.generateSamples();
     writeSampleData(samples);
     boolean runClusterer = false;
+    double convergenceDelta = 0.001;
     if (runClusterer) {
       int numClusters = 3;
-      runSequentialKMeansClusterer(conf, samples, output, measure, numClusters);
+      runSequentialKMeansClusterer(conf, samples, output, measure, numClusters, convergenceDelta);
     } else {
       int maxIterations = 10;
-      runSequentialKMeansClassifier(conf, samples, output, measure, maxIterations);
+      runSequentialKMeansClassifier(conf, samples, output, measure, maxIterations, convergenceDelta);
     }
     new DisplayKMeans();
   }
   
-  private static void runSequentialKMeansClassifier(Configuration conf,
-                                                    Path samples,
-                                                    Path output,
-                                                    DistanceMeasure measure,
-                                                    int numClusters) throws IOException {
+  private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int numClusters, double convergenceDelta) throws IOException {
     Collection<Vector> points = Lists.newArrayList();
     for (int i = 0; i < numClusters; i++) {
       points.add(SAMPLE_DATA.get(i).get());
@@ -80,33 +78,27 @@ public class DisplayKMeans extends Displ
     List<Cluster> initialClusters = Lists.newArrayList();
     int id = 0;
     for (Vector point : points) {
-      initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(
-          point, id++, measure));
+      initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure));
     }
-    ClusterClassifier prior = new ClusterClassifier(initialClusters);
-    Path priorClassifier = new Path(output, "clusters-0");
-    ClusterIterator.writeClassifier(prior, priorClassifier);
+    ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
+    Path priorPath = new Path(output, "clusters-0");
+    prior.writeToSeqFiles(priorPath);
     
     int maxIter = 10;
     ClusteringPolicy policy = new KMeansClusteringPolicy();
-    new ClusterIterator(policy).iterateSeq(samples, priorClassifier, output, maxIter);
+    new ClusterIterator(policy).iterateSeq(samples, priorPath, output, maxIter);
     for (int i = 1; i <= maxIter; i++) {
-      ClusterClassifier posterior = ClusterIterator.readClassifier(new Path(output, "classifier-" + i));
+      ClusterClassifier posterior = new ClusterClassifier();
+      posterior.readFromSeqFiles(new Path(output, "classifier-" + i));
       CLUSTERS.add(posterior.getModels());
     }
   }
   
-  private static void runSequentialKMeansClusterer(Configuration conf,
-                                                   Path samples,
-                                                   Path output,
-                                                   DistanceMeasure measure,
-                                                   int maxIterations)
-    throws IOException, InterruptedException, ClassNotFoundException {
-    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
-        output, "clusters-0"), 3, measure);
-    double distanceThreshold = 0.001;
-    KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
-        maxIterations, true, true);
+  private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int maxIterations, double convergenceDelta) throws IOException, InterruptedException,
+      ClassNotFoundException {
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure);
+    KMeansDriver.run(samples, clusters, output, measure, convergenceDelta, maxIterations, true, true);
     loadClusters(output);
   }
   



Mime
View raw message