mahout-commits mailing list archives

From dfili...@apache.org
Subject svn commit: r1482907 - in /mahout/trunk: ./ core/ core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/ core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/ src/conf/
Date Wed, 15 May 2013 15:35:13 GMT
Author: dfilimon
Date: Wed May 15 15:35:12 2013
New Revision: 1482907

URL: http://svn.apache.org/r1482907
Log:
MAHOUT-1181: Adding StreamingKMeans MapReduce classes

These classes implement the MapReduce version of StreamingKMeans, add a driver
and a new command line tool.


Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/pom.xml
    mahout/trunk/src/conf/driver.classes.default.props

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1482907&r1=1482906&r2=1482907&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Wed May 15 15:35:12 2013
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.8 - unreleased
 
+  MAHOUT-1181: Adding StreamingKMeans MapReduce classes (dfilimon)
+
   MAHOUT-1212: Incorrect classify-20newsgroups.sh file description (Julian Ortega via smarthi)
    
   MAHOUT-1209: DRY out maven-compiler-plugin configuration (Stevo Slavic via smarthi) 

Modified: mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/mahout/trunk/core/pom.xml?rev=1482907&r1=1482906&r2=1482907&view=diff
==============================================================================
--- mahout/trunk/core/pom.xml (original)
+++ mahout/trunk/core/pom.xml Wed May 15 15:35:12 2013
@@ -187,6 +187,13 @@
       <artifactId>easymock</artifactId>
       <scope>test</scope>
     </dependency>
+
+    <dependency>
+      <groupId>org.apache.mrunit</groupId>
+      <artifactId>mrunit</artifactId>
+      <version>1.0.0</version>
+      <classifier>hadoop1</classifier>
+    </dependency>
   </dependencies>
   
   <profiles>

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java Wed May 15 15:35:12 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+public class CentroidWritable implements Writable {
+  private Centroid centroid = null;
+
+  public CentroidWritable() {}
+
+  public CentroidWritable(Centroid centroid) {
+    this.centroid = centroid;
+  }
+
+  public Centroid getCentroid() {
+    return centroid;
+  }
+
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(centroid.getIndex());
+    dataOutput.writeDouble(centroid.getWeight());
+    VectorWritable.writeVector(dataOutput, centroid.getVector());
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    if (centroid == null) {
+      centroid = read(dataInput);
+      return;
+    }
+    centroid.setIndex(dataInput.readInt());
+    centroid.setWeight(dataInput.readDouble());
+    centroid.assign(VectorWritable.readVector(dataInput));
+  }
+
+  public static Centroid read(DataInput dataInput) throws IOException {
+    int index = dataInput.readInt();
+    double weight = dataInput.readDouble();
+    Vector v = VectorWritable.readVector(dataInput);
+    return new Centroid(index, v, weight);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof CentroidWritable)) {
+      return false;
+    }
+    CentroidWritable writable = (CentroidWritable) o;
+    return centroid.equals(writable.centroid);
+  }
+
+  @Override
+  public int hashCode() {
+    return centroid.hashCode();
+  }
+
+  @Override
+  public String toString() {
+    return centroid.toString();
+  }
+}
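
For reference, a minimal round trip through the new Writable, using an in-memory
DataOutput/DataInput pair (the stream setup is illustrative and not part of this
commit):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
    import org.apache.mahout.math.Centroid;
    import org.apache.mahout.math.DenseVector;

    public class CentroidWritableRoundTrip {
      public static void main(String[] args) throws IOException {
        Centroid original = new Centroid(0, new DenseVector(new double[]{1.0, 2.0}), 3.0);
        // Serialize through write(DataOutput)...
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        new CentroidWritable(original).write(new DataOutputStream(bytes));
        // ... and deserialize through readFields(DataInput).
        CentroidWritable copy = new CentroidWritable();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        // The index, weight and underlying vector should all survive the round trip.
        System.out.println(original.equals(copy.getCentroid()));
      }
    }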

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java Wed May 15 15:35:12 2013
@@ -0,0 +1,474 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
+ */
+public final class StreamingKMeansDriver extends AbstractJob {
+  /**
+   * Streaming KMeans options
+   */
+  /**
+   * The number of clusters that the Mappers will use should be \(O(k \log n)\), where k is the number of clusters
+   * to get at the end and n is the number of points to cluster. This doesn't need to be exact.
+   * It will be adjusted at runtime.
+   */
+  public static final String ESTIMATED_NUM_MAP_CLUSTERS = "estimatedNumMapClusters";
+  /**
+   * The initial estimated distance cutoff between two points for forming new clusters.
+   * Defaults to 10e-6.
+   * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans
+   */
+  public static final String ESTIMATED_DISTANCE_CUTOFF = "estimatedDistanceCutoff";
+
+  /**
+   * Ball KMeans options
+   */
+  /**
+   * After mapping finishes, we get an intermediate set of vectors that represent approximate
+   * clusterings of the data from each Mapper. These can be clustered by the Reducer using
+   * BallKMeans in memory. This variable is the maximum number of iterations in the final
+   * BallKMeans algorithm.
+   * Defaults to 10.
+   */
+  public static final String MAX_NUM_ITERATIONS = "maxNumIterations";
+  /**
+   * The "ball" aspect of ball k-means means that only the closest points to the centroid will actually be used
+   * for updating. The fraction of the points to be used is those points whose distance to the center is within
+   * trimFraction * distance to the closest other center.
+   * Defaults to 0.9.
+   */
+  public static final String TRIM_FRACTION = "trimFraction";
+  /**
+   * Whether to use k-means++ initialization or random initialization of the seed centroids.
+   * Essentially, k-means++ provides better clusters, but takes longer, whereas random initialization takes less
+   * time, but produces worse clusters, tends to fail more often, and needs multiple runs to be comparable to
+   * k-means++. If set, random initialization is used.
+   * @see org.apache.mahout.clustering.streaming.cluster.BallKMeans
+   */
+  public static final String RANDOM_INIT = "randomInit";
+  /**
+   * Whether to correct the weights of the centroids after the clustering is done. The weights end up being wrong
+   * because of the trimFraction and possible train/test splits. In some cases, especially in a pipeline, having
+   * an accurate count of the weights is useful. If set, ignores the final weights.
+   */
+  public static final String IGNORE_WEIGHTS = "ignoreWeights";
+  /**
+   * The percentage of points that go into the "test" set when evaluating BallKMeans runs in the reducer.
+   */
+  public static final String TEST_PROBABILITY = "testProbability";
+  /**
+   * The number of BallKMeans runs in the reducer used to pick the best clustering (trained on a training set, evaluated on a test set).
+   */
+  public static final String NUM_BALLKMEANS_RUNS = "numBallKMeansRuns";
+
+  /**
+   * Searcher options
+   */
+  /**
+   * The Searcher class when performing nearest neighbor search in StreamingKMeans.
+   * Defaults to ProjectionSearch.
+   */
+  public static final String SEARCHER_CLASS_OPTION = "searcherClass";
+  /**
+   * The number of projections to use when using a projection searcher like ProjectionSearch or
+   * FastProjectionSearch. Projection searches work by projecting all the vectors onto a set of
+   * basis vectors and searching for the projected query in that totally ordered set. This,
+   * however, can produce false positives (vectors that are closer when projected than they
+   * actually are).
+   * So, there must be more than one projection vector in the basis. This variable is the number
+   * of vectors in a basis.
+   * Defaults to 3.
+   */
+  public static final String NUM_PROJECTIONS_OPTION = "numProjections";
+  /**
+   * When using approximate searches (anything that's not BruteSearch),
+   * more than just the seemingly closest element must be considered. This variable has different
+   * meanings depending on the actual Searcher class used but is a measure of how many candidates
+   * will be considered.
+   * See the ProjectionSearch, FastProjectionSearch, LocalitySensitiveHashSearch classes for more
+   * details.
+   * Defaults to 2.
+   */
+  public static final String SEARCH_SIZE_OPTION = "searchSize";
+
+  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class);
+
+  private static final double INVALID_DISTANCE_CUTOFF = -1;
+
+  @Override
+  public int run(String[] args) throws Exception {
+    // Standard options for any Mahout job.
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.overwriteOption().create());
+
+    // The number of clusters to create for the data.
+    addOption(DefaultOptionCreator.numClustersOption().withDescription(
+        "The k in k-Means. Approximately this many clusters will be generated.").create());
+
+    // StreamingKMeans (mapper) options
+    // There will be k final clusters, but in the Map phase to get a good approximation of the data, O(k log n)
+    // clusters are needed. Since n is the number of data points and is not known until all the vectors have been
+    // read, provide a decent estimate.
+    addOption(ESTIMATED_NUM_MAP_CLUSTERS, "km", "The estimated number of clusters to use for the " +
+        "Map phase of the job when running StreamingKMeans. This should be around k * log(n), " +
+        "where k is the final number of clusters and n is the total number of data points to " +
+        "cluster.");
+
+    addOption(ESTIMATED_DISTANCE_CUTOFF, "e", "The initial estimated distance cutoff between two " +
+        "points for forming new clusters. If no value is given, it's estimated from the data set",
+        String.valueOf(INVALID_DISTANCE_CUTOFF));
+
+    // BallKMeans (reducer) options
+    addOption(MAX_NUM_ITERATIONS, "mi", "The maximum number of iterations to run for the " +
+        "BallKMeans algorithm used by the reducer. If no value is given, defaults to 10.", String.valueOf(10));
+
+    addOption(TRIM_FRACTION, "tf", "The 'ball' aspect of ball k-means means that only the closest points " +
+        "to the centroid will actually be used for updating. The fraction of the points to be used is those " +
+        "points whose distance to the center is within trimFraction * distance to the closest other center. " +
+        "If no value is given, defaults to 0.9.", String.valueOf(0.9));
+
+    addFlag(RANDOM_INIT, "ri", "Whether to use k-means++ initialization or random initialization " +
+        "of the seed centroids. Essentially, k-means++ provides better clusters, but takes longer, whereas random " +
+        "initialization takes less time, but produces worse clusters, and tends to fail more often and needs " +
+        "multiple runs to compare to k-means++. If set, uses the random initialization.");
+
+    addFlag(IGNORE_WEIGHTS, "iw", "Whether to correct the weights of the centroids after the clustering is done. " +
+        "The weights end up being wrong because of the trimFraction and possible train/test splits. In some cases, " +
+        "especially in a pipeline, having an accurate count of the weights is useful. If set, ignores the final " +
+        "weights");
+
+    addOption(TEST_PROBABILITY, "testp", "A double value between 0 and 1 that represents the percentage of " +
+        "points to be used for 'testing' different clustering runs in the final BallKMeans " +
+        "step. If no value is given, defaults to 0.1", String.valueOf(0.1));
+
+    addOption(NUM_BALLKMEANS_RUNS, "nbkm", "Number of BallKMeans runs to use at the end to try to cluster the " +
+        "points. If no value is given, defaults to 4", String.valueOf(4));
+
+    // Nearest neighbor search options
+    // The distance measure used for computing the distance between two points. Generally, the
+    // SquaredEuclideanDistance is used for clustering problems (it's equivalent to CosineDistance for normalized
+    // vectors).
+    // WARNING! You can use any metric but most of the literature is for the squared Euclidean distance.
+    addOption(DefaultOptionCreator.distanceMeasureOption().create());
+
+    // The default searcher should be something more efficient than BruteSearch (ProjectionSearch, ...). See
+    // o.a.m.math.neighborhood.*
+    addOption(SEARCHER_CLASS_OPTION, "sc", "The type of searcher to be used when performing nearest " +
+        "neighbor searches. Defaults to ProjectionSearch.", ProjectionSearch.class.getCanonicalName());
+
+    // In the original paper, the authors used 1 projection vector.
+    addOption(NUM_PROJECTIONS_OPTION, "np", "The number of projections considered in estimating the " +
+        "distances between vectors. Only used when the distance measure requested is either " +
+        "ProjectionSearch or FastProjectionSearch. If no value is given, defaults to 3.", String.valueOf(3));
+
+    addOption(SEARCH_SIZE_OPTION, "s", "In more efficient searches (non BruteSearch), " +
+        "not all distances are calculated for determining the nearest neighbors. The number of " +
+        "elements whose distances from the query vector is actually computer is proportional to " +
+        "searchSize. If no value is given, defaults to 1.", String.valueOf(2));
+
+    addOption(DefaultOptionCreator.methodOption().create());
+
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+    Path output = getOutputPath();
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), output);
+    }
+    configureOptionsForWorkers();
+    run(getConf(), getInputPath(), output);
+    return 0;
+  }
+
+  private void configureOptionsForWorkers() throws ClassNotFoundException {
+    log.info("Starting to configure options for workers");
+
+    String method = getOption(DefaultOptionCreator.METHOD_OPTION);
+
+    int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+
+    // StreamingKMeans
+    int estimatedNumMapClusters = Integer.parseInt(getOption(ESTIMATED_NUM_MAP_CLUSTERS));
+    float estimatedDistanceCutoff = Float.parseFloat(getOption(ESTIMATED_DISTANCE_CUTOFF));
+
+    // BallKMeans
+    int maxNumIterations = Integer.parseInt(getOption(MAX_NUM_ITERATIONS));
+    float trimFraction = Float.parseFloat(getOption(TRIM_FRACTION));
+    boolean randomInit = hasOption(RANDOM_INIT);
+    boolean ignoreWeights = hasOption(IGNORE_WEIGHTS);
+    float testProbability = Float.parseFloat(getOption(TEST_PROBABILITY));
+    int numBallKMeansRuns = Integer.parseInt(getOption(NUM_BALLKMEANS_RUNS));
+
+    // Nearest neighbor search
+    String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+    String searcherClass = getOption(SEARCHER_CLASS_OPTION);
+
+    // Get more parameters depending on the kind of search class we're working with. BruteSearch
+    // doesn't need anything else.
+    // LocalitySensitiveHashSearch and ProjectionSearches need searchSize.
+    // ProjectionSearches also need the number of projections.
+    boolean getSearchSize = false;
+    boolean getNumProjections = false;
+    if (!searcherClass.equals(BruteSearch.class.getName())) {
+      getSearchSize = true;
+      getNumProjections = true;
+    }
+
+    // The search size to use. This is quite fuzzy and might end up not being configurable at all.
+    int searchSize = 0;
+    if (getSearchSize) {
+      searchSize = Integer.parseInt(getOption(SEARCH_SIZE_OPTION));
+    }
+
+    // The number of projections to use. This is only useful in projection searches which
+    // project the vectors on multiple basis vectors to get distance estimates that are faster to
+    // calculate.
+    int numProjections = 0;
+    if (getNumProjections) {
+      numProjections = Integer.parseInt(getOption(NUM_PROJECTIONS_OPTION));
+    }
+
+    configureOptionsForWorkers(getConf(), numClusters,
+        /* StreamingKMeans */
+        estimatedNumMapClusters,  estimatedDistanceCutoff,
+        /* BallKMeans */
+        maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns,
+        /* Searcher */
+        measureClass, searcherClass,  searchSize, numProjections,
+        method);
+  }
+
+  /**
+   * Checks the parameters for a StreamingKMeans job and prepares a Configuration with them.
+   *
+   * @param conf the Configuration to populate
+   * @param numClusters k, the number of clusters at the end
+   * @param estimatedNumMapClusters O(k log n), the number of clusters requested from each mapper
+   * @param estimatedDistanceCutoff an estimate of the minimum distance that separates two clusters (can be smaller and
+   *                                will be increased dynamically)
+   * @param maxNumIterations the maximum number of iterations of BallKMeans
+   * @param trimFraction the fraction of the points to be considered in updating a ball k-means
+   * @param randomInit whether to initialize the ball k-means seeds randomly
+   * @param ignoreWeights whether to ignore the invalid final ball k-means weights
+   * @param testProbability the percentage of vectors assigned to the test set for selecting the best final centers
+   * @param numBallKMeansRuns the number of BallKMeans runs in the reducer that determine the centroids to return
+   *                          (clusters are computed for the training set and the error is computed on the test set)
+   * @param measureClass string, name of the distance measure class; theory works for Euclidean-like distances
+   * @param searcherClass string, name of the searcher that will be used for nearest neighbor search
+   * @param searchSize the number of closest neighbors to look at for selecting the closest one in approximate nearest
+   *                   neighbor searches
+   * @param numProjections the number of projected vectors to use for faster searching (only useful for ProjectionSearch
+   *                       or FastProjectionSearch); @see org.apache.mahout.math.neighborhood.ProjectionSearch
+   */
+  public static void configureOptionsForWorkers(Configuration conf,
+                                                int numClusters,
+                                                /* StreamingKMeans */
+                                                int estimatedNumMapClusters, float estimatedDistanceCutoff,
+                                                /* BallKMeans */
+                                                int maxNumIterations, float trimFraction, boolean randomInit,
+                                                boolean ignoreWeights, float testProbability, int numBallKMeansRuns,
+                                                /* Searcher */
+                                                String measureClass, String searcherClass,
+                                                int searchSize, int numProjections,
+                                                String method) throws ClassNotFoundException {
+    // Checking preconditions for the parameters.
+    Preconditions.checkArgument(numClusters > 0, "Invalid number of clusters requested");
+
+    // StreamingKMeans
+    Preconditions.checkArgument(estimatedNumMapClusters > numClusters, "Invalid number of estimated map " +
+        "clusters; There must be more than the final number of clusters (k log n vs k)");
+    Preconditions.checkArgument(estimatedDistanceCutoff == INVALID_DISTANCE_CUTOFF || estimatedDistanceCutoff > 0,
+        "estimatedDistanceCutoff cannot be negative");
+
+    // BallKMeans
+    Preconditions.checkArgument(maxNumIterations > 0, "Must have at least one BallKMeans iteration");
+    Preconditions.checkArgument(trimFraction > 0, "trimFraction must be positive");
+    Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "test probability is not in the " +
+        "interval [0, 1)");
+    Preconditions.checkArgument(numBallKMeansRuns > 0, "numBallKMeansRuns must be positive");
+
+    // Searcher
+    if (!searcherClass.contains("Brute")) {
+      // These tests only make sense when a relevant searcher is being used.
+      Preconditions.checkArgument(searchSize > 0, "Invalid searchSize. Must be positive.");
+      if (searcherClass.contains("Projection")) {
+        Preconditions.checkArgument(numProjections > 0, "Invalid numProjections. Must be positive");
+      }
+    }
+
+    // Setting the parameters in the Configuration.
+    conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, numClusters);
+    /* StreamingKMeans */
+    conf.setInt(ESTIMATED_NUM_MAP_CLUSTERS, estimatedNumMapClusters);
+    if (estimatedDistanceCutoff != INVALID_DISTANCE_CUTOFF) {
+      conf.setFloat(ESTIMATED_DISTANCE_CUTOFF, estimatedDistanceCutoff);
+    }
+    /* BallKMeans */
+    conf.setInt(MAX_NUM_ITERATIONS, maxNumIterations);
+    conf.setFloat(TRIM_FRACTION, trimFraction);
+    conf.setBoolean(RANDOM_INIT, randomInit);
+    conf.setBoolean(IGNORE_WEIGHTS, ignoreWeights);
+    conf.setFloat(TEST_PROBABILITY, testProbability);
+    conf.setInt(NUM_BALLKMEANS_RUNS, numBallKMeansRuns);
+    /* Searcher */
+    // Checks if the measureClass is available, throws exception otherwise.
+    Class.forName(measureClass);
+    conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, measureClass);
+    // Checks if the searcherClass is available, throws exception otherwise.
+    Class.forName(searcherClass);
+    conf.set(SEARCHER_CLASS_OPTION, searcherClass);
+    conf.setInt(SEARCH_SIZE_OPTION, searchSize);
+    conf.setInt(NUM_PROJECTIONS_OPTION, numProjections);
+    conf.set(DefaultOptionCreator.METHOD_OPTION, method);
+    log.info("Parameters are: [k] numClusters {}; " +
+        "[SKM] estimatedNumMapClusters {}; estimatedDistanceCutoff {} " +
+        "[BKM] maxNumIterations {}; trimFraction {}; randomInit {}; ignoreWeights {}; " +
+        "testProbability {}; numBallKMeansRuns {}; " +
+        "[S] measureClass {}; searcherClass {}; searcherSize {}; numProjections {}; " +
+        "method {}", numClusters, estimatedNumMapClusters, estimatedDistanceCutoff,
+        maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns,
+        measureClass, searcherClass, searchSize, numProjections, method);
+  }
+
+  /**
+   * Clusters the vectors in the input path and writes the resulting centroids to the output path, using either
+   * the sequential or the MapReduce implementation depending on the configured method option.
+   * @param conf the Configuration holding the options set by configureOptionsForWorkers().
+   * @param input the directory pathname for input points.
+   * @param output the directory pathname for output centroids.
+   * @return 0 on success, -1 on failure.
+   */
+  @SuppressWarnings("unchecked")
+  public static int run(Configuration conf, Path input, Path output)
+      throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
+    log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
+        input.toString(), output.toString());
+
+    if (conf.get(DefaultOptionCreator.METHOD_OPTION,
+        DefaultOptionCreator.MAPREDUCE_METHOD).equals(DefaultOptionCreator.SEQUENTIAL_METHOD)) {
+      return runSequentially(conf, input, output);
+    } else {
+      return runMapReduce(conf, input, output);
+    }
+  }
+
+  private static int runSequentially(Configuration conf, Path input, Path output)
+      throws IOException, ExecutionException, InterruptedException {
+    long start = System.currentTimeMillis();
+    // Run StreamingKMeans step in parallel by spawning 1 thread per input path to process.
+    ExecutorService pool = Executors.newCachedThreadPool();
+    List<Future<Iterable<Centroid>>> intermediateCentroidFutures = Lists.newArrayList();
+    for (FileStatus status : HadoopUtil.listStatus(FileSystem.get(conf), input)) {
+      intermediateCentroidFutures.add(pool.submit(new StreamingKMeansThread(status.getPath(), conf)));
+    }
+    log.info("Finished running Mappers");
+    // Merge the resulting "mapper" centroids.
+    List<Centroid> intermediateCentroids = Lists.newArrayList();
+    for (Future<Iterable<Centroid>> futureIterable : intermediateCentroidFutures) {
+      for (Centroid centroid : futureIterable.get()) {
+        intermediateCentroids.add(centroid);
+      }
+    }
+    pool.shutdown();
+    pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
+    log.info("Finished StreamingKMeans");
+    SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf, output, IntWritable.class,
+        CentroidWritable.class);
+    int numCentroids = 0;
+    // Run BallKMeans on the intermediate centroids.
+    for (Vector finalVector : StreamingKMeansReducer.getBestCentroids(intermediateCentroids, conf)) {
+      Centroid finalCentroid = (Centroid)finalVector;
+      writer.append(new IntWritable(numCentroids++), new CentroidWritable(finalCentroid));
+    }
+    writer.close();
+    long end = System.currentTimeMillis();
+    log.info("Finished BallKMeans. Took {}.", (end - start) / 1000.0);
+    return 0;
+  }
+
+  @SuppressWarnings("unchecked")
+  public static int runMapReduce(Configuration conf, Path input, Path output) throws IOException, ClassNotFoundException, InterruptedException {
+    // Prepare Job for submission.
+    Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+        StreamingKMeansMapper.class, IntWritable.class, CentroidWritable.class,
+        StreamingKMeansReducer.class, IntWritable.class, CentroidWritable.class, SequenceFileOutputFormat.class,
+        conf);
+    job.setJobName(HadoopUtil.getCustomJobName(StreamingKMeansDriver.class.getSimpleName(), job,
+        StreamingKMeansMapper.class, StreamingKMeansReducer.class));
+
+    // There is only one reducer so that the intermediate centroids get collected on one
+    // machine and are clustered in memory to get the right number of clusters.
+    job.setNumReduceTasks(1);
+
+    // Set the JAR (so that the required libraries are available) and run.
+    job.setJarByClass(StreamingKMeansDriver.class);
+
+    // Run job!
+    long start = System.currentTimeMillis();
+    if (!job.waitForCompletion(true)) {
+      return -1;
+    }
+    long end = System.currentTimeMillis();
+
+    log.info("StreamingKMeans clustering complete. Results are in {}. Took {} ms", output.toString(), end - start);
+    return 0;
+  }
+
+  /**
+   * Constructor to be used by the ToolRunner.
+   */
+  private StreamingKMeansDriver() {}
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new StreamingKMeansDriver(), args);
+  }
+}
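
A sketch of driving the job programmatically rather than from the command line; the
paths and parameter values below are made up for illustration (for k = 20 final
clusters and n = 1,000,000 points, k * log(n) is roughly 20 * 13.8, so about 276 map
clusters):

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansDriver;
    import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
    import org.apache.mahout.math.neighborhood.ProjectionSearch;

    public class StreamingKMeansDriverSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        StreamingKMeansDriver.configureOptionsForWorkers(conf, 20,
            /* StreamingKMeans */ 276, -1.0f /* estimate the cutoff from the data */,
            /* BallKMeans */ 10, 0.9f, false, false, 0.1f, 4,
            /* Searcher */ SquaredEuclideanDistanceMeasure.class.getName(),
            ProjectionSearch.class.getName(), 2, 3,
            "mapreduce");
        StreamingKMeansDriver.run(conf, new Path("input"), new Path("output"));
      }
    }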

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java Wed May 15 15:35:12 2013
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+  /**
+   * The clusterer object used to cluster the points received by this mapper online.
+   */
+  private StreamingKMeans clusterer;
+
+  /**
+   * Number of points clustered so far.
+   */
+  private int numPoints = 0;
+
+  @Override
+  public void setup(Context context) {
+    // At this point the configuration received from the Driver is assumed to be valid.
+    // No other checks are made.
+    Configuration conf = context.getConfiguration();
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    // There is no way of estimating the distance cutoff unless we have some data.
+    clusterer = new StreamingKMeans(searcher, numClusters,
+        conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, 1e-4f));
+  }
+
+  @Override
+  public void map(Writable key, VectorWritable point, Context context) {
+    clusterer.cluster(new Centroid(numPoints++, point.get().clone(), 1));
+  }
+
+  @Override
+  public void cleanup(Context context) throws IOException, InterruptedException {
+    // Reindex the centroids before passing them to the reducer.
+    clusterer.reindexCentroids();
+    // All outputs have the same key to go to the same final reducer.
+    for (Centroid centroid : clusterer) {
+      context.write(new IntWritable(0), new CentroidWritable(centroid));
+    }
+  }
+}
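
The mapper is a thin wrapper around the online clusterer; the same flow can be
reproduced locally without Hadoop (the searcher, cutoff and random data below are
arbitrary choices for the sketch):

    import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
    import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
    import org.apache.mahout.math.Centroid;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.neighborhood.ProjectionSearch;

    public class MapperFlowSketch {
      public static void main(String[] args) {
        StreamingKMeans clusterer = new StreamingKMeans(
            new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), 3, 2),
            100 /* roughly k * log(n) */, 1e-4 /* initial distance cutoff */);
        for (int i = 0; i < 1000; ++i) {
          // One Centroid of weight 1 per input point, exactly as map() does.
          clusterer.cluster(new Centroid(i,
              new DenseVector(new double[]{Math.random(), Math.random()}), 1));
        }
        // Reindex before emitting, as cleanup() does.
        clusterer.reindexCentroids();
        for (Centroid centroid : clusterer) {
          System.out.println(centroid);
        }
      }
    }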

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java Wed May 15 15:35:12 2013
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+
+public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
+  /**
+   * Configuration for the MapReduce job.
+   */
+  private Configuration conf;
+
+  @Override
+  public void setup(Context context) {
+    // At this point the configuration received from the Driver is assumed to be valid.
+    // No other checks are made.
+    conf = context.getConfiguration();
+  }
+
+  @Override
+  public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
+                     Context context) throws IOException, InterruptedException {
+    int index = 0;
+    for (Vector centroid : getBestCentroids(centroidWritablesToList(centroids), conf)) {
+      context.write(new IntWritable(index), new CentroidWritable((Centroid)centroid));
+      ++index;
+    }
+  }
+
+  public List<Centroid> centroidWritablesToList(Iterable<CentroidWritable> centroids) {
+    // A new list must be created because Hadoop iterators mutate the contents of the Writable in
+    // place, without allocating new references when iterating through the centroids Iterable.
+    return Lists.newArrayList(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
+      @Override
+      public Centroid apply(CentroidWritable input) {
+        Preconditions.checkNotNull(input);
+        return input.getCentroid().clone();
+      }
+    }));
+  }
+
+  public static Iterable<Vector> getBestCentroids(List<Centroid> centroids, Configuration conf) {
+    int numClusters = conf.getInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1);
+    int maxNumIterations = conf.getInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, 10);
+    float trimFraction = conf.getFloat(StreamingKMeansDriver.TRIM_FRACTION, 0.9f);
+    boolean kMeansPlusPlusInit = !conf.getBoolean(StreamingKMeansDriver.RANDOM_INIT, false);
+    boolean correctWeights = !conf.getBoolean(StreamingKMeansDriver.IGNORE_WEIGHTS, false);
+    float testProbability = conf.getFloat(StreamingKMeansDriver.TEST_PROBABILITY, 0.1f);
+    int numRuns = conf.getInt(StreamingKMeansDriver.NUM_BALLKMEANS_RUNS, 3);
+
+    BallKMeans clusterer = new BallKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(conf),
+        numClusters, maxNumIterations, trimFraction, kMeansPlusPlusInit, correctWeights, testProbability, numRuns);
+    return clusterer.cluster(centroids);
+  }
+}
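
getBestCentroids() is also what the sequential path in the driver calls. A minimal
standalone use, assuming the Configuration keys are set the way the driver would set
them (the data here is synthetic):

    import java.util.List;

    import com.google.common.collect.Lists;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansDriver;
    import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansReducer;
    import org.apache.mahout.common.commandline.DefaultOptionCreator;
    import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
    import org.apache.mahout.math.Centroid;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.neighborhood.BruteSearch;

    public class ReducerSketch {
      public static void main(String[] args) {
        Configuration conf = new Configuration();
        conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION,
            SquaredEuclideanDistanceMeasure.class.getName());
        conf.set(StreamingKMeansDriver.SEARCHER_CLASS_OPTION, BruteSearch.class.getName());
        conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 2);
        // Fake "intermediate" centroids scattered around two separated points.
        List<Centroid> intermediate = Lists.newArrayList();
        for (int i = 0; i < 100; ++i) {
          intermediate.add(new Centroid(i,
              new DenseVector(new double[]{i % 2, 0.01 * Math.random()}), 1));
        }
        for (Vector center : StreamingKMeansReducer.getBestCentroids(intermediate, conf)) {
          System.out.println(center);
        }
      }
    }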

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java Wed May 15 15:35:12 2013
@@ -0,0 +1,43 @@
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.util.concurrent.Callable;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+  private Configuration conf;
+  private Iterable<Centroid> datapoints;
+
+  public StreamingKMeansThread(Path input, Configuration conf) {
+    this.datapoints = StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(new SequenceFileValueIterable<VectorWritable>(input, false, conf));
+    this.conf = conf;
+  }
+
+  public StreamingKMeansThread(Iterable<Centroid> datapoints, Configuration conf) {
+    this.datapoints = datapoints;
+    this.conf = conf;
+  }
+
+  @Override
+  public Iterable<Centroid> call() throws Exception {
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+
+    double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        (float) ClusteringUtils.estimateDistanceCutoff(datapoints, searcher.getDistanceMeasure(), 100));
+
+    StreamingKMeans clusterer = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
+    clusterer.cluster(datapoints);
+    clusterer.reindexCentroids();
+
+    return clusterer;
+  }
+
+}
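
This Callable is what runSequentially() submits once per input file; a trimmed-down
version of that pattern looks like the following (the part file names are invented,
and the Configuration is assumed to be populated as the driver would populate it):

    import java.util.List;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    import com.google.common.collect.Lists;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansThread;
    import org.apache.mahout.math.Centroid;

    public class SequentialFanOutSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        ExecutorService pool = Executors.newCachedThreadPool();
        List<Future<Iterable<Centroid>>> futures = Lists.newArrayList();
        for (String part : new String[]{"input/part-00000", "input/part-00001"}) {
          futures.add(pool.submit(new StreamingKMeansThread(new Path(part), conf)));
        }
        // Merge the per-file "mapper" centroids, as runSequentially() does.
        List<Centroid> merged = Lists.newArrayList();
        for (Future<Iterable<Centroid>> future : futures) {
          for (Centroid centroid : future.get()) {
            merged.add(centroid);
          }
        }
        pool.shutdown();
        System.out.println("Intermediate centroids: " + merged.size());
      }
    }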

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java Wed May 15 15:35:12 2013
@@ -0,0 +1,139 @@
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+public class StreamingKMeansUtilsMR {
+
+  /**
+   * Instantiates a searcher from a given configuration.
+   * @param conf the configuration
+   * @return the instantiated searcher
+   * @throws RuntimeException if the distance measure class cannot be instantiated
+   * @throws IllegalStateException if an unknown searcher class was requested
+   */
+  public static UpdatableSearcher searcherFromConfiguration(Configuration conf) {
+    DistanceMeasure distanceMeasure;
+    String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+    try {
+      distanceMeasure = (DistanceMeasure)Class.forName(distanceMeasureClass).newInstance();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to instantiate distanceMeasure", e);
+    }
+
+    int numProjections =  conf.getInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, 20);
+    int searchSize =  conf.getInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, 10);
+
+    String searcherClass = conf.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION);
+
+    if (searcherClass.equals(BruteSearch.class.getName())) {
+      return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
+          new Class[]{DistanceMeasure.class}, new Object[]{distanceMeasure});
+    } else if (searcherClass.equals(FastProjectionSearch.class.getName()) ||
+        searcherClass.equals(ProjectionSearch.class.getName())) {
+      return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
+          new Class[]{DistanceMeasure.class, int.class, int.class},
+          new Object[]{distanceMeasure, numProjections, searchSize});
+    } else {
+      throw new IllegalStateException("Unknown class instantiation requested");
+    }
+  }
+
+  /**
+   * Returns an Iterable of centroids from an Iterable of VectorWritables by creating a new Centroid containing
+   * a RandomAccessSparseVector as a delegate for each VectorWritable.
+   * @param inputIterable VectorWritable Iterable to get Centroids from
+   * @return the new Centroids
+   */
+  public static Iterable<Centroid> getCentroidsFromVectorWritable(Iterable<VectorWritable> inputIterable) {
+    return Iterables.transform(inputIterable, new Function<VectorWritable, Centroid>() {
+      int numVectors = 0;
+
+      @Override
+      public Centroid apply(VectorWritable input) {
+        Preconditions.checkNotNull(input);
+        return new Centroid(numVectors++, new RandomAccessSparseVector(input.get()), 1);
+      }
+    });
+  }
+
+  /**
+   * Returns an Iterable of Centroid from an Iterable of Vector by either casting each Vector to Centroid (if the
+   * instance extends Centroid) or creating a new Centroid based on that Vector.
+   * The implicit expectation is that the input will not have interleaved vector types. Otherwise, the numbering
+   * of new Centroids will become invalid.
+   * @param input Iterable of Vectors to cast
+   * @return the new Centroids
+   */
+  public static Iterable<Centroid> castVectorsToCentroids(final Iterable<Vector> input) {
+    return Iterables.transform(input, new Function<Vector, Centroid>() {
+      private int numVectors = 0;
+      @Override
+      public Centroid apply(Vector input) {
+        Preconditions.checkNotNull(input);
+        if (input instanceof Centroid) {
+          return (Centroid) input;
+        } else {
+          return new Centroid(numVectors++, input, 1);
+        }
+      }
+    });
+  }
+
+  /**
+   * Writes centroids to a sequence file.
+   * @param centroids the centroids to write.
+   * @param path the path of the output file.
+   * @param conf the configuration for the HDFS to write the file to.
+   * @throws java.io.IOException
+   */
+  public static void writeCentroidsToSequenceFile(Iterable<Centroid> centroids, Path path, Configuration conf)
+      throws IOException {
+    SequenceFile.Writer writer = null;
+    try {
+      writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+          path, IntWritable.class, CentroidWritable.class);
+      int i = 0;
+      for (Centroid centroid : centroids) {
+        writer.append(new IntWritable(i++), new CentroidWritable(centroid));
+      }
+    } finally {
+      Closeables.close(writer, true);
+    }
+  }
+
+  public static void writeVectorsToSequenceFile(Iterable<? extends Vector> datapoints, Path path, Configuration conf)
+      throws IOException {
+    SequenceFile.Writer writer = null;
+    try {
+      writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+          path, IntWritable.class, VectorWritable.class);
+      int i = 0;
+      for (Vector vector : datapoints) {
+        writer.append(new IntWritable(i++), new VectorWritable(vector));
+      }
+    } finally {
+      Closeables.close(writer, true);
+    }
+  }
+}
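
A short sketch tying the helpers together: read VectorWritables from one sequence
file, wrap them as Centroids, and write them back out (the paths are placeholders):

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansUtilsMR;
    import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
    import org.apache.mahout.math.Centroid;
    import org.apache.mahout.math.VectorWritable;

    public class UtilsSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Iterable<Centroid> centroids = StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(
            new SequenceFileValueIterable<VectorWritable>(new Path("in/vectors.seq"), false, conf));
        StreamingKMeansUtilsMR.writeCentroidsToSequenceFile(centroids,
            new Path("out/centroids.seq"), conf);
      }
    }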

Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java?rev=1482907&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java Wed May 15 15:35:12 2013
@@ -0,0 +1,271 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mrunit.mapreduce.MapDriver;
+import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
+import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.DataUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.random.WeightedThing;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(value = Parameterized.class)
+public class StreamingKMeansTestMR {
+  private static final int NUM_DATA_POINTS = 1 << 15;
+  private static final int NUM_DIMENSIONS = 8;
+  private static final int NUM_PROJECTIONS = 3;
+  private static final int SEARCH_SIZE = 5;
+  private static final int MAX_NUM_ITERATIONS = 10;
+  private static final double DISTANCE_CUTOFF = 1e-6;
+
+  private static Pair<List<Centroid>, List<Centroid>> syntheticData =
+      DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, 1e-4);
+
+  private String searcherClassName;
+  private String distanceMeasureClassName;
+
+  public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) {
+    this.searcherClassName = searcherClassName;
+    this.distanceMeasureClassName = distanceMeasureClassName;
+  }
+
+  private void configure(Configuration configuration) {
+    configuration.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, distanceMeasureClassName);
+    configuration.setInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, SEARCH_SIZE);
+    configuration.setInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, NUM_PROJECTIONS);
+    configuration.set(StreamingKMeansDriver.SEARCHER_CLASS_OPTION, searcherClassName);
+    configuration.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1 << NUM_DIMENSIONS);
+    configuration.setInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS,
+        (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS));
+    configuration.setFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, (float) DISTANCE_CUTOFF);
+    configuration.setInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, MAX_NUM_ITERATIONS);
+  }
+
+  @Parameterized.Parameters
+  public static List<Object[]> generateData() {
+    return Arrays.asList(new Object[][]{
+        {ProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()},
+        {FastProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()},
+    });
+  }
+
+  @Test
+  public void testHypercubeMapper() throws IOException {
+    MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
+        MapDriver.newMapDriver(new StreamingKMeansMapper());
+    configure(mapDriver.getConfiguration());
+    System.out.printf("%s mapper test\n",
+        mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+    for (Centroid datapoint : syntheticData.getFirst()) {
+      mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+    }
+    List<org.apache.hadoop.mrunit.types.Pair<IntWritable,CentroidWritable>> results = mapDriver.run();
+    BruteSearch resultSearcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
+    for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
+      resultSearcher.add(result.getSecond().getCentroid());
+    }
+    System.out.printf("Clustered the data into %d clusters\n", results.size());
+    for (Vector mean : syntheticData.getSecond()) {
+      WeightedThing<Vector> closest = resultSearcher.search(mean, 1).get(0);
+      assertTrue("Weight " + closest.getWeight() + " not less than 0.5", closest.getWeight() < 0.5);
+    }
+  }
+
+  @Test
+  public void testMapperVsLocal() throws IOException {
+    // Clusters the data using the StreamingKMeansMapper.
+    MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
+        MapDriver.newMapDriver(new StreamingKMeansMapper());
+    Configuration configuration = mapDriver.getConfiguration();
+    configure(configuration);
+    System.out.printf("%s mapper vs local test\n",
+        mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+
+    for (Centroid datapoint : syntheticData.getFirst()) {
+      mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+    }
+    List<Centroid> mapperCentroids = Lists.newArrayList();
+    for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> pair : mapDriver.run()) {
+      mapperCentroids.add(pair.getSecond().getCentroid());
+    }
+
+    // Clusters the data using local batch StreamingKMeans.
+    StreamingKMeans batchClusterer =
+        new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
+            mapDriver.getConfiguration().getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, -1), DISTANCE_CUTOFF);
+    batchClusterer.cluster(syntheticData.getFirst());
+    List<Centroid> batchCentroids = Lists.newArrayList();
+    for (Vector v : batchClusterer) {
+      batchCentroids.add((Centroid) v);
+    }
+
+    // Clusters the data using point by point StreamingKMeans.
+    StreamingKMeans perPointClusterer =
+        new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
+            (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF);
+    for (Centroid datapoint : syntheticData.getFirst()) {
+      perPointClusterer.cluster(datapoint);
+    }
+    List<Centroid> perPointCentroids = Lists.newArrayList();
+    for (Vector v : perPointClusterer) {
+      perPointCentroids.add((Centroid) v);
+    }
+
+    // Computes the cost (total sum of distances) of these different clusterings.
+    double mapperCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), mapperCentroids);
+    double localCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), batchCentroids);
+    double perPointCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), perPointCentroids);
+    System.out.printf("[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];" +
+        "[ratio m-vs-l %f] [ratio pp-vs-l %f]\n", mapperCost, mapperCentroids.size(),
+        localCost, batchCentroids.size(), perPointCost, perPointCentroids.size(),
+        mapperCost / localCost, perPointCost / localCost);
+
+    // These ratios should be close to 1.0; they have been observed to go as low as 0.6 and
+    // as high as 1.5, so an assertion tolerance of 0.8 around 1.0 (the interval [0.2, 1.8])
+    // seems appropriate.
+    assertEquals("Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1",
+        1.0, mapperCost / localCost, 0.8);
+    assertEquals("One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1",
+        1.0, perPointCost / localCost, 0.8);
+  }
+
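+  // Simulates the mapper by clustering locally, then feeds the intermediate centroids to the
+  // reducer and validates its output with testReducerResults().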
+  @Test
+  public void testHypercubeReducer() throws IOException {
+    ReduceDriver<IntWritable, CentroidWritable, IntWritable, CentroidWritable> reduceDriver =
+        ReduceDriver.newReduceDriver(new StreamingKMeansReducer());
+    Configuration configuration = reduceDriver.getConfiguration();
+    configure(configuration);
+
+    System.out.printf("%s reducer test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+    StreamingKMeans clusterer =
+        new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
+            (1 << NUM_DIMENSIONS) * (int) Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF);
+
+    long start = System.currentTimeMillis();
+    clusterer.cluster(syntheticData.getFirst());
+    long end = System.currentTimeMillis();
+
+    System.out.printf("%f [s]\n", (end - start) / 1000.0);
+    List<CentroidWritable> reducerInputs = Lists.newArrayList();
+    int postMapperTotalWeight = 0;
+    for (Centroid intermediateCentroid : clusterer) {
+      reducerInputs.add(new CentroidWritable(intermediateCentroid));
+      postMapperTotalWeight += intermediateCentroid.getWeight();
+    }
+
+    reduceDriver.addInput(new IntWritable(0), reducerInputs);
+    List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results =
+        reduceDriver.run();
+    testReducerResults(postMapperTotalWeight, results);
+  }
+
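+  // Runs the complete mapper + reducer pipeline over the synthetic data; the total output
+  // weight must equal the number of input points.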
+  @Test
+  public void testHypercubeMapReduce() throws IOException {
+    MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable>
+        mapReduceDriver = new MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable,
+        IntWritable, CentroidWritable>(new StreamingKMeansMapper(), new StreamingKMeansReducer());
+    Configuration configuration = mapReduceDriver.getConfiguration();
+    configure(configuration);
+
+    System.out.printf("%s full test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+    for (Centroid datapoint : syntheticData.getFirst()) {
+      mapReduceDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+    }
+    List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = mapReduceDriver.run();
+    testReducerResults(syntheticData.getFirst().size(), results);
+  }
+
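+  // Exercises StreamingKMeansDriver in sequential (non-MapReduce) mode and verifies that the
+  // sequence file it writes passes the same checks as the MapReduce output.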
+  @Test
+  public void testHypercubeMapReduceRunSequentially()
+      throws IOException, InterruptedException, ExecutionException, ClassNotFoundException {
+    Configuration configuration = new Configuration();
+    configure(configuration);
+    configuration.set(DefaultOptionCreator.METHOD_OPTION, DefaultOptionCreator.SEQUENTIAL_METHOD);
+
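+    // Write the synthetic points to a sequence file for the sequential driver to read as input.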
+    Path inputPath = new Path("testInput");
+    Path outputPath = new Path("testOutput");
+    StreamingKMeansUtilsMR.writeVectorsToSequenceFile(syntheticData.getFirst(), inputPath, configuration);
+
+    StreamingKMeansDriver.run(configuration, inputPath, outputPath);
+
+    testReducerResults(syntheticData.getFirst().size(),
+        Lists.newArrayList(Iterables.transform(
+            new SequenceFileIterable<IntWritable, CentroidWritable>(outputPath, configuration),
+            new Function<
+                Pair<IntWritable, CentroidWritable>,
+                org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>>() {
+              @Override
+              public org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> apply(
+                  org.apache.mahout.common.Pair<IntWritable, CentroidWritable> input) {
+                return new org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>(
+                    input.getFirst(), input.getSecond());
+              }
+            })));
+  }
+
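+  // Checks that the reducer produced exactly 2^NUM_DIMENSIONS clusters with consecutive indices
+  // starting at 0 and that the total input weight is preserved; clusters whose weight deviates
+  // from a perfectly balanced split are reported but do not fail the test.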
+  private void testReducerResults(int totalWeight, List<org.apache.hadoop.mrunit.types.Pair<IntWritable,
+      CentroidWritable>> results) {
+    int expectedNumClusters = 1 << NUM_DIMENSIONS;
+    double expectedWeight = (double) totalWeight / expectedNumClusters;
+    int numClusters = 0;
+    int numUnbalancedClusters = 0;
+    int totalReducerWeight = 0;
+    for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
+      if (result.getSecond().getCentroid().getWeight() != expectedWeight) {
+        System.out.printf("Unbalanced weight %f in centroid %d\n",  result.getSecond().getCentroid().getWeight(),
+            result.getSecond().getCentroid().getIndex());
+        ++numUnbalancedClusters;
+      }
+      assertEquals("Final centroid index is invalid", numClusters, result.getFirst().get());
+      totalReducerWeight += result.getSecond().getCentroid().getWeight();
+      ++numClusters;
+    }
+    System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters);
+    assertEquals("Invalid total weight", totalWeight, totalReducerWeight);
+    assertEquals("Invalid number of clusters", 1 << NUM_DIMENSIONS, numClusters);
+  }
+
+}

Modified: mahout/trunk/src/conf/driver.classes.default.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.default.props?rev=1482907&r1=1482906&r2=1482907&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.default.props (original)
+++ mahout/trunk/src/conf/driver.classes.default.props Wed May 15 15:35:12 2013
@@ -34,6 +34,8 @@ org.apache.mahout.clustering.canopy.Cano
 org.apache.mahout.clustering.spectral.eigencuts.EigencutsDriver = eigencuts : Eigencuts spectral clustering
 org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver = spectralkmeans : Spectral k-means clustering
 org.apache.mahout.clustering.topdown.postprocessor.ClusterOutputPostProcessorDriver = clusterpp : Groups Clustering Output In Clusters
+org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansDriver = streamingkmeans : Streaming k-means clustering
+
 #Freq. Itemset Mining
 org.apache.mahout.fpm.pfpgrowth.FPGrowthDriver = fpg : Frequent Pattern Growth
 #Classification


