commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject [commons-math] 01/04: MATH-1524 Move "chooseInitialCenters" out of the KMeansPlusPlusClusterer
Date Wed, 11 Mar 2020 00:17:13 GMT
This is an automated email from the ASF dual-hosted git repository.

erans pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-math.git

commit 84102c0c4ca34d5d024cf6d02c733c8583fb1b2a
Author: CT <chentao@qq.com>
AuthorDate: Wed Mar 11 01:48:26 2020 +0800

    MATH-1524 Move "chooseInitialCenters" out of the KMeansPlusPlusClusterer
---
 .../ml/clustering/KMeansPlusPlusClusterer.java     | 134 +--------------
 .../initialization/CentroidInitializer.java        |  39 +++++
 .../KMeansPlusPlusCentroidInitializer.java         | 186 +++++++++++++++++++++
 .../initialization/RandomCentroidInitializer.java  |  65 +++++++
 .../initialization/CentroidInitializerTest.java    |  49 ++++++
 5 files changed, 347 insertions(+), 126 deletions(-)

diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
index e5dea41..e05918e 100644
--- a/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/KMeansPlusPlusClusterer.java
@@ -26,6 +26,8 @@ import org.apache.commons.math4.exception.ConvergenceException;
 import org.apache.commons.math4.exception.MathIllegalArgumentException;
 import org.apache.commons.math4.exception.NumberIsTooSmallException;
 import org.apache.commons.math4.exception.util.LocalizedFormats;
+import org.apache.commons.math4.ml.clustering.initialization.CentroidInitializer;
+import org.apache.commons.math4.ml.clustering.initialization.KMeansPlusPlusCentroidInitializer;
 import org.apache.commons.math4.ml.distance.DistanceMeasure;
 import org.apache.commons.math4.ml.distance.EuclideanDistance;
 import org.apache.commons.rng.simple.RandomSource;
@@ -70,6 +72,9 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends
Clusterer<T>
     /** Selected strategy for empty clusters. */
     private final EmptyClusterStrategy emptyStrategy;
 
