mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r964507 [1/3] - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/canopy/ core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/ core/src/main/java/org/apache/...
Date Thu, 15 Jul 2010 17:53:52 GMT
Author: jeastman
Date: Thu Jul 15 17:53:51 2010
New Revision: 964507

URL: http://svn.apache.org/viewvc?rev=964507&view=rev
Log:
MAHOUT-294: 
- Refactored clustering jobs to subclass AbstractJob. 
- Added fuzzy k-means example to synthetic control examples
- Added cluster dump to synthetic control examples
- Fixed _log file access bug in ClusterDumper when run on Hadoop
- all synthetic control examples run on Hadoop cluster
- Fuzzy k-Means produces numerically odd-looking clusters
- added unit tests of run() command line option for each clustering algorithm
- all unit tests run

Added:
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/fuzzykmeans/Job.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/canopy/Job.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/meanshift/Job.java
    mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java Thu Jul 15 17:53:51 2010
@@ -18,13 +18,8 @@
 package org.apache.mahout.clustering.canopy;
 
 import java.io.IOException;
+import java.util.Map;
 
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
@@ -36,68 +31,24 @@ import org.apache.hadoop.mapreduce.lib.o
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.WeightedVectorWritable;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public final class CanopyDriver {
-
-  private static final Logger log = LoggerFactory.getLogger(CanopyDriver.class);
+public class CanopyDriver extends AbstractJob {
 
   public static final String DEFAULT_CLUSTERED_POINTS_DIRECTORY = "clusteredPoints";
 
-  private CanopyDriver() {
+  private static final Logger log = LoggerFactory.getLogger(CanopyDriver.class);
+
+  public CanopyDriver() {
   }
 
-  public static void main(String[] args) throws IOException {
-    Option helpOpt = DefaultOptionCreator.helpOption();
-    Option inputOpt = DefaultOptionCreator.inputOption().create();
-    Option outputOpt = DefaultOptionCreator.outputOption().create();
-    Option measureClassOpt = DefaultOptionCreator.distanceMeasureOption().create();
-    Option t1Opt = DefaultOptionCreator.t1Option().create();
-    Option t2Opt = DefaultOptionCreator.t2Option().create();
-
-    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
-    Option clusteringOpt = DefaultOptionCreator.clusteringOption().create();
-
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(outputOpt)
-        .withOption(overwriteOutput).withOption(measureClassOpt).withOption(t1Opt).withOption(t2Opt)
-        .withOption(clusteringOpt).withOption(helpOpt).create();
-
-    try {
-      Parser parser = new Parser();
-      parser.setGroup(group);
-      parser.setHelpOption(helpOpt);
-      CommandLine cmdLine = parser.parse(args);
-
-      if (cmdLine.hasOption(helpOpt)) {
-        CommandLineUtil.printHelp(group);
-        return;
-      }
-
-      Path input = new Path(cmdLine.getValue(inputOpt).toString());
-      Path output = new Path(cmdLine.getValue(outputOpt).toString());
-      if (cmdLine.hasOption(overwriteOutput)) {
-        HadoopUtil.overwriteOutput(output);
-      }
-      String measureClass = cmdLine.getValue(measureClassOpt).toString();
-      double t1 = Double.parseDouble(cmdLine.getValue(t1Opt).toString());
-      double t2 = Double.parseDouble(cmdLine.getValue(t2Opt).toString());
-
-      runJob(input, output, measureClass, t1, t2, cmdLine.hasOption(clusteringOpt));
-    } catch (OptionException e) {
-      log.error("OptionException", e);
-      CommandLineUtil.printHelp(group);
-    } catch (InterruptedException e) {
-      log.error("InterruptedException", e);
-      CommandLineUtil.printHelp(group);
-    } catch (ClassNotFoundException e) {
-      log.error("ClassNotFoundException", e);
-      CommandLineUtil.printHelp(group);
-    }
+  public static void main(String[] args) throws Exception {
+    new CanopyDriver().run(args);
   }
 
   /**
@@ -120,6 +71,41 @@ public final class CanopyDriver {
    */
   public static void runJob(Path input, Path output, String measureClassName, double t1, double t2, boolean runClustering)
       throws IOException, InterruptedException, ClassNotFoundException {
+    new CanopyDriver().job(input, output, measureClassName, t1, t2, runClustering);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.distanceMeasureOption().create());
+    addOption(DefaultOptionCreator.t1Option().create());
+    addOption(DefaultOptionCreator.t2Option().create());
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption(DefaultOptionCreator.clusteringOption().create());
+
+    Map<String, String> argMap = parseArguments(args);
+    if (argMap == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path output = getOutputPath();
+    if (argMap.containsKey(DefaultOptionCreator.OVERWRITE_OPTION_KEY)) {
+      HadoopUtil.overwriteOutput(output);
+    }
+    String measureClass = argMap.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION_KEY);
+    double t1 = Double.parseDouble(argMap.get(DefaultOptionCreator.T1_OPTION_KEY));
+    double t2 = Double.parseDouble(argMap.get(DefaultOptionCreator.T2_OPTION_KEY));
+    boolean runClustering = argMap.containsKey(DefaultOptionCreator.CLUSTERING_OPTION_KEY);
+
+    job(input, output, measureClass, t1, t2, runClustering);
+    return 0;
+  }
+
+  private void job(Path input, Path output, String measureClassName, double t1, double t2, boolean runClustering)
+      throws IOException, InterruptedException, ClassNotFoundException {
     log.info("Input: {} Out: {} " + "Measure: {} t1: {} t2: {}", new Object[] { input, output, measureClassName, t1, t2 });
     Configuration conf = new Configuration();
     conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, measureClassName);
@@ -166,9 +152,9 @@ public final class CanopyDriver {
    * @throws ClassNotFoundException 
    * @throws InterruptedException 
    */
-  public static void runClustering(Path points, Path canopies, Path output, String measureClassName, double t1, double t2)
+  private void runClustering(Path points, Path canopies, Path output, String measureClassName, double t1, double t2)
       throws IOException, InterruptedException, ClassNotFoundException {
-    
+
     Configuration conf = new Configuration();
     conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, measureClassName);
     conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(t1));
@@ -188,7 +174,7 @@ public final class CanopyDriver {
     Path outPath = new Path(output, DEFAULT_CLUSTERED_POINTS_DIRECTORY);
     FileOutputFormat.setOutputPath(job, outPath);
     HadoopUtil.overwriteOutput(outPath);
-    
+
     job.waitForCompletion(true);
   }
 

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Thu Jul 15 17:53:51 2010
@@ -20,13 +20,10 @@ package org.apache.mahout.clustering.dir
 import java.io.IOException;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
+import java.util.Map;
 
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
@@ -43,16 +40,18 @@ import org.apache.hadoop.mapreduce.lib.o
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
 import org.apache.mahout.clustering.kmeans.OutputLogFilter;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public final class DirichletDriver {
+public class DirichletDriver extends AbstractJob {
 
   public static final String STATE_IN_KEY = "org.apache.mahout.clustering.dirichlet.stateIn";
 
@@ -70,62 +69,85 @@ public final class DirichletDriver {
 
   public static final String THRESHOLD_KEY = "org.apache.mahout.clustering.dirichlet.threshold";
 
+  protected static final String MODEL_PROTOTYPE_CLASS_OPTION = "modelPrototypeClass";
+
+  public static final String MODEL_PROTOTYPE_CLASS_OPTION_KEY = "--" + MODEL_PROTOTYPE_CLASS_OPTION;
+
+  protected static final String MODEL_DISTRIBUTION_CLASS_OPTION = "modelDistClass";
+
+  public static final String MODEL_DISTRIBUTION_CLASS_OPTION_KEY = "--" + MODEL_DISTRIBUTION_CLASS_OPTION;
+
+  protected static final String ALPHA_OPTION = "alpha";
+
+  public static final String ALPHA_OPTION_KEY = "--" + ALPHA_OPTION;
+
   private static final Logger log = LoggerFactory.getLogger(DirichletDriver.class);
 
-  private DirichletDriver() {
+  protected DirichletDriver() {
   }
 
   public static void main(String[] args) throws Exception {
-    Option helpOpt = DefaultOptionCreator.helpOption();
-    Option inputOpt = DefaultOptionCreator.inputOption().create();
-    Option outputOpt = DefaultOptionCreator.outputOption().create();
-    Option maxIterOpt = DefaultOptionCreator.maxIterationsOption().create();
-    Option kOpt = DefaultOptionCreator.kOption().withRequired(true).create();
-    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
-    Option clusteringOpt = DefaultOptionCreator.clusteringOption().create();
-    Option alphaOpt = DefaultOptionCreator.alphaOption().create();
-    Option modelDistOpt = DefaultOptionCreator.modelDistributionOption().create();
-    Option prototypeOpt = DefaultOptionCreator.modelPrototypeOption().create();
-    Option numRedOpt = DefaultOptionCreator.numReducersOption().create();
-    Option emitMostLikelyOpt = DefaultOptionCreator.emitMostLikelyOption().create();
-    Option thresholdOpt = DefaultOptionCreator.thresholdOption().create();
-
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(outputOpt)
-        .withOption(overwriteOutput).withOption(modelDistOpt).withOption(prototypeOpt)
-        .withOption(maxIterOpt).withOption(alphaOpt).withOption(kOpt).withOption(helpOpt)
-        .withOption(numRedOpt).withOption(clusteringOpt).withOption(emitMostLikelyOpt)
-        .withOption(thresholdOpt).create();
-
-    try {
-      Parser parser = new Parser();
-      parser.setGroup(group);
-      parser.setHelpOption(helpOpt);
-      CommandLine cmdLine = parser.parse(args);
-      if (cmdLine.hasOption(helpOpt)) {
-        CommandLineUtil.printHelp(group);
-        return;
-      }
+    new DirichletDriver().run(args);
+  }
 
-      Path input = new Path(cmdLine.getValue(inputOpt).toString());
-      Path output = new Path(cmdLine.getValue(outputOpt).toString());
-      if (cmdLine.hasOption(overwriteOutput)) {
-        HadoopUtil.overwriteOutput(output);
-      }
-      String modelFactory = cmdLine.getValue(modelDistOpt).toString();
-      String modelPrototype = cmdLine.getValue(prototypeOpt).toString();
-      int numModels = Integer.parseInt(cmdLine.getValue(kOpt).toString());
-      int numReducers = Integer.parseInt(cmdLine.getValue(numRedOpt).toString());
-      int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
-      boolean emitMostLikely = Boolean.parseBoolean(cmdLine.getValue(emitMostLikelyOpt).toString());
-      double threshold = Double.parseDouble(cmdLine.getValue(thresholdOpt).toString());
-      double alpha0 = Double.parseDouble(cmdLine.getValue(alphaOpt).toString());
-
-      runJob(input, output, modelFactory, modelPrototype, numModels, maxIterations, alpha0, numReducers, cmdLine
-          .hasOption(clusteringOpt), emitMostLikely, threshold);
-    } catch (OptionException e) {
-      log.error("Exception parsing command line: ", e);
-      CommandLineUtil.printHelp(group);
+  /* (non-Javadoc)
+   * @see org.apache.hadoop.util.Tool#run(java.lang.String[])
+   */
+  public int run(String[] args) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException,
+      NoSuchMethodException, InvocationTargetException, InterruptedException {
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.maxIterationsOption().create());
+    addOption(DefaultOptionCreator.numClustersOption().withRequired(true).create());
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption(DefaultOptionCreator.clusteringOption().create());
+    addOption(new DefaultOptionBuilder().withLongName(ALPHA_OPTION).withRequired(false).withShortName("m")
+        .withArgument(new ArgumentBuilder().withName(ALPHA_OPTION).withDefault("1.0").withMinimum(1).withMaximum(1).create())
+        .withDescription("The alpha0 value for the DirichletDistribution. Defaults to 1.0").create());
+    addOption(new DefaultOptionBuilder().withLongName(MODEL_DISTRIBUTION_CLASS_OPTION).withRequired(false).withShortName("md")
+        .withArgument(new ArgumentBuilder().withName(MODEL_DISTRIBUTION_CLASS_OPTION).withDefault(NormalModelDistribution.class
+            .getName()).withMinimum(1).withMaximum(1).create()).withDescription("The ModelDistribution class name. "
+            + "Defaults to NormalModelDistribution").create());
+    addOption(new DefaultOptionBuilder().withLongName(MODEL_PROTOTYPE_CLASS_OPTION).withRequired(false).withShortName("mp")
+        .withArgument(new ArgumentBuilder().withName("prototypeClass").withDefault(RandomAccessSparseVector.class.getName())
+            .withMinimum(1).withMaximum(1).create())
+        .withDescription("The ModelDistribution prototype Vector class name. Defaults to RandomAccessSparseVector").create());
+    addOption(DefaultOptionCreator.emitMostLikelyOption().create());
+    addOption(DefaultOptionCreator.thresholdOption().create());
+    addOption(DefaultOptionCreator.numReducersOption().create());
+
+    Map<String, String> argMap = parseArguments(args);
+    if (argMap == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path output = getOutputPath();
+    if (argMap.containsKey(DefaultOptionCreator.OVERWRITE_OPTION_KEY)) {
+      HadoopUtil.overwriteOutput(output);
     }
+    String modelFactory = argMap.get(MODEL_DISTRIBUTION_CLASS_OPTION_KEY);
+    String modelPrototype = argMap.get(MODEL_PROTOTYPE_CLASS_OPTION_KEY);
+    int numModels = Integer.parseInt(argMap.get(DefaultOptionCreator.NUM_CLUSTERS_OPTION_KEY));
+    int numReducers = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_REDUCERS_OPTION_KEY));
+    int maxIterations = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_ITERATIONS_OPTION_KEY));
+    boolean emitMostLikely = Boolean.parseBoolean(argMap.get(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION_KEY));
+    double threshold = Double.parseDouble(argMap.get(DefaultOptionCreator.THRESHOLD_OPTION_KEY));
+    double alpha0 = Double.parseDouble(argMap.get(ALPHA_OPTION_KEY));
+    boolean runClustering = argMap.containsKey(DefaultOptionCreator.CLUSTERING_OPTION_KEY);
+
+    job(input,
+        output,
+        modelFactory,
+        modelPrototype,
+        numModels,
+        maxIterations,
+        alpha0,
+        numReducers,
+        runClustering,
+        emitMostLikely,
+        threshold);
+    return 0;
   }
 
   /**
@@ -165,39 +187,65 @@ public final class DirichletDriver {
                             int numReducers,
                             boolean runClustering,
                             boolean emitMostLikely,
-                            double threshold)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException, IOException,
-             SecurityException, NoSuchMethodException, InvocationTargetException, InterruptedException {
+                            double threshold) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
+      IOException, SecurityException, NoSuchMethodException, InvocationTargetException, InterruptedException {
 
-    Path clustersIn = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+    new DirichletDriver().job(input,
+                              output,
+                              modelFactory,
+                              modelPrototype,
+                              numClusters,
+                              maxIterations,
+                              alpha0,
+                              numReducers,
+                              runClustering,
+                              emitMostLikely,
+                              threshold);
+  }
 
-    int protoSize = readPrototypeSize(input);
+  /**
+   * Creates a DirichletState object from the given arguments. Note that the modelFactory is presumed to be a
+   * subclass of VectorModelDistribution that can be initialized with a concrete Vector prototype.
+   * 
+   * @param modelFactory
+   *          a String which is the class name of the model factory
+   * @param modelPrototype
+   *          a String which is the class name of the Vector used to initialize the factory
+   * @param prototypeSize
+   *          an int number of dimensions of the model prototype vector
+   * @param numModels
+   *          an int number of models to be created
+   * @param alpha0
+   *          the double alpha_0 argument to the algorithm
+   * @return an initialized DirichletState
+   */
+  static DirichletState<VectorWritable> createState(String modelFactory,
+                                                    String modelPrototype,
+                                                    int prototypeSize,
+                                                    int numModels,
+                                                    double alpha0) throws ClassNotFoundException, InstantiationException,
+      IllegalAccessException, SecurityException, NoSuchMethodException, IllegalArgumentException, InvocationTargetException {
 
-    writeInitialState(output, clustersIn, modelFactory, modelPrototype, protoSize, numClusters, alpha0);
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+    Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory)
+        .asSubclass(AbstractVectorModelDistribution.class);
+    AbstractVectorModelDistribution factory = cl.newInstance();
 
-    for (int iteration = 1; iteration <= maxIterations; iteration++) {
-      log.info("Iteration {}", iteration);
-      // point the output to a new directory per iteration
-      Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
-      runIteration(input,
-                   clustersIn,
-                   clustersOut,
-                   modelFactory,
-                   modelPrototype,
-                   protoSize,
-                   numClusters,
-                   alpha0,
-                   numReducers);
-      // now point the input to the old output directory
-      clustersIn = clustersOut;
-    }
-    if (runClustering) {
-      // now cluster the most likely points
-      runClustering(input, clustersIn, new Path(output, Cluster.CLUSTERED_POINTS_DIR), emitMostLikely, threshold);
-    }
+    Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
+    Constructor<? extends Vector> v = vcl.getConstructor(int.class);
+    factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
+    return new DirichletState<VectorWritable>(factory, numModels, alpha0);
   }
 
-  private static int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
+  /**
+   * Read the first input vector to determine the prototype size for the modelPrototype
+   * @param input
+   * @return
+   * @throws IOException
+   * @throws InstantiationException
+   * @throws IllegalAccessException
+   */
+  private int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
     Configuration conf = new Configuration();
     FileSystem fs = FileSystem.get(input.toUri(), conf);
     FileStatus[] status = fs.listStatus(input, new OutputLogFilter());
@@ -215,15 +263,31 @@ public final class DirichletDriver {
     return protoSize;
   }
 
-  private static void writeInitialState(Path output,
-                                        Path stateIn,
-                                        String modelFactory,
-                                        String modelPrototype,
-                                        int prototypeSize,
-                                        int numModels,
-                                        double alpha0)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException, IOException,
-             SecurityException, NoSuchMethodException, InvocationTargetException {
+  /**
+   * Write initial state (prior distribution) to the output path directory
+   * @param output the output Path
+   * @param stateIn the state input Path
+   * @param modelFactory the String class name of the modelFactory
+   * @param modelPrototype the String class name of the modelPrototype
+   * @param prototypeSize the int size of the modelPrototype vectors
+   * @param numModels the int number of models to generate
+   * @param alpha0 the double alpha_0 argument to the DirichletDistribution
+   * @throws ClassNotFoundException
+   * @throws InstantiationException
+   * @throws IllegalAccessException
+   * @throws IOException
+   * @throws SecurityException
+   * @throws NoSuchMethodException
+   * @throws InvocationTargetException
+   */
+  private void writeInitialState(Path output,
+                                 Path stateIn,
+                                 String modelFactory,
+                                 String modelPrototype,
+                                 int prototypeSize,
+                                 int numModels,
+                                 double alpha0) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
+      IOException, SecurityException, NoSuchMethodException, InvocationTargetException {
 
     DirichletState<VectorWritable> state = createState(modelFactory, modelPrototype, prototypeSize, numModels, alpha0);
     Configuration conf = new Configuration();
@@ -237,41 +301,7 @@ public final class DirichletDriver {
   }
 
   /**
-   * Creates a DirichletState object from the given arguments. Note that the modelFactory is presumed to be a
-   * subclass of VectorModelDistribution that can be initialized with a concrete Vector prototype.
-   * 
-   * @param modelFactory
-   *          a String which is the class name of the model factory
-   * @param modelPrototype
-   *          a String which is the class name of the Vector used to initialize the factory
-   * @param prototypeSize
-   *          an int number of dimensions of the model prototype vector
-   * @param numModels
-   *          an int number of models to be created
-   * @param alpha0
-   *          the double alpha_0 argument to the algorithm
-   * @return an initialized DirichletState
-   */
-  public static DirichletState<VectorWritable> createState(String modelFactory,
-                                                           String modelPrototype,
-                                                           int prototypeSize,
-                                                           int numModels,
-                                                           double alpha0)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException,
-             SecurityException, NoSuchMethodException, IllegalArgumentException, InvocationTargetException {
-
-    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-    Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory).asSubclass(AbstractVectorModelDistribution.class);
-    AbstractVectorModelDistribution factory = cl.newInstance();
-
-    Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
-    Constructor<? extends Vector> v = vcl.getConstructor(int.class);
-    factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
-    return new DirichletState<VectorWritable>(factory, numModels, alpha0);
-  }
-
-  /**
-   * Run the job using supplied arguments
+   * Run an iteration using supplied arguments
    * 
    * @param input
    *          the directory pathname for input points
@@ -295,15 +325,15 @@ public final class DirichletDriver {
    * @throws ClassNotFoundException 
    * @throws InterruptedException 
    */
-  public static void runIteration(Path input,
-                                  Path stateIn,
-                                  Path stateOut,
-                                  String modelFactory,
-                                  String modelPrototype,
-                                  int prototypeSize,
-                                  int numClusters,
-                                  double alpha0,
-                                  int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+  private void runIteration(Path input,
+                            Path stateIn,
+                            Path stateOut,
+                            String modelFactory,
+                            String modelPrototype,
+                            int prototypeSize,
+                            int numClusters,
+                            double alpha0,
+                            int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(STATE_IN_KEY, stateIn.toString());
     conf.set(MODEL_FACTORY_KEY, modelFactory);
@@ -313,7 +343,7 @@ public final class DirichletDriver {
     conf.set(ALPHA_0_KEY, Double.toString(alpha0));
 
     Job job = new Job(conf);
-    
+
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
     job.setOutputKeyClass(Text.class);
@@ -348,13 +378,12 @@ public final class DirichletDriver {
    * @throws InterruptedException 
    * @throws IOException 
    */
-  public static void runClustering(Path input, Path stateIn, Path output, boolean emitMostLikely, double threshold)
-    throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClustering(Path input, Path stateIn, Path output, boolean emitMostLikely, double threshold) throws IOException,
+      InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(STATE_IN_KEY, stateIn.toString());
     conf.set(EMIT_MOST_LIKELY_KEY, Boolean.toString(emitMostLikely));
     conf.set(THRESHOLD_KEY, Double.toString(threshold));
-
     Job job = new Job(conf);
     job.setOutputKeyClass(IntWritable.class);
     job.setOutputValueClass(WeightedVectorWritable.class);
@@ -363,10 +392,74 @@ public final class DirichletDriver {
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
     job.setNumReduceTasks(0);
     job.setJarByClass(DirichletDriver.class);
-    
+
     FileInputFormat.addInputPath(job, input);
     FileOutputFormat.setOutputPath(job, output);
 
     job.waitForCompletion(true);
   }
+
+  /**
+   * Run the job
+   * @param input
+   *          the directory pathname for input points
+   * @param output
+   *          the directory pathname for output points
+   * @param modelFactory
+   *          the String ModelDistribution class name to use
+   * @param modelPrototype
+   *          the String class name of the model prototype
+   * @param numClusters
+   *          the number of models
+   * @param maxIterations
+   *          the maximum number of iterations
+   * @param alpha0
+   *          the alpha_0 value for the DirichletDistribution
+   * @param numReducers
+   *          the number of Reducers desired
+   * @param runClustering
+   *          true if the clustering of points is to be done after the iterations complete
+   * @param emitMostLikely
+   *          if true, emit only the most likely cluster for each point
+   * @param threshold
+   *          when emitMostLikely is false, emit all clusters whose pdf exceeds this threshold value
+   * @throws IOException
+   * @throws InstantiationException
+   * @throws IllegalAccessException
+   * @throws ClassNotFoundException
+   * @throws NoSuchMethodException
+   * @throws InvocationTargetException
+   * @throws InterruptedException
+   */
+  private void job(Path input,
+                   Path output,
+                   String modelFactory,
+                   String modelPrototype,
+                   int numClusters,
+                   int maxIterations,
+                   double alpha0,
+                   int numReducers,
+                   boolean runClustering,
+                   boolean emitMostLikely,
+                   double threshold) throws IOException, InstantiationException, IllegalAccessException, ClassNotFoundException,
+      NoSuchMethodException, InvocationTargetException, InterruptedException {
+    Path clustersIn = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+
+    int protoSize = readPrototypeSize(input);
+
+    writeInitialState(output, clustersIn, modelFactory, modelPrototype, protoSize, numClusters, alpha0);
+
+    for (int iteration = 1; iteration <= maxIterations; iteration++) {
+      log.info("Iteration {}", iteration);
+      // point the output to a new directory per iteration
+      Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
+      runIteration(input, clustersIn, clustersOut, modelFactory, modelPrototype, protoSize, numClusters, alpha0, numReducers);
+      // now point the input to the old output directory
+      clustersIn = clustersOut;
+    }
+    if (runClustering) {
+      // now cluster the most likely points
+      runClustering(input, clustersIn, new Path(output, Cluster.CLUSTERED_POINTS_DIR), emitMostLikely, threshold);
+    }
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java Thu Jul 15 17:53:51 2010
@@ -20,13 +20,10 @@ package org.apache.mahout.clustering.fuz
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
@@ -44,98 +41,26 @@ import org.apache.hadoop.mapreduce.lib.o
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public final class FuzzyKMeansDriver {
+public class FuzzyKMeansDriver extends AbstractJob {
+
+  protected static final String M_OPTION = "m";
+
+  public static final String M_OPTION_KEY = "--" + M_OPTION;
 
   private static final Logger log = LoggerFactory.getLogger(FuzzyKMeansDriver.class);
 
-  private FuzzyKMeansDriver() {
+  public FuzzyKMeansDriver() {
   }
 
   public static void main(String[] args) throws Exception {
-    Option inputOpt = DefaultOptionCreator.inputOption().create();
-    Option outputOpt = DefaultOptionCreator.outputOption().create();
-    Option measureClassOpt = DefaultOptionCreator.distanceMeasureOption().create();
-    Option clustersOpt = DefaultOptionCreator.clustersInOption().withDescription(
-        "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
-            + "If k is also specified, then a random set of vectors will be selected"
-            + " and written out to this path first")
-        .create();
-    Option kOpt = DefaultOptionCreator.kOption().withDescription(
-        "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen"
-            + " as the Centroid and written to the clusters input path.").create();
-    Option convergenceDeltaOpt = DefaultOptionCreator.convergenceOption().create();
-    Option maxIterationsOpt = DefaultOptionCreator.maxIterationsOption().create();
-    Option helpOpt = DefaultOptionCreator.helpOption();
-    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
-    Option mOpt = DefaultOptionCreator.mOption().create();
-    Option numReduceTasksOpt = DefaultOptionCreator.numReducersOption().create();
-    Option numMapTasksOpt = DefaultOptionCreator.numMappersOption().create();
-    Option clusteringOpt = DefaultOptionCreator.clusteringOption().create();
-    Option emitMostLikelyOpt = DefaultOptionCreator.emitMostLikelyOption().create();
-    Option thresholdOpt = DefaultOptionCreator.thresholdOption().create();
-
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(clustersOpt)
-        .withOption(outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt)
-        .withOption(maxIterationsOpt).withOption(kOpt).withOption(mOpt)
-        .withOption(overwriteOutput).withOption(helpOpt).withOption(numMapTasksOpt)
-        .withOption(numReduceTasksOpt).withOption(emitMostLikelyOpt).withOption(thresholdOpt).create();
-
-    try {
-      Parser parser = new Parser();
-      parser.setGroup(group);
-      parser.setHelpOption(helpOpt);
-      CommandLine cmdLine = parser.parse(args);
-      if (cmdLine.hasOption(helpOpt)) {
-        CommandLineUtil.printHelp(group);
-        return;
-      }
-      Path input = new Path(cmdLine.getValue(inputOpt).toString());
-      Path clusters = new Path(cmdLine.getValue(clustersOpt).toString());
-      Path output = new Path(cmdLine.getValue(outputOpt).toString());
-      String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
-      if (cmdLine.hasOption(measureClassOpt)) {
-        measureClass = cmdLine.getValue(measureClassOpt).toString();
-      }
-      double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
-      float m = Float.parseFloat(cmdLine.getValue(mOpt).toString());
-
-      int numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
-      int numMapTasks = Integer.parseInt(cmdLine.getValue(numMapTasksOpt).toString());
-      int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
-      if (cmdLine.hasOption(overwriteOutput)) {
-        HadoopUtil.overwriteOutput(output);
-      }
-      boolean emitMostLikely = Boolean.parseBoolean(cmdLine.getValue(emitMostLikelyOpt).toString());
-      double threshold = Double.parseDouble(cmdLine.getValue(thresholdOpt).toString());
-      if (cmdLine.hasOption(kOpt)) {
-        clusters = RandomSeedGenerator.buildRandom(input, clusters,
-                                                   Integer.parseInt(cmdLine.getValue(kOpt).toString()));
-      }
-      runJob(input,
-             clusters,
-             output,
-             measureClass,
-             convergenceDelta,
-             maxIterations,
-             numMapTasks,
-             numReduceTasks,
-             m,
-             cmdLine.hasOption(clusteringOpt),
-             emitMostLikely,
-             threshold);
-
-    } catch (OptionException e) {
-      log.error("Exception", e);
-      CommandLineUtil.printHelp(group);
-    }
-
+    new FuzzyKMeansDriver().run(args);
   }
 
   /**
@@ -153,8 +78,6 @@ public final class FuzzyKMeansDriver {
    *          the convergence delta value
    * @param maxIterations
    *          the maximum number of iterations
-   * @param numMapTasks
-   *          the number of mapper tasks
    * @param numReduceTasks
    *          the number of reduce tasks
    * @param m
@@ -174,13 +97,119 @@ public final class FuzzyKMeansDriver {
                             String measureClass,
                             double convergenceDelta,
                             int maxIterations,
-                            int numMapTasks,
                             int numReduceTasks,
                             float m,
                             boolean runClustering,
                             boolean emitMostLikely,
                             double threshold) throws IOException, ClassNotFoundException, InterruptedException {
 
+    new FuzzyKMeansDriver().job(input,
+                                clustersIn,
+                                output,
+                                measureClass,
+                                convergenceDelta,
+                                maxIterations,
+                                numReduceTasks,
+                                m,
+                                runClustering,
+                                emitMostLikely,
+                                threshold);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.distanceMeasureOption().create());
+    addOption(DefaultOptionCreator.clustersInOption()
+        .withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
+            + "If k is also specified, then a random set of vectors will be selected" + " and written out to this path first")
+        .create());
+    addOption(DefaultOptionCreator.numClustersOption()
+        .withDescription("The k in k-Means.  If specified, then a random selection of k Vectors will be chosen"
+            + " as the Centroid and written to the clusters input path.").create());
+    addOption(DefaultOptionCreator.convergenceOption().create());
+    addOption(DefaultOptionCreator.maxIterationsOption().create());
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption(new DefaultOptionBuilder().withLongName(M_OPTION).withRequired(true).withArgument(new ArgumentBuilder()
+        .withName(M_OPTION).withMinimum(1).withMaximum(1).create())
+        .withDescription("coefficient normalization factor, must be greater than 1").withShortName(M_OPTION).create());
+    addOption(DefaultOptionCreator.numReducersOption().create());
+    //TODO: addOption(DefaultOptionCreator.numMappersOption().create()); but how to set in new Job?
+    addOption(DefaultOptionCreator.clusteringOption().create());
+    addOption(DefaultOptionCreator.emitMostLikelyOption().create());
+    addOption(DefaultOptionCreator.thresholdOption().create());
+
+    Map<String, String> argMap = parseArguments(args);
+    if (argMap == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path clusters = new Path(argMap.get(DefaultOptionCreator.CLUSTERS_IN_OPTION_KEY));
+    Path output = getOutputPath();
+    String measureClass = argMap.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION_KEY);
+    if (measureClass == null) {
+      measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+    }
+    double convergenceDelta = Double.parseDouble(argMap.get(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION_KEY));
+    float fuzziness = Float.parseFloat(argMap.get(M_OPTION_KEY));
+
+    int numReduceTasks = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_REDUCERS_OPTION_KEY));
+    int maxIterations = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_ITERATIONS_OPTION_KEY));
+    if (argMap.containsKey(DefaultOptionCreator.OVERWRITE_OPTION_KEY)) {
+      HadoopUtil.overwriteOutput(output);
+    }
+    boolean emitMostLikely = Boolean.parseBoolean(argMap.get(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION_KEY));
+    double threshold = Double.parseDouble(argMap.get(DefaultOptionCreator.THRESHOLD_OPTION_KEY));
+    if (argMap.containsKey(DefaultOptionCreator.NUM_CLUSTERS_OPTION_KEY)) {
+      clusters = RandomSeedGenerator.buildRandom(input, clusters, Integer.parseInt(argMap
+          .get(DefaultOptionCreator.NUM_CLUSTERS_OPTION_KEY)));
+    }
+    boolean runClustering = argMap.containsKey(DefaultOptionCreator.CLUSTERING_OPTION_KEY);
+    job(input,
+        clusters,
+        output,
+        measureClass,
+        convergenceDelta,
+        maxIterations,
+        numReduceTasks,
+        fuzziness,
+        runClustering,
+        emitMostLikely,
+        threshold);
+    return 0;
+  }
+
+  /**
+   * Run the full clustering job
+   * @param input the directory pathname for input points
+   * @param clustersIn the directory pathname for the initial input clusters
+   * @param output the directory pathname for output points
+   * @param measureClass the classname of the DistanceMeasure
+   * @param convergenceDelta the convergence delta value
+   * @param maxIterations the maximum number of iterations
+   * @param numReduceTasks the number of reduce tasks
+   * @param m the coefficient normalization factor; must be greater than 1
+   * @param runClustering true if the clustering of points is to be done after the iterations complete
+   * @param emitMostLikely if true, emit only the most likely cluster for each point
+   * @param threshold when emitMostLikely is false, emit all clusters whose pdf exceeds this threshold value
+   * @throws IOException
+   * @throws ClassNotFoundException
+   * @throws InterruptedException
+   */
+  private void job(Path input,
+                   Path clustersIn,
+                   Path output,
+                   String measureClass,
+                   double convergenceDelta,
+                   int maxIterations,
+                   int numReduceTasks,
+                   float m,
+                   boolean runClustering,
+                   boolean emitMostLikely,
+                   double threshold) throws IOException, ClassNotFoundException, InterruptedException {
     boolean converged = false;
     int iteration = 1;
 
@@ -190,36 +219,29 @@ public final class FuzzyKMeansDriver {
 
       // point the output to a new directory per iteration
       Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
-      converged = runIteration(input,
-                               clustersIn,
-                               clustersOut,
-                               measureClass,
-                               convergenceDelta,
-                               numMapTasks,
-                               numReduceTasks,
-                               iteration,
-                               m);
+      converged = runIteration(input, clustersIn, clustersOut, measureClass, convergenceDelta, numReduceTasks, iteration, m);
 
       // now point the input to the old output directory
       clustersIn = clustersOut;
       iteration++;
     }
 
-    // now actually cluster the points
-    log.info("Clustering ");
-    runClustering(input,
-                  clustersIn,
-                  new Path(output, Cluster.CLUSTERED_POINTS_DIR),
-                  measureClass,
-                  convergenceDelta,
-                  numMapTasks,
-                  m,
-                  emitMostLikely,
-                  threshold);
+    // now actually cluster the points if requested
+    if (runClustering) {
+      log.info("Clustering ");
+      runClustering(input,
+                    clustersIn,
+                    new Path(output, Cluster.CLUSTERED_POINTS_DIR),
+                    measureClass,
+                    convergenceDelta,
+                    m,
+                    emitMostLikely,
+                    threshold);
+    }
   }
 
   /**
-   * Run the job using supplied arguments
+   * Run the iteration using supplied arguments
    * 
    * @param input
    *          the directory pathname for input points
@@ -231,8 +253,6 @@ public final class FuzzyKMeansDriver {
    *          the classname of the DistanceMeasure
    * @param convergenceDelta
    *          the convergence delta value
-   * @param numMapTasks
-   *          the number of map tasks
    * @param iterationNumber
    *          the iteration number that is going to run
    * @param m
@@ -241,15 +261,14 @@ public final class FuzzyKMeansDriver {
    * @return true if the iteration successfully runs
    * @throws IOException 
    */
-  private static boolean runIteration(Path input,
-                                      Path clustersIn,
-                                      Path clustersOut,
-                                      String measureClass,
-                                      double convergenceDelta,
-                                      int numMapTasks,
-                                      int numReduceTasks,
-                                      int iterationNumber,
-                                      float m) throws IOException {
+  private boolean runIteration(Path input,
+                               Path clustersIn,
+                               Path clustersOut,
+                               String measureClass,
+                               double convergenceDelta,
+                               int numReduceTasks,
+                               int iterationNumber,
+                               float m) throws IOException {
 
     Configuration conf = new Configuration();
     conf.set(FuzzyKMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
@@ -259,7 +278,7 @@ public final class FuzzyKMeansDriver {
     // these values don't matter during iterations as only used for clustering if requested
     conf.set(FuzzyKMeansConfigKeys.EMIT_MOST_LIKELY_KEY, Boolean.toString(true));
     conf.set(FuzzyKMeansConfigKeys.THRESHOLD_KEY, Double.toString(0));
-    
+
     Job job = new Job(conf);
     job.setMapOutputKeyClass(Text.class);
     job.setMapOutputValueClass(FuzzyKMeansInfo.class);
@@ -271,16 +290,12 @@ public final class FuzzyKMeansDriver {
     job.setMapperClass(FuzzyKMeansMapper.class);
     job.setCombinerClass(FuzzyKMeansCombiner.class);
     job.setReducerClass(FuzzyKMeansReducer.class);
-    //TODO: job.setNumMapTasks(numMapTasks);
     job.setNumReduceTasks(numReduceTasks);
     job.setJarByClass(FuzzyKMeansDriver.class);
 
     FileInputFormat.addInputPath(job, input);
     FileOutputFormat.setOutputPath(job, clustersOut);
 
-    // uncomment it to run locally
-    // conf.set("mapred.job.tracker", "local");
-
     try {
       job.waitForCompletion(true);
       FileSystem fs = FileSystem.get(clustersOut.toUri(), conf);
@@ -310,23 +325,20 @@ public final class FuzzyKMeansDriver {
    *          the classname of the DistanceMeasure
    * @param convergenceDelta
    *          the convergence delta value
-   * @param numMapTasks
-   *          the number of map tasks
    * @param emitMostLikely
    *          a boolean if true emit only most likely cluster for each point
    * @param threshold 
    *          a double threshold value emits all clusters having greater pdf (emitMostLikely = false)
    * @throws IOException 
    */
-  private static void runClustering(Path input,
-                                    Path clustersIn,
-                                    Path output,
-                                    String measureClass,
-                                    double convergenceDelta,
-                                    int numMapTasks,
-                                    float m,
-                                    boolean emitMostLikely,
-                                    double threshold) throws IOException, ClassNotFoundException, InterruptedException {
+  private void runClustering(Path input,
+                             Path clustersIn,
+                             Path output,
+                             String measureClass,
+                             double convergenceDelta,
+                             float m,
+                             boolean emitMostLikely,
+                             double threshold) throws IOException, ClassNotFoundException, InterruptedException {
 
     Configuration conf = new Configuration();
     conf.set(FuzzyKMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
@@ -338,7 +350,7 @@ public final class FuzzyKMeansDriver {
 
     // Clear output
     output.getFileSystem(conf).delete(output, true);
-    
+
     Job job = new Job(conf);
     job.setOutputKeyClass(IntWritable.class);
     job.setOutputValueClass(WeightedVectorWritable.class);
@@ -350,7 +362,6 @@ public final class FuzzyKMeansDriver {
 
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    //TODO: job.setNumMapTasks(numMapTasks);
     job.setNumReduceTasks(0);
     job.setJarByClass(FuzzyKMeansDriver.class);
 
@@ -370,7 +381,7 @@ public final class FuzzyKMeansDriver {
    * @throws IOException
    *           if there was an IO error
    */
-  private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
+  private boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
 
     Path clusterPath = new Path(filePath, "*");
     List<Path> result = new ArrayList<Path>();
@@ -382,8 +393,7 @@ public final class FuzzyKMeansDriver {
       }
     };
 
-    FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(
-        fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
+    FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
 
     for (FileStatus match : matches) {
       result.add(fs.makeQualified(match.getPath()));

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java Thu Jul 15 17:53:51 2010
@@ -43,12 +43,11 @@ public class FuzzyKMeansReducer extends 
       } else {
         cluster.addPoints(value.getVector(), value.getProbability());
       }
-
     }
     // force convergence calculation
     boolean converged = clusterer.computeConvergence(cluster);
     if (converged) {
-      // TODO: reporter.incrCounter("Clustering", "Converged Clusters", 1);
+      context.getCounter("Clustering", "Converged Clusters").increment(1);
     }
     context.write(new Text(cluster.getIdentifier()), cluster);
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Thu Jul 15 17:53:51 2010
@@ -17,13 +17,8 @@
 package org.apache.mahout.clustering.kmeans;
 
 import java.io.IOException;
+import java.util.Map;
 
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
@@ -38,73 +33,23 @@ import org.apache.hadoop.mapreduce.lib.i
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.WeightedVectorWritable;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public final class KMeansDriver {
+public class KMeansDriver extends AbstractJob {
 
   private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
 
-  private KMeansDriver() {
+  protected KMeansDriver() {
   }
 
   public static void main(String[] args) throws Exception {
-    Option inputOpt = DefaultOptionCreator.inputOption().create();
-    Option clustersOpt = DefaultOptionCreator.clustersInOption().withDescription(
-        "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
-            + "If k is also specified, then a random set of vectors will be selected"
-            + " and written out to this path first")
-        .create();
-    Option kOpt = DefaultOptionCreator.kOption().withDescription(
-        "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen"
-            + " as the Centroid and written to the clusters input path.").create();
-    Option outputOpt = DefaultOptionCreator.outputOption().create();
-    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
-    Option measureClassOpt = DefaultOptionCreator.distanceMeasureOption().create();
-    Option convergenceDeltaOpt = DefaultOptionCreator.convergenceOption().create();
-    Option maxIterationsOpt = DefaultOptionCreator.maxIterationsOption().create();
-    Option numReduceTasksOpt = DefaultOptionCreator.numReducersOption().create();
-    Option clusteringOpt = DefaultOptionCreator.clusteringOption().create();
-    Option helpOpt = DefaultOptionCreator.helpOption();
-
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(clustersOpt)
-        .withOption(outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt)
-        .withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt).withOption(overwriteOutput)
-        .withOption(helpOpt).withOption(clusteringOpt).create();
-    try {
-      Parser parser = new Parser();
-      parser.setGroup(group);
-      parser.setHelpOption(helpOpt);
-      CommandLine cmdLine = parser.parse(args);
-
-      if (cmdLine.hasOption(helpOpt)) {
-        CommandLineUtil.printHelp(group);
-        return;
-      }
-      Path input = new Path(cmdLine.getValue(inputOpt).toString());
-      Path clusters = new Path(cmdLine.getValue(clustersOpt).toString());
-      Path output = new Path(cmdLine.getValue(outputOpt).toString());
-      String measureClass = cmdLine.getValue(measureClassOpt).toString();
-      double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
-      int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
-      int numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
-      if (cmdLine.hasOption(overwriteOutput)) {
-        HadoopUtil.overwriteOutput(output);
-      }
-      if (cmdLine.hasOption(kOpt)) {
-        clusters = RandomSeedGenerator.buildRandom(input, clusters,
-                                                   Integer.parseInt(cmdLine.getValue(kOpt).toString()));
-      }
-      runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks, cmdLine
-          .hasOption(clusteringOpt));
-    } catch (OptionException e) {
-      log.error("Exception", e);
-      CommandLineUtil.printHelp(group);
-    }
+    new KMeansDriver().run(args);
   }
 
   /**
@@ -129,16 +74,77 @@ public final class KMeansDriver {
    * @throws ClassNotFoundException 
    * @throws InterruptedException 
    */
-  public static void runJob(Path input, Path clustersIn, Path output, String measureClass, double convergenceDelta,
-      int maxIterations, int numReduceTasks, boolean runClustering) throws IOException, InterruptedException,
-      ClassNotFoundException {
+  public static void runJob(Path input,
+                            Path clustersIn,
+                            Path output,
+                            String measureClass,
+                            double convergenceDelta,
+                            int maxIterations,
+                            int numReduceTasks,
+                            boolean runClustering) throws IOException, InterruptedException, ClassNotFoundException {
+    new KMeansDriver().job(input, clustersIn, output, measureClass, convergenceDelta, maxIterations, numReduceTasks, runClustering);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.distanceMeasureOption().create());
+    addOption(DefaultOptionCreator.clustersInOption()
+        .withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
+            + "If k is also specified, then a random set of vectors will be selected" + " and written out to this path first")
+        .create());
+    addOption(DefaultOptionCreator.numClustersOption()
+        .withDescription("The k in k-Means.  If specified, then a random selection of k Vectors will be chosen"
+            + " as the Centroid and written to the clusters input path.").create());
+    addOption(DefaultOptionCreator.convergenceOption().create());
+    addOption(DefaultOptionCreator.maxIterationsOption().create());
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption(DefaultOptionCreator.numReducersOption().create());
+    addOption(DefaultOptionCreator.clusteringOption().create());
+
+    Map<String, String> argMap = parseArguments(args);
+    if (argMap == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path clusters = new Path(argMap.get(DefaultOptionCreator.CLUSTERS_IN_OPTION_KEY));
+    Path output = getOutputPath();
+    String measureClass = argMap.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION_KEY);
+    if (measureClass == null) {
+      measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+    }
+    double convergenceDelta = Double.parseDouble(argMap.get(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION_KEY));
+    int numReduceTasks = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_REDUCERS_OPTION_KEY));
+    int maxIterations = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_ITERATIONS_OPTION_KEY));
+    if (argMap.containsKey(DefaultOptionCreator.OVERWRITE_OPTION_KEY)) {
+      HadoopUtil.overwriteOutput(output);
+    }
+    if (argMap.containsKey(DefaultOptionCreator.NUM_CLUSTERS_OPTION_KEY)) {
+      clusters = RandomSeedGenerator.buildRandom(input, clusters, Integer.parseInt(argMap
+          .get(DefaultOptionCreator.NUM_CLUSTERS_OPTION_KEY)));
+    }
+    boolean runClustering = argMap.containsKey(DefaultOptionCreator.CLUSTERING_OPTION_KEY);
+    job(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks, runClustering);
+    return 0;
+  }
+
+  private void job(Path input,
+                   Path clustersIn,
+                   Path output,
+                   String measureClass,
+                   double convergenceDelta,
+                   int maxIterations,
+                   int numReduceTasks,
+                   boolean runClustering) throws IOException, InterruptedException, ClassNotFoundException {
     // iterate until the clusters converge
     String delta = Double.toString(convergenceDelta);
     if (log.isInfoEnabled()) {
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
-               new Object[] { input, clustersIn, output, measureClass });
-      log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
-               new Object[] { convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName() });
+      log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] { input, clustersIn, output, measureClass });
+      log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}", new Object[] { convergenceDelta,
+          maxIterations, numReduceTasks, VectorWritable.class.getName() });
     }
     boolean converged = false;
     int iteration = 1;
@@ -177,13 +183,12 @@ public final class KMeansDriver {
    * @throws ClassNotFoundException 
    * @throws InterruptedException 
    */
-  private static boolean runIteration(Path input,
-                                      Path clustersIn,
-                                      Path clustersOut,
-                                      String measureClass,
-                                      String convergenceDelta,
-                                      int numReduceTasks)
-    throws IOException, InterruptedException, ClassNotFoundException {
+  private boolean runIteration(Path input,
+                               Path clustersIn,
+                               Path clustersOut,
+                               String measureClass,
+                               String convergenceDelta,
+                               int numReduceTasks) throws IOException, InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn.toString());
     conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
@@ -230,16 +235,11 @@ public final class KMeansDriver {
    * @throws ClassNotFoundException 
    * @throws InterruptedException 
    */
-  private static void runClustering(Path input,
-                                    Path clustersIn,
-                                    Path output,
-                                    String measureClass,
-                                    String convergenceDelta)
-    throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClustering(Path input, Path clustersIn, Path output, String measureClass, String convergenceDelta)
+      throws IOException, InterruptedException, ClassNotFoundException {
     if (log.isInfoEnabled()) {
       log.info("Running Clustering");
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
-               new Object[] { input, clustersIn, output, measureClass });
+      log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] { input, clustersIn, output, measureClass });
       log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class.getName());
     }
     Configuration conf = new Configuration();
@@ -277,7 +277,7 @@ public final class KMeansDriver {
    * @throws IOException
    *           if there was an IO error
    */
-  private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
+  private boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
     FileStatus[] parts = fs.listStatus(filePath);
     for (FileStatus part : parts) {
       String name = part.getPath().getName();

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=964507&r1=964506&r2=964507&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Thu Jul 15 17:53:51 2010
@@ -18,14 +18,11 @@
 package org.apache.mahout.clustering.lda;
 
 import java.io.IOException;
+import java.util.Map;
 import java.util.Random;
 
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
@@ -37,7 +34,7 @@ import org.apache.hadoop.mapreduce.lib.i
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.IntPairWritable;
 import org.apache.mahout.common.RandomUtils;
@@ -50,7 +47,19 @@ import org.slf4j.LoggerFactory;
  * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
  * it outputs a matrix of log probabilities of each topic.
  */
-public final class LDADriver {
+public final class LDADriver extends AbstractJob {
+
+  private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
+
+  private static final String TOPIC_SMOOTHING_OPTION_KEY = "--" + TOPIC_SMOOTHING_OPTION;
+
+  private static final String NUM_WORDS_OPTION = "numWords";
+
+  private static final String NUM_WORDS_OPTION_KEY = "--" + NUM_WORDS_OPTION;
+
+  private static final String NUM_TOPICS_OPTION = "numTopics";
+
+  private static final String NUM_TOPICS_OPTION_KEY = "--" + NUM_TOPICS_OPTION;
 
   static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
 
@@ -71,50 +80,8 @@ public final class LDADriver {
   private LDADriver() {
   }
 
-  public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
-    Option inputOpt = DefaultOptionCreator.inputOption().create();
-    Option outputOpt = DefaultOptionCreator.outputOption().create();
-    Option overwriteOutput = DefaultOptionCreator.overwriteOption().create();
-    Option topicsOpt = DefaultOptionCreator.numTopicsOption().create();
-    Option wordsOpt = DefaultOptionCreator.numWordsOption().create();
-    Option topicSmOpt = DefaultOptionCreator.topicSmoothingOption().create();
-    Option maxIterOpt = DefaultOptionCreator.maxIterationsOption().withRequired(false).create();
-    Option numReducOpt = DefaultOptionCreator.numReducersOption().create();
-    Option helpOpt = DefaultOptionCreator.helpOption();
-
-    Group group = new GroupBuilder().withName("Options").withOption(inputOpt).withOption(outputOpt)
-        .withOption(topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt)
-        .withOption(numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
-    try {
-      Parser parser = new Parser();
-      parser.setGroup(group);
-      parser.setHelpOption(helpOpt);
-      CommandLine cmdLine = parser.parse(args);
-
-      if (cmdLine.hasOption(helpOpt)) {
-        CommandLineUtil.printHelp(group);
-        return;
-      }
-      Path input = new Path(cmdLine.getValue(inputOpt).toString());
-      Path output = new Path(cmdLine.getValue(outputOpt).toString());
-      if (cmdLine.hasOption(overwriteOutput)) {
-        HadoopUtil.overwriteOutput(output);
-      }
-      int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
-      int numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
-      int numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
-      int numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
-      double topicSmoothing = Double.parseDouble(cmdLine.getValue(maxIterOpt).toString());
-      if (topicSmoothing < 1) {
-        topicSmoothing = 50.0 / numTopics;
-      }
-
-      runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);
-
-    } catch (OptionException e) {
-      log.error("Exception", e);
-      CommandLineUtil.printHelp(group);
-    }
+  public static void main(String[] args) throws Exception {
+    new LDADriver().run(args);
   }
 
   /**
@@ -144,6 +111,113 @@ public final class LDADriver {
                             int maxIterations,
                             int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
 
+    new LDADriver().job(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReducers);
+  }
+
+  static LDAState createState(Configuration job) throws IOException {
+    String statePath = job.get(STATE_IN_KEY);
+    int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
+    int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
+    double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
+
+    Path dir = new Path(statePath);
+    FileSystem fs = dir.getFileSystem(job);
+
+    DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = 0.0;
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      while (reader.next(key, value)) {
+        int topic = key.getFirst();
+        int word = key.getSecond();
+        if (word == TOPIC_SUM_KEY) {
+          logTotals[topic] = value.get();
+          if (Double.isInfinite(value.get())) {
+            throw new IllegalArgumentException();
+          }
+        } else if (topic == LOG_LIKELIHOOD_KEY) {
+          ll = value.get();
+        } else {
+          if (!((topic >= 0) && (word >= 0))) {
+            throw new IllegalArgumentException(topic + " " + word);
+          }
+          if (pWgT.getQuick(topic, word) != 0.0) {
+            throw new IllegalArgumentException();
+          }
+          pWgT.setQuick(topic, word, value.get());
+          if (Double.isInfinite(pWgT.getQuick(topic, word))) {
+            throw new IllegalArgumentException();
+          }
+        }
+      }
+      reader.close();
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption(new DefaultOptionBuilder().withLongName(NUM_TOPICS_OPTION).withRequired(true).withArgument(new ArgumentBuilder()
+        .withName(NUM_TOPICS_OPTION).withMinimum(1).withMaximum(1).create())
+        .withDescription("The total number of topics in the corpus").withShortName("k").create());
+    addOption(new DefaultOptionBuilder().withLongName(NUM_WORDS_OPTION).withRequired(true).withArgument(new ArgumentBuilder()
+        .withName(NUM_WORDS_OPTION).withMinimum(1).withMaximum(1).create())
+        .withDescription("The total number of words in the corpus (can be approximate, needs to exceed the actual value)")
+        .withShortName("v").create());
+    addOption(new DefaultOptionBuilder().withLongName(TOPIC_SMOOTHING_OPTION).withRequired(false)
+        .withArgument(new ArgumentBuilder().withName(TOPIC_SMOOTHING_OPTION).withDefault(-1.0).withMinimum(0).withMaximum(1)
+            .create()).withDescription("Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create());
+    addOption(DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
+    addOption(DefaultOptionCreator.numReducersOption().create());
+
+    Map<String, String> argMap = parseArguments(args);
+    if (argMap == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path output = getOutputPath();
+    if (argMap.containsKey(DefaultOptionCreator.OVERWRITE_OPTION_KEY)) {
+      HadoopUtil.overwriteOutput(output);
+    }
+    int maxIterations = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_ITERATIONS_OPTION_KEY));
+    int numReduceTasks = Integer.parseInt(argMap.get(DefaultOptionCreator.MAX_REDUCERS_OPTION_KEY));
+    int numTopics = Integer.parseInt(argMap.get(NUM_TOPICS_OPTION_KEY));
+    int numWords = Integer.parseInt(argMap.get(NUM_WORDS_OPTION_KEY));
+    double topicSmoothing = Double.parseDouble(argMap.get(TOPIC_SMOOTHING_OPTION_KEY));
+    if (topicSmoothing < 1) {
+      topicSmoothing = 50.0 / numTopics;
+    }
+
+    job(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);
+
+    return 0;
+  }
+
+  /**
+   * @param input
+   * @param output
+   * @param numTopics
+   * @param numWords
+   * @param topicSmoothing
+   * @param maxIterations
+   * @param numReducers
+   * @throws IOException
+   * @throws InterruptedException
+   * @throws ClassNotFoundException
+   */
+  private void job(Path input, Path output, int numTopics, int numWords, double topicSmoothing, int maxIterations, int numReducers)
+      throws IOException, InterruptedException, ClassNotFoundException {
     Path stateIn = new Path(output, "state-0");
     writeInitialState(stateIn, numTopics, numWords);
     double oldLL = Double.NEGATIVE_INFINITY;
@@ -167,7 +241,7 @@ public final class LDADriver {
     }
   }
 
-  private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
+  private void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
     Configuration job = new Configuration();
     FileSystem fs = statePath.getFileSystem(job);
 
@@ -196,7 +270,7 @@ public final class LDADriver {
     }
   }
 
-  private static double findLL(Path statePath, Configuration job) throws IOException {
+  private double findLL(Path statePath, Configuration job) throws IOException {
     FileSystem fs = statePath.getFileSystem(job);
 
     double ll = 0.0;
@@ -232,13 +306,13 @@ public final class LDADriver {
    * @param numReducers
    *          the number of Reducers desired
    */
-  public static double runIteration(Path input,
-                                    Path stateIn,
-                                    Path stateOut,
-                                    int numTopics,
-                                    int numWords,
-                                    double topicSmoothing,
-                                    int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+  private double runIteration(Path input,
+                              Path stateIn,
+                              Path stateOut,
+                              int numTopics,
+                              int numWords,
+                              double topicSmoothing,
+                              int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(STATE_IN_KEY, stateIn.toString());
     conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
@@ -263,51 +337,4 @@ public final class LDADriver {
     job.waitForCompletion(true);
     return findLL(stateOut, conf);
   }
-
-  static LDAState createState(Configuration job) throws IOException {
-    String statePath = job.get(STATE_IN_KEY);
-    int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
-    int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
-    double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
-
-    Path dir = new Path(statePath);
-    FileSystem fs = dir.getFileSystem(job);
-
-    DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
-    double[] logTotals = new double[numTopics];
-    double ll = 0.0;
-
-    IntPairWritable key = new IntPairWritable();
-    DoubleWritable value = new DoubleWritable();
-    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
-      Path path = status.getPath();
-      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
-      while (reader.next(key, value)) {
-        int topic = key.getFirst();
-        int word = key.getSecond();
-        if (word == TOPIC_SUM_KEY) {
-          logTotals[topic] = value.get();
-          if (Double.isInfinite(value.get())) {
-            throw new IllegalArgumentException();
-          }
-        } else if (topic == LOG_LIKELIHOOD_KEY) {
-          ll = value.get();
-        } else {
-          if (!((topic >= 0) && (word >= 0))) {
-            throw new IllegalArgumentException(topic + " " + word);
-          }
-          if (pWgT.getQuick(topic, word) != 0.0) {
-            throw new IllegalArgumentException();
-          }
-          pWgT.setQuick(topic, word, value.get());
-          if (Double.isInfinite(pWgT.getQuick(topic, word))) {
-            throw new IllegalArgumentException();
-          }
-        }
-      }
-      reader.close();
-    }
-
-    return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
-  }
 }



Mime
View raw message