mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jeast...@apache.org
Subject svn commit: r1222524 - /mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Date Fri, 23 Dec 2011 01:01:58 GMT
Author: jeastman
Date: Fri Dec 23 01:01:58 2011
New Revision: 1222524

URL: http://svn.apache.org/viewvc?rev=1222524&view=rev
Log:
MAHOUT-846: Cache pdf zProd2piR term constant over life of cluster

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=1222524&r1=1222523&r2=1222524&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Fri Dec 23 01:01:58 2011
@@ -48,27 +48,47 @@ public class GaussianCluster extends Abs
     return new GaussianCluster(getCenter(), getRadius(), getId());
   }
   
+  // cached product(r[i]*SQRT2PI) term of the pdf; null until computed by computeProd2piR().
+  private Double zProd2piR = null;
+  
   @Override
   public double pdf(VectorWritable vw) {
-    Vector x = vw.get();
-    Vector m = getCenter();
-    Vector s = getRadius().plus(0.0000001); // add a small prior to avoid divide by zero
-    return Math.exp(-(divideSquareAndSum(x.minus(m), s) / 2)) / zProdSqt2Pi(s);
+    if (zProd2piR == null) {
+      computeProd2piR();
+    }
+    return Math.exp(-(sumXminusCdivRsquared(vw.get()) / 2)) / zProd2piR;
   }
   
-  private double zProdSqt2Pi(Vector s) {
-    double prod = 1;
-    for (int i = 0; i < s.size(); i++) {
-      prod *= s.getQuick(i) * UncommonDistributions.SQRT2PI;
+  /**
+   * Compute the product(r[i]*SQRT2PI) over all i. Note that the cluster Radius
+   * corresponds to the Stdev of a Gaussian and the Center to its Mean.
+   */
+  private void computeProd2piR() {
+    zProd2piR = 1.0;
+    for (Iterator<Element> it = getRadius().iterateNonZero(); it.hasNext();) {
+      Element radius = it.next();
+      zProd2piR *= radius.get() * UncommonDistributions.SQRT2PI;
     }
-    return prod;
   }
   
-  private double divideSquareAndSum(Vector numerator, Vector denominator) {
+  @Override
+  public void computeParameters() {
+    super.computeParameters();
+    zProd2piR = null;
+  }
+  
+  /**
+   * @param x
+   *          a Vector
+   * @return the sum of ((x[i]-c[i])/r[i])^2 over the non-zero elements r[i] of the radius
+   */
+  private double sumXminusCdivRsquared(Vector x) {
     double result = 0;
-    for (Iterator<Element> it = denominator.iterateNonZero(); it.hasNext();) {
-      Element denom = it.next();
-      double quotient = numerator.getQuick(denom.index()) / denom.get();
+    for (Iterator<Element> it = getRadius().iterateNonZero(); it.hasNext();) {
+      Element radiusElem = it.next();
+      int index = radiusElem.index();
+      double quotient = (x.get(index) - getCenter().get(index))
+          / radiusElem.get();
       result += quotient * quotient;
     }
     return result;



Mime
View raw message