mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject svn commit: r1060450 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/ core/src/test/java/org/apache/mahout/clustering/ utils/src/main/java/org/apache/mahout/clustering/cdbw/
Date Tue, 18 Jan 2011 16:28:09 GMT
Author: srowen
Date: Tue Jan 18 16:28:08 2011
New Revision: 1060450

URL: http://svn.apache.org/viewvc?rev=1060450&view=rev
Log:
MAHOUT-533 make stdev calculation more accurate in clustering

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
    mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java Tue Jan 18 16:28:08 2011
@@ -50,8 +50,9 @@ public interface GaussianAccumulator {
    * Observe the vector 
    * 
    * @param x a Vector
+   * @param weight the double observation weight (usually 1.0)
    */
-  void observe(Vector x);
+  void observe(Vector x, double weight);
 
   /**
    * Compute the mean, variance and standard deviation

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java Tue Jan 18 16:28:08 2011
@@ -24,17 +24,15 @@ import org.apache.mahout.math.function.S
  * numerically-stable. See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
  */
 public class OnlineGaussianAccumulator implements GaussianAccumulator {
-  private double n = 0;
 
+  private double sumWeight = 0.0;
   private Vector mean;
-
-  private Vector m2;
-
+  private Vector s;
   private Vector variance;
 
   @Override
   public double getN() {
-    return n;
+    return sumWeight;
   }
 
   @Override
@@ -47,23 +45,44 @@ public class OnlineGaussianAccumulator i
     return variance.clone().assign(new SquareRootFunction());
   }
 
-  @Override
-  public void observe(Vector x) {
-    n++;
-    Vector delta;
-    if (mean != null) {
-      delta = x.minus(mean);
-    } else {
+  /* from Wikipedia: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+   * 
+   * Weighted incremental algorithm
+   * 
+   * def weighted_incremental_variance(dataWeightPairs):
+   * mean = 0
+   * S = 0
+   * sumweight = 0
+   * for x, weight in dataWeightPairs: # Alternately "for x in zip(data, weight):"
+   *     temp = weight + sumweight
+   *     Q = x - mean
+   *      R = Q * weight / temp
+   *      S = S + sumweight * Q * R
+   *      mean = mean + R
+   *      sumweight = temp
+   *  Variance = S / (sumweight-1)  # if sample is the population, omit -1
+   *  return Variance
+   */
+
+  @Override
+  public void observe(Vector x, double weight) {
+    double temp = weight + sumWeight;
+    Vector Q;
+    if (mean == null) {
       mean = x.like();
-      delta = x.clone();
+      Q = x.clone();
+    } else {
+      Q = x.minus(mean);
     }
-    mean = mean.plus(delta.divide(n));
-    if (m2 != null) {
-      m2 = m2.plus(delta.times(x.minus(mean)));
+    Vector R = Q.times(weight).divide(temp);
+    if (s == null) {
+      s = Q.times(sumWeight).times(R);
     } else {
-      m2 = delta.times(x.minus(mean));
+      s = s.plus(Q.times(sumWeight).times(R));
     }
-    variance = m2.divide(n - 1);
+    mean = mean.plus(R);
+    sumWeight = temp;
+    variance = s.divide(sumWeight - 1);//  # if sample is the population, omit -1
   }
 
   @Override
@@ -73,8 +92,8 @@ public class OnlineGaussianAccumulator i
 
   @Override
   public double getAverageStd() {
-    if (n == 0) {
-      return 0;
+    if (sumWeight == 0.0) {
+      return 0.0;
     } else {
       Vector std = getStd();
       return std.zSum() / std.size();

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java Tue Jan 18 16:28:08 2011
@@ -25,14 +25,11 @@ import org.apache.mahout.math.function.S
  * Suffers from overflow, underflow and roundoff error but has minimal observe-time overhead
  */
 public class RunningSumsGaussianAccumulator implements GaussianAccumulator {
-  private double s0 = 0;
 
+  private double s0 = 0.0;
   private Vector s1;
-
   private Vector s2;
-
   private Vector mean;
-
   private Vector std;
 
   @Override
@@ -52,8 +49,8 @@ public class RunningSumsGaussianAccumula
 
   @Override
   public double getAverageStd() {
-    if (s0 == 0) {
-      return 0;
+    if (s0 == 0.0) {
+      return 0.0;
     } else {
       return std.zSum() / std.size();
     }
@@ -65,14 +62,15 @@ public class RunningSumsGaussianAccumula
   }
 
   @Override
-  public void observe(Vector x) {
-    s0++;
+  public void observe(Vector x, double weight) {
+    s0 += weight;
+    Vector weightedX = x.times(weight);
     if (s1 == null) {
-      s1 = x.clone();
+      s1 = weightedX;
     } else {
-      x.addTo(s1);
+      weightedX.addTo(s1);
     }
-    Vector x2 = x.times(x);
+    Vector x2 = x.times(x).times(weight);
     if (s2 == null) {
       s2 = x2;
     } else {
@@ -82,11 +80,10 @@ public class RunningSumsGaussianAccumula
 
   @Override
   public void compute() {
-    if (s0 == 0) {
-      return;
+    if (s0 != 0.0) {
+      mean = s1.divide(s0);
+      std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
     }
-    mean = s1.divide(s0);
-    std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
   }
 
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java Tue Jan 18 16:28:08 2011
@@ -22,7 +22,9 @@ import java.util.Collection;
 import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.SquareRootFunction;
 import org.junit.Before;
 import org.junit.Test;
 import org.slf4j.Logger;
@@ -30,37 +32,38 @@ import org.slf4j.LoggerFactory;
 
 public final class TestGaussianAccumulators extends MahoutTestCase {
 
-  private Collection<VectorWritable> sampleData = new ArrayList<VectorWritable>();
-
   private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);
 
+  private Collection<VectorWritable> sampleData = new ArrayList<VectorWritable>();
+  private int sampleN;
+  private Vector sampleMean;
+  private Vector sampleStd;
+  private Vector sampleVar;
+
   @Override
   @Before
   public void setUp() throws Exception {
     super.setUp();
     sampleData = new ArrayList<VectorWritable>();
     generateSamples();
-  }
+    sampleN = 0;
+    Vector sum = new DenseVector(2);
+    for (VectorWritable v : sampleData) {
+      v.get().addTo(sum);
+      sampleN++;
+    }
+    sampleMean = sum.divide(sampleN);
 
-  /**
-   * Generate random samples and add them to the sampleData
-   * 
-   * @param num
-   *          int number of samples to generate
-   * @param mx
-   *          double x-value of the sample mean
-   * @param my
-   *          double y-value of the sample mean
-   * @param sd
-   *          double standard deviation of the samples
-   * @throws Exception 
-   */
-  private void generateSamples(int num, double mx, double my, double sd) {
-    log.info("Generating {} samples m=[{}, {}] sd={}", new Object[] { num, mx, my, sd });
-    for (int i = 0; i < num; i++) {
-      sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx,
sd),
-          UncommonDistributions.rNorm(my, sd) })));
+    sampleVar = new DenseVector(2);
+    for (VectorWritable v : sampleData) {
+      Vector delta = v.get().minus(sampleMean);
+      delta.times(delta).addTo(sampleVar);
     }
+    sampleVar = sampleVar.divide(sampleN - 1);
+    sampleStd = sampleVar.clone();
+    sampleStd.assign(new SquareRootFunction());
+    log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+             new Object[] { sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0),
sampleStd.get(1) });
   }
 
   /**
@@ -86,7 +89,7 @@ public final class TestGaussianAccumulat
   }
 
   private void generateSamples() {
-    generate2dSamples(500, 1, 2, 3, 4);
+    generate2dSamples(50000, 1, 2, 3, 4);
   }
 
   @Test
@@ -101,18 +104,76 @@ public final class TestGaussianAccumulat
   }
 
   @Test
-  public void testAccumulatorResults() {
+  public void testAccumulatorOneSample() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    Vector sample = new DenseVector(2);
+    accumulator0.observe(sample, 1.0);
+    accumulator1.observe(sample, 1.0);
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+    assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(),
EPSILON);
+  }
+
+  @Test
+  public void testOLAccumulatorResults() {
+    GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator.observe(vw.get(), 1.0);
+    }
+    accumulator.compute();
+    log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] { accumulator.getN(),
accumulator.getMean().get(0),
+        accumulator.getMean().get(1), accumulator.getStd().get(0), accumulator.getStd().get(1)
});
+    assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+    assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+    assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON);
+  }
+
+  @Test
+  public void testRSAccumulatorResults() {
+    GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator.observe(vw.get(), 1.0);
+    }
+    accumulator.compute();
+    log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] { (int) accumulator.getN(),
+        accumulator.getMean().get(0), accumulator.getMean().get(1), accumulator.getStd().get(0),
accumulator.getStd().get(1) });
+    assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+    assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+    assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001);
+  }
+
+  @Test
+  public void testAccumulatorWeightedResults() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator0.observe(vw.get(), 0.5);
+      accumulator1.observe(vw.get(), 0.5);
+    }
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(),
0.01);
+  }
+
+  @Test
+  public void testAccumulatorWeightedResults2() {
     GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
     GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
     for (VectorWritable vw : sampleData) {
-      accumulator0.observe(vw.get());
-      accumulator1.observe(vw.get());
+      accumulator0.observe(vw.get(), 1.5);
+      accumulator1.observe(vw.get(), 1.5);
     }
     accumulator0.compute();
     accumulator1.compute();
     assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
     assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
-    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.01);
-    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(),
0.1);
+    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(),
0.01);
   }
 }

Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java Tue Jan 18 16:28:08 2011
@@ -50,13 +50,9 @@ public class CDbwEvaluator {
   private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class);
 
   private final Map<Integer, List<VectorWritable>> representativePoints;
-
   private final Map<Integer, Double> stDevs = new HashMap<Integer, Double>();
-
   private final List<Cluster> clusters;
-
   private final DistanceMeasure measure;
-
   private boolean pruned;
 
   /**
@@ -136,7 +132,7 @@ public class CDbwEvaluator {
     List<VectorWritable> repPts = representativePoints.get(cI);
     GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
     for (VectorWritable vw : repPts) {
-      accumulator.observe(vw.get());
+      accumulator.observe(vw.get(), 1.0);
     }
     accumulator.compute();
     double d = accumulator.getAverageStd();



Mime
View raw message