commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1132448 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java site/xdoc/changes.xml
Date Sun, 05 Jun 2011 16:27:53 GMT
Author: luc
Date: Sun Jun  5 16:27:53 2011
New Revision: 1132448

URL: http://svn.apache.org/viewvc?rev=1132448&view=rev
Log:
Improved k-means++ clustering performance and initial cluster center choice.

JIRA: MATH-584

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
    commons/proper/math/trunk/src/site/xdoc/changes.xml

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java?rev=1132448&r1=1132447&r2=1132448&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java Sun Jun  5 16:27:53 2011
@@ -19,6 +19,7 @@ package org.apache.commons.math.stat.clu
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 
@@ -193,41 +194,121 @@ public class KMeansPlusPlusClusterer<T e
     private static <T extends Clusterable<T>> List<Cluster<T>>
         chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
 
-        final List<T> pointSet = new ArrayList<T>(points);
+        // Convert to list for indexed access. Make it unmodifiable, since removal of items
+        // would screw up the logic of this method.
+        final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points));
+
+        // The number of points in the list.
+        final int numPoints = pointList.size();
+
+        // Set the corresponding element in this array to indicate when
+        // elements of pointList are no longer available.
+        final boolean[] taken = new boolean[numPoints];
+
+        // The resulting list of initial centers.
         final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
 
         // Choose one center uniformly at random from among the data points.
-        final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
+        final int firstPointIndex = random.nextInt(numPoints);
+
+        final T firstPoint = pointList.get(firstPointIndex);
+
         resultSet.add(new Cluster<T>(firstPoint));
 
-        final double[] dx2 = new double[pointSet.size()];
+        // Must mark it as taken
+        taken[firstPointIndex] = true;
+
+        // To keep track of the minimum distance squared of elements of
+        // pointList to elements of resultSet.
+        final double[] minDistSquared = new double[numPoints];
+
+        // Initialize the elements.  Since the only point in resultSet is firstPoint,
+        // this is very easy.
+        for (int i = 0; i < numPoints; i++) {
+            if (i != firstPointIndex) { // That point isn't considered
+                double d = firstPoint.distanceFrom(pointList.get(i));
+                minDistSquared[i] = d*d;
+            }
+        }
+
         while (resultSet.size() < k) {
-            // For each data point x, compute D(x), the distance between x and
-            // the nearest center that has already been chosen.
-            double sum = 0;
-            for (int i = 0; i < pointSet.size(); i++) {
-                final T p = pointSet.get(i);
-                int nearestClusterIndex = getNearestCluster(resultSet, p);
-                final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
-                final double d = p.distanceFrom(nearest.getCenter());
-                sum += d * d;
-                dx2[i] = sum;
+
+            // Sum up the squared distances for the points in pointList not
+            // already taken.
+            double distSqSum = 0.0;
+
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    distSqSum += minDistSquared[i];
+                }
             }
 
             // Add one new data point as a center. Each point x is chosen with
             // probability proportional to D(x)^2
-            final double r = random.nextDouble() * sum;
-            for (int i = 0 ; i < dx2.length; i++) {
-                if (dx2[i] >= r) {
-                    final T p = pointSet.remove(i);
-                    resultSet.add(new Cluster<T>(p));
-                    break;
+            final double r = random.nextDouble() * distSqSum;
+
+            // The index of the next point to be added to the resultSet.
+            int nextPointIndex = -1;
+
+            // Sum through the squared min distances again, stopping when
+            // sum >= r.
+            double sum = 0.0;
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    sum += minDistSquared[i];
+                    if (sum >= r) {
+                        nextPointIndex = i;
+                        break;
+                    }
                 }
             }
+
+            // If it's not set to >= 0, the point wasn't found in the previous
+            // for loop, probably because distances are extremely small.  Just pick
+            // the last available point.
+            if (nextPointIndex == -1) {
+                for (int i = numPoints - 1; i >= 0; i--) {
+                    if (!taken[i]) {
+                        nextPointIndex = i;
+                        break;
+                    }
+                }
+            }
+
+            // We found one.
+            if (nextPointIndex >= 0) {
+
+                final T p = pointList.get(nextPointIndex);
+
+                resultSet.add(new Cluster<T> (p));
+
+                // Mark it as taken.
+                taken[nextPointIndex] = true;
+
+                if (resultSet.size() < k) {
+                    // Now update elements of minDistSquared.  We only have to compute
+                    // the distance to the new center to do this.
+                    for (int j = 0; j < numPoints; j++) {
+                        // Only have to worry about the points still not taken.
+                        if (!taken[j]) {
+                            double d = p.distanceFrom(pointList.get(j));
+                            double d2 = d * d;
+                            if (d2 < minDistSquared[j]) {
+                                minDistSquared[j] = d2;
+                            }
+                        }
+                    }
+                }
+
+            } else {
+                // None found --
+                // Break from the while loop to prevent
+                // an infinite loop.
+                break;
+            }
         }
 
         return resultSet;
-
     }
 
     /**

Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=1132448&r1=1132447&r2=1132448&view=diff
==============================================================================
--- commons/proper/math/trunk/src/site/xdoc/changes.xml (original)
+++ commons/proper/math/trunk/src/site/xdoc/changes.xml Sun Jun  5 16:27:53 2011
@@ -52,6 +52,9 @@ The <action> type attribute can be add,u
     If the output is not quite correct, check for invisible trailing spaces!
      -->
     <release version="3.0" date="TBD" description="TBD">
+      <action dev="luc" type="fix" issue="MATH-584" due-to="Randall Scarberry">
+        Improved k-means++ clustering performance and initial cluster center choice.
+      </action>
       <action dev="luc" type="fix" issue="MATH-504" due-to="X. B.">
         Fixed tricube function implementation in Loess interpolator.
       </action>



Mime
View raw message