mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gsing...@apache.org
Subject svn commit: r755548 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/kmeans/ core/src/test/java/org/apache/mahout/clustering/kmeans/ examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/
Date Wed, 18 Mar 2009 11:07:16 GMT
Author: gsingers
Date: Wed Mar 18 11:07:16 2009
New Revision: 755548

URL: http://svn.apache.org/viewvc?rev=755548&view=rev
Log:
MAHOUT-99: Fix k-means speed issue

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.mahout.clustering.kmeans;
 
 import org.apache.hadoop.io.Text;
@@ -64,6 +63,7 @@
    * Format the cluster for output
    * 
    * @param cluster the Cluster
+   * @return the cluster formatted as a String for output
    */
   public static String formatCluster(Cluster cluster) {
     return cluster.getIdentifier() + ": "
@@ -73,8 +73,7 @@
   /**
    * Decodes and returns a Cluster from the formattedString
    * 
-   * @param formattedString
-   *            a String produced by formatCluster
+   * @param formattedString a String produced by formatCluster
    * @return a new Canopy
    */
   public static Cluster decodeCluster(String formattedString) {
@@ -83,8 +82,8 @@
     String center = formattedString.substring(beginIndex);
     char firstChar = id.charAt(0);
     boolean startsWithV = firstChar == 'V';
-    if (firstChar == 'C' || startsWithV) {
-      int clusterId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));
+     if (firstChar == 'C' || startsWithV) {
+      int clusterId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));    
       Vector clusterCenter = AbstractVector.decodeVector(center);
       Cluster cluster = new Cluster(clusterCenter, clusterId);
       cluster.converged = startsWithV;
