mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r1000599 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ utils/src/main/java/org/apache/mahout/clustering/cdbw/
Date Thu, 23 Sep 2010 20:02:40 GMT
Author: jeastman
Date: Thu Sep 23 20:02:40 2010
New Revision: 1000599

URL: http://svn.apache.org/viewvc?rev=1000599&view=rev
Log:
Added small prior to variance in pdf() computations to avoid numeric instability when it is 0. All tests run.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
    mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java Thu Sep 23 20:02:40 2010
@@ -357,7 +357,8 @@ public class DirichletClusterer {
     throws IOException {
     Vector pi = new DenseVector(clusters.size());
     for (int i = 0; i < clusters.size(); i++) {
-      pi.set(i, clusters.get(i).getModel().pdf(vector));
+      double pdf = clusters.get(i).getModel().pdf(vector);
+      pi.set(i, pdf);
     }
     pi = pi.divide(pi.zSum());
     if (emitMostLikely) {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Thu Sep 23 20:02:40 2010
@@ -120,7 +120,8 @@ public class AsymmetricSampledNormalMode
     // TODO: is this reasonable? correct? It seems to work in some cases.
     double pdf = 1;
     for (int i = 0; i < x.size(); i++) {
-      pdf *= UncommonDistributions.dNorm(x.getQuick(i), getCenter().getQuick(i), getRadius().getQuick(i));
+      // small prior on stdDev to avoid numeric instability when stdDev==0
+      pdf *= UncommonDistributions.dNorm(x.getQuick(i), mean.getQuick(i), stdDev.getQuick(i) + 0.000001);
     }
     return pdf;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java Thu Sep 23 20:02:40 2010
@@ -38,7 +38,8 @@ public class GaussianCluster extends Abs
     for (int i = 0; i < x.size(); i++) {
       double x2 = x.get(i);
       double m = getCenter().get(i);
-      double s = getRadius().get(i);
+      // small prior on s to avoid numeric instability when s==0
+      double s = getRadius().get(i) + 0.000001;
       double dNorm = UncommonDistributions.dNorm(x2, m, s);
       pdf += dNorm;
     }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Thu Sep 23 20:02:40 2010
@@ -36,27 +36,29 @@ import com.google.gson.GsonBuilder;
 import com.google.gson.reflect.TypeToken;
 
 public class NormalModel implements Cluster {
-  
+
   private static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
 
-  private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {}.getType();
+  private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
+  }.getType();
 
   private int id;
-  
+
   // the parameters
   private Vector mean;
-  
+
   private double stdDev;
-  
+
   // the observation statistics, initialized by the first observation
   private int s0;
-  
+
   private Vector s1;
-  
+
   private Vector s2;
 
-  public NormalModel() { }
-  
+  public NormalModel() {
+  }
+
   public NormalModel(int id, Vector mean, double stdDev) {
     this.id = id;
     this.mean = mean;
@@ -65,19 +67,19 @@ public class NormalModel implements Clus
     this.s1 = mean.like();
     this.s2 = mean.like();
   }
-  
+
   int getS0() {
     return s0;
   }
-  
+
   public Vector getMean() {
     return mean;
   }
-  
+
   public double getStdDev() {
     return stdDev;
   }
-  
+
   /**
    * TODO: Return a proper sample from the posterior. For now, return an instance with the same parameters
    * 
@@ -86,7 +88,7 @@ public class NormalModel implements Clus
   public NormalModel sampleFromPosterior() {
     return new NormalModel(id, mean, stdDev);
   }
-  
+
   @Override
   public void observe(VectorWritable x) {
     s0++;
@@ -102,7 +104,7 @@ public class NormalModel implements Clus
       s2 = s2.plus(v.times(v));
     }
   }
-  
+
   @Override
   public void computeParameters() {
     if (s0 == 0) {
@@ -117,26 +119,28 @@ public class NormalModel implements Clus
       stdDev = Double.MIN_VALUE;
     }
   }
-  
+
   @Override
   public double pdf(VectorWritable v) {
     Vector x = v.get();
-    double sd2 = stdDev * stdDev;
+    // small prior on std to avoid numeric instability when std==0
+    double std = stdDev + 0.000001;
+    double sd2 = std * std;
     double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
     double ex = Math.exp(exp);
-    return ex / (stdDev * SQRT2PI);
+    return ex / (std * SQRT2PI);
   }
-  
+
   @Override
   public int count() {
     return s0;
   }
-  
+
   @Override
   public String toString() {
     return asFormatString(null);
   }
-  
+
   @Override
   public String asFormatString(String[] bindings) {
     StringBuilder buf = new StringBuilder();
@@ -147,7 +151,7 @@ public class NormalModel implements Clus
     buf.append(" sd=").append(String.format(Locale.ENGLISH, "%.2f", stdDev)).append('}');
     return buf.toString();
   }
-  
+
   @Override
   public void readFields(DataInput in) throws IOException {
     this.id = in.readInt();
@@ -161,7 +165,7 @@ public class NormalModel implements Clus
     temp.readFields(in);
     this.s2 = temp.get();
   }
-  
+
   @Override
   public void write(DataOutput out) throws IOException {
     out.writeInt(id);
@@ -171,7 +175,7 @@ public class NormalModel implements Clus
     VectorWritable.writeVector(out, s1);
     VectorWritable.writeVector(out, s2);
   }
-  
+
   @Override
   public String asJsonString() {
     GsonBuilder builder = new GsonBuilder();

Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java Thu Sep 23 20:02:40 2010
@@ -59,6 +59,9 @@ public class CDbwMapper extends Mapper<I
     WeightedVectorWritable currentMDP = mostDistantPoints.get(key);
 
     List<VectorWritable> refPoints = representativePoints.get(key);
+    if (refPoints == null){
+      System.out.println();
+    }
     double totalDistance = 0.0;
     for (VectorWritable refPoint : refPoints) {
       totalDistance += measure.distance(refPoint.get(), point.getVector());



Mime
View raw message