mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r908235 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/ main/java/org/apache/mahout/clustering/dirichlet/ main/java/org/apache/mahout/clustering/dirichlet/models/ main/java/org/apache/mahout/clustering/kmeans/ te...
Date Tue, 09 Feb 2010 21:30:18 GMT
Author: jeastman
Date: Tue Feb  9 21:30:17 2010
New Revision: 908235

URL: http://svn.apache.org/viewvc?rev=908235&view=rev
Log:
MAHOUT-270:

This patch adds the Printable interface and unifies ClusterBase, Model and DirichletCluster
as implementors. I did not extend this to Vector or Matrix, though that is possible, preferring
to write a simple vector formatter in ClusterBase which is used by all implementations.

Added back the JsonModelAdapter and JsonClusterAdapter which I removed recently because they
were unused. Adding back Json formatting to the Dirichlet clustering required their reinstatement.

Added TestPrintableInterface which tests the new implementations. All other tests continue
to pass.

TODO: ClusterDumper still needs to be modified to complete this issue but, as it has users who
may want some time to review the new code, I will hold off on that for a few days.

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java Tue Feb  9 21:30:17 2010
@@ -17,15 +17,64 @@
 
 package org.apache.mahout.clustering;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.Iterator;
+
 import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.JsonVectorAdapter;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
 
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
+public abstract class ClusterBase implements Writable, Printable {
 
-public abstract class ClusterBase implements Writable {
+  /**
+   * Return a human-readable formatted string representation of the vector, not intended
+   * to be complete nor usable as an input/output representation such as Json
+   * 
+   * @param v a Vector
+   * @return a String
+   */
+  public static String formatVector(Vector v, String[] bindings) {
+    StringBuilder buf = new StringBuilder();
+    int nzero = 0;
+    Iterator<Element> iterateNonZero = v.iterateNonZero();
+    while (iterateNonZero.hasNext()) {
+      iterateNonZero.next();
+      nzero++;
+    }
+    // if vector is sparse or if we have bindings, use sparse notation
+    if (nzero < v.size() || bindings != null) {
+      buf.append('[');
+      for (int i = 0; i < v.size(); i++) {
+        double elem = v.get(i);
+        if (elem == 0.0)
+          continue;
+        String label = null;
+        if (bindings != null && (label = bindings[i]) != null)
+          buf.append(label).append(":");
+        else
+          buf.append(i).append(":");
+        buf.append(String.format("%.3f", elem)).append(", ");
+      }
+    } else {
+      buf.append('[');
+      for (int i = 0; i < v.size(); i++) {
+        double elem = v.get(i);
+        buf.append(String.format("%.3f", elem)).append(", ");
+      }
+    }
+    buf.setLength(buf.length() - 2);
+    buf.append(']');
+    return buf.toString();
+  }
 
   // this cluster's clusterId
   private int id;
@@ -71,8 +120,37 @@
     this.pointTotal = pointTotal;
   }
 
+  /**
+   * @deprecated
+   * @return
+   */
   public abstract String asFormatString();
 
+  /**
+   * Format this cluster as its identifier followed by its centroid.
+   *
+   * @param bindings optional labels for the centroid's elements, passed to
+   *          {@link #formatVector(Vector, String[])}; may be null
+   * @return a String of the form "identifier: [centroid]"
+   * @see org.apache.mahout.clustering.Printable#asFormatString(java.lang.String[])
+   */
+  @Override
+  public String asFormatString(String[] bindings) {
+    return getIdentifier() + ": " + formatVector(computeCentroid(), bindings);
+  }
+
+  public abstract Vector computeCentroid();
+
+  public abstract Object getIdentifier();
+
+  /** Gson type token for Vector, used to register the Json vector adapter. */
+  private static final Type vectorType = new TypeToken<Vector>() {
+  }.getType();
+
+  /**
+   * Produce a Json representation of this cluster, encoding Vector-typed values with
+   * {@link JsonVectorAdapter}.
+   *
+   * @return a Json String
+   * @see org.apache.mahout.clustering.Printable#asJsonString()
+   */
+  @Override
+  public String asJsonString() {
+    GsonBuilder gBuilder = new GsonBuilder();
+    gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
+    Gson gson = gBuilder.create();
+    return gson.toJson(this, this.getClass());
+  }
+
   /**
    * Simply writes out the id, and that's it!
    *

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java?rev=908235&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java Tue Feb  9 21:30:17 2010
@@ -0,0 +1,43 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+
+/**
+ * Implementations of this interface have a printable representation. This representation
+ * may be enhanced by an optional Vector label bindings dictionary.
+ *
+ */
+public interface Printable {
+
+  /**
+   * Produce a custom, printable representation of the receiver.
+   *
+   * @param bindings an optional String[] of labels, indexed by vector element position, used to
+   *          format the primary Vector(s) of this implementation; may be null
+   * @return a String
+   */
+  String asFormatString(String[] bindings);
+
+  /**
+   * Produce a printable representation of the receiver using Json. (Label bindings
+   * are transient and not part of the Json representation)
+   *
+   * @return a Json String
+   */
+  String asJsonString();
+
+}

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Tue Feb  9 21:30:17 2010
@@ -16,17 +16,21 @@
  */
 package org.apache.mahout.clustering.dirichlet;
 
-import com.google.gson.reflect.TypeToken;
-import org.apache.hadoop.io.Writable;
-import org.apache.mahout.clustering.dirichlet.models.Model;
-import org.apache.mahout.math.Vector;
-
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
 import java.lang.reflect.Type;
 
-public class DirichletCluster<O> implements Writable {
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.Printable;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.math.Vector;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
+public class DirichletCluster<O> implements Writable, Printable {
 
   @Override
   public void readFields(DataInput in) throws IOException {
@@ -50,6 +54,12 @@
     this.totalCount = totalCount;
   }
 
+  /**
+   * Create a cluster wrapping the given model, starting from a total count of zero.
+   *
+   * @param model the Model for this cluster
+   */
+  public DirichletCluster(Model<O> model) {
+    this.totalCount = 0.0;
+    this.model = model;
+  }
+
   public DirichletCluster() {
     super();
   }
@@ -67,16 +77,16 @@
     return totalCount;
   }
 
-  private static final Type typeOfModel = new TypeToken<DirichletCluster<Vector>>()
{
+  private static final Type clusterType = new TypeToken<DirichletCluster<Vector>>()
{
   }.getType();
 
   /** Reads a typed Model instance from the input stream */
+  @SuppressWarnings("unchecked")
   public static <O> Model<O> readModel(DataInput in) throws IOException {
     String modelClassName = in.readUTF();
     Model<O> model;
     try {
-      model = (Model<O>) Class.forName(modelClassName).asSubclass(Model.class)
-          .newInstance();
+      model = (Model<O>) Class.forName(modelClassName).asSubclass(Model.class).newInstance();
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     } catch (IllegalAccessException e) {
@@ -94,4 +104,17 @@
     model.write(out);
   }
 
+  /**
+   * Format this cluster by delegating to its model, so that label bindings are honored.
+   *
+   * @param bindings optional vector labels passed through to the model; may be null
+   * @return the model's formatted representation
+   */
+  @Override
+  public String asFormatString(String[] bindings) {
+    return model.asFormatString(bindings);
+  }
+
+  /**
+   * Produce a Json representation of this cluster, with {@link JsonModelAdapter} registered
+   * for Model.
+   *
+   * @return a Json String
+   */
+  @Override
+  public String asJsonString() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
+    Gson gson = builder.create();
+    return gson.toJson(this, clusterType);
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java Tue Feb  9 21:30:17 2010
@@ -45,7 +45,7 @@
     // sample initial prior models
     clusters = new ArrayList<DirichletCluster<O>>();
     for (Model<O> m : modelFactory.sampleFromPrior(numClusters)) {
-      clusters.add(new DirichletCluster<O>(m, 0.0));
+      clusters.add(new DirichletCluster<O>(m));
     }
     // sample the mixture parameters from a Dirichlet distribution on the totalCounts 
     mixture = UncommonDistributions.rDirichlet(totalCounts(), alpha_0);

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java?rev=908235&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java Tue Feb  9 21:30:17 2010
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+public class JsonClusterAdapter implements JsonSerializer<DirichletCluster<?>>,
JsonDeserializer<DirichletCluster<?>> {
+
+  private static final Logger log = LoggerFactory.getLogger(JsonClusterAdapter.class);
+
+  @Override
+  public JsonElement serialize(DirichletCluster<?> src, Type typeOfSrc, JsonSerializationContext
context) {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    JsonObject obj = new JsonObject();
+    obj.add("total", new JsonPrimitive(src.getTotalCount()));
+    obj.add("modelClass", new JsonPrimitive(src.getModel().getClass().getName()));
+    obj.add("modelJson", new JsonPrimitive(gson.toJson(src)));
+    return obj;
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public DirichletCluster<?> deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext
context)
+      throws JsonParseException {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    JsonObject obj = json.getAsJsonObject();
+    double total = obj.get("total").getAsDouble();
+    String klass = obj.get("modelClass").getAsString();
+    String modelJson = obj.get("modelJson").getAsString();
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+    Class<?> cl = null;
+    try {
+      cl = ccl.loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      log.warn("Error while loading class", e);
+    }
+    Model<Vector> model = (Model<Vector>) gson.fromJson(modelJson, cl);
+    return new DirichletCluster<Vector>(model, total);
+  }
+}

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java?rev=908235&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java Tue Feb  9 21:30:17 2010
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+public class JsonModelAdapter implements JsonSerializer<Model<?>>,
+    JsonDeserializer<Model<?>> {
+
+  private static final Logger log = LoggerFactory.getLogger(JsonModelAdapter.class);
+
+  @Override
+  public JsonElement serialize(Model<?> src, Type typeOfSrc,
+                               JsonSerializationContext context) {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    JsonObject obj = new JsonObject();
+    obj.add("class", new JsonPrimitive(src.getClass().getName()));
+    obj.add("model", new JsonPrimitive(gson.toJson(src)));
+    return obj;
+  }
+
+  @Override
+  public Model<?> deserialize(JsonElement json, Type typeOfT,
+                              JsonDeserializationContext context) throws JsonParseException
{
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    JsonObject obj = json.getAsJsonObject();
+    String klass = obj.get("class").getAsString();
+    String model = obj.get("model").getAsString();
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+    Class<?> cl = null;
+    try {
+      cl = ccl.loadClass(klass);
+    } catch (ClassNotFoundException e) {
+      log.warn("Error while loading class", e);
+    }
+    return (Model<?>) gson.fromJson(model, cl);
+  }
+}

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Tue Feb  9 21:30:17 2010
@@ -17,13 +17,20 @@
 
 package org.apache.mahout.clustering.dirichlet.models;
 
-import org.apache.mahout.math.function.SquareRootFunction;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.lang.reflect.Type;
+
+import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
 
 public class AsymmetricSampledNormalModel implements Model<VectorWritable> {
 
@@ -41,6 +48,10 @@
 
   private Vector s2;
 
+  
+  private static final Type modelType = new TypeToken<Model<Vector>>() {
+  }.getType();
+
   public AsymmetricSampledNormalModel() {
     super();
   }
@@ -95,8 +106,7 @@
     mean = s1.divide(s0);
     // compute the component stds
     if (s0 > 1) {
-      stdDev = s2.times(s0).minus(s1.times(s1))
-          .assign(new SquareRootFunction()).divide(s0);
+      stdDev = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
     } else {
       stdDev.assign(Double.MIN_NORMAL);
     }
@@ -134,20 +144,20 @@
 
   @Override
   public String toString() {
+    return asFormatString(null);
+  }
+
+  public String asFormatString(String[] bindings) {
     StringBuilder buf = new StringBuilder(50);
-    buf.append("asnm{n=").append(s0).append(" m=[");
+    buf.append("asnm{n=").append(s0).append(" m=");
     if (mean != null) {
-      for (int i = 0; i < mean.size(); i++) {
-        buf.append(String.format("%.2f", mean.get(i))).append(", ");
-      }
+      buf.append(ClusterBase.formatVector(mean, bindings));
     }
-    buf.append("] sd=[");
+    buf.append(" sd=");
     if (stdDev != null) {
-      for (int i = 0; i < stdDev.size(); i++) {
-        buf.append(String.format("%.2f", stdDev.get(i))).append(", ");
-      }
+      buf.append(ClusterBase.formatVector(stdDev, bindings));
     }
-    buf.append("]}");
+    buf.append("}");
     return buf.toString();
   }
 
@@ -168,4 +178,12 @@
     VectorWritable.writeVector(out, s1);
     VectorWritable.writeVector(out, s2);
   }
+
+  /**
+   * Produce a Json representation of this model, with {@link JsonModelAdapter} registered
+   * for Model.
+   *
+   * @return a Json String
+   */
+  @Override
+  public String asJsonString() {
+    Gson gson = new GsonBuilder().registerTypeAdapter(Model.class, new JsonModelAdapter()).create();
+    return gson.toJson(this, modelType);
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Tue Feb  9 21:30:17 2010
@@ -19,12 +19,19 @@
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.lang.reflect.Type;
 
+import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
 public class L1Model implements Model<VectorWritable> {
 
   private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -44,6 +51,9 @@
 
   private Vector observed;
 
+  private static final Type modelType = new TypeToken<Model<Vector>>() {
+  }.getType();
+
   @Override
   public void computeParameters() {
     coefficients = observed.divide(count);
@@ -81,23 +91,28 @@
 
   @Override
   public String toString() {
+    return asFormatString(null);
+  }
+
+  public String asFormatString(String[] bindings) {
     StringBuilder buf = new StringBuilder();
-    buf.append("l1m{n=").append(count).append(" c=[");
+    buf.append("l1m{n=").append(count).append(" c=");
     if (coefficients != null) {
-      // handle sparse Vectors gracefully, suppressing zero values
-      int nextIx = 0;
-      for (int i = 0; i < coefficients.size(); i++) {
-        double elem = coefficients.get(i);
-        if (elem == 0.0)
-          continue;
-        if (i > nextIx)
-          buf.append("..{").append(i).append("}=");
-        buf.append(String.format("%.2f", elem)).append(", ");
-        nextIx = i + 1;
-      }
+      buf.append(ClusterBase.formatVector(coefficients, bindings));
     }
-    buf.append("]}");
+    buf.append("}");
     return buf.toString();
   }
 
+  /**
+   * Produce a Json representation of this model, with {@link JsonModelAdapter} registered
+   * for Model.
+   *
+   * @return a Json String
+   * @see org.apache.mahout.clustering.Printable#asJsonString()
+   */
+  @Override
+  public String asJsonString() {
+    Gson gson = new GsonBuilder().registerTypeAdapter(Model.class, new JsonModelAdapter()).create();
+    return gson.toJson(this, modelType);
+  }
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java Tue Feb  9 21:30:17 2010
@@ -18,12 +18,13 @@
 package org.apache.mahout.clustering.dirichlet.models;
 
 import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.Printable;
 
 /**
  * A model is a probability distribution over observed data points and allows the probability of any data point to be
  * computed.
  */
-public interface Model<O> extends Writable {
+public interface Model<O> extends Writable, Printable {
 
   /**
    * Observe the given observation, retaining information about it

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Tue Feb  9 21:30:17 2010
@@ -17,13 +17,20 @@
 
 package org.apache.mahout.clustering.dirichlet.models;
 
+import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
 import org.apache.mahout.math.function.SquareRootFunction;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.lang.reflect.Type;
 
 public class NormalModel implements Model<VectorWritable> {
 
@@ -41,6 +48,9 @@
 
   private Vector s2;
 
+  private static final Type modelType = new TypeToken<Model<Vector>>() {
+  }.getType();
+
   public NormalModel() {
   }
 
@@ -120,22 +130,16 @@
 
   @Override
   public String toString() {
+    return asFormatString(null);
+  }
+
+  public String asFormatString(String[] bindings) {
     StringBuilder buf = new StringBuilder();
-    buf.append("nm{n=").append(s0).append(" m=[");
+    buf.append("nm{n=").append(s0).append(" m=");
     if (mean != null) {
-      // handle sparse Vectors gracefully, suppressing zero values
-      int nextIx = 0;
-      for (int i = 0; i < mean.size(); i++) {
-        double elem = mean.get(i);
-        if (elem == 0.0)
-          continue;
-        if (i > nextIx)
-          buf.append("..{").append(i).append("}=");
-        buf.append(String.format("%.2f", elem)).append(", ");
-        nextIx = i + 1;
-      }
+      buf.append(ClusterBase.formatVector(mean, bindings));
     }
-    buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');
+    buf.append(" sd=").append(String.format("%.2f", stdDev)).append('}');
     return buf.toString();
   }
 
@@ -156,4 +160,12 @@
     VectorWritable.writeVector(out, s1);
     VectorWritable.writeVector(out, s2);
   }
+
+  /**
+   * Produce a Json representation of this model, with {@link JsonModelAdapter} registered
+   * for Model.
+   *
+   * @return a Json String
+   */
+  @Override
+  public String asJsonString() {
+    Gson gson = new GsonBuilder().registerTypeAdapter(Model.class, new JsonModelAdapter()).create();
+    return gson.toJson(this, modelType);
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Tue Feb  9 21:30:17 2010
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.clustering.dirichlet.models;
 
+import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.math.Vector;
 
 public class SampledNormalModel extends NormalModel {
@@ -31,15 +32,7 @@
 
   @Override
   public String toString() {
-    StringBuilder buf = new StringBuilder();
-    buf.append("snm{n=").append(getS0()).append(" m=[");
-    if (getMean() != null) {
-      for (int i = 0; i < getMean().size(); i++) {
-        buf.append(String.format("%.2f", getMean().get(i))).append(", ");
-      }
-    }
-    buf.append("] sd=").append(String.format("%.2f", getStdDev())).append('}');
-    return buf.toString();
+    return asFormatString(null);
   }
 
   /**
@@ -51,4 +44,14 @@
   public NormalModel sample() {
     return new SampledNormalModel(getMean(), getStdDev());
   }
+
+  /**
+   * Format this model as its sample count, mean vector and standard deviation.
+   *
+   * @param bindings optional labels for the mean vector's elements; may be null
+   * @return a String of the form "snm{n=... m=[...] sd=...}"
+   */
+  @Override
+  public String asFormatString(String[] bindings) {
+    StringBuilder buf = new StringBuilder();
+    // "snm" (as in the previous toString) keeps sampled models distinguishable from NormalModel's "nm"
+    buf.append("snm{n=").append(getS0()).append(" m=");
+    if (getMean() != null) {
+      buf.append(ClusterBase.formatVector(getMean(), bindings));
+    }
+    buf.append(" sd=").append(String.format("%.2f", getStdDev())).append('}');
+    return buf.toString();
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=908235&r1=908234&r2=908235&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Tue Feb  9 21:30:17 2010
@@ -113,7 +113,7 @@
    * 
    * @return the new centroid
    */
-  private Vector computeCentroid() {
+  public Vector computeCentroid() {
     if (getNumPoints() == 0) {
       return getCenter();
     } else if (centroid == null) {

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java?rev=908235&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java	(added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestPrintableInterface.java	Tue Feb  9 21:30:17 2010
@@ -0,0 +1,247 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.lang.reflect.Type;
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.dirichlet.DirichletCluster;
+import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
+import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
+import org.apache.mahout.clustering.dirichlet.models.L1Model;
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
+import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
+/**
+ * Tests the {@link Printable} implementations of the clustering classes: the
+ * human-readable {@code asFormatString} output, with and without index bindings,
+ * and the {@code asJsonString} round trip through Gson for Dirichlet models and clusters.
+ */
+public class TestPrintableInterface extends TestCase {
+
+  /** Gson deserialization target for a bare model. */
+  private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
+  }.getType();
+
+  /** Gson deserialization target for a Dirichlet cluster. */
+  private static final Type CLUSTER_TYPE = new TypeToken<DirichletCluster<Vector>>() {
+  }.getType();
+
+  /**
+   * Creates the Gson instance used by the JSON round-trip tests, with
+   * {@link JsonModelAdapter} registered for {@link Model}.
+   */
+  private static Gson createGson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
+    return builder.create();
+  }
+
+  /** Returns a new dense vector [1.1, 2.2, 3.3]. */
+  private static Vector denseVector() {
+    return new DenseVector(new double[] { 1.1, 2.2, 3.3 });
+  }
+
+  /** Returns a new sparse vector holding [1.1, 0.0, 3.3]. */
+  private static Vector sparseVector() {
+    Vector m = new SequentialAccessSparseVector(3);
+    m.assign(new double[] { 1.1, 0.0, 3.3 });
+    return m;
+  }
+
+  public void testDirichletNormalModel() {
+    Printable model = new NormalModel(denseVector(), 0.75);
+    String format = model.asFormatString(null);
+    assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
+    String json = model.asJsonString();
+    NormalModel model2 = createGson().fromJson(json, MODEL_TYPE);
+    assertEquals("Json", format, model2.asFormatString(null));
+  }
+
+  public void testDirichletSampledNormalModel() {
+    Printable model = new SampledNormalModel(denseVector(), 0.75);
+    String format = model.asFormatString(null);
+    assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
+    String json = model.asJsonString();
+    SampledNormalModel model2 = createGson().fromJson(json, MODEL_TYPE);
+    assertEquals("Json", format, model2.asFormatString(null));
+  }
+
+  public void testDirichletASNormalModel() {
+    Vector m = denseVector();
+    Printable model = new AsymmetricSampledNormalModel(m, m);
+    String format = model.asFormatString(null);
+    assertEquals("format",
+        "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
+    String json = model.asJsonString();
+    AsymmetricSampledNormalModel model2 = createGson().fromJson(json, MODEL_TYPE);
+    assertEquals("Json", format, model2.asFormatString(null));
+  }
+
+  public void testDirichletL1Model() {
+    Printable model = new L1Model(denseVector());
+    String format = model.asFormatString(null);
+    assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
+    String json = model.asJsonString();
+    L1Model model2 = createGson().fromJson(json, MODEL_TYPE);
+    assertEquals("Json", format, model2.asFormatString(null));
+  }
+
+  public void testDirichletNormalModelClusterAsFormatString() {
+    NormalModel model = new NormalModel(denseVector(), 0.75);
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String format = cluster.asFormatString(null);
+    assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
+  }
+
+  public void testDirichletNormalModelClusterAsJsonString() {
+    NormalModel model = new NormalModel(denseVector(), 0.75);
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String json = cluster.asJsonString();
+    DirichletCluster<VectorWritable> result = createGson().fromJson(json, CLUSTER_TYPE);
+    assertNotNull("result null", result);
+    assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
+  }
+
+  public void testDirichletAsymmetricSampledNormalModelClusterAsFormatString() {
+    Vector m = denseVector();
+    AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String format = cluster.asFormatString(null);
+    assertEquals("format",
+        "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
+  }
+
+  public void testDirichletAsymmetricSampledNormalModelClusterAsJsonString() {
+    Vector m = denseVector();
+    AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String json = cluster.asJsonString();
+    DirichletCluster<VectorWritable> result = createGson().fromJson(json, CLUSTER_TYPE);
+    assertNotNull("result null", result);
+    assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
+  }
+
+  public void testDirichletL1ModelClusterAsFormatString() {
+    L1Model model = new L1Model(denseVector());
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String format = cluster.asFormatString(null);
+    assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
+  }
+
+  public void testDirichletL1ModelClusterAsJsonString() {
+    L1Model model = new L1Model(denseVector());
+    Printable cluster = new DirichletCluster<VectorWritable>(model, 35.0);
+    String json = cluster.asJsonString();
+    DirichletCluster<VectorWritable> result = createGson().fromJson(json, CLUSTER_TYPE);
+    assertNotNull("result null", result);
+    assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
+  }
+
+  public void testCanopyAsFormatString() {
+    Printable cluster = new Canopy(denseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
+  }
+
+  public void testCanopyAsFormatStringSparse() {
+    Printable cluster = new Canopy(sparseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
+  }
+
+  public void testCanopyAsFormatStringWithBindings() {
+    Printable cluster = new Canopy(denseVector(), 123);
+    String[] bindings = { "fee", null, null };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, 1:2.200, 2:3.300]", formatString);
+  }
+
+  public void testCanopyAsFormatStringSparseWithBindings() {
+    Printable cluster = new Canopy(sparseVector(), 123);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, foo:3.300]", formatString);
+  }
+
+  public void testClusterAsFormatString() {
+    Printable cluster = new Cluster(denseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
+  }
+
+  public void testClusterAsFormatStringSparse() {
+    Printable cluster = new Cluster(sparseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
+  }
+
+  public void testClusterAsFormatStringWithBindings() {
+    Printable cluster = new Cluster(denseVector(), 123);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, 1:2.200, foo:3.300]", formatString);
+  }
+
+  public void testClusterAsFormatStringSparseWithBindings() {
+    Printable cluster = new Cluster(sparseVector(), 123);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, foo:3.300]", formatString);
+  }
+
+  public void testMSCanopyAsFormatString() {
+    Printable cluster = new MeanShiftCanopy(denseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [1.100, 2.200, 3.300]", formatString);
+  }
+
+  public void testMSCanopyAsFormatStringSparse() {
+    Printable cluster = new MeanShiftCanopy(sparseVector(), 123);
+    String formatString = cluster.asFormatString(null);
+    assertEquals("format", "C123: [0:1.100, 2:3.300]", formatString);
+  }
+
+  public void testMSCanopyAsFormatStringWithBindings() {
+    Printable cluster = new MeanShiftCanopy(denseVector(), 123);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, 1:2.200, foo:3.300]", formatString);
+  }
+
+  public void testMSCanopyAsFormatStringSparseWithBindings() {
+    Printable cluster = new MeanShiftCanopy(sparseVector(), 123);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertEquals("format", "C123: [fee:1.100, foo:3.300]", formatString);
+  }
+
+}



Mime
View raw message