hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tjungb...@apache.org
Subject svn commit: r1375724 - in /hama/trunk: ./ examples/ examples/src/main/java/org/apache/hama/examples/ ml/src/main/java/org/apache/hama/ml/ ml/src/main/java/org/apache/hama/ml/distance/ ml/src/main/java/org/apache/hama/ml/writable/ ml/src/test/java/org/a...
Date Tue, 21 Aug 2012 19:20:06 GMT
Author: tjungblut
Date: Tue Aug 21 19:20:06 2012
New Revision: 1375724

URL: http://svn.apache.org/viewvc?rev=1375724&view=rev
Log:
[HAMA-547]: Add K-Means to ML package

Added:
    hama/trunk/examples/src/main/java/org/apache/hama/examples/Kmeans.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java
Modified:
    hama/trunk/CHANGES.txt
    hama/trunk/examples/pom.xml
    hama/trunk/examples/src/main/java/org/apache/hama/examples/ExampleDriver.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java

Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1375724&r1=1375723&r2=1375724&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Tue Aug 21 19:20:06 2012
@@ -15,6 +15,7 @@ Release 0.5 - April 10, 2012 
 
   NEW FEATURES
 
+   HAMA-547: Add K-Means to ML package (tjungblut)
    HAMA-591: Improve Pagerank (tjungblut) 
    HAMA-550: Implementation of Bipartite Matching (Apurv Verma via tjungblut)
    HAMA-588: Add voteToHalt() mechanism in Graph API (edwardyoon)

Modified: hama/trunk/examples/pom.xml
URL: http://svn.apache.org/viewvc/hama/trunk/examples/pom.xml?rev=1375724&r1=1375723&r2=1375724&view=diff
==============================================================================
--- hama/trunk/examples/pom.xml (original)
+++ hama/trunk/examples/pom.xml Tue Aug 21 19:20:06 2012
@@ -41,6 +41,11 @@
       <artifactId>hama-graph</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.hama</groupId>
+      <artifactId>hama-ml</artifactId>
+      <version>${project.version}</version>
+    </dependency>
   </dependencies>
   <build>
     <finalName>hama-examples-${project.version}</finalName>

