mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r987647 [1/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/ core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ core/src/main/java/org/apache/ma...
Date Fri, 20 Aug 2010 21:56:17 GMT
Author: jeastman
Date: Fri Aug 20 21:56:16 2010
New Revision: 987647

URL: http://svn.apache.org/viewvc?rev=987647&view=rev
Log:
MAHOUT-479: added unit tests to test VectorModelClassifier, ModelDistribution serialization and to ensure
GaussianClusterDistribution and DistanceMeasureClusterDistributions work in Dirichlet. Refactored model
distribution arguments to allow Java developers to provide fully-configured model distributions vs multiple
string parameters. Added distance measure parameter to Dirichlet for use with DMClusterDistributions.

All unit tests run.

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java
      - copied, changed from r987240, mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
Removed:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonDistanceMeasureAdapter.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * Gson adapter that (de)serializes a {@link DistanceMeasure} together with its concrete class.
+ * <p>
+ * A measure is written as a JSON object with two members: {@code "class"}, the fully qualified
+ * name of the implementation, and {@code "model"}, a string holding the Gson serialization of
+ * the instance. On read, the class is loaded through the thread context class loader and the
+ * nested string is decoded into an instance of that class.
+ */
+public class JsonDistanceMeasureAdapter implements JsonSerializer<DistanceMeasure>, JsonDeserializer<DistanceMeasure> {
+
+  private static final Logger log = LoggerFactory.getLogger(JsonDistanceMeasureAdapter.class);
+
+  @Override
+  public JsonElement serialize(DistanceMeasure src, Type typeOfSrc, JsonSerializationContext context) {
+    JsonObject obj = new JsonObject();
+    obj.add("class", new JsonPrimitive(src.getClass().getName()));
+    obj.add("model", new JsonPrimitive(createGson().toJson(src)));
+    return obj;
+  }
+
+  /**
+   * Rebuilds a measure from the {@code "class"} / {@code "model"} object written by
+   * {@link #serialize}.
+   *
+   * @throws JsonParseException if a member is missing, the class cannot be loaded, or the class
+   *           does not implement {@link DistanceMeasure}
+   */
+  @Override
+  public DistanceMeasure deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) {
+    JsonObject obj = json.getAsJsonObject();
+    String klass = requiredString(obj, "class");
+    String model = requiredString(obj, "model");
+    Class<?> cl = loadClass(klass);
+    if (!DistanceMeasure.class.isAssignableFrom(cl)) {
+      throw new JsonParseException(klass + " does not implement " + DistanceMeasure.class.getName());
+    }
+    return (DistanceMeasure) createGson().fromJson(model, cl);
+  }
+
+  /** Builds the Gson used for the nested payload, with {@link JsonVectorAdapter} for Vector fields. */
+  private static Gson createGson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    return builder.create();
+  }
+
+  /** Returns the named member as a string, failing with a parse error if it is absent. */
+  private static String requiredString(JsonObject obj, String member) {
+    JsonElement element = obj.get(member);
+    if (element == null) {
+      throw new JsonParseException("Serialized DistanceMeasure is missing the '" + member + "' member");
+    }
+    return element.getAsString();
+  }
+
+  /**
+   * Loads the named class with the thread context class loader; a missing class is reported as a
+   * {@link JsonParseException} that keeps the original cause.
+   */
+  private static Class<?> loadClass(String klass) {
+    try {
+      return Thread.currentThread().getContextClassLoader().loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      throw new JsonParseException("Unable to load DistanceMeasure class " + klass, e);
+    }
+  }
+}

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java (from r987240, mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java&r1=987240&r2=987647&rev=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelAdapter.java Fri Aug 20 21:56:16 2010
@@ -14,11 +14,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.mahout.clustering.dirichlet;
+package org.apache.mahout.clustering;
 
 import java.lang.reflect.Type;
 
-import org.apache.mahout.clustering.Model;
 import org.apache.mahout.math.JsonVectorAdapter;
 import org.apache.mahout.math.Vector;
 import org.slf4j.Logger;

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/JsonModelDistributionAdapter.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * Gson adapter that (de)serializes a {@link ModelDistribution} together with its concrete class.
+ * <p>
+ * A distribution is written as a JSON object with two members: {@code "class"}, the fully
+ * qualified name of the implementation, and {@code "model"}, a string holding the Gson
+ * serialization of the instance. The nested Gson registers adapters for {@link Vector} and
+ * {@link DistanceMeasure} fields.
+ */
+public class JsonModelDistributionAdapter
+    implements JsonSerializer<ModelDistribution<?>>, JsonDeserializer<ModelDistribution<?>> {
+
+  private static final Logger log = LoggerFactory.getLogger(JsonModelDistributionAdapter.class);
+
+  @Override
+  public JsonElement serialize(ModelDistribution<?> src, Type typeOfSrc, JsonSerializationContext context) {
+    JsonObject obj = new JsonObject();
+    obj.add("class", new JsonPrimitive(src.getClass().getName()));
+    obj.add("model", new JsonPrimitive(createGson().toJson(src)));
+    return obj;
+  }
+
+  /**
+   * Rebuilds a distribution from the {@code "class"} / {@code "model"} object written by
+   * {@link #serialize}.
+   *
+   * @throws JsonParseException if a member is missing, the class cannot be loaded, or the class
+   *           does not implement {@link ModelDistribution}
+   */
+  @Override
+  public ModelDistribution<?> deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) {
+    JsonObject obj = json.getAsJsonObject();
+    String klass = requiredString(obj, "class");
+    String model = requiredString(obj, "model");
+    Class<?> cl = loadClass(klass);
+    if (!ModelDistribution.class.isAssignableFrom(cl)) {
+      throw new JsonParseException(klass + " does not implement " + ModelDistribution.class.getName());
+    }
+    return (ModelDistribution<?>) createGson().fromJson(model, cl);
+  }
+
+  /** Builds the Gson used for the nested payload, with adapters for Vector and DistanceMeasure fields. */
+  private static Gson createGson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    return builder.create();
+  }
+
+  /** Returns the named member as a string, failing with a parse error if it is absent. */
+  private static String requiredString(JsonObject obj, String member) {
+    JsonElement element = obj.get(member);
+    if (element == null) {
+      throw new JsonParseException("Serialized ModelDistribution is missing the '" + member + "' member");
+    }
+    return element.getAsString();
+  }
+
+  /**
+   * Loads the named class with the thread context class loader; a missing class is reported as a
+   * {@link JsonParseException} that keeps the original cause.
+   */
+  private static Class<?> loadClass(String klass) {
+    try {
+      return Thread.currentThread().getContextClassLoader().loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      throw new JsonParseException("Unable to load ModelDistribution class " + klass, e);
+    }
+  }
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ModelDistribution.java Fri Aug 20 21:56:16 2010
@@ -18,6 +18,7 @@
 package org.apache.mahout.clustering;
 
 
