mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r1000021 - in /mahout/trunk/utils/src: main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
Date Wed, 22 Sep 2010 15:05:09 GMT
Author: jeastman
Date: Wed Sep 22 15:05:09 2010
New Revision: 1000021

URL: http://svn.apache.org/viewvc?rev=1000021&view=rev
Log:
Added unit tests to surface corner-case math problems and fixed same.

Modified:
    mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
    mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java

Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java?rev=1000021&r1=1000020&r2=1000021&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java Wed Sep 22 15:05:09 2010
@@ -33,9 +33,13 @@ import org.apache.mahout.common.distance
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.function.SquareRootFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class CDbwEvaluator {
 
+  private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class);
+
   private final Map<Integer, List<VectorWritable>> representativePoints;
 
   private final Map<Integer, Double> stDevs = new HashMap<Integer, Double>();
@@ -73,11 +77,10 @@ public class CDbwEvaluator {
    * @param clustersIn
    *            a String path to the input clusters directory
    */
-  public CDbwEvaluator(Configuration conf, Path clustersIn)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException, IOException
{
+  public CDbwEvaluator(Configuration conf, Path clustersIn) throws ClassNotFoundException,
InstantiationException,
+      IllegalAccessException, IOException {
     ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-    measure = ccl.loadClass(conf.get(CDbwDriver.DISTANCE_MEASURE_KEY))
-        .asSubclass(DistanceMeasure.class).newInstance();
+    measure = ccl.loadClass(conf.get(CDbwDriver.DISTANCE_MEASURE_KEY)).asSubclass(DistanceMeasure.class).newInstance();
     representativePoints = CDbwMapper.getRepresentativePoints(conf);
     clusters = loadClusters(conf, clustersIn);
     for (Integer cId : representativePoints.keySet()) {
@@ -92,10 +95,10 @@ public class CDbwEvaluator {
   public double intraClusterDensity() {
     double avgStd = 0.0;
     for (Integer cId : representativePoints.keySet()) {
-      avgStd += stDevs.get(cId);
+      avgStd += getStdev(cId);
     }
     avgStd /= representativePoints.size();
-  
+
     double sum = 0.0;
     for (Map.Entry<Integer, List<VectorWritable>> entry : representativePoints.entrySet())
{
       Integer cId = entry.getKey();
@@ -103,12 +106,14 @@ public class CDbwEvaluator {
       double cSum = 0.0;
       for (VectorWritable aRepI : repI) {
         double inDensity = intraDensity(clusters.get(cId).getCenter(), aRepI.get(), avgStd);
-        double std = stDevs.get(cId);
+        double std = getStdev(cId);
         if (std > 0.0) {
           cSum += inDensity / std;
         }
       }
-      sum += cSum / repI.size();
+      if (repI.size() > 0) {
+        sum += cSum / repI.size();
+      }
     }
     return sum / representativePoints.size();
   }
@@ -118,7 +123,7 @@ public class CDbwEvaluator {
     for (Map.Entry<Integer, List<VectorWritable>> entry1 : representativePoints.entrySet())
{
       Integer cI = entry1.getKey();
       List<VectorWritable> repI = entry1.getValue();
-      double stDevI = stDevs.get(cI);      
+      double stDevI = getStdev(cI);
       for (Map.Entry<Integer, List<VectorWritable>> entry2 : representativePoints.entrySet())
{
         Integer cJ = entry2.getKey();
         if (cI.equals(cJ)) {
@@ -138,23 +143,20 @@ public class CDbwEvaluator {
             }
           }
         }
-        double stDevJ = stDevs.get(cJ);
-        double interDensity = interDensity(uIJ, cI, cJ);
+        double stDevJ = getStdev(cJ);
+        double interDensity = uIJ == null ? 0 : interDensity(uIJ, cI, cJ);
         double stdSum = stDevI + stDevJ;
         double density = 0.0;
         if (stdSum > 0.0) {
           density = minDistance * interDensity / stdSum;
         }
-  
-        // Use a logger
-        //if (false) {
-        //  System.out.println("minDistance[" + cI + "," + cJ + "]=" + minDistance);
-        //  System.out.println("stDev[" + cI + "]=" + stDevI);
-        //  System.out.println("stDev[" + cJ + "]=" + stDevJ);
-        //  System.out.println("interDensity[" + cI + "," + cJ + "]=" + interDensity);
-        //  System.out.println("density[" + cI + "," + cJ + "]=" + density);
-        //  System.out.println();
-        //}
+
+        log.debug("minDistance[" + cI + "," + cJ + "]=" + minDistance);
+        log.debug("stDev[" + cI + "]=" + stDevI);
+        log.debug("stDev[" + cJ + "]=" + stDevJ);
+        log.debug("interDensity[" + cI + "," + cJ + "]=" + interDensity);
+        log.debug("density[" + cI + "," + cJ + "]=" + density);
+
         sum += density;
       }
     }
@@ -162,6 +164,20 @@ public class CDbwEvaluator {
     return sum;
   }
 
+  /**
+   * Returns the standard deviation stored for a cluster, or 0 when no entry exists
+   * (e.g. for empty clusters, for which no stDev is stored).
+   * @param cI the cluster id to look up in stDevs
+   * @return the stored stDev for cI, or 0.0 when stDevs has no entry for it
+   */
+  private Double getStdev(Integer cI) {
+    Double result = stDevs.get(cI);
+    if (result == null) {
+      return 0.0; // autoboxed via Double.valueOf; avoids the deprecated Double(double) ctor
+    }
+    return result;
+  }
+
   public double separation() {
     double minDistance = Double.MAX_VALUE;
     for (Map.Entry<Integer, List<VectorWritable>> entry1 : representativePoints.entrySet())
{
@@ -192,8 +208,8 @@ public class CDbwEvaluator {
    *            a String pathname to the directory containing input cluster files
    * @return a List<Cluster> of the clusters
    */
-  private static Map<Integer, Cluster> loadClusters(Configuration conf, Path clustersIn)
-      throws InstantiationException, IllegalAccessException, IOException {
+  private static Map<Integer, Cluster> loadClusters(Configuration conf, Path clustersIn)
throws InstantiationException,
+      IllegalAccessException, IOException {
     Map<Integer, Cluster> clusters = new HashMap<Integer, Cluster>();
     FileSystem fs = clustersIn.getFileSystem(conf);
     for (FileStatus part : fs.listStatus(clustersIn)) {
@@ -217,7 +233,7 @@ public class CDbwEvaluator {
     List<VectorWritable> repI = representativePoints.get(cI);
     List<VectorWritable> repJ = representativePoints.get(cJ);
     double density = 0.0;
-    double std = (stDevs.get(cI) + stDevs.get(cJ)) / 2.0;
+    double std = (getStdev(cI) + getStdev(cJ)) / 2.0;
     for (VectorWritable vwI : repI) {
       if (measure.distance(uIJ, vwI.get()) <= std) {
         density++;
@@ -245,10 +261,12 @@ public class CDbwEvaluator {
       s1 = s1 == null ? v.clone() : s1.plus(v);
       s2 = s2 == null ? v.times(v) : s2.plus(v.times(v));
     }
-    Vector std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
-    double d = std.zSum() / std.size();
-    //System.out.println("stDev[" + cI + "]=" + d);
-    stDevs.put(cI, d);
+    if (s0 > 1) {
+      Vector std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
+      double d = std.zSum() / std.size();
+      //System.out.println("stDev[" + cI + "]=" + d);
+      stDevs.put(cI, d);
+    }
   }
 
   /*

Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java?rev=1000021&r1=1000020&r2=1000021&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java Wed Sep 22 15:05:09 2010
@@ -29,11 +29,11 @@ import org.apache.mahout.math.VectorWrit
 
 public class CDbwReducer extends Reducer<IntWritable, WeightedVectorWritable, IntWritable,
VectorWritable> {
 
-  private Map<Integer, List<VectorWritable>> referencePoints;
+  private Map<Integer, List<VectorWritable>> representativePoints;
 
   @Override
   protected void cleanup(Context context) throws IOException, InterruptedException {
-    for (Map.Entry<Integer, List<VectorWritable>> entry : referencePoints.entrySet())
{
+    for (Map.Entry<Integer, List<VectorWritable>> entry : representativePoints.entrySet())
{
       IntWritable iw = new IntWritable(entry.getKey());
       for (VectorWritable vw : entry.getValue()) {
         context.write(iw, vw);
@@ -60,7 +60,7 @@ public class CDbwReducer extends Reducer
     super.setup(context);
     Configuration conf = context.getConfiguration();
     try {
-      referencePoints = CDbwMapper.getRepresentativePoints(conf);
+      representativePoints = CDbwMapper.getRepresentativePoints(conf);
     } catch (NumberFormatException e) {
       throw new IllegalStateException(e);
     } catch (SecurityException e) {
@@ -70,8 +70,8 @@ public class CDbwReducer extends Reducer
     }
   }
 
-  public void configure(Map<Integer, List<VectorWritable>> referencePoints) {
-    this.referencePoints = referencePoints;
+  public void configure(Map<Integer, List<VectorWritable>> representativePoints)
{
+    this.representativePoints = representativePoints;
   }
 
 }

Modified: mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java?rev=1000021&r1=1000020&r2=1000021&view=diff
==============================================================================
--- mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java (original)
+++ mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java Wed Sep 22 15:05:09 2010
@@ -95,7 +95,7 @@ public final class TestCDbwEvaluator ext
    * Initialize synthetic data using 4 clusters dC units from origin having 4 representative
points dP from each center
    * @param dC a double cluster center offset
    * @param dP a double representative point offset
-   * @param measure TODO
+   * @param measure the DistanceMeasure
    */
   private void initData(double dC, double dP, DistanceMeasure measure) {
     clusters = new HashMap<Integer, Cluster>();
@@ -239,5 +239,36 @@ public final class TestCDbwEvaluator ext
                       1);
     checkRefPoints(numIterations);
   }
+  
+  @Test
+  public void testEmptyCluster() { // adds one extra cluster at (10, 10) that has no representative points
+    DistanceMeasure measure = new EuclideanDistanceMeasure();
+    initData(1, 0.25, measure);
+    Canopy cluster = new Canopy(new DenseVector(new double[] { 10, 10 }), 19, measure);
+    clusters.put(cluster.getId(), cluster);
+    List<VectorWritable> points = new ArrayList<VectorWritable>();
+    representativePoints.put(cluster.getId(), points); // registered with an empty point list
+    CDbwEvaluator evaluator = new CDbwEvaluator(representativePoints, clusters, measure);
+    assertEquals("inter cluster density", 0.0, evaluator.interClusterDensity(), EPSILON);
+    assertEquals("separation", 1.5, evaluator.separation(), EPSILON);
+    assertEquals("intra cluster density", 0.7155417527999326, evaluator.intraClusterDensity(),
EPSILON);
+    assertEquals("CDbw", 1.073312629199899, evaluator.getCDbw(), EPSILON);
+  }
+
+  @Test
+  public void testSingleValueCluster() { // adds one extra cluster at the origin with exactly one representative point
+    DistanceMeasure measure = new EuclideanDistanceMeasure();
+    initData(1, 0.25, measure);
+    Canopy cluster = new Canopy(new DenseVector(new double[] { 0, 0 }), 19, measure);
+    clusters.put(cluster.getId(), cluster);
+    List<VectorWritable> points = new ArrayList<VectorWritable>();
+    points.add(new VectorWritable(cluster.getCenter().plus(new DenseVector(new double[] {
1, 1 })))); // the single point: center + (1, 1)
+    representativePoints.put(cluster.getId(), points);
+    CDbwEvaluator evaluator = new CDbwEvaluator(representativePoints, clusters, measure);
+    assertEquals("inter cluster density", 0.0, evaluator.interClusterDensity(), EPSILON);
+    assertEquals("separation", 0.0, evaluator.separation(), EPSILON);
+    assertEquals("intra cluster density", 0.7155417527999326, evaluator.intraClusterDensity(),
EPSILON);
+    assertEquals("CDbw", 0.0, evaluator.getCDbw(), EPSILON);
+  }
 
 }



Mime
View raw message