Modified: hama/trunk/examples/src/main/java/org/apache/hama/examples/ExampleDriver.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/ExampleDriver.java?rev=1375724&r1=1375723&r2=1375724&view=diff
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/ExampleDriver.java (original)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/ExampleDriver.java Tue Aug 21 19:20:06 2012
@@ -36,6 +36,8 @@ public class ExampleDriver {
       pgd.addClass("inlnkcount", InlinkCount.class, "InlinkCount");
       pgd.addClass("bipartite", BipartiteMatching.class, 
           "Bipartite Matching");
+      pgd.addClass("kmeans", Kmeans.class, 
+          "K-Means Clustering");
       pgd.driver(args);
     } catch (Throwable e) {
       e.printStackTrace();

Added: hama/trunk/examples/src/main/java/org/apache/hama/examples/Kmeans.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/Kmeans.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/Kmeans.java (added)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/Kmeans.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.examples;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.ml.KMeansBSP;
+import org.apache.hama.ml.writable.VectorWritable;
+
+/**
+ * Uses the {@link KMeansBSP} class to run a Kmeans Clustering with BSP. You can
+ * provide your own input, or generate some random input for benchmarking.
+ * 
+ * For your own input, supply a text file that contains a tab separated
+ * sequence of doubles on each line. The first k vectors are used as the seed
+ * centers.
+ * 
+ * For random input, supply the "-g" flag followed by the number of vectors to
+ * generate and the dimension of the vectors.
+ * 
+ * You must always pass an input path, an output path, the maximum number of
+ * iterations the algorithm should run (it also stops once the centers no
+ * longer move) and k, the number of centers.
+ * 
+ * The final centers are stored in the given output path under
+ * center/center_output.seq. This is a sequencefile with
+ * {@link VectorWritable} as key and {@link NullWritable} as value. To print it
+ * with the FS text utility (e.g. {@code hadoop fs -text}), add the hama-ml jar
+ * to the lib directory of Hadoop, so it can find the vector classes.
+ * 
+ * The assignments from a center index (the order of the centers in the above
+ * sequencefile matters, and indices start at 0) to a vector can be found in
+ * the output path as a text file.
+ */
+public class Kmeans {
+
+  public static void main(String[] args) throws Exception {
+    // two invocations are supported: 4 arguments for text input and 7
+    // arguments for generated input ("-g <COUNT> <DIMENSION>")
+    if (args.length != 4 && args.length != 7) {
+      System.out
+          .println("USAGE: <INPUT_PATH> <OUTPUT_PATH> <MAXITERATIONS> <K (how many centers)> [-g <COUNT> <DIMENSION OF VECTORS>]");
+      return;
+    }
+    Configuration conf = new Configuration();
+    Path in = new Path(args[0]);
+    Path out = new Path(args[1]);
+    FileSystem fs = FileSystem.get(conf);
+    // the seed centers live next to an input file, or inside an input dir
+    Path center = null;
+    if (fs.isFile(in))
+      center = new Path(in.getParent(), "center/cen.seq");
+    else
+      center = new Path(in, "center/cen.seq");
+    Path centerOut = new Path(out, "center/center_output.seq");
+    conf.set(KMeansBSP.CENTER_IN_PATH, center.toString());
+    conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString());
+    int iterations = Integer.parseInt(args[2]);
+    conf.setInt(KMeansBSP.MAX_ITERATIONS_KEY, iterations);
+    int k = Integer.parseInt(args[3]);
+    if (args.length == 7 && args[4].equals("-g")) {
+      int count = Integer.parseInt(args[5]);
+      if (k > count)
+        throw new IllegalArgumentException("K can't be greater than n!");
+      int dimension = Integer.parseInt(args[6]);
+      System.out.println("N: " + count + " Dimension: " + dimension
+          + " Iterations: " + iterations);
+      // prepare the input, like deleting old versions and creating centers
+      KMeansBSP.prepareInput(count, k, dimension, conf, in, center, out, fs);
+    } else {
+      // convert the text input into a sequencefile and run the job on the
+      // path that was actually written (it differs for file vs. directory)
+      in = KMeansBSP.prepareInputText(k, conf, in, center, out, fs);
+    }
+
+    BSPJob job = KMeansBSP.createJob(conf, in, out, true);
+
+    // just submit the job
+    job.waitForCompletion(true);
+  }
+}

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.writable.VectorWritable;
+
+/**
+ * Writable message carrying a center index, an increment counter and a vector.
+ * KMeansBSP uses it to broadcast per-center partial sums together with the
+ * number of vectors that went into each sum.
+ */
+public final class CenterMessage implements Writable {
+
+  // index of the center this message refers to
+  private int centerIndex;
+  // how many vectors were summed into newCenter
+  private int incrementCounter;
+  // the (partial) center vector
+  private DoubleVector newCenter;
+
+  /** No-arg constructor; fields are filled by {@link #readFields(DataInput)}. */
+  public CenterMessage() {
+  }
+
+  /** Creates a message with an increment counter of zero. */
+  public CenterMessage(int key, DoubleVector value) {
+    this(key, 0, value);
+  }
+
+  /** Creates a message for the given center index, counter and vector. */
+  public CenterMessage(int key, int increment, DoubleVector value) {
+    this.centerIndex = key;
+    this.incrementCounter = increment;
+    this.newCenter = value;
+  }
+
+  /** Reads index, counter and vector in the order {@link #write} emits them. */
+  @Override
+  public final void readFields(DataInput in) throws IOException {
+    this.centerIndex = in.readInt();
+    this.incrementCounter = in.readInt();
+    this.newCenter = VectorWritable.readVector(in);
+  }
+
+  /** Serializes index, counter and vector, in that order. */
+  @Override
+  public final void write(DataOutput out) throws IOException {
+    out.writeInt(this.centerIndex);
+    out.writeInt(this.incrementCounter);
+    VectorWritable.writeVector(this.newCenter, out);
+  }
+
+  public int getCenterIndex() {
+    return this.centerIndex;
+  }
+
+  public int getIncrementCounter() {
+    return this.incrementCounter;
+  }
+
+  /** Alias of {@link #getCenterIndex()}. */
+  public final int getTag() {
+    return this.centerIndex;
+  }
+
+  /** Returns the carried vector. */
+  public final DoubleVector getData() {
+    return this.newCenter;
+  }
+
+}

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,482 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+import org.apache.hama.ml.distance.DistanceMeasurer;
+import org.apache.hama.ml.distance.EuclidianDistance;
+import org.apache.hama.ml.math.DenseDoubleVector;
+import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.writable.VectorWritable;
+import org.apache.hama.util.ReflectionUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * K-Means in BSP that reads a bunch of vectors from input system and a given
+ * centroid path that contains initial centers.
+ * 
+ */
+public final class KMeansBSP
+    extends
+    BSP<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> {
+
+  public static final String CENTER_OUT_PATH = "center.out.path";
+  public static final String MAX_ITERATIONS_KEY = "k.means.max.iterations";
+  public static final String CACHING_ENABLED_KEY = "k.means.caching.enabled";
+  public static final String DISTANCE_MEASURE_CLASS = "distance.measure.class";
+  public static final String CENTER_IN_PATH = "center.in.path";
+
+  private static final Log LOG = LogFactory.getLog(KMeansBSP.class);
+  // a task local copy of our cluster centers
+  private DoubleVector[] centers;
+  // simple cache to speed up computation, because the algorithm is disk based
+  private List<DoubleVector> cache;
+  // numbers of maximum iterations to do
+  private int maxIterations;
+  // our distance measurement
+  private DistanceMeasurer distanceMeasurer;
+  private Configuration conf;
+
+  @Override
+  public final void setup(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException, InterruptedException {
+
+    conf = peer.getConfiguration();
+
+    Path centroids = new Path(peer.getConfiguration().get(CENTER_IN_PATH));
+    FileSystem fs = FileSystem.get(peer.getConfiguration());
+    final ArrayList<DoubleVector> centers = new ArrayList<DoubleVector>();
+    SequenceFile.Reader reader = null;
+    try {
+      reader = new SequenceFile.Reader(fs, centroids, peer.getConfiguration());
+      VectorWritable key = new VectorWritable();
+      NullWritable value = NullWritable.get();
+      while (reader.next(key, value)) {
+        DoubleVector center = key.getVector();
+        centers.add(center);
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+
+    Preconditions.checkArgument(centers.size() > 0,
+        "Centers file must contain at least a single center!");
+    this.centers = centers.toArray(new DoubleVector[centers.size()]);
+
+    distanceMeasurer = new EuclidianDistance();
+    String distanceClass = peer.getConfiguration().get(DISTANCE_MEASURE_CLASS);
+    if (distanceClass != null) {
+      try {
+        distanceMeasurer = ReflectionUtils.newInstance(distanceClass);
+      } catch (ClassNotFoundException e) {
+        e.printStackTrace();
+      }
+    }
+
+    maxIterations = peer.getConfiguration().getInt(MAX_ITERATIONS_KEY, -1);
+    // normally we want to rely on OS caching, but if not, we can cache in heap
+    if (peer.getConfiguration().getBoolean(CACHING_ENABLED_KEY, false)) {
+      cache = new ArrayList<DoubleVector>();
+    }
+  }
+
+  @Override
+  public final void bsp(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException, InterruptedException, SyncException {
+    long converged;
+    while (true) {
+      assignCenters(peer);
+      peer.sync();
+      converged = updateCenters(peer);
+      peer.reopenInput();
+      if (converged == 0)
+        break;
+      if (maxIterations > 0 && maxIterations < peer.getSuperstepCount())
+        break;
+    }
+    LOG.info("Finished! Writing the assignments...");
+    recalculateAssignmentsAndWrite(peer);
+    LOG.info("Done.");
+  }
+
+  private long updateCenters(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException {
+    // this is the update step
+    DoubleVector[] msgCenters = new DoubleVector[centers.length];
+    int[] incrementSum = new int[centers.length];
+    CenterMessage msg;
+    // basically just summing incoming vectors
+    while ((msg = peer.getCurrentMessage()) != null) {
+      DoubleVector oldCenter = msgCenters[msg.getCenterIndex()];
+      DoubleVector newCenter = msg.getData();
+      incrementSum[msg.getCenterIndex()] += msg.getIncrementCounter();
+      if (oldCenter == null) {
+        msgCenters[msg.getCenterIndex()] = newCenter;
+      } else {
+        msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter);
+      }
+    }
+    // divide by how often we globally summed vectors
+    for (int i = 0; i < msgCenters.length; i++) {
+      // and only if we really have an update for c
+      if (msgCenters[i] != null) {
+        msgCenters[i] = msgCenters[i].divide(incrementSum[i]);
+      }
+    }
+    // finally check for convergence by the absolute difference
+    long convergedCounter = 0L;
+    for (int i = 0; i < msgCenters.length; i++) {
+      final DoubleVector oldCenter = centers[i];
+      if (msgCenters[i] != null) {
+        double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum();
+        if (calculateError > 0.0d) {
+          centers[i] = msgCenters[i];
+          convergedCounter++;
+        }
+      }
+    }
+    return convergedCounter;
+  }
+
+  private void assignCenters(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException {
+    // each task has all the centers, if a center has been updated it
+    // needs to be broadcasted.
+    final DoubleVector[] newCenterArray = new DoubleVector[centers.length];
+    final int[] summationCount = new int[centers.length];
+    // if our cache is not enabled, iterate over the disk items
+    if (cache == null) {
+      // we have an assignment step
+      final NullWritable value = NullWritable.get();
+      final VectorWritable key = new VectorWritable();
+      while (peer.readNext(key, value)) {
+        assignCentersInternal(newCenterArray, summationCount, key.getVector()
+            .deepCopy());
+      }
+    } else {
+      // if our cache is enabled but empty, we have to read it from disk first
+      if (cache.isEmpty()) {
+        final NullWritable value = NullWritable.get();
+        final VectorWritable key = new VectorWritable();
+        while (peer.readNext(key, value)) {
+          DoubleVector deepCopy = key.getVector().deepCopy();
+          cache.add(deepCopy);
+          // but do the assignment directly
+          assignCentersInternal(newCenterArray, summationCount, deepCopy);
+        }
+      } else {
+        // now we can iterate in memory and check against the centers
+        for (DoubleVector v : cache) {
+          assignCentersInternal(newCenterArray, summationCount, v);
+        }
+      }
+    }
+    // now send messages about the local updates to each other peer
+    for (int i = 0; i < newCenterArray.length; i++) {
+      if (newCenterArray[i] != null) {
+        for (String peerName : peer.getAllPeerNames()) {
+          peer.send(peerName, new CenterMessage(i, summationCount[i],
+              newCenterArray[i]));
+        }
+      }
+    }
+  }
+
+  private void assignCentersInternal(final DoubleVector[] newCenterArray,
+      final int[] summationCount, final DoubleVector key) {
+    final int lowestDistantCenter = getNearestCenter(key);
+    final DoubleVector clusterCenter = newCenterArray[lowestDistantCenter];
+    if (clusterCenter == null) {
+      newCenterArray[lowestDistantCenter] = key;
+    } else {
+      // add the vector to the center
+      newCenterArray[lowestDistantCenter] = newCenterArray[lowestDistantCenter]
+          .add(key);
+      summationCount[lowestDistantCenter]++;
+    }
+  }
+
+  private int getNearestCenter(DoubleVector key) {
+    int lowestDistantCenter = 0;
+    double lowestDistance = Double.MAX_VALUE;
+    for (int i = 0; i < centers.length; i++) {
+      final double estimatedDistance = distanceMeasurer.measureDistance(
+          centers[i], key);
+      // check if we have a can assign a new center, because we
+      // got a lower distance
+      if (estimatedDistance < lowestDistance) {
+        lowestDistance = estimatedDistance;
+        lowestDistantCenter = i;
+      }
+    }
+    return lowestDistantCenter;
+  }
+
+  private void recalculateAssignmentsAndWrite(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException {
+    final NullWritable value = NullWritable.get();
+    // also use our cache to speed up the final writes if exists
+    if (cache == null) {
+      final VectorWritable key = new VectorWritable();
+      IntWritable keyWrite = new IntWritable();
+      while (peer.readNext(key, value)) {
+        final int lowestDistantCenter = getNearestCenter(key.getVector());
+        keyWrite.set(lowestDistantCenter);
+        peer.write(keyWrite, key);
+      }
+    } else {
+      IntWritable keyWrite = new IntWritable();
+      for (DoubleVector v : cache) {
+        final int lowestDistantCenter = getNearestCenter(v);
+        keyWrite.set(lowestDistantCenter);
+        peer.write(keyWrite, new VectorWritable(v));
+      }
+    }
+    // just on the first task write the centers to filesystem to prevent
+    // collisions
+    if (peer.getPeerName().equals(peer.getPeerName(0))) {
+      String pathString = conf.get(CENTER_OUT_PATH);
+      if (pathString != null) {
+        final SequenceFile.Writer dataWriter = SequenceFile.createWriter(
+            FileSystem.get(conf), conf, new Path(pathString),
+            VectorWritable.class, NullWritable.class, CompressionType.NONE);
+        for (DoubleVector center : centers) {
+          dataWriter.append(new VectorWritable(center), value);
+        }
+        dataWriter.close();
+      }
+    }
+  }
+
+  /**
+   * Creates a basic job with sequencefiles as in and output.
+   */
+  public static BSPJob createJob(Configuration cnf, Path in, Path out,
+      boolean textOut) throws IOException {
+    HamaConfiguration conf = new HamaConfiguration(cnf);
+    BSPJob job = new BSPJob(conf, KMeansBSP.class);
+    job.setJobName("KMeans Clustering");
+    job.setJarByClass(KMeansBSP.class);
+    job.setBspClass(KMeansBSP.class);
+    job.setInputPath(in);
+    job.setOutputPath(out);
+    job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
+    if (textOut)
+      job.setOutputFormat(org.apache.hama.bsp.TextOutputFormat.class);
+    else
+      job.setOutputFormat(org.apache.hama.bsp.SequenceFileOutputFormat.class);
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+    return job;
+  }
+
+  public static void main(String[] args) throws IOException,
+      ClassNotFoundException, InterruptedException {
+
+    if (args.length < 6) {
+      LOG.info("USAGE: <INPUT_PATH> <OUTPUT_PATH> <COUNT> <K> <DIMENSION
OF VECTORS> <MAXITERATIONS> <optional: num of tasks>");
+      return;
+    }
+
+    Configuration conf = new Configuration();
+    int count = Integer.parseInt(args[2]);
+    int k = Integer.parseInt(args[3]);
+    int dimension = Integer.parseInt(args[4]);
+    int iterations = Integer.parseInt(args[5]);
+    conf.setInt(MAX_ITERATIONS_KEY, iterations);
+
+    Path in = new Path(args[0]);
+    Path out = new Path(args[1]);
+    Path center = new Path(in, "center/cen.seq");
+    Path centerOut = new Path(out, "center/center_output.seq");
+
+    conf.set(CENTER_IN_PATH, center.toString());
+    conf.set(CENTER_OUT_PATH, centerOut.toString());
+    // if you're in local mode, you can increase this to match your core sizes
+    conf.set("bsp.local.tasks.maximum", ""
+        + Runtime.getRuntime().availableProcessors());
+    // deactivate (set to false) if you want to iterate over disk, else it will
+    // cache the input vectors in memory
+    conf.setBoolean(CACHING_ENABLED_KEY, true);
+    BSPJob job = createJob(conf, in, out, false);
+
+    LOG.info("N: " + count + " k: " + k + " Dimension: " + dimension
+        + " Iterations: " + iterations);
+
+    FileSystem fs = FileSystem.get(conf);
+    // prepare the input, like deleting old versions and creating centers
+    prepareInput(count, k, dimension, conf, in, center, out, fs);
+    if (args.length == 7) {
+      job.setNumBspTask(Integer.parseInt(args[6]));
+    }
+
+    // just submit the job
+    job.waitForCompletion(true);
+  }
+
+  /**
+   * Reads the centers outputted from the clustering job.
+   * 
+   * @return an index on the key dimension, and a cluster center on the value.
+   */
+  public static HashMap<Integer, DoubleVector> readOutput(Configuration conf,
+      Path out, Path centerPath, FileSystem fs) throws IOException {
+    HashMap<Integer, DoubleVector> centerMap = new HashMap<Integer, DoubleVector>();
+    SequenceFile.Reader centerReader = new SequenceFile.Reader(fs, centerPath,
+        conf);
+    int index = 0;
+    VectorWritable center = new VectorWritable();
+    while (centerReader.next(center, NullWritable.get())) {
+      centerMap.put(index++, center.getVector());
+    }
+    centerReader.close();
+    return centerMap;
+  }
+
+  /**
+   * Reads input text files and writes it to a sequencefile.
+   */
+  public static Path prepareInputText(int k, Configuration conf, Path txtIn,
+      Path center, Path out, FileSystem fs) throws IOException {
+
+    Path in = null;
+    if (fs.isFile(txtIn)) {
+      in = new Path(txtIn.getParent(), "textinput/in.seq");
+    } else {
+      in = new Path(txtIn, "textinput/in.seq");
+    }
+
+    if (fs.exists(out))
+      fs.delete(out, true);
+
+    if (fs.exists(center))
+      fs.delete(center, true);
+
+    if (fs.exists(in))
+      fs.delete(in, true);
+
+    final NullWritable value = NullWritable.get();
+
+    Writer centerWriter = new SequenceFile.Writer(fs, conf, center,
+        VectorWritable.class, NullWritable.class);
+
+    final SequenceFile.Writer dataWriter = SequenceFile.createWriter(fs, conf,
+        in, VectorWritable.class, NullWritable.class, CompressionType.NONE);
+
+    int i = 0;
+
+    BufferedReader br = new BufferedReader(
+        new InputStreamReader(fs.open(txtIn)));
+    String line;
+    while ((line = br.readLine()) != null) {
+      String[] split = line.split("\t");
+      DenseDoubleVector vec = new DenseDoubleVector(split.length);
+      for (int j = 0; j < split.length; j++) {
+        vec.set(j, Double.parseDouble(split[j]));
+      }
+      VectorWritable vector = new VectorWritable(vec);
+      dataWriter.append(vector, value);
+      if (k > i) {
+        centerWriter.append(vector, value);
+      } else {
+        if (centerWriter != null) {
+          centerWriter.close();
+          centerWriter = null;
+        }
+      }
+      i++;
+    }
+    br.close();
+    dataWriter.close();
+    return in;
+  }
+
+  /**
+   * Create some random vectors as input and assign the first k vectors as
+   * intial centers.
+   */
+  public static void prepareInput(int count, int k, int dimension,
+      Configuration conf, Path in, Path center, Path out, FileSystem fs)
+      throws IOException {
+    if (fs.exists(out))
+      fs.delete(out, true);
+
+    if (fs.exists(center))
+      fs.delete(out, true);
+
+    if (fs.exists(in))
+      fs.delete(in, true);
+
+    final SequenceFile.Writer centerWriter = SequenceFile.createWriter(fs,
+        conf, center, VectorWritable.class, NullWritable.class,
+        CompressionType.NONE);
+    final NullWritable value = NullWritable.get();
+
+    final SequenceFile.Writer dataWriter = SequenceFile.createWriter(fs, conf,
+        in, VectorWritable.class, NullWritable.class, CompressionType.NONE);
+
+    Random r = new Random();
+    for (int i = 0; i < count; i++) {
+
+      double[] arr = new double[dimension];
+      for (int d = 0; d < dimension; d++) {
+        arr[d] = r.nextInt(count);
+      }
+      VectorWritable vector = new VectorWritable(new DenseDoubleVector(arr));
+      dataWriter.append(vector, value);
+      if (k > i) {
+        centerWriter.append(vector, value);
+      } else if (k == i) {
+        centerWriter.close();
+      }
+    }
+    dataWriter.close();
+  }
+}

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/CosineDistance.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.distance;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+public final class CosineDistance implements DistanceMeasurer {
+
+  @Override
+  public double measureDistance(double[] set1, double[] set2) {
+    double dotProduct = 0.0;
+    double lengthSquaredp1 = 0.0;
+    double lengthSquaredp2 = 0.0;
+    for (int i = 0; i < set1.length; i++) {
+      lengthSquaredp1 += set1[i] * set1[i];
+      lengthSquaredp2 += set2[i] * set2[i];
+      dotProduct += set1[i] * set2[i];
+    }
+    double denominator = Math.sqrt(lengthSquaredp1)
+        * Math.sqrt(lengthSquaredp2);
+
+    // correct for floating-point rounding errors
+    if (denominator < dotProduct) {
+      denominator = dotProduct;
+    }
+    // prevent NaNs
+    if (denominator == 0.0d)
+      return 1.0;
+
+    return 1.0 - dotProduct / denominator;
+  }
+
+  @Override
+  public double measureDistance(DoubleVector vec1, DoubleVector vec2) {
+    double lengthSquaredv1 = vec1.pow(2).sum();
+    double lengthSquaredv2 = vec2.pow(2).sum();
+
+    double dotProduct = vec2.dot(vec1);
+    double denominator = Math.sqrt(lengthSquaredv1)
+        * Math.sqrt(lengthSquaredv2);
+
+    // correct for floating-point rounding errors
+    if (denominator < dotProduct) {
+      denominator = dotProduct;
+    }
+
+    return 1.0 - dotProduct / denominator;
+  }
+
+}

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/DistanceMeasurer.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.distance;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * Common interface for distance functions between two points. Every
+ * implementation accepts points either as raw {@code double[]} arrays or as
+ * {@link DoubleVector}s.
+ */
+public interface DistanceMeasurer {
+
+  /**
+   * Measures the distance between two points stored as arrays.
+   *
+   * @param set1 the first point
+   * @param set2 the second point
+   * @return the distance between {@code set1} and {@code set2}
+   */
+  double measureDistance(double[] set1, double[] set2);
+
+  /**
+   * Measures the distance between two points stored as vectors.
+   *
+   * @param vec1 the first point
+   * @param vec2 the second point
+   * @return the distance between {@code vec1} and {@code vec2}
+   */
+  double measureDistance(DoubleVector vec1, DoubleVector vec2);
+
+}

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/distance/EuclidianDistance.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.distance;
+
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * Euclidean (L2) distance: the square root of the sum of squared
+ * per-coordinate differences. The class name keeps its historical spelling.
+ */
+public final class EuclidianDistance implements DistanceMeasurer {
+
+  /**
+   * Computes the Euclidean distance between two dense points.
+   *
+   * @param set1 the first point
+   * @param set2 the second point; only its first {@code set1.length} entries
+   *          are read, so it is assumed to be at least that long
+   * @return the L2 distance between the two points
+   */
+  @Override
+  public double measureDistance(double[] set1, double[] set2) {
+    final int n = set1.length;
+    double squaredSum = 0;
+    int idx = 0;
+    while (idx < n) {
+      final double delta = set2[idx] - set1[idx];
+      // a plain product is cheaper than Math.pow(delta, 2)
+      squaredSum += (delta * delta);
+      idx++;
+    }
+    return Math.sqrt(squaredSum);
+  }
+
+  /**
+   * Computes the Euclidean distance between two vectors using the vector API.
+   *
+   * @param vec1 the first vector
+   * @param vec2 the second vector
+   * @return the L2 norm of {@code vec2 - vec1}
+   */
+  @Override
+  public double measureDistance(DoubleVector vec1, DoubleVector vec2) {
+    double squaredSum = vec2.subtract(vec1).pow(2).sum();
+    return Math.sqrt(squaredSum);
+  }
+
+}

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java?rev=1375724&r1=1375723&r2=1375724&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/writable/VectorWritable.java Tue Aug 21 19:20:06 2012
@@ -99,27 +99,9 @@ public final class VectorWritable implem
 
+  /**
+   * Serializes {@code vector} to {@code out}: an int header holding
+   * {@code vector.getLength()}, followed by {@code vector.get(i)} as a double
+   * for each index {@code i} from 0 up to {@code vector.getDimension()}.
+   * Only the header and the values are written.
+   *
+   * @param vector the vector to write
+   * @param out the destination stream
+   * @throws IOException if writing to {@code out} fails
+   */
   public static void writeVector(DoubleVector vector, DataOutput out)
       throws IOException {
-    out.writeBoolean(vector.isSparse());
     out.writeInt(vector.getLength());
-    if (vector.isSparse()) {
-      out.writeInt(vector.getDimension());
-      Iterator<DoubleVector.DoubleVectorElement> iterateNonZero = vector
-          .iterateNonZero();
-      while (iterateNonZero.hasNext()) {
-        DoubleVector.DoubleVectorElement next = iterateNonZero.next();
-        out.writeInt(next.getIndex());
-        out.writeDouble(next.getValue());
-      }
-    } else {
-      for (int i = 0; i < vector.getDimension(); i++) {
-        out.writeDouble(vector.get(i));
-      }
-    }
-    if (vector.isNamed() && vector.getName() != null) {
-      out.writeBoolean(true);
-      out.writeUTF(vector.getName());
-    } else {
-      out.writeBoolean(false);
+    // NOTE(review): the header is getLength() but this loop runs to
+    // getDimension(); confirm the two agree for every DoubleVector type, or a
+    // reader that trusts the header will misparse the stream.
+    for (int i = 0; i < vector.getDimension(); i++) {
+      out.writeDouble(vector.get(i));
     }
   }
 

Added: hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java?rev=1375724&view=auto
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java (added)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java Tue Aug 21 19:20:06 2012
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.HashMap;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * End-to-end test for {@link KMeansBSP}. It clusters the 100 points (i, i),
+ * i = 0..99, into k = 1 cluster and checks that the single resulting center
+ * lies strictly between 50 and 51 in both coordinates.
+ */
+public class TestKMeansBSP extends TestCase {
+
+  public void testRunJob() throws Exception {
+
+    Configuration conf = new Configuration();
+    Path in = new Path("/tmp/clustering/in/in.txt");
+    Path out = new Path("/tmp/clustering/out/");
+    FileSystem fs = FileSystem.get(conf);
+
+    // Write the input first. "in" is then a regular file, so the center file
+    // must live next to it, not underneath it.
+    writeInput(fs, in);
+    Path center = new Path(in.getParent(), "center/cen.seq");
+    Path centerOut = new Path(out, "center/center_output.seq");
+    conf.set(KMeansBSP.CENTER_IN_PATH, center.toString());
+    conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString());
+    int iterations = 10;
+    conf.setInt(KMeansBSP.MAX_ITERATIONS_KEY, iterations);
+    int k = 1;
+
+    in = KMeansBSP.prepareInputText(k, conf, in, center, out, fs);
+
+    BSPJob job = KMeansBSP.createJob(conf, in, out, true);
+
+    // just submit the job
+    boolean result = job.waitForCompletion(true);
+
+    assertTrue("KMeans job did not complete successfully", result);
+
+    HashMap<Integer, DoubleVector> centerMap = KMeansBSP.readOutput(conf, out,
+        centerOut, fs);
+    assertEquals("unexpected centers: " + centerMap, 1, centerMap.size());
+    DoubleVector doubleVector = centerMap.get(0);
+    assertTrue("center out of range: " + centerMap,
+        doubleVector.get(0) > 50 && doubleVector.get(0) < 51);
+    assertTrue("center out of range: " + centerMap,
+        doubleVector.get(1) > 50 && doubleVector.get(1) < 51);
+
+  }
+
+  /**
+   * Writes 100 lines "i\ti\n" (i = 0..99) to {@code in}. The writer is always
+   * closed, even when a write fails.
+   */
+  private static void writeInput(FileSystem fs, Path in) throws IOException {
+    StringBuilder sb = new StringBuilder();
+    for (int i = 0; i < 100; i++) {
+      sb.append(i).append('\t').append(i).append('\n');
+    }
+
+    FSDataOutputStream create = fs.create(in);
+    BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(create));
+    try {
+      bw.write(sb.toString());
+    } finally {
+      bw.close();
+    }
+  }
+
+}



Mime
View raw message