@@ -96,12 +95,11 @@
   /**
    * Configure the distance measure from the job
    * 
-   * @param job
-   *            the JobConf for the job
+   * @param job the JobConf for the job
    */
   public static void configure(JobConf job) {
     try {
-      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+      final ClassLoader ccl = Thread.currentThread().getContextClassLoader();
       Class<?> cl = ccl.loadClass(job.get(DISTANCE_MEASURE_KEY));
       measure = (DistanceMeasure) cl.newInstance();
       measure.configure(job);
@@ -119,10 +117,8 @@
   /**
    * Configure the distance measure directly. Used by unit tests.
    * 
-   * @param aMeasure
-   *            the DistanceMeasure
-   * @param aConvergenceDelta
-   *            the delta value used to define convergence
+   * @param aMeasure the DistanceMeasure
+   * @param aConvergenceDelta the delta value used to define convergence
    */
   public static void config(DistanceMeasure aMeasure, double aConvergenceDelta) {
     measure = aMeasure;
@@ -133,15 +129,11 @@
   /**
    * Emit the point to the nearest cluster center
    * 
-   * @param point
-   *            a point
-   * @param clusters
-   *            a List<Cluster> to test
-   * @param values
-   *            a Writable containing the input point and possible other values
-   *            of interest (payload)
-   * @param output
-   *            the OutputCollector to emit into
+   * @param point a point
+   * @param clusters a List<Cluster> to test
+   * @param values a Writable containing the input point and possible other
+   *        values of interest (payload)
+   * @param output the OutputCollector to emit into
    * @throws IOException
    */
   public static void emitPointToNearestCluster(Vector point,
@@ -156,7 +148,26 @@
         nearestDistance = distance;
       }
     }
-    output.collect(new Text(formatCluster(nearestCluster)), values);
+    // emit only clusterID
+    String outKey = nearestCluster.getIdentifier();
+    String value = "1\t" + values.toString();
+    output.collect(new Text(outKey), new Text(value));
+  }
+
+  public static void outputPointWithClusterInfo(String key, Vector point,
+      List<Cluster> clusters, Text values, OutputCollector<Text, Text> output)
+      throws IOException {
+    Cluster nearestCluster = null;
+    double nearestDistance = Double.MAX_VALUE;
+    for (Cluster cluster : clusters) {
+      double distance = measure.distance(point, cluster.getCenter());
+      if (nearestCluster == null || distance < nearestDistance) {
+        nearestCluster = cluster;
+        nearestDistance = distance;
+      }
+    }
+    output.collect(new Text(key), new Text(Integer
+        .toString(nearestCluster.clusterId)));
   }
 
   /**
@@ -177,10 +188,10 @@
   /**
    * Construct a new cluster with the given point as its center
    * 
-   * @param center
-   *            the center point
+   * @param center the center point
    */
   public Cluster(Vector center) {
+    super();
     this.clusterId = nextClusterId++;
     this.center = center;
     this.numPoints = 0;
@@ -190,16 +201,28 @@
   /**
    * Construct a new cluster with the given point as its center
    * 
-   * @param center
-   *            the center point
+   * @param center the center point
    */
   public Cluster(Vector center, int clusterId) {
+    super();
     this.clusterId = clusterId;
     this.center = center;
     this.numPoints = 0;
     this.pointTotal = center.like();
   }
 
+  /**
+   * Construct a new cluster with the given id as identifier
+   * 
+   * @param clusterId a prefix character followed by the numeric id; a leading "V" marks the cluster as converged
+   */
+  public Cluster(String clusterId) {
+
+    this.clusterId = Integer.parseInt((clusterId.substring(1)));
+    this.numPoints = 0;
+    this.converged = clusterId.startsWith("V");
+  }
+
   @Override
   public String toString() {
     return getIdentifier() + " - " + center.asFormatString();
@@ -215,25 +238,17 @@
   /**
    * Add the point to the cluster
    * 
-   * @param point
-   *            a point to add
+   * @param point a point to add
    */
   public void addPoint(Vector point) {
-    centroid = null;
-    numPoints++;
-    if (pointTotal == null)
-      pointTotal = point.copy();
-    else
-      pointTotal = point.plus(pointTotal);
+    addPoints(1, point);
   }
 
   /**
    * Add the point to the cluster
    * 
-   * @param count
-   *            the number of points in the delta
-   * @param delta
-   *            a point to add
+   * @param count the number of points in the delta
+   * @param delta a point to add
    */
   public void addPoints(int count, Vector delta) {
     centroid = null;
@@ -241,7 +256,7 @@
     if (pointTotal == null)
       pointTotal = delta.copy();
     else
-      pointTotal = delta.plus(pointTotal);
+      pointTotal = pointTotal.plus(delta);
   }
 
   public Vector getCenter() {

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java?rev=755548&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java Wed Mar 18 11:07:16 2009
@@ -0,0 +1,36 @@
+package org.apache.mahout.clustering.kmeans;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.matrix.AbstractVector;
+import org.apache.mahout.matrix.Vector;
+
+public class KMeansClusterMapper extends KMeansMapper {
+  public void map(WritableComparable<?> key, Text values,
+      OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+    Vector point = AbstractVector.decodeVector(values.toString());
+    Cluster.outputPointWithClusterInfo(values.toString(), point, clusters,
+        values, output);
+  }
+
+}

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -14,9 +14,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.Iterator;
+
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.MapReduceBase;
@@ -25,20 +27,19 @@
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.matrix.AbstractVector;
 
-import java.io.IOException;
-import java.util.Iterator;
-
 public class KMeansCombiner extends MapReduceBase implements
     Reducer<Text, Text, Text, Text> {
 
   @Override
   public void reduce(Text key, Iterator<Text> values,
       OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
-    Cluster cluster = Cluster.decodeCluster(key.toString());
+    Cluster cluster = new Cluster(key.toString());
     while (values.hasNext()) {
-      cluster.addPoint(AbstractVector.decodeVector(values.next().toString()));
+      String[] numPointnValue = values.next().toString().split("\t");
+      cluster.addPoints(Integer.parseInt(numPointnValue[0].trim()),
+          AbstractVector.decodeVector(numPointnValue[1].trim()));
     }
-    output.collect(key, new Text(cluster.getNumPoints() + ", "
+    output.collect(key, new Text(cluster.getNumPoints() + "\t"
         + cluster.getPointTotal().asFormatString()));
   }
 

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -14,23 +14,28 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.fs.PathFilter;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.FileInputFormat;
 import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.FileSplit;
 import org.apache.hadoop.mapred.JobClient;
 import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapred.KeyValueLineRecordReader;
+import org.apache.hadoop.mapred.TextInputFormat;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-
 public class KMeansDriver {
 
   private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
@@ -45,21 +50,23 @@
     String measureClass = args[3];
     double convergenceDelta = Double.parseDouble(args[4]);
     int maxIterations = Integer.parseInt(args[5]);
-    runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations);
+    runJob(input, clusters, output, measureClass, convergenceDelta,
+        maxIterations, 2);
   }
 
   /**
    * Run the job using supplied arguments
-   *
-   * @param input            the directory pathname for input points
-   * @param clustersIn       the directory pathname for initial & computed clusters
-   * @param output           the directory pathname for output points
-   * @param measureClass     the classname of the DistanceMeasure
+   * 
+   * @param input the directory pathname for input points
+   * @param clustersIn the directory pathname for initial & computed clusters
+   * @param output the directory pathname for output points
+   * @param measureClass the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
-   * @param maxIterations    the maximum number of iterations
+   * @param maxIterations the maximum number of iterations
    */
   public static void runJob(String input, String clustersIn, String output,
-                            String measureClass, double convergenceDelta, int maxIterations) {
+      String measureClass, double convergenceDelta, int maxIterations,
+      int numCentroids) {
     // iterate until the clusters converge
     boolean converged = false;
     int iteration = 0;
@@ -70,32 +77,32 @@
       // point the output to a new directory per iteration
       String clustersOut = output + "/clusters-" + iteration;
       converged = runIteration(input, clustersIn, clustersOut, measureClass,
-              delta);
+          delta, numCentroids);
       // now point the input to the old output directory
       clustersIn = output + "/clusters-" + iteration;
       iteration++;
     }
     // now actually cluster the points
     log.info("Clustering ");
-    runClustering(input, clustersIn, output + "/points", measureClass,
-            delta);
+    runClustering(input, clustersIn, output + "/points", measureClass, delta);
   }
 
   /**
    * Run the job using supplied arguments
-   *
-   * @param input            the directory pathname for input points
-   * @param clustersIn       the directory pathname for iniput clusters
-   * @param clustersOut      the directory pathname for output clusters
-   * @param measureClass     the classname of the DistanceMeasure
+   * 
+   * @param input the directory pathname for input points
+   * @param clustersIn the directory pathname for input clusters
+   * @param clustersOut the directory pathname for output clusters
+   * @param measureClass the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
    * @return true if the iteration successfully runs
    */
   private static boolean runIteration(String input, String clustersIn,
-                              String clustersOut, String measureClass, String convergenceDelta) {
+      String clustersOut, String measureClass, String convergenceDelta,
+      int numReduceTasks) {
     JobClient client = new JobClient();
     JobConf conf = new JobConf(KMeansDriver.class);
-
+    conf.setInputFormat(TextInputFormat.class);
     conf.setOutputKeyClass(Text.class);
     conf.setOutputValueClass(Text.class);
 
@@ -106,12 +113,16 @@
     conf.setMapperClass(KMeansMapper.class);
     conf.setCombinerClass(KMeansCombiner.class);
     conf.setReducerClass(KMeansReducer.class);
-    conf.setNumReduceTasks(1);
-    conf.setOutputFormat(SequenceFileOutputFormat.class);
+    // conf.setNumMapTasks(numMapTasks);
+    conf.setNumReduceTasks(numReduceTasks);
     conf.set(Cluster.CLUSTER_PATH_KEY, clustersIn);
     conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
 
+    conf.set("mapred.child.java.opts", "-Xmx1536m");
+    // uncomment it to run locally
+    conf.set("mapred.job.tracker", "local");
+
     client.setConf(conf);
     try {
       JobClient.runJob(conf);
@@ -125,15 +136,15 @@
 
   /**
    * Run the job using supplied arguments
-   *
-   * @param input            the directory pathname for input points
-   * @param clustersIn       the directory pathname for input clusters
-   * @param output           the directory pathname for output points
-   * @param measureClass     the classname of the DistanceMeasure
+   * 
+   * @param input the directory pathname for input points
+   * @param clustersIn the directory pathname for input clusters
+   * @param output the directory pathname for output points
+   * @param measureClass the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
    */
-  private static void runClustering(String input, String clustersIn, String output,
-                            String measureClass, String convergenceDelta) {
+  private static void runClustering(String input, String clustersIn,
+      String output, String measureClass, String convergenceDelta) {
     JobClient client = new JobClient();
     JobConf conf = new JobConf(KMeansDriver.class);
 
@@ -144,13 +155,16 @@
     Path outPath = new Path(output);
     FileOutputFormat.setOutputPath(conf, outPath);
 
-    conf.setMapperClass(KMeansMapper.class);
+    conf.setMapperClass(KMeansClusterMapper.class);
     conf.setNumReduceTasks(0);
     conf.set(Cluster.CLUSTER_PATH_KEY, clustersIn);
     conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
 
     client.setConf(conf);
+    // uncomment it to run locally
+    // conf.set("mapred.job.tracker", "local");
+    conf.set("mapred.child.java.opts", "-Xmx1536m");
     try {
       JobClient.runJob(conf);
     } catch (IOException e) {
@@ -160,23 +174,52 @@
 
   /**
    * Return if all of the Clusters in the filePath have converged or not
-   *
+   * 
    * @param filePath the file path to the single file containing the clusters
-   * @param conf     the JobConf
-   * @param fs       the FileSystem
+   * @param conf the JobConf
+   * @param fs the FileSystem
    * @return true if all Clusters are converged
    * @throws IOException if there was an IO error
    */
-  private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
-          throws IOException {
-    Path outPart = new Path(filePath);
-    SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
-    Text key = new Text();
-    Text value = new Text();
+  private static boolean isConverged(String filePath, JobConf conf,
+      FileSystem fs) throws IOException {
+    Path clusterPath = new Path(filePath);
+    List<Path> result = new ArrayList<Path>();
+
+    PathFilter clusterFileFilter = new PathFilter() {
+      public boolean accept(Path path) {
+        return path.getName().startsWith("part");
+      }
+    };
+
+    FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(
+        clusterPath, clusterFileFilter)), clusterFileFilter);
+
+    for (FileStatus match : matches) {
+      result.add(fs.makeQualified(match.getPath()));
+    }
     boolean converged = true;
-    while (converged && reader.next(key, value)) {
-      converged = value.toString().charAt(0) == 'V';
+
+    for (Path p : result) {
+      KeyValueLineRecordReader reader = null;
+
+      try {
+        reader = new KeyValueLineRecordReader(conf, new FileSplit(p, 0, fs
+            .getFileStatus(p).getLen(), (String[]) null));
+        Text key = new Text();
+        Text value = new Text();
+
+        while (converged && reader.next(key, value)) {
+          converged = value.toString().startsWith("V");
+        }
+      } finally {
+        if (reader != null) {
+          reader.close();
+        }
+      }
+
     }
+
     return converged;
   }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,11 @@
 /**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -17,26 +18,37 @@
 
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapred.JobConf;
 
-import java.io.IOException;
-
 public class KMeansJob {
 
   private KMeansJob() {
   }
 
   public static void main(String[] args) throws IOException {
-    String input = args[0];
-    String clusters = args[1];
-    String output = args[2];
-    String measureClass = args[3];
-    double convergenceDelta = Double.parseDouble(args[4]);
-    int maxIterations = Integer.parseInt(args[5]);
+    int index = 0;
+
+    if (args.length != 7) {
+      System.out.println("Expected number of arguments 10 and received:"
+          + args.length);
+      System.out
+          .println("Usage:input clustersIn output measureClass convergenceDelta maxIterations numCentroids");
+      System.exit(1);
+    }
+    String input = args[index++];
+    String clusters = args[index++];
+    String output = args[index++];
+    String measureClass = args[index++];
+    double convergenceDelta = Double.parseDouble(args[index++]);
+    int maxIterations = Integer.parseInt(args[index++]);
+    int numCentroids = Integer.parseInt(args[index++]);
+
     runJob(input, clusters, output, measureClass, convergenceDelta,
-        maxIterations);
+        maxIterations, numCentroids);
   }
 
   /**
@@ -51,7 +63,8 @@
    * @param maxIterations the maximum number of iterations
    */
   public static void runJob(String input, String clustersIn, String output,
-      String measureClass, double convergenceDelta, int maxIterations) throws IOException {
+      String measureClass, double convergenceDelta, int maxIterations,
+      int numCentroids) throws IOException {
     // delete the output directory
     JobConf conf = new JobConf(KMeansJob.class);
     Path outPath = new Path(output);
@@ -60,7 +73,8 @@
       fs.delete(outPath, true);
     }
     fs.mkdirs(outPath);
+
     KMeansDriver.runJob(input, clustersIn, output, measureClass,
-        convergenceDelta, maxIterations);
+        convergenceDelta, maxIterations, numCentroids);
   }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Wed Mar 18 11:07:16 2009
@@ -14,12 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.mahout.clustering.kmeans;
 
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapred.JobConf;
@@ -30,57 +30,43 @@
 import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.Vector;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
 public class KMeansMapper extends MapReduceBase implements
-        Mapper<WritableComparable<?>, Text, Text, Text> {
+    Mapper<WritableComparable<?>, Text, Text, Text> {
 
-  private List<Cluster> clusters;
+  protected List<Cluster> clusters;
 
   @Override
   public void map(WritableComparable<?> key, Text values,
-                  OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+      OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
     Vector point = AbstractVector.decodeVector(values.toString());
     Cluster.emitPointToNearestCluster(point, clusters, values, output);
   }
 
   /**
    * Configure the mapper by providing its clusters. Used by unit tests.
-   *
+   * 
    * @param clusters a List<Cluster>
    */
   void config(List<Cluster> clusters) {
     this.clusters = clusters;
   }
 
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.mapred.MapReduceBase#configure(org.apache.hadoop.mapred.JobConf)
+   */
   @Override
   public void configure(JobConf job) {
     super.configure(job);
     Cluster.configure(job);
 
-    String clusterPath = job.get(Cluster.CLUSTER_PATH_KEY);
     clusters = new ArrayList<Cluster>();
 
-    try {
-      FileSystem fs = FileSystem.get(job);
-      Path path = new Path(clusterPath + "/part-00000");
-      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
-      try {
-        Text key = new Text();
-        Text value = new Text();
-        while (reader.next(key, value)) {
-          Cluster cluster = Cluster.decodeCluster(value.toString());
-          // add the center so the centroid will be correct on output formatting
-          cluster.addPoint(cluster.getCenter());
-          clusters.add(cluster);
-        }
-      } finally {
-        reader.close();
-      }
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    KMeansUtil.configureWithClusterInfo(job.get(Cluster.CLUSTER_PATH_KEY),
+        clusters);
+
+    if (clusters.size() == 0)
+      throw new NullPointerException("Cluster is empty!!!");
   }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -14,9 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.MapReduceBase;
@@ -24,37 +30,62 @@
 import org.apache.hadoop.mapred.Reducer;
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.matrix.AbstractVector;
-import org.apache.mahout.matrix.Vector;
-
-import java.io.IOException;
-import java.util.Iterator;
 
 public class KMeansReducer extends MapReduceBase implements
-        Reducer<Text, Text, Text, Text> {
+    Reducer<Text, Text, Text, Text> {
 
-  //double delta = 0;
+  protected Map<String, Cluster> clusterMap;
 
   @Override
   public void reduce(Text key, Iterator<Text> values,
-                     OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
-    Cluster cluster = Cluster.decodeCluster(key.toString());
+      OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+    Cluster cluster = clusterMap.get(key.toString());
+
     while (values.hasNext()) {
       String value = values.next().toString();
-      int ix = value.indexOf(',');
-      int count = Integer.parseInt(value.substring(0, ix));
-      Vector total = AbstractVector.decodeVector(value.substring(ix + 2));
-      cluster.addPoints(count, total);
+      String[] numNValue = value.split("\t");
+      cluster.addPoints(Integer.parseInt(numNValue[0].trim()), AbstractVector
+          .decodeVector(numNValue[1].trim()));
     }
     // force convergence calculation
     cluster.computeConvergence();
     output.collect(new Text(cluster.getIdentifier()), new Text(Cluster
-            .formatCluster(cluster)));
+        .formatCluster(cluster)));
   }
 
+  /*
+   * (non-Javadoc)
+   * 
+   * @see org.apache.hadoop.mapred.MapReduceBase#configure(org.apache.hadoop.mapred.JobConf)
+   */
   @Override
   public void configure(JobConf job) {
+
     super.configure(job);
     Cluster.configure(job);
+    clusterMap = new HashMap<String, Cluster>();
+
+    List<Cluster> clusters = new ArrayList<Cluster>();
+    KMeansUtil.configureWithClusterInfo(job.get(Cluster.CLUSTER_PATH_KEY),
+        clusters);
+    setClusterMap(clusters);
+
+    if (clusterMap.size() == 0)
+      throw new NullPointerException("Cluster is empty!!!");
+  }
+
+  private void setClusterMap(List<Cluster> clusters) {
+    clusterMap = new HashMap<String, Cluster>();
+    for (Cluster cluster : clusters) {
+      clusterMap.put(cluster.getIdentifier(), cluster);
+    }
+    clusters.clear();
+    clusters = null;
+  }
+
+  public void config(List<Cluster> clusters) {
+    setClusterMap(clusters);
+
   }
 
 }

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java?rev=755548&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java Wed Mar 18 11:07:16 2009
@@ -0,0 +1,98 @@
+package org.apache.mahout.clustering.kmeans;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.FileSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.KeyValueLineRecordReader;
+import org.apache.hadoop.mapred.RecordReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class KMeansUtil { // static helpers shared by KMeansMapper and KMeansReducer configure()
+  private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class);
+
+  /**
+   * Appends to clusters every Cluster decoded from the "part*" files found via clusterPathStr.
+   * I/O failures are logged and rethrown wrapped in a RuntimeException.
+   * @param clusterPathStr location of the cluster files; used as the globStatus pattern below
+   * @param clusters list that receives the decoded clusters; existing entries are kept
+   */
+  public static void configureWithClusterInfo(String clusterPathStr,
+      List<Cluster> clusters) {
+    // NOTE(review): a fresh JobConf is built here instead of using the caller's job config — confirm intended
+    JobConf job = new JobConf(KMeansUtil.class);
+    Path clusterPath = new Path(clusterPathStr);
+    List<Path> result = new ArrayList<Path>();
+
+    // Accept only entries whose name starts with "part" (presumably job output files)
+    PathFilter clusterFileFilter = new PathFilter() {
+      public boolean accept(Path path) {
+        return path.getName().startsWith("part");
+      }
+    };
+
+    try {
+      // Glob clusterPath, then list each match; the same filter is applied at both steps
+      FileSystem fs = clusterPath.getFileSystem(job);
+      FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(
+          clusterPath, clusterFileFilter)), clusterFileFilter); // NOTE(review): a non-glob clusterPath whose own name lacks "part" may be filtered out here — confirm what callers pass
+
+      for (FileStatus match : matches) {
+        result.add(fs.makeQualified(match.getPath()));
+      }
+
+      // Read each matched file in full as key/value text lines
+      for (Path path : result) {
+        RecordReader<Text, Text> recordReader = null;
+        try {
+          recordReader = new KeyValueLineRecordReader(job, new FileSplit(path,
+              0, fs.getFileStatus(path).getLen(), (String[]) null));
+          Text key = new Text();
+          Text value = new Text();
+          int counter = 1; // NOTE(review): unused local
+          while (recordReader.next(key, value)) {
+            // Only the value (the encoded cluster) is decoded; the key is ignored
+            Cluster cluster = Cluster.decodeCluster(value.toString());
+            clusters.add(cluster);
+          }
+        } finally {
+          if (recordReader != null) {
+            recordReader.close();
+          }
+
+        }
+      }
+
+    } catch (IOException e) {
+      log.info("Exception occurred in loading clusters:", e); // NOTE(review): logged at info AND printed below before rethrow; log.error alone would suffice
+      e.printStackTrace();
+      throw new RuntimeException(e);
+    }
+  }
+
+}

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Wed Mar 18 11:07:16 2009
@@ -17,10 +17,24 @@
 
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 import junit.framework.TestCase;
+
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapred.JobConf;
@@ -34,42 +48,28 @@
 import org.apache.mahout.utils.EuclideanDistanceMeasure;
 import org.apache.mahout.utils.ManhattanDistanceMeasure;
 
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.IOException;
-import java.io.FileOutputStream;
-import java.io.OutputStreamWriter;
-import java.io.InputStreamReader;
-import java.io.FileInputStream;
-import java.util.ArrayList;
-import java.util.List;
-import java.nio.charset.Charset;
-
 public class TestKmeansClustering extends TestCase {
 
-  public static final double[][] reference = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 },
-      { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
+  public static final double[][] reference = { { 1, 1 }, { 2, 1 }, { 1, 2 },
+      { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
 
-  public static final int[][] expectedNumPoints = { { 9 }, { 4, 5 }, { 4, 5, 0 },
-      { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 }, { 1, 1, 1, 1, 1, 4 },
+  public static final int[][] expectedNumPoints = { { 9 }, { 4, 5 },
+      { 4, 5, 0 }, { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 }, { 1, 1, 1, 1, 1, 4 },
       { 1, 1, 1, 1, 1, 2, 2 }, { 1, 1, 1, 1, 1, 1, 2, 1 },
       { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
 
-  private static void rmr(String path) throws Exception {
+  private void rmr(String path) throws Exception {
     File f = new File(path);
     if (f.exists()) {
       if (f.isDirectory()) {
         String[] contents = f.list();
-        for (String content : contents) {
-          rmr(f.toString() + File.separator + content);
-        }
+        for (int i = 0; i < contents.length; i++)
+          rmr(f.toString() + File.separator + contents[i]);
       }
       f.delete();
     }
   }
 
-  @Override
   protected void setUp() throws Exception {
     super.setUp();
     rmr("output");
@@ -81,16 +81,12 @@
    * over the points and clusters until their centers converge or until the
    * maximum number of iterations is exceeded.
    * 
-   * @param points
-   *            the input List<Vector> of points
-   * @param clusters
-   *            the initial List<Cluster> of clusters
-   * @param measure
-   *            the DistanceMeasure to use
-   * @param maxIter
-   *            the maximum number of iterations
+   * @param points the input List<Vector> of points
+   * @param clusters the initial List<Cluster> of clusters
+   * @param measure the DistanceMeasure to use
+   * @param maxIter the maximum number of iterations
    */
-  private static void referenceKmeans(List<Vector> points, List<Cluster> clusters,
+  private void referenceKmeans(List<Vector> points, List<Cluster> clusters,
       DistanceMeasure measure, int maxIter) {
     boolean converged = false;
     int iteration = 0;
@@ -103,16 +99,15 @@
    * Perform a single iteration over the points and clusters, assigning points
    * to clusters and returning if the iterations are completed.
    * 
-   * @param points
-   *            the List<Vector> having the input points
-   * @param clusters
-   *            the List<Cluster> clusters
-   * @param measure
-   *            a DistanceMeasure to use
+   * @param points the List<Vector> having the input points
+   * @param clusters the List<Cluster> clusters
+   * @param measure a DistanceMeasure to use
    * @return
    */
-  private static boolean iterateReference(List<Vector> points, List<Cluster> clusters,
+  private boolean iterateReference(List<Vector> points, List<Cluster> clusters,
       DistanceMeasure measure) {
+    boolean converged;
+    converged = true;
     // iterate through all points, assigning each to the nearest cluster
     for (Vector point : points) {
       Cluster closestCluster = null;
@@ -127,7 +122,6 @@
       closestCluster.addPoint(point);
     }
     // test for convergence
-    boolean converged = true;
     for (Cluster cluster : clusters) {
       if (!cluster.computeConvergence())
         converged = false;
@@ -141,7 +135,8 @@
 
   public static List<Vector> getPoints(double[][] raw) {
     List<Vector> points = new ArrayList<Vector>();
-    for (double[] fr : raw) {
+    for (int i = 0; i < raw.length; i++) {
+      double[] fr = raw[i];
       Vector vec = new SparseVector(fr.length);
       vec.assign(fr);
       points.add(vec);
@@ -160,7 +155,7 @@
     Cluster.config(measure, 0.001);
     // try all possible values of k
     for (int k = 0; k < points.size(); k++) {
-      System.out.println("Test k=" + (k + 1) + ':');
+      System.out.println("Test k=" + (k + 1) + ":");
       // pick k initial cluster centers at random
       List<Cluster> clusters = new ArrayList<Cluster>();
       for (int i = 0; i < k + 1; i++) {
@@ -179,6 +174,15 @@
     }
   }
 
+  private Map<String, Cluster> loadClusterMap(List<Cluster> clusters) {
+    Map<String, Cluster> clusterMap = new HashMap<String, Cluster>();
+
+    for (Cluster cluster : clusters) {
+      clusterMap.put(cluster.getIdentifier(), cluster);
+    }
+    return clusterMap;
+  }
+
   /**
    * Story: test that the mapper will map input points to the nearest cluster
    * 
@@ -193,12 +197,15 @@
       // pick k initial cluster centers at random
       DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
       List<Cluster> clusters = new ArrayList<Cluster>();
+
       for (int i = 0; i < k + 1; i++) {
         Cluster cluster = new Cluster(points.get(i));
         // add the center so the centroid will be correct upon output
         cluster.addPoint(cluster.getCenter());
         clusters.add(cluster);
       }
+
+      Map<String, Cluster> clusterMap = loadClusterMap(clusters);
       mapper.config(clusters);
       // map the data
       for (Vector point : points) {
@@ -208,10 +215,12 @@
       assertEquals("Number of map results", k + 1, collector.getData().size());
       // now verify that all points are correctly allocated
       for (String key : collector.getKeys()) {
-        Cluster cluster = Cluster.decodeCluster(key);
+        Cluster cluster = clusterMap.get(key);
         List<Text> values = collector.getValue(key);
         for (Writable value : values) {
-          Vector point = AbstractVector.decodeVector(value.toString());
+          String[] pointInfo = value.toString().split("\t");
+
+          Vector point = AbstractVector.decodeVector(pointInfo[1]);
           double distance = euclideanDistanceMeasure.distance(cluster
               .getCenter(), point);
           for (Cluster c : clusters)
@@ -266,10 +275,10 @@
         List<Text> values = collector2.getValue(key);
         assertEquals("too many values", 1, values.size());
         String value = values.get(0).toString();
-        int ix = value.indexOf(',');
-        count += Integer.parseInt(value.substring(0, ix));
-        total = total
-            .plus(AbstractVector.decodeVector(value.substring(ix + 2)));
+
+        String[] pointInfo = value.split("\t");
+        count += Integer.parseInt(pointInfo[0]);
+        total = total.plus(AbstractVector.decodeVector(pointInfo[1]));
       }
       assertEquals("total points", 9, count);
       assertEquals("point total[0]", 27, (int) total.get(0));
@@ -297,7 +306,7 @@
         Vector vec = points.get(i);
         Cluster cluster = new Cluster(vec, i);
         // add the center so the centroid will be correct upon output
-        cluster.addPoint(cluster.getCenter());
+        // cluster.addPoint(cluster.getCenter());
         clusters.add(cluster);
       }
       mapper.config(clusters);
@@ -315,6 +324,7 @@
 
       // now reduce the data
       KMeansReducer reducer = new KMeansReducer();
+      reducer.config(clusters);
       DummyOutputCollector<Text, Text> collector3 = new DummyOutputCollector<Text, Text>();
       for (String key : collector2.getKeys())
         reducer.reduce(new Text(key), collector2.getValue(key).iterator(),
@@ -337,7 +347,8 @@
 
       // now verify that all clusters have correct centers
       converged = true;
-      for (Cluster ref : reference) {
+      for (int i = 0; i < reference.size(); i++) {
+        Cluster ref = reference.get(i);
         String key = ref.getIdentifier();
         List<Text> values = collector3.getValue(key);
         String value = values.get(0).toString();
@@ -373,52 +384,52 @@
 
     writePointsToFile(points, "testdata/points/file1");
     writePointsToFile(points, "testdata/points/file2");
-    for (int k = 0; k < points.size(); k++) {
+    for (int k = 1; k < points.size(); k++) {
       System.out.println("testKMeansMRJob k= " + k);
       // pick k initial cluster centers at random
       JobConf job = new JobConf(KMeansDriver.class);
       FileSystem fs = FileSystem.get(job);
       Path path = new Path("testdata/clusters/part-00000");
-      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
-          Text.class, Text.class);
+      BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(fs
+          .create(path)));
+
       for (int i = 0; i < k + 1; i++) {
         Vector vec = points.get(i);
 
-        Cluster cluster = new Cluster(vec);
+        Cluster cluster = new Cluster(vec, i);
         // add the center so the centroid will be correct upon output
         cluster.addPoint(cluster.getCenter());
-        writer.append(new Text(cluster.getIdentifier()), new Text(Cluster
-            .formatCluster(cluster)));
+        writer.write(cluster.getIdentifier() + "\t"
+            + Cluster.formatCluster(cluster) + "\n");
       }
       writer.close();
-
       // now run the Job
       KMeansJob.runJob("testdata/points", "testdata/clusters", "output",
-          EuclideanDistanceMeasure.class.getName(), 0.001, 10);
-
+          EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1);
       // now compare the expected clusters with actual
       File outDir = new File("output/points");
       assertTrue("output dir exists?", outDir.exists());
       String[] outFiles = outDir.list();
-      assertEquals("output dir files?", 4, outFiles.length);
-      BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(
-          "output/points/part-00000"), Charset.forName("UTF-8")));
+      // assertEquals("output dir files?", 4, outFiles.length);
+      BufferedReader reader = new BufferedReader(new InputStreamReader(
+          new FileInputStream("output/points/part-00000"), Charset
+              .forName("UTF-8")));
       int[] expect = expectedNumPoints[k];
       DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
       while (reader.ready()) {
         String line = reader.readLine();
         String[] lineParts = line.split("\t");
         assertEquals("line parts", 2, lineParts.length);
-        String cl = line.substring(0, line.indexOf(':'));
-        collector.collect(new Text(cl), new Text(lineParts[1]));
+        // String cl = line.substring(0, line.indexOf(':'));
+        collector.collect(new Text(lineParts[1]), new Text(lineParts[0]));
       }
       reader.close();
       if (k == 2)
         // cluster 3 is empty so won't appear in output
-        assertEquals("clusters[" + k + ']', expect.length - 1, collector
+        assertEquals("clusters[" + k + "]", expect.length - 1, collector
             .getKeys().size());
       else
-        assertEquals("clusters[" + k + ']', expect.length, collector.getKeys()
+        assertEquals("clusters[" + k + "]", expect.length, collector.getKeys()
             .size());
     }
   }
@@ -429,7 +440,7 @@
    * 
    * @throws Exception
    */
-  public static void textKMeansWithCanopyClusterInput() throws Exception {
+  public void textKMeansWithCanopyClusterInput() throws Exception {
     List<Vector> points = getPoints(reference);
     File testData = new File("testdata");
     if (!testData.exists())
@@ -446,15 +457,16 @@
 
     // now run the KMeans job
     KMeansJob.runJob("testdata/points", "testdata/canopies", "output",
-        EuclideanDistanceMeasure.class.getName(), 0.001, 10);
+        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1);
 
     // now compare the expected clusters with actual
     File outDir = new File("output/points");
     assertTrue("output dir exists?", outDir.exists());
     String[] outFiles = outDir.list();
     assertEquals("output dir files?", 4, outFiles.length);
-    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(
-        "output/points/part-00000"), Charset.forName("UTF-8")));
+    BufferedReader reader = new BufferedReader(new InputStreamReader(
+        new FileInputStream("output/points/part-00000"), Charset
+            .forName("UTF-8")));
     DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
     while (reader.ready()) {
       String line = reader.readLine();
@@ -470,7 +482,8 @@
 
   public static void writePointsToFileWithPayload(List<Vector> points,
       String fileName, String payload) throws IOException {
-    BufferedWriter output = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileName), Charset.forName("UTF-8")));
+    BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
+        new FileOutputStream(fileName), Charset.forName("UTF-8")));
     for (Vector point : points) {
       output.write(point.asFormatString());
       output.write(payload);

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Wed Mar 18 11:07:16 2009
@@ -80,7 +80,7 @@
     CanopyClusteringJob
         .runJob(output + "/data", output, measureClass, t1, t2);
     KMeansDriver.runJob(output + "/data", output + "/canopies", output,
-        measureClass, convergenceDelta, maxIterations);
+        measureClass, convergenceDelta, maxIterations,1);
     OutputDriver.runJob(output + "/points", output + "/clustered-points");
   }
 }



Mime
View raw message