mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r764529 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ core/src/main/java/org/apache/mahout/clustering/kmeans/ core/src/main/java/org/apache/mahout/matrix/ examples/src/main/java/org/apache/maho...
Date Mon, 13 Apr 2009 17:01:19 GMT
Author: jeastman
Date: Mon Apr 13 17:01:19 2009
New Revision: 764529

URL: http://svn.apache.org/viewvc?rev=764529&view=rev
Log:
- Added Canopy and K-Means clustering examples that display clusters using the same sample data
as the Dirichlet examples, for direct comparison of results
- Added SquareRootFunction to improve model parameter computations
- Fixed model's parameter computations
- Added standard deviation calculation to K-Means cluster to show effective radius of observed
clusters

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=764529&r1=764528&r2=764529&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Mon Apr 13 17:01:19 2009
@@ -16,6 +16,7 @@
  */
 package org.apache.mahout.clustering.dirichlet.models;
 
+import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 
 public class AsymmetricSampledNormalModel implements Model<Vector> {
@@ -73,14 +74,12 @@
     if (s0 == 0)
       return;
     mean = s1.divide(s0);
-    // the average of the two component stds
-    Vector ss = s2.times(s0).minus(s1.times(s1));
+    // compute the two component stds
     if (s0 > 1) {
-      sd.set(0, Math.sqrt(ss.get(0)) / s0);
-      sd.set(1, Math.sqrt(ss.get(1)) / s0);
+      sd = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction())
+          .divide(s0);
     } else {
-      sd.set(0, Double.MIN_NORMAL);
-      sd.set(1, Double.MIN_NORMAL);
+      sd.assign(Double.MIN_NORMAL);
     }
   }
 

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=764529&r1=764528&r2=764529&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Mon Apr 13 17:01:19 2009
@@ -16,6 +16,7 @@
  */
 package org.apache.mahout.clustering.dirichlet.models;
 