+
 /** A model distribution allows us to sample a model from its prior distribution. */
 public interface ModelDistribution<O> {
   
@@ -39,4 +40,10 @@ public interface ModelDistribution<O> {
    */
   Model<O>[] sampleFromPosterior(Model<O>[] posterior);
   
+  /**
+   * Return a JSON string representing the receiver. Needed to pass persistent state.
+   * @return a String
+   */
+  String asJsonString();
+  
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java Fri Aug 20 21:56:16 2010
@@ -1,8 +1,11 @@
 package org.apache.mahout.clustering;
 
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -20,11 +23,25 @@ public class VectorModelClassifier exten
   @Override
   public Vector classify(Vector instance) {
     Vector pdfs = new DenseVector(models.size());
-    int i = 0;
-    for (Model<VectorWritable> model : models) {
-      pdfs.set(i++, model.pdf(new VectorWritable(instance)));
+    // SoftCluster (fuzzy k-means) models: delegate to FuzzyKMeansClusterer.computePi using the
+    // distance from the instance to each cluster center. Check emptiness before models.get(0).
+    if (!models.isEmpty() && models.get(0) instanceof SoftCluster) {
+      List<SoftCluster> clusters = new ArrayList<SoftCluster>();
+      List<Double> distances = new ArrayList<Double>();
+      for (Model<VectorWritable> model : models) {
+        SoftCluster sc = (SoftCluster) model;
+        clusters.add(sc);
+        distances.add(sc.getMeasure().distance(instance, sc.getCenter()));
+      }
+      return new FuzzyKMeansClusterer().computePi(clusters, distances);
+    } else {
+      // Otherwise scale each model's pdf by 1 / sum so the returned weights sum to one.
+      int i = 0;
+      for (Model<VectorWritable> model : models) {
+        pdfs.set(i++, model.pdf(new VectorWritable(instance)));
+      }
+      return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
     }
-    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
   }
 
   @Override

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Fri Aug 20 21:56:16 2010
@@ -36,13 +36,16 @@ import org.apache.hadoop.mapreduce.lib.i
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ModelDistribution;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
 import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
 import org.apache.mahout.clustering.kmeans.OutputLogFilter;
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -53,7 +56,7 @@ public class DirichletDriver extends Abs
 
   public static final String STATE_IN_KEY = "org.apache.mahout.clustering.dirichlet.stateIn";
 
-  public static final String MODEL_FACTORY_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
+  public static final String MODEL_DISTRIBUTION_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
 
   public static final String MODEL_PROTOTYPE_KEY = "org.apache.mahout.clustering.dirichlet.modelPrototype";
 
@@ -97,6 +100,7 @@ public class DirichletDriver extends Abs
               "mp",
               "The ModelDistribution prototype Vector class name. Defaults to RandomAccessSparseVector",
               RandomAccessSparseVector.class.getName());
+    addOption(DefaultOptionCreator.distanceMeasureOption().withRequired(false).create());
     addOption(DefaultOptionCreator.emitMostLikelyOption().create());
     addOption(DefaultOptionCreator.thresholdOption().create());
     addOption(DefaultOptionCreator.numReducersOption().create());
@@ -113,6 +117,7 @@ public class DirichletDriver extends Abs
     }
     String modelFactory = getOption(MODEL_DISTRIBUTION_CLASS_OPTION);
     String modelPrototype = getOption(MODEL_PROTOTYPE_CLASS_OPTION);
+    String distanceMeasure = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
     int numModels = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
     int numReducers = Integer.parseInt(getOption(DefaultOptionCreator.MAX_REDUCERS_OPTION));
     int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
@@ -121,11 +126,16 @@ public class DirichletDriver extends Abs
     double alpha0 = Double.parseDouble(getOption(ALPHA_OPTION));
     boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION);
     boolean runSequential = (getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(DefaultOptionCreator.SEQUENTIAL_METHOD));
