mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r953164 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/kmeans/ core/src/test/java/org/apache/mahout/clustering/kmeans/ examples/src/main/java/org/apache/mahout/...
Date Wed, 09 Jun 2010 21:18:39 GMT
Author: jeastman
Date: Wed Jun  9 21:18:37 2010
New Revision: 953164

URL: http://svn.apache.org/viewvc?rev=953164&view=rev
Log:
MAHOUT-167: Converted k-Means to the Hadoop 0.20.2 API (org.apache.hadoop.mapreduce)

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java Wed Jun  9 21:18:37 2010
@@ -22,8 +22,8 @@ import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapreduce.Mapper.Context;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.dirichlet.models.Model;
 import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
@@ -124,12 +124,8 @@ public class DirichletClusterer<O> {
    * @param burnin
    *          the int burnin interval, used to suppress early iterations
    */
-  public DirichletClusterer(List<O> sampleData,
-                            ModelDistribution<O> modelFactory,
-                            double alpha0,
-                            int numClusters,
-                            int thin,
-                            int burnin) {
+  public DirichletClusterer(List<O> sampleData, ModelDistribution<O> modelFactory, double alpha0, int numClusters, int thin,
+      int burnin) {
     this.sampleData = sampleData;
     this.modelFactory = modelFactory;
     this.thin = thin;
@@ -224,8 +220,9 @@ public class DirichletClusterer<O> {
     return pi;
   }
 
-  public void emitPointToClusters(VectorWritable vector, List<DirichletCluster<VectorWritable>> clusters, Context context)
-      throws IOException, InterruptedException {
+  public void emitPointToClusters(VectorWritable vector, List<DirichletCluster<VectorWritable>> clusters,
+      Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context) throws IOException,
+      InterruptedException {
     Vector pi = new DenseVector(clusters.size());
     for (int i = 0; i < clusters.size(); i++) {
       pi.set(i, clusters.get(i).getModel().pdf(vector));
@@ -247,7 +244,8 @@ public class DirichletClusterer<O> {
    * @throws InterruptedException 
    */
   private void emitMostLikelyCluster(VectorWritable point, List<DirichletCluster<VectorWritable>> clusters, Vector pi,
-      Context context) throws IOException, InterruptedException {
+      Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context) throws IOException,
+      InterruptedException {
     int clusterId = -1;
     double clusterPdf = 0;
     for (int i = 0; i < clusters.size(); i++) {
@@ -262,7 +260,8 @@ public class DirichletClusterer<O> {
   }
 
   private void emitAllClusters(VectorWritable point, List<DirichletCluster<VectorWritable>> clusters, Vector pi,
-      Context context) throws IOException, InterruptedException {
+      Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context) throws IOException,
+      InterruptedException {
     for (int i = 0; i < clusters.size(); i++) {
       double pdf = pi.get(i);
       if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
@@ -290,15 +289,9 @@ public class DirichletClusterer<O> {
    * @param numIterations
    *          number of iterations to be performed
    */
-  public static List<Model<Vector>[]> clusterPoints(List<Vector> points, 
-                                                    ModelDistribution<Vector> modelFactory,
-                                                    double alpha0,
-                                                    int numClusters,
-                                                    int thin,
-                                                    int burnin,
-                                                    int numIterations) {
-    DirichletClusterer<Vector> clusterer =
-        new DirichletClusterer<Vector>(points, modelFactory, alpha0, numClusters, thin, burnin);
+  public static List<Model<Vector>[]> clusterPoints(List<Vector> points, ModelDistribution<Vector> modelFactory, double alpha0,
+      int numClusters, int thin, int burnin, int numIterations) {
+    DirichletClusterer<Vector> clusterer = new DirichletClusterer<Vector>(points, modelFactory, alpha0, numClusters, thin, burnin);
     return clusterer.cluster(numIterations);
 
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Wed Jun  9 21:18:37 2010
@@ -204,7 +204,7 @@ public final class DirichletDriver {
     int protoSize = 0;
     for (FileStatus s : status) {
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, s.getPath(), conf);
-      WritableComparable key = (WritableComparable) reader.getKeyClass().newInstance();
+      WritableComparable<?> key = (WritableComparable<?>) reader.getKeyClass().newInstance();
       VectorWritable value = new VectorWritable();
       if (reader.next(key, value)) {
         protoSize = value.get().size();

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java Wed Jun  9 21:18:37 2010
@@ -21,50 +21,49 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.WritableComparable;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.Mapper;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
-public class KMeansClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
+public class KMeansClusterMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
   
   private final List<Cluster> clusters = new ArrayList<Cluster>();
   private KMeansClusterer clusterer;
   
+  
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Mapper#map(java.lang.Object, java.lang.Object, org.apache.hadoop.mapreduce.Mapper.Context)
+   */
   @Override
-  public void map(WritableComparable<?> key,
-                  VectorWritable point,
-                  OutputCollector<IntWritable,WeightedVectorWritable> output,
-                  Reporter reporter) throws IOException {
-    clusterer.outputPointWithClusterInfo(point.get(), clusters, output);
+  protected void map(WritableComparable<?> key, VectorWritable point, Context context) throws IOException, InterruptedException {
+    clusterer.outputPointWithClusterInfo(point.get(), clusters, context);
   }
-  
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
+   */
   @Override
-  public void configure(JobConf job) {
-    super.configure(job);
-    
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
     try {
       ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-      Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+      Class<?> cl = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
       DistanceMeasure measure = (DistanceMeasure) cl.newInstance();
-      measure.configure(job);
+      measure.configure(conf);
       
-      String clusterPath = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
+      String clusterPath = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
       if ((clusterPath != null) && (clusterPath.length() > 0)) {
         KMeansUtil.configureWithClusterInfo(new Path(clusterPath), clusters);
         if (clusters.isEmpty()) {
           throw new IllegalStateException("Cluster is empty!");
         }
-      }
-      
+      }  
       this.clusterer = new KMeansClusterer(measure);
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
@@ -74,5 +73,4 @@ public class KMeansClusterMapper extends
       throw new IllegalStateException(e);
     }
   }
-  
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Wed Jun  9 21:18:37 2010
@@ -22,7 +22,8 @@ import java.util.List;
 
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.Vector;
@@ -60,10 +61,11 @@ public class KMeansClusterer {
    *          a point to find a cluster for.
    * @param clusters
    *          a List<Cluster> to test.
+   * @throws InterruptedException 
+   * @throws IOException 
    */
-  public void emitPointToNearestCluster(Vector point,
-                                        Iterable<Cluster> clusters,
-                                        OutputCollector<Text, KMeansInfo> output) throws IOException {
+  public void emitPointToNearestCluster(Vector point, List<Cluster> clusters,
+      Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo>.Context context) throws IOException, InterruptedException {
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
@@ -77,12 +79,11 @@ public class KMeansClusterer {
         nearestDistance = distance;
       }
     }
-    // emit only clusterID
-    output.collect(new Text(nearestCluster.getIdentifier()), new KMeansInfo(1, point));
+    context.write(new Text(nearestCluster.getIdentifier()), new KMeansInfo(1, point));
   }
 
-  public void outputPointWithClusterInfo(Vector vector, Iterable<Cluster> clusters,
-      OutputCollector<IntWritable, WeightedVectorWritable> output) throws IOException {
+  public void outputPointWithClusterInfo(Vector vector, List<Cluster> clusters,
+      Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable>.Context context) throws IOException, InterruptedException {
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
@@ -93,8 +94,7 @@ public class KMeansClusterer {
         nearestDistance = distance;
       }
     }
-
-    output.collect(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, new VectorWritable(vector)));
+    context.write(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, new VectorWritable(vector)));
   }
 
   /**

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java Wed Jun  9 21:18:37 2010
@@ -20,24 +20,23 @@ import java.io.IOException;
 import java.util.Iterator;
 
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapreduce.Reducer;
 
-public class KMeansCombiner extends MapReduceBase implements Reducer<Text,KMeansInfo,Text,KMeansInfo> {
-  
+public class KMeansCombiner extends Reducer<Text, KMeansInfo, Text, KMeansInfo> {
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Reducer#reduce(java.lang.Object, java.lang.Iterable, org.apache.hadoop.mapreduce.Reducer.Context)
+   */
   @Override
-  public void reduce(Text key,
-                     Iterator<KMeansInfo> values,
-                     OutputCollector<Text,KMeansInfo> output,
-                     Reporter reporter) throws IOException {
+  protected void reduce(Text key, Iterable<KMeansInfo> values, Context context) throws IOException, InterruptedException {
+
     Cluster cluster = new Cluster(key.toString());
-    while (values.hasNext()) {
-      KMeansInfo next = values.next();
+    Iterator<KMeansInfo> it = values.iterator();
+    while (it.hasNext()) {
+      KMeansInfo next = it.next();
       cluster.addPoints(next.getPoints(), next.getPointTotal());
     }
-    output.collect(key, new KMeansInfo(cluster.getNumPoints(), cluster.getPointTotal()));
+    context.write(key, new KMeansInfo(cluster.getNumPoints(), cluster.getPointTotal()));
   }
-  
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Wed Jun  9 21:18:37 2010
@@ -24,6 +24,7 @@ import org.apache.commons.cli2.Option;
 import org.apache.commons.cli2.OptionException;
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -31,12 +32,11 @@ import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
-import org.apache.hadoop.mapred.FileInputFormat;
-import org.apache.hadoop.mapred.FileOutputFormat;
-import org.apache.hadoop.mapred.JobClient;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.SequenceFileInputFormat;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.HadoopUtil;
@@ -56,8 +56,7 @@ public final class KMeansDriver {
     Option inputOpt = DefaultOptionCreator.inputOption().create();
     Option clustersOpt = DefaultOptionCreator.clustersInOption().withDescription(
         "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
-            + "If k is also specified, then a random set of vectors will be selected"
-            + " and written out to this path first")
+            + "If k is also specified, then a random set of vectors will be selected" + " and written out to this path first")
         .create();
     Option kOpt = DefaultOptionCreator.kOption().withDescription(
         "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen"
@@ -71,9 +70,8 @@ public final class KMeansDriver {
     Option clusteringOpt = DefaultOptionCreator.clusteringOption().create();
     Option helpOpt = DefaultOptionCreator.helpOption();
 
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(clustersOpt)
-        .withOption(outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt)
-        .withOption(maxIterationsOpt).withOption(numReduceTasksOpt)
+    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt)
+        .withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt)
         .withOption(kOpt).withOption(overwriteOutput).withOption(helpOpt).withOption(clusteringOpt).create();
     try {
       Parser parser = new Parser();
@@ -96,8 +94,7 @@ public final class KMeansDriver {
         HadoopUtil.overwriteOutput(output);
       }
       if (cmdLine.hasOption(kOpt)) {
-        clusters = RandomSeedGenerator.buildRandom(input, clusters,
-                                                   Integer.parseInt(cmdLine.getValue(kOpt).toString()));
+        clusters = RandomSeedGenerator.buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString()));
       }
       runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks, cmdLine
           .hasOption(clusteringOpt));
@@ -126,16 +123,18 @@ public final class KMeansDriver {
    *          the number of reducers
    * @param runClustering 
    *          true if points are to be clustered after iterations are completed
+   * @throws ClassNotFoundException 
+   * @throws InterruptedException 
    */
   public static void runJob(Path input, Path clustersIn, Path output, String measureClass, double convergenceDelta,
-      int maxIterations, int numReduceTasks, boolean runClustering) throws IOException {
+      int maxIterations, int numReduceTasks, boolean runClustering) throws IOException, InterruptedException,
+      ClassNotFoundException {
     // iterate until the clusters converge
     String delta = Double.toString(convergenceDelta);
     if (log.isInfoEnabled()) {
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
-               new Object[] {input, clustersIn, output, measureClass});
-      log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
-               new Object[] {convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName()});
+      log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] { input, clustersIn, output, measureClass });
+      log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}", new Object[] { convergenceDelta,
+          maxIterations, numReduceTasks, VectorWritable.class.getName() });
     }
     boolean converged = false;
     int iteration = 1;
@@ -171,40 +170,38 @@ public final class KMeansDriver {
    * @param numReduceTasks
    *          the number of reducer tasks
    * @return true if the iteration successfully runs
+   * @throws ClassNotFoundException 
+   * @throws InterruptedException 
    */
-  private static boolean runIteration(Path input,
-                                      Path clustersIn,
-                                      Path clustersOut,
-                                      String measureClass,
-                                      String convergenceDelta,
-                                      int numReduceTasks) throws IOException {
-    JobConf conf = new JobConf(KMeansDriver.class);
-    conf.setMapOutputKeyClass(Text.class);
-    conf.setMapOutputValueClass(KMeansInfo.class);
-    conf.setOutputKeyClass(Text.class);
-    conf.setOutputValueClass(Cluster.class);
-
-    FileInputFormat.setInputPaths(conf, input);
-    FileOutputFormat.setOutputPath(conf, clustersOut);
-    HadoopUtil.overwriteOutput(clustersOut);
-    conf.setInputFormat(SequenceFileInputFormat.class);
-    conf.setOutputFormat(SequenceFileOutputFormat.class);
-    conf.setMapperClass(KMeansMapper.class);
-    conf.setCombinerClass(KMeansCombiner.class);
-    conf.setReducerClass(KMeansReducer.class);
-    conf.setNumReduceTasks(numReduceTasks);
+  private static boolean runIteration(Path input, Path clustersIn, Path clustersOut, String measureClass, String convergenceDelta,
+      int numReduceTasks) throws IOException, InterruptedException, ClassNotFoundException {
+    Configuration conf = new Configuration();
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
     conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
 
-    try {
-      JobClient.runJob(conf);
-      FileSystem fs = FileSystem.get(clustersOut.toUri(), conf);
-      return isConverged(clustersOut, conf, fs);
-    } catch (IOException e) {
-      log.warn(e.toString(), e);
-      return true;
-    }
+    Job job = new Job(conf);
+
+    job.setMapOutputKeyClass(Text.class);
+    job.setMapOutputValueClass(KMeansInfo.class);
+    job.setOutputKeyClass(Text.class);
+    job.setOutputValueClass(Cluster.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setMapperClass(KMeansMapper.class);
+    job.setCombinerClass(KMeansCombiner.class);
+    job.setReducerClass(KMeansReducer.class);
+    job.setNumReduceTasks(numReduceTasks);
+
+    FileInputFormat.addInputPath(job, input);
+    FileOutputFormat.setOutputPath(job, clustersOut);
+
+    HadoopUtil.overwriteOutput(clustersOut);
+    job.waitForCompletion(true);
+    FileSystem fs = FileSystem.get(clustersOut.toUri(), conf);
+
+    return isConverged(clustersOut, conf, fs);
   }
 
   /**
@@ -220,40 +217,35 @@ public final class KMeansDriver {
    *          the classname of the DistanceMeasure
    * @param convergenceDelta
    *          the convergence delta value
+   * @throws ClassNotFoundException 
+   * @throws InterruptedException 
    */
-  private static void runClustering(Path input,
-                                    Path clustersIn,
-                                    Path output,
-                                    String measureClass,
-                                    String convergenceDelta) throws IOException {
+  private static void runClustering(Path input, Path clustersIn, Path output, String measureClass, String convergenceDelta)
+      throws IOException, InterruptedException, ClassNotFoundException {
     if (log.isInfoEnabled()) {
       log.info("Running Clustering");
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
-               new Object[] {input, clustersIn, output, measureClass});
+      log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] { input, clustersIn, output, measureClass });
       log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class.getName());
     }
-    JobConf conf = new JobConf(KMeansDriver.class);
-    conf.setInputFormat(SequenceFileInputFormat.class);
-    conf.setOutputFormat(SequenceFileOutputFormat.class);
+    Configuration conf = new Configuration();
+    conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
+    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
+    conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
 
-    conf.setOutputKeyClass(IntWritable.class);
-    conf.setOutputValueClass(WeightedVectorWritable.class);
+    Job job = new Job(conf);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(WeightedVectorWritable.class);
 
-    FileInputFormat.setInputPaths(conf, input);
+    FileInputFormat.setInputPaths(job, input);
     HadoopUtil.overwriteOutput(output);
-    FileOutputFormat.setOutputPath(conf, output);
+    FileOutputFormat.setOutputPath(job, output);
 
-    conf.setMapperClass(KMeansClusterMapper.class);
-    conf.setNumReduceTasks(0);
-    conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
-    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
-    conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
+    job.setMapperClass(KMeansClusterMapper.class);
+    job.setNumReduceTasks(0);
 
-    try {
-      JobClient.runJob(conf);
-    } catch (IOException e) {
-      log.warn(e.toString(), e);
-    }
+    job.waitForCompletion(true);
   }
 
   /**
@@ -269,7 +261,7 @@ public final class KMeansDriver {
    * @throws IOException
    *           if there was an IO error
    */
-  private static boolean isConverged(Path filePath, JobConf conf, FileSystem fs) throws IOException {
+  private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
     FileStatus[] parts = fs.listStatus(filePath);
     for (FileStatus part : parts) {
       String name = part.getPath().getName();

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Wed Jun  9 21:18:37 2010
@@ -20,61 +20,50 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparable;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.Mapper;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
-public class KMeansMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>,VectorWritable,Text,KMeansInfo> {
-  
+public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo> {
+
   private KMeansClusterer clusterer;
+
   private final List<Cluster> clusters = new ArrayList<Cluster>();
-  
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Mapper#map(java.lang.Object, java.lang.Object, org.apache.hadoop.mapreduce.Mapper.Context)
+   */
   @Override
-  public void map(WritableComparable<?> key,
-                  VectorWritable point,
-                  OutputCollector<Text,KMeansInfo> output,
-                  Reporter reporter) throws IOException {
-    this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, output);
+  protected void map(WritableComparable<?> key, VectorWritable point, Context context) throws IOException, InterruptedException {
+    this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, context);
   }
-  
-  /**
-   * Configure the mapper by providing its clusters. Used by unit tests.
-   * 
-   * @param clusters
-   *          a List<Cluster>
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
    */
-  void config(List<Cluster> clusters) {
-    this.clusters.clear();
-    this.clusters.addAll(clusters);
-  }
-  
   @Override
-  public void configure(JobConf job) {
-    super.configure(job);
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
     try {
       ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-      Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+      Class<?> cl = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
       DistanceMeasure measure = (DistanceMeasure) cl.newInstance();
-      measure.configure(job);
-      
+      measure.configure(conf);
+
       this.clusterer = new KMeansClusterer(measure);
-      
-      String clusterPath = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
+
+      String clusterPath = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
       if ((clusterPath != null) && (clusterPath.length() > 0)) {
         KMeansUtil.configureWithClusterInfo(new Path(clusterPath), clusters);
         if (clusters.isEmpty()) {
           throw new IllegalStateException("Cluster is empty!");
         }
       }
-      
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     } catch (IllegalAccessException e) {
@@ -83,4 +72,17 @@ public class KMeansMapper extends MapRed
       throw new IllegalStateException(e);
     }
   }
+
+  /**
+   * Configure the mapper by providing its clusters. Used by unit tests.
+   * 
+   * @param clusters
+   *          a List<Cluster>
+   * @param measure the DistanceMeasure used to construct this mapper's KMeansClusterer
+   */
+  void setup(List<Cluster> clusters, DistanceMeasure measure) {
+    this.clusters.clear();
+    this.clusters.addAll(clusters);
+    this.clusterer = new KMeansClusterer(measure);
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Wed Jun  9 21:18:37 2010
@@ -23,55 +23,57 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapreduce.Reducer;
 import org.apache.mahout.common.distance.DistanceMeasure;
 
-public class KMeansReducer extends MapReduceBase implements Reducer<Text,KMeansInfo,Text,Cluster> {
-  
-  private Map<String,Cluster> clusterMap;
+public class KMeansReducer extends Reducer<Text, KMeansInfo, Text, Cluster> {
+
+  private Map<String, Cluster> clusterMap;
+
   private double convergenceDelta;
+
   private DistanceMeasure measure;
-  
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Reducer#reduce(java.lang.Object, java.lang.Iterable, org.apache.hadoop.mapreduce.Reducer.Context)
+   */
   @Override
-  public void reduce(Text key,
-                     Iterator<KMeansInfo> values,
-                     OutputCollector<Text,Cluster> output,
-                     Reporter reporter) throws IOException {
+  protected void reduce(Text key, Iterable<KMeansInfo> values, Context context) throws IOException, InterruptedException {
     Cluster cluster = clusterMap.get(key.toString());
-    
-    while (values.hasNext()) {
-      KMeansInfo delta = values.next();
+    Iterator<KMeansInfo> it = values.iterator();
+    while (it.hasNext()) {
+      KMeansInfo delta = it.next();
       cluster.addPoints(delta.getPoints(), delta.getPointTotal());
     }
     // force convergence calculation
     boolean converged = cluster.computeConvergence(this.measure, this.convergenceDelta);
     if (converged) {
-      reporter.incrCounter("Clustering", "Converged Clusters", 1);
+      //context.getCounter("Clustering", "Converged Clusters").increment(1);
     }
-    output.collect(new Text(cluster.getIdentifier()), cluster);
+    context.write(new Text(cluster.getIdentifier()), cluster);
   }
-  
+
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.mapreduce.Reducer#setup(org.apache.hadoop.mapreduce.Reducer.Context)
+   */
   @Override
-  public void configure(JobConf job) {
-    
-    super.configure(job);
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
     try {
       ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-      Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+      Class<?> cl = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
       this.measure = (DistanceMeasure) cl.newInstance();
-      this.measure.configure(job);
-      
-      this.convergenceDelta = Double.parseDouble(job.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
-      
-      this.clusterMap = new HashMap<String,Cluster>();
-      
-      String path = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
+      this.measure.configure(conf);
+
+      this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
+
+      this.clusterMap = new HashMap<String, Cluster>();
+
+      String path = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
       if (path.length() > 0) {
         List<Cluster> clusters = new ArrayList<Cluster>();
         KMeansUtil.configureWithClusterInfo(new Path(path), clusters);
@@ -88,18 +90,18 @@ public class KMeansReducer extends MapRe
       throw new IllegalStateException(e);
     }
   }
-  
+
   private void setClusterMap(List<Cluster> clusters) {
-    clusterMap = new HashMap<String,Cluster>();
+    clusterMap = new HashMap<String, Cluster>();
     for (Cluster cluster : clusters) {
       clusterMap.put(cluster.getIdentifier(), cluster);
     }
     clusters.clear();
   }
-  
-  public void config(List<Cluster> clusters) {
+
+  public void setup(List<Cluster> clusters, DistanceMeasure measure) {
     setClusterMap(clusters);
-    
+    this.measure = measure;
   }
-  
+
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Wed Jun  9 21:18:37 2010
@@ -29,12 +29,13 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.Job;
 import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.MockMapperContext;
+import org.apache.mahout.clustering.MockReducerContext;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.canopy.CanopyDriver;
 import org.apache.mahout.common.DummyOutputCollector;
-import org.apache.mahout.common.DummyReporter;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
@@ -47,15 +48,11 @@ import org.apache.mahout.math.VectorWrit
 
 public class TestKmeansClustering extends MahoutTestCase {
 
-  public static final double[][] reference =
-      { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 },
-        { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
-
-  private static final int[][] expectedNumPoints = {
-      { 9 }, { 4, 5 }, { 4, 4, 1 },
-      { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 },
-      { 1, 1, 1, 1, 1, 4 }, { 1, 1, 1, 1, 1, 2, 2 },
-      { 1, 1, 1, 1, 1, 1, 2, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
+  public static final double[][] reference = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 },
+      { 5, 5 } };
+
+  private static final int[][] expectedNumPoints = { { 9 }, { 4, 5 }, { 4, 4, 1 }, { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 },
+      { 1, 1, 1, 1, 1, 4 }, { 1, 1, 1, 1, 1, 2, 2 }, { 1, 1, 1, 1, 1, 1, 2, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
 
   private FileSystem fs;
 
@@ -149,15 +146,16 @@ public class TestKmeansClustering extend
   /** Story: test that the mapper will map input points to the nearest cluster */
   public void testKMeansMapper() throws Exception {
     KMeansMapper mapper = new KMeansMapper();
-    JobConf conf = new JobConf();
-    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
+    EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
+    Configuration conf = new Configuration();
+    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measure.getClass().getName());
     conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, "0.001");
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, "");
-    mapper.configure(conf);
     List<VectorWritable> points = getPointsWritable(reference);
     for (int k = 0; k < points.size(); k++) {
       // pick k initial cluster centers at random
       DummyOutputCollector<Text, KMeansInfo> collector = new DummyOutputCollector<Text, KMeansInfo>();
+      MockMapperContext<Text, KMeansInfo> context = new MockMapperContext<Text, KMeansInfo>(mapper, conf, collector);
       List<Cluster> clusters = new ArrayList<Cluster>();
 
       for (int i = 0; i < k + 1; i++) {
@@ -166,15 +164,15 @@ public class TestKmeansClustering extend
         cluster.addPoint(cluster.getCenter());
         clusters.add(cluster);
       }
-      mapper.config(clusters);
+      mapper.setup(clusters, measure);
 
       // map the data
       for (VectorWritable point : points) {
-        mapper.map(new Text(), point, collector, null);
+        mapper.map(new Text(), point, context);
       }
       assertEquals("Number of map results", k + 1, collector.getData().size());
       // now verify that all points are correctly allocated
-      EuclideanDistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
+      EuclideanDistanceMeasure euclideanDistanceMeasure = measure;
       Map<String, Cluster> clusterMap = loadClusterMap(clusters);
       for (Text key : collector.getKeys()) {
         Cluster cluster = clusterMap.get(key.toString());
@@ -195,15 +193,16 @@ public class TestKmeansClustering extend
    */
   public void testKMeansCombiner() throws Exception {
     KMeansMapper mapper = new KMeansMapper();
-    JobConf conf = new JobConf();
-    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
+    EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
+    Configuration conf = new Configuration();
+    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measure.getClass().getName());
     conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, "0.001");
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, "");
-    mapper.configure(conf);
     List<VectorWritable> points = getPointsWritable(reference);
     for (int k = 0; k < points.size(); k++) {
       // pick k initial cluster centers at random
-      DummyOutputCollector<Text, KMeansInfo> collector = new DummyOutputCollector<Text, KMeansInfo>();
+      DummyOutputCollector<Text, KMeansInfo> mapCollector = new DummyOutputCollector<Text, KMeansInfo>();
+      MockMapperContext<Text, KMeansInfo> mapContext = new MockMapperContext<Text, KMeansInfo>(mapper, conf, mapCollector);
       List<Cluster> clusters = new ArrayList<Cluster>();
       for (int i = 0; i < k + 1; i++) {
         Vector vec = points.get(i).get();
@@ -213,24 +212,26 @@ public class TestKmeansClustering extend
         cluster.addPoint(cluster.getCenter());
         clusters.add(cluster);
       }
-      mapper.config(clusters);
+      mapper.setup(clusters, measure);
       // map the data
       for (VectorWritable point : points) {
-        mapper.map(new Text(), point, collector, null);
+        mapper.map(new Text(), point, mapContext);
       }
       // now combine the data
       KMeansCombiner combiner = new KMeansCombiner();
-      DummyOutputCollector<Text, KMeansInfo> collector2 = new DummyOutputCollector<Text, KMeansInfo>();
-      for (Text key : collector.getKeys()) {
-        combiner.reduce(new Text(key), collector.getValue(key).iterator(), collector2, null);
+      DummyOutputCollector<Text, KMeansInfo> combineCollector = new DummyOutputCollector<Text, KMeansInfo>();
+      MockReducerContext<Text, KMeansInfo> combineContext = new MockReducerContext<Text, KMeansInfo>(combiner, conf, combineCollector,
+          Text.class, KMeansInfo.class);
+      for (Text key : mapCollector.getKeys()) {
+        combiner.reduce(new Text(key), mapCollector.getValue(key), combineContext);
       }
 
-      assertEquals("Number of map results", k + 1, collector2.getData().size());
+      assertEquals("Number of map results", k + 1, combineCollector.getData().size());
       // now verify that all points are accounted for
       int count = 0;
       Vector total = new DenseVector(2);
-      for (Text key : collector2.getKeys()) {
-        List<KMeansInfo> values = collector2.getValue(key);
+      for (Text key : combineCollector.getKeys()) {
+        List<KMeansInfo> values = combineCollector.getValue(key);
         assertEquals("too many values", 1, values.size());
         // String value = values.get(0).toString();
         KMeansInfo info = values.get(0);
@@ -250,17 +251,17 @@ public class TestKmeansClustering extend
    */
   public void testKMeansReducer() throws Exception {
     KMeansMapper mapper = new KMeansMapper();
-    EuclideanDistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
-    JobConf conf = new JobConf();
-    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
+    EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
+    Configuration conf = new Configuration();
+    conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measure.getClass().getName());
     conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, "0.001");
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, "");
-    mapper.configure(conf);
     List<VectorWritable> points = getPointsWritable(reference);
     for (int k = 0; k < points.size(); k++) {
       System.out.println("K = " + k);
       // pick k initial cluster centers at random
-      DummyOutputCollector<Text, KMeansInfo> collector = new DummyOutputCollector<Text, KMeansInfo>();
+      DummyOutputCollector<Text, KMeansInfo> mapCollector = new DummyOutputCollector<Text, KMeansInfo>();
+      MockMapperContext<Text, KMeansInfo> mapContext = new MockMapperContext<Text, KMeansInfo>(mapper, conf, mapCollector);
       List<Cluster> clusters = new ArrayList<Cluster>();
       for (int i = 0; i < k + 1; i++) {
         Vector vec = points.get(i).get();
@@ -269,28 +270,31 @@ public class TestKmeansClustering extend
         // cluster.addPoint(cluster.getCenter());
         clusters.add(cluster);
       }
-      mapper.config(clusters);
+      mapper.setup(clusters, new EuclideanDistanceMeasure());
       // map the data
       for (VectorWritable point : points) {
-        mapper.map(new Text(), point, collector, null);
+        mapper.map(new Text(), point, mapContext);
       }
       // now combine the data
       KMeansCombiner combiner = new KMeansCombiner();
-      DummyOutputCollector<Text, KMeansInfo> collector2 = new DummyOutputCollector<Text, KMeansInfo>();
-      for (Text key : collector.getKeys()) {
-        combiner.reduce(new Text(key), collector.getValue(key).iterator(), collector2, null);
+      DummyOutputCollector<Text, KMeansInfo> combineCollector = new DummyOutputCollector<Text, KMeansInfo>();
+      MockReducerContext<Text, KMeansInfo> combineContext = new MockReducerContext<Text, KMeansInfo>(combiner, conf,
+          combineCollector, Text.class, KMeansInfo.class);
+      for (Text key : mapCollector.getKeys()) {
+        combiner.reduce(new Text(key), mapCollector.getValue(key), combineContext);
       }
 
       // now reduce the data
       KMeansReducer reducer = new KMeansReducer();
-      reducer.configure(conf);
-      reducer.config(clusters);
-      DummyOutputCollector<Text, Cluster> collector3 = new DummyOutputCollector<Text, Cluster>();
-      for (Text key : collector2.getKeys()) {
-        reducer.reduce(new Text(key), collector2.getValue(key).iterator(), collector3, new DummyReporter());
+      reducer.setup(clusters, measure);
+      DummyOutputCollector<Text, Cluster> reduceCollector = new DummyOutputCollector<Text, Cluster>();
+      MockReducerContext<Text, Cluster> reduceContext = new MockReducerContext<Text, Cluster>(reducer, conf, reduceCollector,
+          Text.class, Cluster.class);
+      for (Text key : combineCollector.getKeys()) {
+        reducer.reduce(new Text(key), combineCollector.getValue(key), reduceContext);
       }
 
-      assertEquals("Number of map results", k + 1, collector3.getData().size());
+      assertEquals("Number of map results", k + 1, reduceCollector.getData().size());
 
       // compute the reference result after one iteration and compare
       List<Cluster> reference = new ArrayList<Cluster>();
@@ -302,7 +306,7 @@ public class TestKmeansClustering extend
       for (VectorWritable point : points) {
         pointsVectors.add((Vector) point.get());
       }
-      boolean converged = KMeansClusterer.runKMeansIteration(pointsVectors, reference, euclideanDistanceMeasure, 0.001);
+      boolean converged = KMeansClusterer.runKMeansIteration(pointsVectors, reference, measure, 0.001);
       if (k == 8) {
         assertTrue("not converged? " + k, converged);
       } else {
@@ -313,7 +317,7 @@ public class TestKmeansClustering extend
       converged = true;
       for (Cluster ref : reference) {
         String key = ref.getIdentifier();
-        List<Cluster> values = collector3.getValue(new Text(key));
+        List<Cluster> values = reduceCollector.getValue(new Text(key));
         Cluster cluster = values.get(0);
         converged = converged && cluster.isConverged();
         // Since we aren't roundtripping through Writable, we need to compare the reference center with the
@@ -341,10 +345,9 @@ public class TestKmeansClustering extend
     for (int k = 1; k < points.size(); k++) {
       System.out.println("testKMeansMRJob k= " + k);
       // pick k initial cluster centers at random
-      JobConf job = new JobConf(KMeansDriver.class);
       Path path = new Path(clustersPath, "part-00000");
-      FileSystem fs = FileSystem.get(path.toUri(), job);
-      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, Text.class, Cluster.class);
+      FileSystem fs = FileSystem.get(path.toUri(), conf);
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, Cluster.class);
 
       for (int i = 0; i < k + 1; i++) {
         Vector vec = points.get(i).get();
@@ -357,12 +360,11 @@ public class TestKmeansClustering extend
       writer.close();
       // now run the Job
       Path outputPath = getTestTempDirPath("output");
-      KMeansDriver.runJob(pointsPath, clustersPath, outputPath, EuclideanDistanceMeasure.class.getName(), 0.001, 10,
-          k + 1, true);
+      KMeansDriver.runJob(pointsPath, clustersPath, outputPath, EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, true);
       // now compare the expected clusters with actual
       Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
       // assertEquals("output dir files?", 4, outFiles.length);
-      SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-00000"), conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
       int[] expect = expectedNumPoints[k];
       DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
       // The key is the clusterId
@@ -376,7 +378,7 @@ public class TestKmeansClustering extend
       }
       reader.close();
       if (k == 2) {
-      // cluster 3 is empty so won't appear in output
+        // cluster 3 is empty so won't appear in output
         assertEquals("clusters[" + k + ']', expect.length - 1, collector.getKeys().size());
       } else {
         assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
@@ -398,15 +400,15 @@ public class TestKmeansClustering extend
     CanopyDriver.runJob(pointsPath, outputPath, ManhattanDistanceMeasure.class.getName(), 3.1, 2.1, false);
 
     // now run the KMeans job
-    KMeansDriver.runJob(pointsPath, new Path(outputPath, "clusters-0"), outputPath,
-                        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, true);
+    KMeansDriver.runJob(pointsPath, new Path(outputPath, "clusters-0"), outputPath, EuclideanDistanceMeasure.class.getName(),
+        0.001, 10, 1, true);
 
     // now compare the expected clusters with actual
     Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
     //String[] outFiles = outDir.list();
     //assertEquals("output dir files?", 4, outFiles.length);
     DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
-    SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-00000"), conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
 
     // The key is the clusterId
     IntWritable clusterId = new IntWritable(0);

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=953164&r1=953163&r2=953164&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Wed Jun  9 21:18:37 2010
@@ -28,8 +28,6 @@ import org.apache.commons.cli2.builder.D
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.mapred.JobClient;
-import org.apache.hadoop.mapred.JobConf;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.canopy.CanopyDriver;
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
@@ -63,21 +61,19 @@ public final class Job {
         "The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
 
     Option t1Opt = obuilder.withLongName("t1").withRequired(false).withArgument(
-        abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).withDescription("The t1 value to use.")
-        .withShortName("m").create();
+        abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).withDescription("The t1 value to use.").withShortName("m")
+        .create();
     Option t2Opt = obuilder.withLongName("t2").withRequired(false).withArgument(
-        abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).withDescription("The t2 value to use.")
-        .withShortName("m").create();
+        abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).withDescription("The t2 value to use.").withShortName("m")
+        .create();
     Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
         abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Vector implementation class name.  Default is RandomAccessSparseVector.class")
-        .withShortName("v").create();
+        "The Vector implementation class name.  Default is RandomAccessSparseVector.class").withShortName("v").create();
 
     Option helpOpt = DefaultOptionCreator.helpOption();
 
-    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt)
-        .withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt)
-        .withOption(vectorClassOpt).withOption(t1Opt).withOption(t2Opt)
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(measureClassOpt).withOption(
+        convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(vectorClassOpt).withOption(t1Opt).withOption(t2Opt)
         .withOption(helpOpt).create();
     try {
       Parser parser = new Parser();
@@ -90,8 +86,8 @@ public final class Job {
       }
       Path input = new Path(cmdLine.getValue(inputOpt, "testdata").toString());
       Path output = new Path(cmdLine.getValue(outputOpt, "output").toString());
-      String measureClass = cmdLine.getValue(measureClassOpt,
-                                             "org.apache.mahout.common.distance.EuclideanDistanceMeasure").toString();
+      String measureClass = cmdLine.getValue(measureClassOpt, "org.apache.mahout.common.distance.EuclideanDistanceMeasure")
+          .toString();
       double t1 = Double.parseDouble(cmdLine.getValue(t1Opt, "80").toString());
       double t2 = Double.parseDouble(cmdLine.getValue(t2Opt, "55").toString());
       double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt, "0.5").toString());
@@ -132,11 +128,8 @@ public final class Job {
    * @throws InterruptedException 
    */
   private static void runJob(Path input, Path output, String measureClass, double t1, double t2, double convergenceDelta,
-      int maxIterations) throws IOException, InstantiationException, IllegalAccessException, InterruptedException, ClassNotFoundException {
-    JobClient client = new JobClient();
-    JobConf conf = new JobConf(Job.class);
-
-    client.setConf(conf);
+      int maxIterations) throws IOException, InstantiationException, IllegalAccessException, InterruptedException,
+      ClassNotFoundException {
     HadoopUtil.overwriteOutput(output);
 
     Path directoryContainingConvertedInput = new Path(output, Constants.DIRECTORY_CONTAINING_CONVERTED_INPUT);
@@ -145,13 +138,7 @@ public final class Job {
     log.info("Running Canopy to get initial clusters");
     CanopyDriver.runJob(directoryContainingConvertedInput, output, measureClass, t1, t2, false);
     log.info("Running KMeans");
-    KMeansDriver.runJob(directoryContainingConvertedInput,
-                        new Path(output, Cluster.INITIAL_CLUSTERS_DIR),
-                        output,
-                        measureClass,
-                        convergenceDelta,
-                        maxIterations,
-                        1,
-                        true);
+    KMeansDriver.runJob(directoryContainingConvertedInput, new Path(output, Cluster.INITIAL_CLUSTERS_DIR), output, measureClass,
+        convergenceDelta, maxIterations, 1, true);
   }
 }



Mime
View raw message