+import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 
 public class NormalModel implements Model<Vector> {
@@ -29,7 +30,9 @@
 
   // the observation statistics, initialized by the first observation
   int s0 = 0;
+
   Vector s1;
+
   Vector s2;
 
   public NormalModel() {
@@ -69,10 +72,12 @@
     if (s0 == 0)
       return;
     mean = s1.divide(s0);
-    // the average of the two component stds
-    if (s0 > 1)
-      sd = Math.sqrt(s2.times(s0).minus(s1.times(s1)).zSum() / 2) / s0;
-    else
+    // compute the average of the component stds
+    if (s0 > 1) {
+      Vector std = s2.times(s0).minus(s1.times(s1)).assign(
+          new SquareRootFunction()).divide(s0);
+      sd = std.zSum() / s1.cardinality();
+    } else
       sd = Double.MIN_VALUE;
   }
 

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=764529&r1=764528&r2=764529&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Apr 13 17:01:19 2009
@@ -16,17 +16,18 @@
  */
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.List;
+
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 import org.apache.mahout.utils.DistanceMeasure;
 
-import java.io.IOException;
-import java.util.List;
-
 public class Cluster {
 
   public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.kmeans.measure";
@@ -46,12 +47,18 @@
   // the current centroid is lazy evaluated and may be null
   private Vector centroid = null;
 
+  // the standard deviation of the covered points
+  private double std;
+
   // the number of points in the cluster
   private int numPoints = 0;
 
   // the total of all points added to the cluster
   private Vector pointTotal = null;
 
+  // the total of all the points squared, used for std computation
+  private Vector pointSquaredTotal = null;
+
   // has the centroid converged with the center?
   private boolean converged = false;
 
@@ -82,8 +89,9 @@
     String center = formattedString.substring(beginIndex);
     char firstChar = id.charAt(0);
     boolean startsWithV = firstChar == 'V';
-     if (firstChar == 'C' || startsWithV) {
-      int clusterId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));   
+    if (firstChar == 'C' || startsWithV) {
+      int clusterId = Integer.parseInt(formattedString.substring(1,
+          beginIndex - 2));
       Vector clusterCenter = AbstractVector.decodeVector(center);
       Cluster cluster = new Cluster(clusterCenter, clusterId);
       cluster.converged = startsWithV;
@@ -181,6 +189,10 @@
     else if (centroid == null) {
       // lazy compute new centroid
       centroid = pointTotal.divide(numPoints);
+      Vector stds = pointSquaredTotal.times(numPoints).minus(
+          pointTotal.times(pointTotal)).assign(new SquareRootFunction())
+          .divide(numPoints);
+      std = stds.zSum() / 2;
     }
     return centroid;
   }
@@ -196,6 +208,7 @@
     this.center = center;
     this.numPoints = 0;
     this.pointTotal = center.like();
+    this.pointSquaredTotal = center.like();
   }
 
   /**
@@ -209,6 +222,7 @@
     this.center = center;
     this.numPoints = 0;
     this.pointTotal = center.like();
+    this.pointSquaredTotal = center.like();
   }
 
   /**
@@ -251,10 +265,13 @@
   public void addPoints(int count, Vector delta) {
     centroid = null;
     numPoints += count;
-    if (pointTotal == null)
+    if (pointTotal == null) {
       pointTotal = delta.copy();
-    else
+      pointSquaredTotal = delta.times(delta);
+    } else {
       pointTotal = pointTotal.plus(delta);
+      pointSquaredTotal = pointSquaredTotal.plus(delta.times(delta));
+    }
   }
 
   public Vector getCenter() {
@@ -293,4 +310,11 @@
     return converged;
   }
 
+  /**
+   * @return the std
+   */
+  public double getStd() {
+    return std;
+  }
+
 }

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java?rev=764529&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java Mon Apr 13 17:01:19 2009
@@ -0,0 +1,10 @@
+package org.apache.mahout.matrix;
+
+public class SquareRootFunction implements UnaryFunction {
+
+  @Override
+  public double apply(double arg1) {
+    return Math.sqrt(arg1);
+  }
+
+}

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java?rev=764529&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java Mon Apr 13 17:01:19 2009
@@ -0,0 +1,121 @@
+package org.apache.mahout.clustering.canopy;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.mahout.clustering.dirichlet.DisplayDirichlet;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.DistanceMeasure;
+import org.apache.mahout.utils.ManhattanDistanceMeasure;
+
+class DisplayCanopy extends DisplayDirichlet {
+  public DisplayCanopy() {
+    initialize();
+    this.setTitle("Canopy Clusters (> 5% of population)");
+  }
+
+  private static final long serialVersionUID = 1L;
+
+  static List<Canopy> canopies;
+
+  static double t1 = 3.0;
+
+  static double t2 = 1.5;
+
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+    Vector dv = new DenseVector(2);
+    for (Canopy canopy : canopies)
+      if (canopy.getNumPoints() > sampleData.size() * 0.05) {
+        dv.assign(t1);
+        g2.setColor(colors[0]);
+        plotEllipse(g2, canopy.getCenter(), dv);
+        dv.assign(t2);
+        plotEllipse(g2, canopy.getCenter(), dv);
+      }
+  }
+
+  /**
+   * Iterate through the points, adding new canopies. Return the canopies.
+   * 
+   * @param measure
+   *            a DistanceMeasure to use
+   * @param points
+   *            a list<Vector> defining the points to be clustered
+   * @param t1
+   *            the T1 distance threshold
+   * @param t2
+   *            the T2 distance threshold
+   * @return the List<Canopy> created
+   */
+  static List<Canopy> populateCanopies(DistanceMeasure measure,
+      List<Vector> points, double t1, double t2) {
+    List<Canopy> canopies = new ArrayList<Canopy>();
+    Canopy.config(measure, t1, t2);
+    /**
+     * Reference Implementation: Given a distance metric, one can create
+     * canopies as follows: Start with a list of the data points in any order,
+     * and with two distance thresholds, T1 and T2, where T1 > T2. (These
+     * thresholds can be set by the user, or selected by cross-validation.) Pick
+     * a point on the list and measure its distance to all other points. Put all
+     * points that are within distance threshold T1 into a canopy. Remove from
+     * the list all points that are within distance threshold T2. Repeat until
+     * the list is empty.
+     */
+    while (!points.isEmpty()) {
+      Iterator<Vector> ptIter = points.iterator();
+      Vector p1 = ptIter.next();
+      ptIter.remove();
+      Canopy canopy = new VisibleCanopy(p1);
+      canopies.add(canopy);
+      while (ptIter.hasNext()) {
+        Vector p2 = ptIter.next();
+        double dist = measure.distance(p1, p2);
+        // Put all points that are within distance threshold T1 into the canopy
+        if (dist < t1)
+          canopy.addPoint(p2);
+        // Remove from the list all points that are within distance threshold T2
+        if (dist < t2)
+          ptIter.remove();
+      }
+    }
+    return canopies;
+  }
+
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    List<Vector> points = new ArrayList<Vector>();
+    points.addAll(sampleData);
+    canopies = populateCanopies(new ManhattanDistanceMeasure(), points, t1, t2);
+    new DisplayCanopy();
+  }
+
+  static void generateResults() {
+    DisplayDirichlet.generateResults(new NormalModelDistribution());
+  }
+}

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java?rev=764529&r1=764528&r2=764529&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Mon Apr 13 17:01:19 2009
@@ -19,7 +19,7 @@
 import org.apache.mahout.matrix.TimesFunction;
 import org.apache.mahout.matrix.Vector;
 