+    int prototypeSize = readPrototypeSize(input);
+
+    AbstractVectorModelDistribution modelDistribution = createModelDistribution(modelFactory,
+                                                                                modelPrototype,
+                                                                                distanceMeasure,
+                                                                                prototypeSize);
 
     job(input,
         output,
-        modelFactory,
-        modelPrototype,
+        modelDistribution,
         numModels,
         maxIterations,
         alpha0,
@@ -138,15 +148,48 @@ public class DirichletDriver extends Abs
   }
 
   /**
+   * Creates an AbstractVectorModelDistribution from class names such as those given on the command line.
+   * @param modelFactory class name of the AbstractVectorModelDistribution subclass to instantiate
+   * @param modelPrototype class name of the Vector implementation; it must declare a public (int) constructor
+   * @param distanceMeasure DistanceMeasure class name; read only when modelFactory is a DistanceMeasureClusterDistribution
+   * @param prototypeSize the int cardinality passed to the prototype Vector constructor
+   * @return the new distribution with its model prototype set (and its measure, for DistanceMeasureClusterDistribution)
+   * @throws ClassNotFoundException if a named class cannot be loaded
+   * @throws InstantiationException if the distribution or measure class cannot be instantiated via its no-arg constructor
+   * @throws IllegalAccessException if a constructor that is invoked is not accessible
+   * @throws NoSuchMethodException if the prototype class has no public (int) constructor
+   * @throws InvocationTargetException if the prototype's (int) constructor throws
+   */
+  public static AbstractVectorModelDistribution createModelDistribution(String modelFactory,
+                                                                 String modelPrototype,
+                                                                 String distanceMeasure,
+                                                                 int prototypeSize) throws ClassNotFoundException,
+      InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+    Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory)
+        .asSubclass(AbstractVectorModelDistribution.class);
+    AbstractVectorModelDistribution modelDistribution = cl.newInstance();
+
+    Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
+    Constructor<? extends Vector> v = vcl.getConstructor(int.class);
+    modelDistribution.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
+
+    if (modelDistribution instanceof DistanceMeasureClusterDistribution) {
+      Class<? extends DistanceMeasure> measureCl = ccl.loadClass(distanceMeasure).asSubclass(DistanceMeasure.class);
+      DistanceMeasure measure = measureCl.newInstance();
+      ((DistanceMeasureClusterDistribution) modelDistribution).setMeasure(measure);
+    }
+    return modelDistribution;
+  }
+
+  /**
    * Run the job using supplied arguments on a new driver instance (convenience)
    * 
    * @param input
    *          the directory pathname for input points
    * @param output
    *          the directory pathname for output points
-   * @param modelFactory
-   *          the String ModelDistribution class name to use
-   * @param modelPrototype
+   * @param modelDistribution
    *          the String class name of the model prototype
    * @param numClusters
    *          the number of models
@@ -166,8 +209,7 @@ public class DirichletDriver extends Abs
    */
   public static void runJob(Path input,
                             Path output,
-                            String modelFactory,
-                            String modelPrototype,
+                            ModelDistribution<VectorWritable> modelDistribution,
                             int numClusters,
                             int maxIterations,
                             double alpha0,
@@ -180,8 +222,7 @@ public class DirichletDriver extends Abs
 
     new DirichletDriver().job(input,
                               output,
-                              modelFactory,
-                              modelPrototype,
+                              modelDistribution,
                               numClusters,
                               maxIterations,
                               alpha0,
@@ -196,37 +237,21 @@ public class DirichletDriver extends Abs
    * Creates a DirichletState object from the given arguments. Note that the modelFactory is presumed to be a
    * subclass of VectorModelDistribution that can be initialized with a concrete Vector prototype.
    * 
-   * @param modelFactory
-   *          a String which is the class name of the model factory
-   * @param modelPrototype
-   *          a String which is the class name of the Vector used to initialize the factory
-   * @param prototypeSize
-   *          an int number of dimensions of the model prototype vector
-   * @param numModels
-   *          an int number of models to be created
-   * @param alpha0
-   *          the double alpha_0 argument to the algorithm
+   * @param modelDistribution the ModelDistribution
+   * @param numModels an int number of models to be created
+   * @param alpha0 the double alpha_0 argument to the algorithm
    * @return an initialized DirichletState
    */
-  static DirichletState createState(String modelFactory, String modelPrototype, int prototypeSize, int numModels, double alpha0)
+  static DirichletState createState(ModelDistribution<VectorWritable> modelDistribution, int numModels, double alpha0)
       throws ClassNotFoundException, InstantiationException, IllegalAccessException, SecurityException, NoSuchMethodException,
       IllegalArgumentException, InvocationTargetException {
-
-    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-    Class<? extends AbstractVectorModelDistribution> cl = ccl.loadClass(modelFactory)
-        .asSubclass(AbstractVectorModelDistribution.class);
-    AbstractVectorModelDistribution factory = cl.newInstance();
-
-    Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
-    Constructor<? extends Vector> v = vcl.getConstructor(int.class);
-    factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
-    return new DirichletState(factory, numModels, alpha0);
+    return new DirichletState(modelDistribution, numModels, alpha0);
   }
 
   /**
    * Read the first input vector to determine the prototype size for the modelPrototype
    */
-  private int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
+  public static int readPrototypeSize(Path input) throws IOException, InstantiationException, IllegalAccessException {
     Configuration conf = new Configuration();
     FileSystem fs = FileSystem.get(input.toUri(), conf);
     FileStatus[] status = fs.listStatus(input, new OutputLogFilter());
@@ -248,22 +273,18 @@ public class DirichletDriver extends Abs
    * Write initial state (prior distribution) to the output path directory
    * @param output the output Path
    * @param stateOut the state output Path
-   * @param modelFactory the String class name of the modelFactory
-   * @param modelPrototype the String class name of the modelPrototype
-   * @param prototypeSize the int size of the modelPrototype vectors
+   * @param modelDistribution the ModelDistribution
    * @param numModels the int number of models to generate
    * @param alpha0 the double alpha_0 argument to the DirichletDistribution
    */
   private void writeInitialState(Path output,
                                  Path stateOut,
-                                 String modelFactory,
-                                 String modelPrototype,
-                                 int prototypeSize,
+                                 ModelDistribution<VectorWritable> modelDistribution,
                                  int numModels,
                                  double alpha0) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
       IOException, SecurityException, NoSuchMethodException, InvocationTargetException {
 
-    DirichletState state = createState(modelFactory, modelPrototype, prototypeSize, numModels, alpha0);
+    DirichletState state = createState(modelDistribution, numModels, alpha0);
     writeState(output, stateOut, numModels, state);
   }
 
@@ -281,39 +302,24 @@ public class DirichletDriver extends Abs
   /**
    * Run an iteration using supplied arguments
    * 
-   * @param input
-   *          the directory pathname for input points
-   * @param stateIn
-   *          the directory pathname for input state
-   * @param stateOut
-   *          the directory pathname for output state
-   * @param modelFactory
-   *          the class name of the model factory class
-   * @param modelPrototype
-   *          the class name of the model prototype (a Vector implementation)
-   * @param prototypeSize
-   *          the size of the model prototype vector
-   * @param numClusters
-   *          the number of clusters
-   * @param alpha0
-   *          alpha_0
-   * @param numReducers
-   *          the number of Reducers desired
+   * @param input the directory pathname for input points
+   * @param stateIn the directory pathname for input state
+   * @param stateOut the directory pathname for output state
+   * @param modelDistribution the ModelDistribution
+   * @param numClusters the number of clusters
+   * @param alpha0 alpha_0
+   * @param numReducers the number of Reducers desired
    */
   private void runIteration(Path input,
                             Path stateIn,
                             Path stateOut,
-                            String modelFactory,
-                            String modelPrototype,
-                            int prototypeSize,
+                            ModelDistribution<VectorWritable> modelDistribution,
                             int numClusters,
                             double alpha0,
                             int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(STATE_IN_KEY, stateIn.toString());
-    conf.set(MODEL_FACTORY_KEY, modelFactory);
-    conf.set(MODEL_PROTOTYPE_KEY, modelPrototype);
-    conf.set(PROTOTYPE_SIZE_KEY, Integer.toString(prototypeSize));
+    conf.set(MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
     conf.set(NUM_CLUSTERS_KEY, Integer.toString(numClusters));
     conf.set(ALPHA_0_KEY, Double.toString(alpha0));
 
@@ -344,9 +350,7 @@ public class DirichletDriver extends Abs
    *          the directory Path for input points
    * @param output
    *          the directory Path for output points
-   * @param modelFactory
-   *          the String ModelDistribution class name to use
-   * @param modelPrototype
+   * @param modelDistribution
    *          the String class name of the model's prototype vector
    * @param numClusters
    *          the number of models to iterate over
@@ -366,8 +370,7 @@ public class DirichletDriver extends Abs
    */
   public void job(Path input,
                   Path output,
-                  String modelFactory,
-                  String modelPrototype,
+                  ModelDistribution<VectorWritable> modelDistribution,
                   int numClusters,
                   int maxIterations,
                   double alpha0,
@@ -379,8 +382,7 @@ public class DirichletDriver extends Abs
       ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InterruptedException {
     Path clustersOut = buildClusters(input,
                                      output,
-                                     modelFactory,
-                                     modelPrototype,
+                                     modelDistribution,
                                      numClusters,
                                      maxIterations,
                                      alpha0,
@@ -398,9 +400,7 @@ public class DirichletDriver extends Abs
    *          the directory Path for input points
    * @param output
    *          the directory Path for output points
-   * @param modelFactory
-   *          the String ModelDistribution class name to use
-   * @param modelPrototype
+   * @param modelDistribution
    *          the String class name of the model's prototype vector
    * @param numClusters
    *          the number of models to iterate over
@@ -415,8 +415,7 @@ public class DirichletDriver extends Abs
    */
   public Path buildClusters(Path input,
                             Path output,
-                            String modelFactory,
-                            String modelPrototype,
+                            ModelDistribution<VectorWritable> modelDistribution,
                             int numClusters,
                             int maxIterations,
                             double alpha0,
@@ -424,47 +423,24 @@ public class DirichletDriver extends Abs
                             boolean runSequential) throws IOException, InstantiationException, IllegalAccessException,
       ClassNotFoundException, NoSuchMethodException, InvocationTargetException, InterruptedException {
     Path clustersIn = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
-
-    int protoSize = readPrototypeSize(input);
-
-    writeInitialState(output, clustersIn, modelFactory, modelPrototype, protoSize, numClusters, alpha0);
+    writeInitialState(output, clustersIn, modelDistribution, numClusters, alpha0);
 
     if (runSequential) {
-      clustersIn = buildClustersSeq(input,
-                                    output,
-                                    modelFactory,
-                                    modelPrototype,
-                                    numClusters,
-                                    maxIterations,
-                                    alpha0,
-                                    numReducers,
-                                    clustersIn,
-                                    protoSize);
+      clustersIn = buildClustersSeq(input, output, modelDistribution, numClusters, maxIterations, alpha0, numReducers, clustersIn);
     } else {
-      clustersIn = buildClustersMR(input,
-                                   output,
-                                   modelFactory,
-                                   modelPrototype,
-                                   numClusters,
-                                   maxIterations,
-                                   alpha0,
-                                   numReducers,
-                                   clustersIn,
-                                   protoSize);
+      clustersIn = buildClustersMR(input, output, modelDistribution, numClusters, maxIterations, alpha0, numReducers, clustersIn);
     }
     return clustersIn;
   }
 
   private Path buildClustersSeq(Path input,
                                 Path output,
-                                String modelFactory,
-                                String modelPrototype,
+                                ModelDistribution<VectorWritable> modelDistribution,
                                 int numClusters,
                                 int maxIterations,
                                 double alpha0,
                                 int numReducers,
-                                Path clustersIn,
-                                int protoSize) throws IOException, ClassNotFoundException, InstantiationException,
+                                Path clustersIn) throws IOException, ClassNotFoundException, InstantiationException,
       IllegalAccessException, NoSuchMethodException, InvocationTargetException {
     for (int iteration = 1; iteration <= maxIterations; iteration++) {
       log.info("Iteration {}", iteration);
@@ -472,10 +448,8 @@ public class DirichletDriver extends Abs
       Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
       DirichletState state = DirichletMapper.loadState(new Configuration(),
                                                        clustersIn.toString(),
-                                                       modelFactory,
-                                                       modelPrototype,
+                                                       modelDistribution,
                                                        alpha0,
-                                                       protoSize,
                                                        numClusters);
       Cluster[] newModels = (Cluster[]) state.getModelFactory().sampleFromPosterior(state.getModels());
       DirichletClusterer clusterer = new DirichletClusterer(state);
@@ -506,19 +480,17 @@ public class DirichletDriver extends Abs
 
   private Path buildClustersMR(Path input,
                                Path output,
-                               String modelFactory,
-                               String modelPrototype,
+                               ModelDistribution<VectorWritable> modelDistribution,
                                int numClusters,
                                int maxIterations,
                                double alpha0,
                                int numReducers,
-                               Path clustersIn,
-                               int protoSize) throws IOException, InterruptedException, ClassNotFoundException {
+                               Path clustersIn) throws IOException, InterruptedException, ClassNotFoundException {
     for (int iteration = 1; iteration <= maxIterations; iteration++) {
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
-      runIteration(input, clustersIn, clustersOut, modelFactory, modelPrototype, protoSize, numClusters, alpha0, numReducers);
+      runIteration(input, clustersIn, clustersOut, modelDistribution, numClusters, alpha0, numReducers);
       // now point the input to the old output directory
       clustersIn = clustersOut;
     }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Fri Aug 20 21:56:16 2010
@@ -29,17 +29,22 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.JsonModelDistributionAdapter;
+import org.apache.mahout.clustering.ModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
 import org.apache.mahout.clustering.kmeans.OutputLogFilter;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.VectorWritable;
 
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
 public class DirichletMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
 
   private DirichletClusterer clusterer;
 
   @Override
-  protected void map(WritableComparable<?> key, VectorWritable v, Context context)
-      throws IOException, InterruptedException {
+  protected void map(WritableComparable<?> key, VectorWritable v, Context context) throws IOException, InterruptedException {
     int k = clusterer.assignToModel(v);
     context.write(new Text(String.valueOf(k)), v);
   }
@@ -72,23 +77,19 @@ public class DirichletMapper extends Map
     this.clusterer = new DirichletClusterer(state);
   }
 
-  public static DirichletState getDirichletState(Configuration conf) throws NoSuchMethodException,
-      InvocationTargetException {
+  public static DirichletState getDirichletState(Configuration conf) throws NoSuchMethodException, InvocationTargetException {
     String statePath = conf.get(DirichletDriver.STATE_IN_KEY);
-    String modelFactory = conf.get(DirichletDriver.MODEL_FACTORY_KEY);
-    String modelPrototype = conf.get(DirichletDriver.MODEL_PROTOTYPE_KEY);
-    String prototypeSize = conf.get(DirichletDriver.PROTOTYPE_SIZE_KEY);
+    String json = conf.get(DirichletDriver.MODEL_DISTRIBUTION_KEY);
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+    Gson gson = builder.create();
+    ModelDistribution<VectorWritable> modelDistribution = gson.fromJson(json,
+                                                                        AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
     String numClusters = conf.get(DirichletDriver.NUM_CLUSTERS_KEY);
     String alpha0 = conf.get(DirichletDriver.ALPHA_0_KEY);
 
     try {
-      return loadState(conf,
-                       statePath,
-                       modelFactory,
-                       modelPrototype,
-                       Double.parseDouble(alpha0),
-                       Integer.parseInt(prototypeSize),
-                       Integer.parseInt(numClusters));
+      return loadState(conf, statePath, modelDistribution, Double.parseDouble(alpha0), Integer.parseInt(numClusters));
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     } catch (InstantiationException e) {
@@ -101,15 +102,12 @@ public class DirichletMapper extends Map
   }
 
   protected static DirichletState loadState(Configuration conf,
-                                                            String statePath,
-                                                            String modelFactory,
-                                                            String modelPrototype,
-                                                            double alpha,
-                                                            int pSize,
-                                                            int k)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException,
+                                            String statePath,
+                                            ModelDistribution<VectorWritable> modelDistribution,
+                                            double alpha,
+                                            int k) throws ClassNotFoundException, InstantiationException, IllegalAccessException,
       NoSuchMethodException, InvocationTargetException, IOException {
-    DirichletState state = DirichletDriver.createState(modelFactory, modelPrototype, pSize, k, alpha);
+    DirichletState state = DirichletDriver.createState(modelDistribution, k, alpha);
     Path path = new Path(statePath);
     FileSystem fs = FileSystem.get(path.toUri(), conf);
     FileStatus[] status = fs.listStatus(path, new OutputLogFilter());

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AbstractVectorModelDistribution.java Fri Aug 20 21:56:16 2010
@@ -16,29 +16,53 @@
 
 package org.apache.mahout.clustering.dirichlet.models;
 
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.JsonDistanceMeasureAdapter;
+import org.apache.mahout.clustering.JsonModelDistributionAdapter;
 import org.apache.mahout.clustering.ModelDistribution;
+import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
 public abstract class AbstractVectorModelDistribution implements ModelDistribution<VectorWritable> {
 
+  public static final Type MODEL_DISTRIBUTION_TYPE = new TypeToken<ModelDistribution<VectorWritable>>() {
+  }.getType();
+
   // a prototype instance used for creating prior model distributions using like(). It
   // should be of the class and cardinality desired for the particular application.
   private VectorWritable modelPrototype;
 
   protected AbstractVectorModelDistribution() {
   }
-  
+
   protected AbstractVectorModelDistribution(VectorWritable modelPrototype) {
     this.modelPrototype = modelPrototype;
   }
 
+  /* (non-Javadoc)
+   * @see org.apache.mahout.clustering.ModelDistribution#asJsonString()
+   */
+  @Override
+  public String asJsonString() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    Gson gson = builder.create();
+    return gson.toJson(this, MODEL_DISTRIBUTION_TYPE);
+  }
+
   /**
    * @return the modelPrototype
    */
   public VectorWritable getModelPrototype() {
     return modelPrototype;
   }
-  
+
   /**
    * @param modelPrototype
    *          the modelPrototype to set
@@ -46,5 +70,5 @@ public abstract class AbstractVectorMode
   public void setModelPrototype(VectorWritable modelPrototype) {
     this.modelPrototype = modelPrototype;
   }
-  
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Fri Aug 20 21:56:16 2010
@@ -24,8 +24,8 @@ import java.lang.reflect.Type;
 
 import org.apache.mahout.clustering.AbstractCluster;
 import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
 import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.function.SquareRootFunction;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/DistanceMeasureClusterDistribution.java Fri Aug 20 21:56:16 2010
@@ -19,24 +19,30 @@ package org.apache.mahout.clustering.dir
 
 import org.apache.mahout.clustering.DistanceMeasureCluster;
 import org.apache.mahout.clustering.Model;
+import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
 /**
  * An implementation of the ModelDistribution interface suitable for testing the DirichletCluster algorithm.
- * Uses a Normal Distribution to sample the prior model values. Model values have a vector standard deviation,
- * allowing assymetrical regions to be covered by a model.
+ * Models use a DistanceMeasure to calculate pdf values.
  */
 public class DistanceMeasureClusterDistribution extends AbstractVectorModelDistribution {
 
-  ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
+  DistanceMeasure measure;
 
   public DistanceMeasureClusterDistribution() {
   }
 
   public DistanceMeasureClusterDistribution(VectorWritable modelPrototype) {
     super(modelPrototype);
+    this.measure = new ManhattanDistanceMeasure();
+  }
+
+  public DistanceMeasureClusterDistribution(VectorWritable modelPrototype, DistanceMeasure measure) {
+    super(modelPrototype);
+    this.measure = measure;
   }
 
   @Override
@@ -59,4 +65,12 @@ public class DistanceMeasureClusterDistr
     return result;
   }
 
+  public void setMeasure(DistanceMeasure measure) {
+    this.measure = measure;
+  }
+
+  public DistanceMeasure getMeasure() {
+    return measure;
+  }
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Fri Aug 20 21:56:16 2010
@@ -23,8 +23,8 @@ import java.lang.reflect.Type;
 
 import org.apache.mahout.clustering.AbstractCluster;
 import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
 import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.Vector;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Fri Aug 20 21:56:16 2010
@@ -25,8 +25,8 @@ import java.util.Locale;
 
 import org.apache.mahout.clustering.AbstractCluster;
 import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.JsonModelAdapter;
 import org.apache.mahout.clustering.Model;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.function.SquareRootFunction;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Fri Aug 20 21:56:16 2010
@@ -62,6 +62,10 @@ public class FuzzyKMeansClusterer {
     this.configure(conf);
   }
 
+  public FuzzyKMeansClusterer() {
+    // TODO Auto-generated constructor stub
+  }
+
   /**
    * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
    * until their centers converge or until the maximum number of iterations is exceeded.
@@ -222,11 +226,7 @@ public class FuzzyKMeansClusterer {
       clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
     }
     // calculate point pdf for all clusters
-    Vector pi = new DenseVector(clusters.size());
-    for (int i = 0; i < clusters.size(); i++) {
-      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
-      pi.set(i, probWeight);
-    }
+    Vector pi = computePi(clusters, clusterDistanceList);
     if (emitMostLikely) {
       emitMostLikelyCluster(point.get(), clusters, pi, context);
     } else {
@@ -234,6 +234,15 @@ public class FuzzyKMeansClusterer {
     }
   }
 
+  public Vector computePi(List<SoftCluster> clusters, List<Double> clusterDistanceList) {
+    Vector pi = new DenseVector(clusters.size());
+    for (int i = 0; i < clusters.size(); i++) {
+      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+      pi.set(i, probWeight);
+    }
+    return pi;
+  }
+
   /**
    * Emit the point to the cluster with the highest pdf
    */
@@ -302,12 +311,7 @@ public class FuzzyKMeansClusterer {
     for (SoftCluster cluster : clusters) {
       clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
     }
-    // calculate point pdf for all clusters
-    Vector pi = new DenseVector(clusters.size());
-    for (int i = 0; i < clusters.size(); i++) {
-      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
-      pi.set(i, probWeight);
-    }
+    Vector pi = computePi(clusters, clusterDistanceList);
     if (emitMostLikely) {
       emitMostLikelyCluster(point.get(), clusters, pi, writer);
     } else {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Fri Aug 20 21:56:16 2010
@@ -17,9 +17,11 @@
 
 package org.apache.mahout.clustering.fuzzykmeans;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.mahout.clustering.kmeans.Cluster;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 
 public class SoftCluster extends Cluster {
 
@@ -46,4 +48,13 @@ public class SoftCluster extends Cluster
   public String getIdentifier() {
     return (isConverged() ? "SV-" : "SC-") + getId();
   }
+
+  /* (non-Javadoc)
+   * @see org.apache.mahout.clustering.DistanceMeasureCluster#pdf(org.apache.mahout.math.VectorWritable)
+   */
+  @Override
+  public double pdf(VectorWritable vw) {
+    // SoftCluster pdf cannot be calculated out of context. See FuzzyKMeansClusterer
+    throw new NotImplementedException();
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Fri Aug 20 21:56:16 2010
@@ -22,10 +22,12 @@ import java.io.DataOutput;
 import java.io.IOException;
 import java.lang.reflect.Type;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.mahout.clustering.kmeans.Cluster;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.JsonVectorAdapter;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.list.IntArrayList;
 
 import com.google.gson.Gson;
@@ -156,4 +158,13 @@ public class MeanShiftCanopy extends Clu
     return (isConverged() ? "MSV-" : "MSC-") + getId();
   }
 
+  /* (non-Javadoc)
+   * @see org.apache.mahout.clustering.DistanceMeasureCluster#pdf(org.apache.mahout.math.VectorWritable)
+   */
+  @Override
+  public double pdf(VectorWritable vw) {
+    // MSCanopy membership is explicit via membership in boundPoints. Can't compute pdf for Arbitrary point
+    throw new NotImplementedException();
+  }
+
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java Fri Aug 20 21:56:16 2010
@@ -22,7 +22,6 @@ import java.lang.reflect.Type;
 import org.apache.mahout.clustering.canopy.Canopy;
 import org.apache.mahout.clustering.dirichlet.DirichletCluster;
 import org.apache.mahout.clustering.dirichlet.JsonClusterModelAdapter;
-import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
 import org.apache.mahout.clustering.dirichlet.models.L1Model;
 import org.apache.mahout.clustering.dirichlet.models.NormalModel;

Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+
+public class TestModelDistributionSerialization extends MahoutTestCase {
+
+  public void testGaussianClusterDistribution() {
+    GaussianClusterDistribution dist = new GaussianClusterDistribution(new VectorWritable(new DenseVector(2)));
+    String json = dist.asJsonString();
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    Gson gson = builder.create();
+    GaussianClusterDistribution dist1 = (GaussianClusterDistribution) gson
+        .fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
+    assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+  }
+
+  public void testDMClusterDistribution() {
+    DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)));
+    String json = dist.asJsonString();
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    Gson gson = builder.create();
+    DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) gson
+        .fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
+    assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+    assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+  }
+
+  public void testDMClusterDistribution2() {
+    DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)),
+                                                                                     new EuclideanDistanceMeasure());
+    String json = dist.asJsonString();
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
+    builder.registerTypeAdapter(DistanceMeasure.class, new JsonDistanceMeasureAdapter());
+    Gson gson = builder.create();
+    DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) gson
+        .fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
+    assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+    assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=987647&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java Fri Aug 20 21:56:16 2010
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Tests that {@link VectorModelClassifier} produces the expected classification vectors for the
+ * different {@link Model} implementations it can wrap. Each test builds three models from the
+ * vectors [1,1], [0,0] and [-1,-1] (indices 0, 1 and 2) and classifies [0,0] and [2,2].
+ */
+public class TestVectorModelClassifier extends MahoutTestCase {
+
+  /**
+   * Builds a {@link VectorModelClassifier} over {@code models}, classifies [0,0] and [2,2], and
+   * compares each result, formatted by {@link AbstractCluster#formatVector}, to the expectation.
+   *
+   * @param models           the models to classify against
+   * @param expectedAtOrigin expected formatted classification of [0,0]
+   * @param expectedAtTwos   expected formatted classification of [2,2]
+   */
+  private void assertClassification(List<Model<VectorWritable>> models,
+                                    String expectedAtOrigin,
+                                    String expectedAtTwos) {
+    AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+    Vector pdf = classifier.classify(new DenseVector(2));
+    assertEquals("[0,0]", expectedAtOrigin, AbstractCluster.formatVector(pdf, null));
+    pdf = classifier.classify(new DenseVector(2).assign(2));
+    assertEquals("[2,2]", expectedAtTwos, AbstractCluster.formatVector(pdf, null));
+  }
+
+  public void testDMClusterClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure));
+    models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
+    models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure));
+    assertClassification(models, "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
+  }
+
+  public void testCanopyClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    models.add(new Canopy(new DenseVector(2).assign(1), 0, measure));
+    models.add(new Canopy(new DenseVector(2), 1, measure));
+    models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure));
+    assertClassification(models, "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
+  }
+
+  public void testClusterClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    models.add(new Cluster(new DenseVector(2).assign(1), 0, measure));
+    models.add(new Cluster(new DenseVector(2), 1, measure));
+    models.add(new Cluster(new DenseVector(2).assign(-1), 2, measure));
+    assertClassification(models, "[0.107, 0.787, 0.107]", "[0.867, 0.117, 0.016]");
+  }
+
+  public void testMSCanopyClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    models.add(new MeanShiftCanopy(new DenseVector(2).assign(1), 0, measure));
+    models.add(new MeanShiftCanopy(new DenseVector(2), 1, measure));
+    models.add(new MeanShiftCanopy(new DenseVector(2).assign(-1), 2, measure));
+    AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+    try {
+      classifier.classify(new DenseVector(2));
+      fail("Expected NotImplementedException");
+    } catch (NotImplementedException expected) {
+      // expected: classifying against MeanShiftCanopy models must throw NotImplementedException
+    }
+  }
+
+  public void testSoftClusterClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure));
+    models.add(new SoftCluster(new DenseVector(2), 1, measure));
+    models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
+    assertClassification(models, "[0.000, 1.000, 0.000]", "[0.735, 0.184, 0.082]");
+  }
+
+  public void testGaussianClusterClassification() {
+    List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
+    models.add(new GaussianCluster(new DenseVector(2).assign(1), new DenseVector(2).assign(1), 0));
+    models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2).assign(1), 1));
+    models.add(new GaussianCluster(new DenseVector(2).assign(-1), new DenseVector(2).assign(1), 2));
+    assertClassification(models, "[0.107, 0.787, 0.107]", "[0.998, 0.002, 0.000]");
+  }
+
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java?rev=987647&r1=987646&r2=987647&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java Fri Aug 20 21:56:16 2010
@@ -21,8 +21,9 @@ import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.Model;
 import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
 import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
 import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution;
 import org.apache.mahout.common.MahoutTestCase;