+    /** Clusters centroids initializer. */
+    private final CentroidInitializer centroidInitializer;
+
     /** Build a clusterer.
      * <p>
      * The default strategy for handling empty clusters that may appear during
@@ -148,6 +153,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends
Clusterer<T>
         this.maxIterations = maxIterations;
         this.random        = random;
         this.emptyStrategy = emptyStrategy;
+        // Use K-means++ to choose the initial centers.
+        this.centroidInitializer = new KMeansPlusPlusCentroidInitializer(measure, random);
     }
 
     /**
@@ -203,7 +210,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends
Clusterer<T>
         }
 
         // create the initial clusters
-        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);
+        List<CentroidCluster<T>> clusters = centroidInitializer.selectCentroids(points,
k);
 
         // create an array containing the latest assignment of a point to a cluster
         // no need to initialize the array, as it will be filled with the first assignment
@@ -277,131 +284,6 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends
Clusterer<T>
     }
 
     /**
-     * Use K-means++ to choose the initial centers.
-     *
-     * @param points the points to choose the initial centers from
-     * @return the initial centers
-     */
-    private List<CentroidCluster<T>> chooseInitialCenters(final Collection<T>
points) {
-
-        // Convert to list for indexed access. Make it unmodifiable, since removal of items
-        // would screw up the logic of this method.
-        final List<T> pointList = Collections.unmodifiableList(new ArrayList<>
(points));
-
-        // The number of points in the list.
-        final int numPoints = pointList.size();
-
-        // Set the corresponding element in this array to indicate when
-        // elements of pointList are no longer available.
-        final boolean[] taken = new boolean[numPoints];
-
-        // The resulting list of initial centers.
-        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
-
-        // Choose one center uniformly at random from among the data points.
-        final int firstPointIndex = random.nextInt(numPoints);
-
-        final T firstPoint = pointList.get(firstPointIndex);
-
-        resultSet.add(new CentroidCluster<T>(firstPoint));
-
-        // Must mark it as taken
-        taken[firstPointIndex] = true;
-
-        // To keep track of the minimum distance squared of elements of
-        // pointList to elements of resultSet.
-        final double[] minDistSquared = new double[numPoints];
-
-        // Initialize the elements.  Since the only point in resultSet is firstPoint,
-        // this is very easy.
-        for (int i = 0; i < numPoints; i++) {
-            if (i != firstPointIndex) { // That point isn't considered
-                double d = distance(firstPoint, pointList.get(i));
-                minDistSquared[i] = d*d;
-            }
-        }
-
-        while (resultSet.size() < k) {
-
-            // Sum up the squared distances for the points in pointList not
-            // already taken.
-            double distSqSum = 0.0;
-
-            for (int i = 0; i < numPoints; i++) {
-                if (!taken[i]) {
-                    distSqSum += minDistSquared[i];
-                }
-            }
-
-            // Add one new data point as a center. Each point x is chosen with
-            // probability proportional to D(x)2
-            final double r = random.nextDouble() * distSqSum;
-
-            // The index of the next point to be added to the resultSet.
-            int nextPointIndex = -1;
-
-            // Sum through the squared min distances again, stopping when
-            // sum >= r.
-            double sum = 0.0;
-            for (int i = 0; i < numPoints; i++) {
-                if (!taken[i]) {
-                    sum += minDistSquared[i];
-                    if (sum >= r) {
-                        nextPointIndex = i;
-                        break;
-                    }
-                }
-            }
-
-            // If it's not set to >= 0, the point wasn't found in the previous
-            // for loop, probably because distances are extremely small.  Just pick
-            // the last available point.
-            if (nextPointIndex == -1) {
-                for (int i = numPoints - 1; i >= 0; i--) {
-                    if (!taken[i]) {
-                        nextPointIndex = i;
-                        break;
-                    }
-                }
-            }
-
-            // We found one.
-            if (nextPointIndex >= 0) {
-
-                final T p = pointList.get(nextPointIndex);
-
-                resultSet.add(new CentroidCluster<T> (p));
-
-                // Mark it as taken.
-                taken[nextPointIndex] = true;
-
-                if (resultSet.size() < k) {
-                    // Now update elements of minDistSquared.  We only have to compute
-                    // the distance to the new center to do this.
-                    for (int j = 0; j < numPoints; j++) {
-                        // Only have to worry about the points still not taken.
-                        if (!taken[j]) {
-                            double d = distance(p, pointList.get(j));
-                            double d2 = d * d;
-                            if (d2 < minDistSquared[j]) {
-                                minDistSquared[j] = d2;
-                            }
-                        }
-                    }
-                }
-
-            } else {
-                // None found --
-                // Break from the while loop to prevent
-                // an infinite loop.
-                break;
-            }
-        }
-
-        return resultSet;
-    }
-
-    /**
      * Get a random point from the {@link Cluster} with the largest distance variance.
      *
      * @param clusters the {@link Cluster}s to search
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
new file mode 100644
index 0000000..dcddc53
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializer.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Strategy used by a clusterer to choose the initial cluster centers
+ * (centroids) before its iterative refinement starts.
+ */
+public interface CentroidInitializer {
+
+    /**
+     * Chooses the initial centers.
+     *
+     * @param <T> the type of the points to cluster
+     * @param points the points to choose the initial centers from
+     * @param k the number of clusters
+     * @return the initial centers
+     */
+    <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(Collection<T> points, int k);
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
new file mode 100644
index 0000000..f0ab288
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/KMeansPlusPlusCentroidInitializer.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.math4.ml.distance.DistanceMeasure;
+import org.apache.commons.rng.UniformRandomProvider;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Use K-means++ to choose the initial centers.
+ *
+ * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
+ */
+public class KMeansPlusPlusCentroidInitializer implements CentroidInitializer {
+    private final DistanceMeasure measure;
+    private final UniformRandomProvider random;
+
+    /**
+     * Build a K-means++ CentroidInitializer
+     * @param measure the distance measure to use
+     * @param random the random to use.
+     */
+    public KMeansPlusPlusCentroidInitializer(final DistanceMeasure measure, final UniformRandomProvider
random) {
+        this.measure = measure;
+        this.random = random;
+    }
+
+    /**
+     * Use K-means++ to choose the initial centers.
+     *
+     * @param points the points to choose the initial centers from
+     * @param k      The number of clusters
+     * @return the initial centers
+     */
+    @Override
+    public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final
Collection<T> points, final int k) {
+        // Convert to list for indexed access. Make it unmodifiable, since removal of items
+        // would screw up the logic of this method.
+        final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));
+
+        // The number of points in the list.
+        final int numPoints = pointList.size();
+
+        // Set the corresponding element in this array to indicate when
+        // elements of pointList are no longer available.
+        final boolean[] taken = new boolean[numPoints];
+
+        // The resulting list of initial centers.
+        final List<CentroidCluster<T>> resultSet = new ArrayList<>();
+
+        // Choose one center uniformly at random from among the data points.
+        final int firstPointIndex = random.nextInt(numPoints);
+
+        final T firstPoint = pointList.get(firstPointIndex);
+
+        resultSet.add(new CentroidCluster<>(firstPoint));
+
+        // Must mark it as taken
+        taken[firstPointIndex] = true;
+
+        // To keep track of the minimum distance squared of elements of
+        // pointList to elements of resultSet.
+        final double[] minDistSquared = new double[numPoints];
+
+        // Initialize the elements.  Since the only point in resultSet is firstPoint,
+        // this is very easy.
+        for (int i = 0; i < numPoints; i++) {
+            if (i != firstPointIndex) { // That point isn't considered
+                double d = distance(firstPoint, pointList.get(i));
+                minDistSquared[i] = d * d;
+            }
+        }
+
+        while (resultSet.size() < k) {
+
+            // Sum up the squared distances for the points in pointList not
+            // already taken.
+            double distSqSum = 0.0;
+
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    distSqSum += minDistSquared[i];
+                }
+            }
+
+            // Add one new data point as a center. Each point x is chosen with
+            // probability proportional to D(x)2
+            final double r = random.nextDouble() * distSqSum;
+
+            // The index of the next point to be added to the resultSet.
+            int nextPointIndex = -1;
+
+            // Sum through the squared min distances again, stopping when
+            // sum >= r.
+            double sum = 0.0;
+            for (int i = 0; i < numPoints; i++) {
+                if (!taken[i]) {
+                    sum += minDistSquared[i];
+                    if (sum >= r) {
+                        nextPointIndex = i;
+                        break;
+                    }
+                }
+            }
+
+            // If it's not set to >= 0, the point wasn't found in the previous
+            // for loop, probably because distances are extremely small.  Just pick
+            // the last available point.
+            if (nextPointIndex == -1) {
+                for (int i = numPoints - 1; i >= 0; i--) {
+                    if (!taken[i]) {
+                        nextPointIndex = i;
+                        break;
+                    }
+                }
+            }
+
+            // We found one.
+            if (nextPointIndex >= 0) {
+
+                final T p = pointList.get(nextPointIndex);
+
+                resultSet.add(new CentroidCluster<>(p));
+
+                // Mark it as taken.
+                taken[nextPointIndex] = true;
+
+                if (resultSet.size() < k) {
+                    // Now update elements of minDistSquared.  We only have to compute
+                    // the distance to the new center to do this.
+                    for (int j = 0; j < numPoints; j++) {
+                        // Only have to worry about the points still not taken.
+                        if (!taken[j]) {
+                            double d = distance(p, pointList.get(j));
+                            double d2 = d * d;
+                            if (d2 < minDistSquared[j]) {
+                                minDistSquared[j] = d2;
+                            }
+                        }
+                    }
+                }
+
+            } else {
+                // None found --
+                // Break from the while loop to prevent
+                // an infinite loop.
+                break;
+            }
+        }
+
+        return resultSet;
+    }
+
+    /**
+     * Calculates the distance between two {@link Clusterable} instances
+     * with the configured {@link DistanceMeasure}.
+     *
+     * @param p1 the first clusterable
+     * @param p2 the second clusterable
+     * @return the distance between the two clusterables
+     */
+    protected double distance(final Clusterable p1, final Clusterable p2) {
+        return measure.compute(p1.getPoint(), p2.getPoint());
+    }
+}
diff --git a/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
new file mode 100644
index 0000000..f3f561d
--- /dev/null
+++ b/src/main/java/org/apache/commons/math4/ml/clustering/initialization/RandomCentroidInitializer.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.ListSampler;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Random choose the initial centers.
+ */
+public class RandomCentroidInitializer implements CentroidInitializer {
+    private final UniformRandomProvider random;
+
+    /**
+     * Build a random RandomCentroidInitializer
+     *
+     * @param random the random to use.
+     */
+    public RandomCentroidInitializer(final UniformRandomProvider random) {
+        this.random = random;
+    }
+
+    /**
+     * Random choose the initial centers.
+     *
+     * @param points the points to choose the initial centers from
+     * @param k      The number of clusters
+     * @return the initial centers
+     */
+    @Override
+    public <T extends Clusterable> List<CentroidCluster<T>> selectCentroids(final
Collection<T> points, final int k) {
+        if (k < 1) {
+            return Collections.emptyList();
+        }
+        final ArrayList<T> list = new ArrayList<>(points);
+        ListSampler.shuffle(random, list);
+        final List<CentroidCluster<T>> result = new ArrayList<>(k);
+        for (int i = 0; i < k; i++) {
+            result.add(new CentroidCluster<>(list.get(i)));
+        }
+        return result;
+    }
+}
diff --git a/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java
b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java
new file mode 100644
index 0000000..989fd14
--- /dev/null
+++ b/src/test/java/org/apache/commons/math4/ml/clustering/initialization/CentroidInitializerTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math4.ml.clustering.initialization;
+
+import org.apache.commons.math4.ml.clustering.CentroidCluster;
+import org.apache.commons.math4.ml.clustering.Clusterable;
+import org.apache.commons.math4.ml.clustering.DoublePoint;
+import org.apache.commons.math4.ml.distance.EuclideanDistance;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Tests for the {@link CentroidInitializer} implementations.
+ */
+public class CentroidInitializerTest {
+    /**
+     * Checks that, for every {@code k} in [1, 50), the initializer returns
+     * exactly {@code k} centers, each of them a different input point.
+     *
+     * @param initializer the initializer under test
+     */
+    private void test_generate_appropriate_number_of_cluster(
+            final CentroidInitializer initializer) {
+        // Generate some data
+        final List<DoublePoint> points = new ArrayList<>();
+        final UniformRandomProvider rnd = RandomSource.create(RandomSource.MT_64);
+        for (int i = 0; i < 500; i++) {
+            double[] p = new double[2];
+            p[0] = rnd.nextDouble();
+            p[1] = rnd.nextDouble();
+            points.add(new DoublePoint(p));
+        }
+        // The centers are chosen at random, so we can only check how many
+        // there are and that each one is a different input point.
+        for (int k = 1; k < 50; k++) {
+            final List<CentroidCluster<DoublePoint>> centroidClusters =
+                    initializer.selectCentroids(points, k);
+            Assert.assertEquals(k, centroidClusters.size());
+
+            // Identity-based set, so that two centers built from the same input
+            // object are counted once.
+            final Set<Clusterable> centers =
+                    Collections.newSetFromMap(new IdentityHashMap<Clusterable, Boolean>());
+            for (final CentroidCluster<DoublePoint> cluster : centroidClusters) {
+                final Clusterable center = cluster.getCenter();
+                Assert.assertTrue(points.contains(center));
+                centers.add(center);
+            }
+            Assert.assertEquals(k, centers.size());
+        }
+    }
+
+    @Test
+    public void test_RandomCentroidInitializer() {
+        final CentroidInitializer initializer =
+                new RandomCentroidInitializer(RandomSource.create(RandomSource.MT_64));
+        test_generate_appropriate_number_of_cluster(initializer);
+    }
+
+    @Test
+    public void test_KMeanPlusPlusCentroidInitializer() {
+        final CentroidInitializer initializer =
+                new KMeansPlusPlusCentroidInitializer(new EuclideanDistance(),
+                        RandomSource.create(RandomSource.MT_64));
+        test_generate_appropriate_number_of_cluster(initializer);
+    }
+}


Mime
View raw message