-class DisplayDirichlet extends Frame {
+public class DisplayDirichlet extends Frame {
   private static final long serialVersionUID = 1L;
 
   int res; //screen resolution
@@ -28,15 +28,15 @@
 
   int size = 8; // screen size in inches
 
-  static List<Vector> sampleData = new ArrayList<Vector>();
+  public static List<Vector> sampleData = new ArrayList<Vector>();
 
-  static List<Model<Vector>[]> result;
+  protected static List<Model<Vector>[]> result;
 
-  static double significance = 0.05;
+  protected static double significance = 0.05;
 
   static List<Vector> sampleParams = new ArrayList<Vector>();
 
-  static Color[] colors = { Color.red, Color.orange, Color.yellow, Color.green,
+  protected static Color[] colors = { Color.red, Color.orange, Color.yellow, Color.green,
       Color.blue, Color.magenta, Color.lightGray };
 
   /**
@@ -56,11 +56,11 @@
    * limitations under the License.
    */
 
-  DisplayDirichlet() {
+  public DisplayDirichlet() {
     initialize();
   }
 
-  void initialize() {
+  public void initialize() {
     //Get screen resolution
     res = Toolkit.getDefaultToolkit().getScreenResolution();
 
@@ -99,7 +99,7 @@
     }
   }
 
-  void plotSampleData(Graphics g) {
+  public void plotSampleData(Graphics g) {
     Graphics2D g2 = (Graphics2D) g;
     double sx = (double) res / ds;
     g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
@@ -123,7 +123,7 @@
    * @param v a Vector of rectangle centers
    * @param dv a Vector of rectangle sizes
    */
-  void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
+  public void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
     int h = size / 2;
     double[] flip = { 1, -1 };
     Vector v2 = v.copy().assign(new DenseVector(flip), new TimesFunction());
@@ -140,7 +140,7 @@
    * @param v a Vector of rectangle centers
    * @param dv a Vector of rectangle sizes
    */
-  void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
+  public void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
     int h = size / 2;
     double[] flip = { 1, -1 };
     Vector v2 = v.copy().assign(new DenseVector(flip), new TimesFunction());
@@ -167,13 +167,13 @@
     System.out.println();
   }
 
-  static void generateSamples() {
+  public static void generateSamples() {
     generateSamples(400, 1, 1, 3);
     generateSamples(300, 1, 0, 0.5);
     generateSamples(300, 0, 2, 0.1);
   }
 
-  static void generate2dSamples() {
+  public static void generate2dSamples() {
     generate2dSamples(400, 1, 1, 3, 1);
     generate2dSamples(300, 1, 0, 0.5, 1);
     generate2dSamples(300, 0, 2, 0.1, 0.5);
@@ -217,14 +217,14 @@
           UncommonDistributions.rNorm(my, sdy) }));
   }
 
-  static void generateResults(ModelDistribution<Vector> modelDist) {
+  public static void generateResults(ModelDistribution<Vector> modelDist) {
     DirichletClusterer<Vector> dc = new DirichletClusterer<Vector>(sampleData,
         modelDist, 1.0, 10, 2, 2);
     result = dc.cluster(20);
     printModels(result, 5);
   }
 
-  static boolean isSignificant(Model<Vector> model) {
+ public static boolean isSignificant(Model<Vector> model) {
     return (((double) model.count() / sampleData.size()) > significance);
   }
 

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java?rev=764529&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java Mon Apr 13 17:01:19 2009
@@ -0,0 +1,193 @@
+package org.apache.mahout.clustering.kmeans;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.awt.BasicStroke;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.canopy.VisibleCanopy;
+import org.apache.mahout.clustering.dirichlet.DisplayDirichlet;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.DistanceMeasure;
+import org.apache.mahout.utils.ManhattanDistanceMeasure;
+
+class DisplayKMeans extends DisplayDirichlet {
+  public DisplayKMeans() {
+    initialize();
+    this.setTitle("K-Means Clusters (> 5% of population)");
+  }
+
+  private static final long serialVersionUID = 1L;
+
+  static List<Canopy> canopies;
+
+  static List<List<Cluster>> clusters;
+
+  static double t1 = 3.0;
+
+  static double t2 = 1.5;
+
+  public void paint(Graphics g) {
+    super.plotSampleData(g);
+    Graphics2D g2 = (Graphics2D) g;
+    Vector dv = new DenseVector(2);
+    int i = clusters.size() - 1;
+    for (List<Cluster> cls : clusters) {
+      g2.setStroke(new BasicStroke(i == 0 ? 3 : 1));
+      g2.setColor(colors[Math.min(colors.length - 1, i--)]);
+      for (Cluster cluster : cls)
+        if (true || cluster.getNumPoints() > sampleData.size() * 0.05) {
+          dv.assign(cluster.getStd() * 3);
+          plotEllipse(g2, cluster.getCenter(), dv);
+        }
+    }
+  }
+
+  /**
+   * This is the reference k-means implementation. Given its inputs it iterates
+   * over the points and clusters until their centers converge or until the
+   * maximum number of iterations is exceeded.
+   * 
+   * @param points the input List<Vector> of points
+   * @param clusters the initial List<Cluster> of clusters
+   * @param measure the DistanceMeasure to use
+   * @param maxIter the maximum number of iterations
+   */
+  private static void referenceKmeans(List<Vector> points,
+      List<List<Cluster>> clusters, DistanceMeasure measure, int maxIter) {
+    boolean converged = false;
+    int iteration = 0;
+    while (!converged && iteration < maxIter) {
+      List<Cluster> next = new ArrayList<Cluster>();
+      List<Cluster> cs = clusters.get(iteration++);
+      for (Cluster c : cs)
+        next.add(new Cluster(c.getCenter()));
+      clusters.add(next);
+      converged = iterateReference(points, clusters.get(iteration), measure);
+    }
+  }
+
+  /**
+   * Perform a single iteration over the points and clusters, assigning points
+   * to clusters and returning if the iterations are completed.
+   * 
+   * @param points the List<Vector> having the input points
+   * @param clusters the List<Cluster> clusters
+   * @param measure a DistanceMeasure to use
+   * @return
+   */
+  private static boolean iterateReference(List<Vector> points,
+      List<Cluster> clusters, DistanceMeasure measure) {
+    boolean converged;
+    converged = true;
+    // iterate through all points, assigning each to the nearest cluster
+    for (Vector point : points) {
+      Cluster closestCluster = null;
+      double closestDistance = Double.MAX_VALUE;
+      for (Cluster cluster : clusters) {
+        double distance = measure.distance(cluster.getCenter(), point);
+        if (closestCluster == null || closestDistance > distance) {
+          closestCluster = cluster;
+          closestDistance = distance;
+        }
+      }
+      closestCluster.addPoint(point);
+    }
+    // test for convergence
+    for (Cluster cluster : clusters) {
+      if (!cluster.computeConvergence())
+        converged = false;
+    }
+    // update the cluster centers
+    if (!converged)
+      for (Cluster cluster : clusters)
+        cluster.recomputeCenter();
+    return converged;
+  }
+
+  /**
+   * Iterate through the points, adding new canopies. Return the canopies.
+   * 
+   * @param measure
+   *            a DistanceMeasure to use
+   * @param points
+   *            a list<Vector> defining the points to be clustered
+   * @param t1
+   *            the T1 distance threshold
+   * @param t2
+   *            the T2 distance threshold
+   * @return the List<Canopy> created
+   */
+  static List<Canopy> populateCanopies(DistanceMeasure measure,
+      List<Vector> points, double t1, double t2) {
+    List<Canopy> canopies = new ArrayList<Canopy>();
+    Canopy.config(measure, t1, t2);
+    /**
+     * Reference Implementation: Given a distance metric, one can create
+     * canopies as follows: Start with a list of the data points in any order,
+     * and with two distance thresholds, T1 and T2, where T1 > T2. (These
+     * thresholds can be set by the user, or selected by cross-validation.) Pick
+     * a point on the list and measure its distance to all other points. Put all
+     * points that are within distance threshold T1 into a canopy. Remove from
+     * the list all points that are within distance threshold T2. Repeat until
+     * the list is empty.
+     */
+    while (!points.isEmpty()) {
+      Iterator<Vector> ptIter = points.iterator();
+      Vector p1 = ptIter.next();
+      ptIter.remove();
+      Canopy canopy = new VisibleCanopy(p1);
+      canopies.add(canopy);
+      while (ptIter.hasNext()) {
+        Vector p2 = ptIter.next();
+        double dist = measure.distance(p1, p2);
+        // Put all points that are within distance threshold T1 into the canopy
+        if (dist < t1)
+          canopy.addPoint(p2);
+        // Remove from the list all points that are within distance threshold T2
+        if (dist < t2)
+          ptIter.remove();
+      }
+    }
+    return canopies;
+  }
+
+  public static void main(String[] args) {
+    UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    generateSamples();
+    List<Vector> points = new ArrayList<Vector>();
+    points.addAll(sampleData);
+    canopies = populateCanopies(new ManhattanDistanceMeasure(), points, t1, t2);
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    Cluster.config(measure, 0.001);
+    clusters = new ArrayList<List<Cluster>>();
+    clusters.add(new ArrayList<Cluster>());
+    for (Canopy canopy : canopies)
+      if (canopy.getNumPoints() > 0.05 * sampleData.size())
+        clusters.get(0).add(new Cluster(canopy.getCenter()));
+    referenceKmeans(sampleData, clusters, measure, 10);
+    new DisplayKMeans();
+  }
+}

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java?rev=764529&r1=764528&r2=764529&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java Mon Apr 13 17:01:19 2009
@@ -17,6 +17,7 @@
 package org.apache.mahout.clustering.syntheticcontrol.dirichlet;
 
 import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 
 public class NormalScModel implements Model<Vector> {
@@ -72,17 +73,18 @@
     if (s0 == 0)
       return;
     mean = s1.divide(s0);
-    //TODO: is this the average of the 60 component stds??
-    if (s0 > 1)
-      sd = Math.sqrt(s2.times(s0).minus(s1.times(s1)).zSum() / (60 * 60)) / s0;
-    else
+    //compute the average of the 60 component stds
+    if (s0 > 1) {
+      Vector std = s2.times(s0).minus(s1.times(s1)).assign(
+          new SquareRootFunction()).divide(s0);
+      sd = std.zSum() / s1.cardinality();
+    } else
       sd = Double.MIN_VALUE;
   }
 
   @Override
   // TODO: need to revisit this for reasonableness
   public double pdf(Vector x) {
-    assert x.size() == 60;
     double sd2 = sd * sd;
     double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
     double ex = Math.exp(exp);



Mime
View raw message