@@ -48,10 +49,8 @@ public class TestDirichletClustering ext
    * @param sd  double standard deviation of the samples
    * @param card int cardinality of the generated sample vectors
    */
-  private void generateSamples(int num, double mx, double my, double sd,
-      int card) {
-    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
-        + "] sd=" + sd);
+  private void generateSamples(int num, double mx, double my, double sd, int card) {
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd);
     for (int i = 0; i < num; i++) {
       DenseVector v = new DenseVector(card);
       for (int j = 0; j < card; j++) {
@@ -72,14 +71,13 @@ public class TestDirichletClustering ext
     generateSamples(num, mx, my, sd, 2);
   }
 
-  private static void printResults(List<Cluster[]> result,
-      int significant) {
+  private static void printResults(List<Cluster[]> result, int significant) {
     int row = 0;
-    for (Model<VectorWritable>[] r : result) {
+    for (Cluster[] r : result) {
       System.out.print("sample[" + row++ + "]= ");
-      for (Model<VectorWritable> model : r) {
+      for (Cluster model : r) {
         if (model.count() > significant) {
-          System.out.print(model.toString() + ", ");
+          System.out.print(model.asFormatString(null) + ", ");
         }
       }
       System.out.println();
@@ -93,9 +91,12 @@ public class TestDirichletClustering ext
     generateSamples(30, 1, 0, 0.1);
     generateSamples(30, 0, 1, 0.1);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new NormalModelDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new NormalModelDistribution(new VectorWritable(new DenseVector(2))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);
@@ -107,9 +108,12 @@ public class TestDirichletClustering ext
     generateSamples(30, 1, 0, 0.1);
     generateSamples(30, 0, 1, 0.1);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new SampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new SampledNormalDistribution(new VectorWritable(new DenseVector(2))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);
@@ -121,107 +125,29 @@ public class TestDirichletClustering ext
     generateSamples(30, 1, 0, 0.1);
     generateSamples(30, 0, 1, 0.1);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(2))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);
   }
 
-  public void testDirichletCluster1000() {
-    System.out.println("testDirichletCluster1000");
-    generateSamples(400, 1, 1, 3);
-    generateSamples(300, 1, 0, 0.1);
-    generateSamples(300, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new NormalModelDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 20);
-    assertNotNull(result);
-  }
-
-  public void testDirichletCluster1000s() {
-    System.out.println("testDirichletCluster1000s");
-    generateSamples(400, 1, 1, 3);
-    generateSamples(300, 1, 0, 0.1);
-    generateSamples(300, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new SampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 20);
-    assertNotNull(result);
-  }
-
-  public void testDirichletCluster1000as() {
-    System.out.println("testDirichletCluster1000as");
-    generateSamples(400, 1, 1, 3);
-    generateSamples(300, 1, 0, 0.1);
-    generateSamples(300, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 20);
-    assertNotNull(result);
-  }
-
-  public void testDirichletCluster10000() {
-    System.out.println("testDirichletCluster10000");
-    generateSamples(4000, 1, 1, 3);
-    generateSamples(3000, 1, 0, 0.1);
-    generateSamples(3000, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new NormalModelDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 200);
-    assertNotNull(result);
-  }
-
-  public void testDirichletCluster10000as() {
-    System.out.println("testDirichletCluster10000as");
-    generateSamples(4000, 1, 1, 3);
-    generateSamples(3000, 1, 0, 0.1);
-    generateSamples(3000, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 200);
-    assertNotNull(result);
-  }
-
-  public void testDirichletCluster10000s() {
-    System.out.println("testDirichletCluster10000s");
-    generateSamples(4000, 1, 1, 3);
-    generateSamples(3000, 1, 0, 0.1);
-    generateSamples(3000, 0, 1, 0.1);
-
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new SampledNormalDistribution(new VectorWritable(
-            new DenseVector(2))), 1.0, 10, 1, 0);
-    List<Cluster[]> result = dc.cluster(30);
-    printResults(result, 200);
-    assertNotNull(result);
-  }
-
   public void testDirichletCluster100C3() {
     System.out.println("testDirichletCluster100");
     generateSamples(40, 1, 1, 3, 3);
     generateSamples(30, 1, 0, 0.1, 3);
     generateSamples(30, 0, 1, 0.1, 3);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new NormalModelDistribution(new VectorWritable(
-            new DenseVector(3))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new NormalModelDistribution(new VectorWritable(new DenseVector(3))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);
@@ -233,9 +159,12 @@ public class TestDirichletClustering ext
     generateSamples(30, 1, 0, 0.1, 3);
     generateSamples(30, 0, 1, 0.1, 3);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new SampledNormalDistribution(new VectorWritable(
-            new DenseVector(3))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new SampledNormalDistribution(new VectorWritable(new DenseVector(3))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);
@@ -247,9 +176,46 @@ public class TestDirichletClustering ext
     generateSamples(30, 1, 0, 0.1, 3);
     generateSamples(30, 0, 1, 0.1, 3);
 
-    DirichletClusterer dc = new DirichletClusterer(
-        sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable(
-            new DenseVector(3))), 1.0, 10, 1, 0);
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(3))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
+    List<Cluster[]> result = dc.cluster(30);
+    printResults(result, 2);
+    assertNotNull(result);
+  }
+
+  public void testDirichletGaussianCluster100() {
+    System.out.println("testDirichletGaussianCluster100");
+    // 40 samples with mean [1,1] and sd 3, then 30 each with means [1,0] and [0,1] and sd 0.1
+    generateSamples(40, 1, 1, 3);
+    generateSamples(30, 1, 0, 0.1);
+    generateSamples(30, 0, 1, 0.1);
+
+    VectorWritable prototype = new VectorWritable(new DenseVector(2));
+    GaussianClusterDistribution modelDistribution = new GaussianClusterDistribution(prototype);
+    DirichletClusterer clusterer =
+        new DirichletClusterer(sampleData, modelDistribution, 1.0, 10, 1, 0);
+    List<Cluster[]> clusters = clusterer.cluster(30);
+    // printResults only prints models whose count() exceeds 2
+    printResults(clusters, 2);
+    assertNotNull(clusters);
+  }
+
+  public void testDirichletDMCluster100() {
+    System.out.println("testDirichletDMCluster100");
+    generateSamples(40, 1, 1, 3);
+    generateSamples(30, 1, 0, 0.1);
+    generateSamples(30, 0, 1, 0.1);
+
+    DirichletClusterer dc = new DirichletClusterer(sampleData,
+                                                   new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2))),
+                                                   1.0,
+                                                   10,
+                                                   1,
+                                                   0);
     List<Cluster[]> result = dc.cluster(30);
     printResults(result, 2);
     assertNotNull(result);



Mime
View raw message