hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1392918 - in /hama/trunk/ml/src: main/java/org/apache/hama/ml/ main/java/org/apache/hama/ml/kmeans/ test/java/org/apache/hama/ml/ test/java/org/apache/hama/ml/kmeans/
Date Tue, 02 Oct 2012 13:55:51 GMT
Author: tommaso
Date: Tue Oct  2 13:55:51 2012
New Revision: 1392918

URL: http://svn.apache.org/viewvc?rev=1392918&view=rev
Log:
[HAMA-650] - moved KMeansBSP and CenterMessage to a dedicated kmeans package

Added:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/
    hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java   (with props)
    hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java   (with props)
    hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/
    hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java   (with props)
Removed:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/CenterMessage.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/KMeansBSP.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/TestKMeansBSP.java

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java?rev=1392918&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java Tue Oct  2 13:55:51 2012
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.kmeans;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.writable.VectorWritable;
+
+/**
+ * Message that carries a vector for one cluster center together with the
+ * index of that center and an increment counter.
+ */
+public final class CenterMessage implements Writable {
+
+  private int centerIndex;
+  private int incrementCounter;
+  private DoubleVector newCenter;
+
+  /**
+   * Creates an empty message; all fields keep their defaults until
+   * {@link #readFields(DataInput)} is called.
+   */
+  public CenterMessage() {
+  }
+
+  /**
+   * Creates a message for the given center index whose increment counter is
+   * zero.
+   */
+  public CenterMessage(int key, DoubleVector value) {
+    this(key, 0, value);
+  }
+
+  /**
+   * Creates a message for the given center index with an explicit increment
+   * counter.
+   */
+  public CenterMessage(int key, int increment, DoubleVector value) {
+    this.centerIndex = key;
+    this.incrementCounter = increment;
+    this.newCenter = value;
+  }
+
+  /**
+   * Reads center index, increment counter and vector, in that order.
+   */
+  @Override
+  public final void readFields(DataInput in) throws IOException {
+    this.centerIndex = in.readInt();
+    this.incrementCounter = in.readInt();
+    this.newCenter = VectorWritable.readVector(in);
+  }
+
+  /**
+   * Writes center index, increment counter and vector, in that order.
+   */
+  @Override
+  public final void write(DataOutput out) throws IOException {
+    out.writeInt(this.centerIndex);
+    out.writeInt(this.incrementCounter);
+    VectorWritable.writeVector(this.newCenter, out);
+  }
+
+  /** @return the index of the center this message refers to. */
+  public int getCenterIndex() {
+    return this.centerIndex;
+  }
+
+  /** @return the increment counter carried by this message. */
+  public int getIncrementCounter() {
+    return this.incrementCounter;
+  }
+
+  /** @return the center index; same value as {@link #getCenterIndex()}. */
+  public final int getTag() {
+    return this.centerIndex;
+  }
+
+  /** @return the vector carried by this message. */
+  public final DoubleVector getData() {
+    return this.newCenter;
+  }
+
+}

Propchange: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/CenterMessage.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java?rev=1392918&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java Tue Oct  2 13:55:51 2012
@@ -0,0 +1,487 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.kmeans;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+import org.apache.hama.ml.distance.DistanceMeasurer;
+import org.apache.hama.ml.distance.EuclidianDistance;
+import org.apache.hama.ml.math.DenseDoubleVector;
+import org.apache.hama.ml.math.DoubleVector;
+import org.apache.hama.ml.writable.VectorWritable;
+import org.apache.hama.util.ReflectionUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * K-Means in BSP that reads a bunch of vectors from input system and a given
+ * centroid path that contains initial centers.
+ * 
+ */
+public final class KMeansBSP
+    extends
+    BSP<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> {
+
+  public static final String CENTER_OUT_PATH = "center.out.path"; // config key: path the first peer writes the final centers to; skipped when unset
+  public static final String MAX_ITERATIONS_KEY = "k.means.max.iterations"; // config key: iteration limit, read with default -1
+  public static final String CACHING_ENABLED_KEY = "k.means.caching.enabled"; // config key: keep input vectors in memory, read with default false
+  public static final String DISTANCE_MEASURE_CLASS = "distance.measure.class"; // config key: DistanceMeasurer class name; EuclidianDistance is used when unset
+  public static final String CENTER_IN_PATH = "center.in.path"; // config key: sequence file holding the initial centers
+
+  private static final Log LOG = LogFactory.getLog(KMeansBSP.class);
+  // a task local copy of our cluster centers, loaded in setup() and replaced in updateCenters()
+  private DoubleVector[] centers;
+  // simple cache to speed up computation, because the algorithm is disk based; stays null unless caching is enabled
+  private List<DoubleVector> cache;
+  // maximum number of iterations to do; the limit is only checked when this is > 0
+  private int maxIterations;
+  // our distance measurement
+  private DistanceMeasurer distanceMeasurer;
+  private Configuration conf; // the peer's configuration, captured in setup()
+
+  /**
+   * Loads the initial centers from the sequence file configured under
+   * {@link #CENTER_IN_PATH}, instantiates the distance measurer and reads the
+   * iteration limit and the caching flag from the peer's configuration.
+   *
+   * @throws IllegalArgumentException if the centers file contains no center.
+   * @throws RuntimeException if reading the centers file fails or the
+   *           configured distance measurer class cannot be found.
+   */
+  @Override
+  public final void setup(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
+      throws IOException, InterruptedException {
+
+    conf = peer.getConfiguration();
+
+    Path centroids = new Path(conf.get(CENTER_IN_PATH));
+    FileSystem fs = FileSystem.get(conf);
+    // named differently from the field so the field is not shadowed
+    final List<DoubleVector> initialCenters = new ArrayList<DoubleVector>();
+    SequenceFile.Reader reader = null;
+    try {
+      reader = new SequenceFile.Reader(fs, centroids, conf);
+      VectorWritable key = new VectorWritable();
+      NullWritable value = NullWritable.get();
+      while (reader.next(key, value)) {
+        initialCenters.add(key.getVector());
+      }
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+
+    Preconditions.checkArgument(initialCenters.size() > 0,
+        "Centers file must contain at least a single center!");
+    this.centers = initialCenters
+        .toArray(new DoubleVector[initialCenters.size()]);
+
+    String distanceClass = conf.get(DISTANCE_MEASURE_CLASS);
+    if (distanceClass != null) {
+      try {
+        distanceMeasurer = ReflectionUtils.newInstance(distanceClass);
+      } catch (ClassNotFoundException e) {
+        // keep the original exception as cause so the stack trace is not lost
+        throw new RuntimeException("Wrong DistanceMeasurer implementation "
+            + distanceClass + " provided", e);
+      }
+    } else {
+      distanceMeasurer = new EuclidianDistance();
+    }
+
+    maxIterations = conf.getInt(MAX_ITERATIONS_KEY, -1);
+    // normally we want to rely on OS caching, but if not, we can cache in heap
+    if (conf.getBoolean(CACHING_ENABLED_KEY, false)) {
+      cache = new ArrayList<DoubleVector>();
+    }
+  }
+
+  /**
+   * Runs the clustering loop: assign vectors to centers, synchronize, update
+   * the centers from the received messages and reopen the input. The loop
+   * ends once no center changed, or once a positive iteration limit is
+   * smaller than the superstep count. Afterwards the final assignments are
+   * written.
+   */
+  @Override
+  public final void bsp(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
+      throws IOException, InterruptedException, SyncException {
+    boolean keepIterating;
+    do {
+      assignCenters(peer);
+      peer.sync();
+      final long changedCenters = updateCenters(peer);
+      peer.reopenInput();
+      // the superstep count is only consulted when centers still changed and
+      // a positive limit is configured
+      keepIterating = changedCenters != 0
+          && !(maxIterations > 0 && maxIterations < peer.getSuperstepCount());
+    } while (keepIterating);
+    LOG.info("Finished! Writing the assignments...");
+    recalculateAssignmentsAndWrite(peer);
+    LOG.info("Done.");
+  }
+
+  private long updateCenters(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage>
peer)
+      throws IOException {
+    // this is the update step
+    DoubleVector[] msgCenters = new DoubleVector[centers.length];
+    int[] incrementSum = new int[centers.length];
+    CenterMessage msg;
+    // basically just summing incoming vectors
+    while ((msg = peer.getCurrentMessage()) != null) {
+      DoubleVector oldCenter = msgCenters[msg.getCenterIndex()];
+      DoubleVector newCenter = msg.getData();
+      incrementSum[msg.getCenterIndex()] += msg.getIncrementCounter();
+      if (oldCenter == null) {
+        msgCenters[msg.getCenterIndex()] = newCenter;
+      } else {
+        msgCenters[msg.getCenterIndex()] = oldCenter.add(newCenter);
+      }
+    }
+    // divide by how often we globally summed vectors
+    for (int i = 0; i < msgCenters.length; i++) {
+      // and only if we really have an update for c
+      if (msgCenters[i] != null) {
+        msgCenters[i] = msgCenters[i].divide(incrementSum[i]);
+      }
+    }
+    // finally check for convergence by the absolute difference
+    long convergedCounter = 0L;
+    for (int i = 0; i < msgCenters.length; i++) {
+      final DoubleVector oldCenter = centers[i];
+      if (msgCenters[i] != null) {
+        double calculateError = oldCenter.subtract(msgCenters[i]).abs().sum();
+        if (calculateError > 0.0d) {
+          centers[i] = msgCenters[i];
+          convergedCounter++;
+        }
+      }
+    }
+    return convergedCounter;
+  }
+
+  /**
+   * Assignment step: adds every input vector to the partial sum of its
+   * nearest center and then sends one message per touched center to every
+   * peer. Vectors come from the in-memory cache when it is enabled and
+   * already filled, otherwise from the peer's input (filling the cache on the
+   * way if it is enabled).
+   */
+  private void assignCenters(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
+      throws IOException {
+    final DoubleVector[] partialSums = new DoubleVector[centers.length];
+    final int[] counts = new int[centers.length];
+
+    final boolean cachingEnabled = cache != null;
+    if (cachingEnabled && !cache.isEmpty()) {
+      // everything is already in memory
+      for (DoubleVector cached : cache) {
+        assignCentersInternal(partialSums, counts, cached);
+      }
+    } else {
+      // read from the input; remember the copies if caching is enabled
+      final VectorWritable key = new VectorWritable();
+      final NullWritable value = NullWritable.get();
+      while (peer.readNext(key, value)) {
+        final DoubleVector copy = key.getVector().deepCopy();
+        if (cachingEnabled) {
+          cache.add(copy);
+        }
+        assignCentersInternal(partialSums, counts, copy);
+      }
+    }
+
+    // tell every peer about the local partial sums
+    for (int i = 0; i < partialSums.length; i++) {
+      if (partialSums[i] == null) {
+        continue;
+      }
+      for (String peerName : peer.getAllPeerNames()) {
+        peer.send(peerName, new CenterMessage(i, counts[i], partialSums[i]));
+      }
+    }
+  }
+
+  /**
+   * Adds the given vector to the partial sum of its nearest center. The first
+   * vector seen for a center is stored as-is; every further vector is added
+   * and increments the counter of that center.
+   */
+  private void assignCentersInternal(final DoubleVector[] newCenterArray,
+      final int[] summationCount, final DoubleVector key) {
+    final int nearest = getNearestCenter(key);
+    if (newCenterArray[nearest] == null) {
+      // NOTE(review): the first vector does not increment the counter, so the
+      // counter is one less than the number of summed vectors - confirm this
+      // is intended before changing it
+      newCenterArray[nearest] = key;
+      return;
+    }
+    newCenterArray[nearest] = newCenterArray[nearest].add(key);
+    summationCount[nearest]++;
+  }
+
+  /**
+   * Finds the center with the smallest measured distance to the given vector.
+   * On ties the center with the lowest index wins.
+   *
+   * @return the index of that center in {@code centers}.
+   */
+  private int getNearestCenter(DoubleVector key) {
+    int bestIndex = 0;
+    double bestDistance = Double.MAX_VALUE;
+    int candidate = 0;
+    for (DoubleVector center : centers) {
+      final double distance = distanceMeasurer.measureDistance(center, key);
+      // strictly smaller, so an earlier center is kept on equal distance
+      if (distance < bestDistance) {
+        bestDistance = distance;
+        bestIndex = candidate;
+      }
+      candidate++;
+    }
+    return bestIndex;
+  }
+
+  /**
+   * Writes every input vector together with the index of its nearest center
+   * to the peer's output, using the cache if it exists. The peer whose name
+   * equals that of peer 0 additionally writes the centers to the path
+   * configured under {@link #CENTER_OUT_PATH}, if that key is set.
+   */
+  private void recalculateAssignmentsAndWrite(
+      BSPPeer<VectorWritable, NullWritable, IntWritable, VectorWritable, CenterMessage> peer)
+      throws IOException {
+    final NullWritable value = NullWritable.get();
+    final IntWritable keyWrite = new IntWritable();
+    // also use our cache to speed up the final writes if exists
+    if (cache == null) {
+      final VectorWritable key = new VectorWritable();
+      while (peer.readNext(key, value)) {
+        keyWrite.set(getNearestCenter(key.getVector()));
+        peer.write(keyWrite, key);
+      }
+    } else {
+      for (DoubleVector v : cache) {
+        keyWrite.set(getNearestCenter(v));
+        peer.write(keyWrite, new VectorWritable(v));
+      }
+    }
+    // just on the first task write the centers to filesystem to prevent
+    // collisions
+    if (peer.getPeerName().equals(peer.getPeerName(0))) {
+      String pathString = conf.get(CENTER_OUT_PATH);
+      if (pathString != null) {
+        final SequenceFile.Writer dataWriter = SequenceFile.createWriter(
+            FileSystem.get(conf), conf, new Path(pathString),
+            VectorWritable.class, NullWritable.class, CompressionType.NONE);
+        try {
+          for (DoubleVector center : centers) {
+            dataWriter.append(new VectorWritable(center), value);
+          }
+        } finally {
+          // close the writer even if an append fails
+          dataWriter.close();
+        }
+      }
+    }
+  }
+
+  /**
+   * Builds a KMeans BSP job that reads sequence files from {@code in} and
+   * writes to {@code out}, as text if {@code textOut} is set and as sequence
+   * files otherwise. Output records use an IntWritable key and a
+   * VectorWritable value.
+   */
+  public static BSPJob createJob(Configuration cnf, Path in, Path out,
+      boolean textOut) throws IOException {
+    final HamaConfiguration hamaConf = new HamaConfiguration(cnf);
+    final BSPJob kmeansJob = new BSPJob(hamaConf, KMeansBSP.class);
+    kmeansJob.setJobName("KMeans Clustering");
+    kmeansJob.setJarByClass(KMeansBSP.class);
+    kmeansJob.setBspClass(KMeansBSP.class);
+    kmeansJob.setInputPath(in);
+    kmeansJob.setOutputPath(out);
+    kmeansJob.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
+    if (!textOut) {
+      kmeansJob
+          .setOutputFormat(org.apache.hama.bsp.SequenceFileOutputFormat.class);
+    } else {
+      kmeansJob.setOutputFormat(org.apache.hama.bsp.TextOutputFormat.class);
+    }
+    kmeansJob.setOutputKeyClass(IntWritable.class);
+    kmeansJob.setOutputValueClass(VectorWritable.class);
+    return kmeansJob;
+  }
+
+  public static void main(String[] args) throws IOException,
+      ClassNotFoundException, InterruptedException {
+
+    if (args.length < 6) {
+      LOG.info("USAGE: <INPUT_PATH> <OUTPUT_PATH> <COUNT> <K> <DIMENSION
OF VECTORS> <MAXITERATIONS> <optional: num of tasks>");
+      return;
+    }
+
+    Configuration conf = new Configuration();
+    int count = Integer.parseInt(args[2]);
+    int k = Integer.parseInt(args[3]);
+    int dimension = Integer.parseInt(args[4]);
+    int iterations = Integer.parseInt(args[5]);
+    conf.setInt(MAX_ITERATIONS_KEY, iterations);
+
+    Path in = new Path(args[0]);
+    Path out = new Path(args[1]);
+    Path center = new Path(in, "center/cen.seq");
+    Path centerOut = new Path(out, "center/center_output.seq");
+
+    conf.set(CENTER_IN_PATH, center.toString());
+    conf.set(CENTER_OUT_PATH, centerOut.toString());
+    // if you're in local mode, you can increase this to match your core sizes
+    conf.set("bsp.local.tasks.maximum", ""
+        + Runtime.getRuntime().availableProcessors());
+    // deactivate (set to false) if you want to iterate over disk, else it will
+    // cache the input vectors in memory
+    conf.setBoolean(CACHING_ENABLED_KEY, true);
+    BSPJob job = createJob(conf, in, out, false);
+
+    LOG.info("N: " + count + " k: " + k + " Dimension: " + dimension
+        + " Iterations: " + iterations);
+
+    FileSystem fs = FileSystem.get(conf);
+    // prepare the input, like deleting old versions and creating centers
+    prepareInput(count, k, dimension, conf, in, center, out, fs);
+    if (args.length == 7) {
+      job.setNumBspTask(Integer.parseInt(args[6]));
+    }
+
+    // just submit the job
+    job.waitForCompletion(true);
+  }
+
+  /**
+   * Reads the centers outputted from the clustering job.
+   *
+   * @param conf configuration used to open the sequence file.
+   * @param out not used by this method.
+   * @param centerPath sequence file that contains the centers.
+   * @param fs file system that holds {@code centerPath}.
+   * @return an index on the key dimension, and a cluster center on the value.
+   *         Indices follow the order of the records in the file, starting at
+   *         zero.
+   * @throws IOException if the file cannot be opened or read.
+   */
+  public static HashMap<Integer, DoubleVector> readOutput(Configuration conf,
+      Path out, Path centerPath, FileSystem fs) throws IOException {
+    HashMap<Integer, DoubleVector> centerMap = new HashMap<Integer, DoubleVector>();
+    SequenceFile.Reader centerReader = new SequenceFile.Reader(fs, centerPath,
+        conf);
+    try {
+      int index = 0;
+      VectorWritable center = new VectorWritable();
+      while (centerReader.next(center, NullWritable.get())) {
+        centerMap.put(index++, center.getVector());
+      }
+    } finally {
+      // release the file handle even if reading fails half way
+      centerReader.close();
+    }
+    return centerMap;
+  }
+
+  /**
+   * Reads a tab separated text file with one vector per line and writes the
+   * vectors to a sequence file. The first {@code k} vectors are additionally
+   * written to {@code center} as initial centers. Existing files at
+   * {@code out}, {@code center} and the returned path are deleted first.
+   *
+   * @param k number of initial centers to take from the start of the input.
+   * @param txtIn the text input; the sequence file is placed in
+   *          "textinput/in.seq" next to it if it is a file, below it
+   *          otherwise.
+   * @return the path of the written sequence file.
+   * @throws IOException if reading or writing fails.
+   * @throws NumberFormatException if a field is not a parsable double.
+   */
+  public static Path prepareInputText(int k, Configuration conf, Path txtIn,
+      Path center, Path out, FileSystem fs) throws IOException {
+
+    Path in;
+    if (fs.isFile(txtIn)) {
+      in = new Path(txtIn.getParent(), "textinput/in.seq");
+    } else {
+      in = new Path(txtIn, "textinput/in.seq");
+    }
+
+    if (fs.exists(out))
+      fs.delete(out, true);
+
+    if (fs.exists(center))
+      fs.delete(center, true);
+
+    if (fs.exists(in))
+      fs.delete(in, true);
+
+    final NullWritable value = NullWritable.get();
+
+    // all three resources are closed in finally blocks; in particular the
+    // center writer is closed even when the input has no more than k lines
+    final Writer centerWriter = new SequenceFile.Writer(fs, conf, center,
+        VectorWritable.class, NullWritable.class);
+    try {
+      final SequenceFile.Writer dataWriter = SequenceFile.createWriter(fs,
+          conf, in, VectorWritable.class, NullWritable.class,
+          CompressionType.NONE);
+      try {
+        final BufferedReader br = new BufferedReader(new InputStreamReader(
+            fs.open(txtIn)));
+        try {
+          int i = 0;
+          String line;
+          while ((line = br.readLine()) != null) {
+            String[] split = line.split("\t");
+            DenseDoubleVector vec = new DenseDoubleVector(split.length);
+            for (int j = 0; j < split.length; j++) {
+              vec.set(j, Double.parseDouble(split[j]));
+            }
+            VectorWritable vector = new VectorWritable(vec);
+            dataWriter.append(vector, value);
+            // the first k vectors double as initial centers
+            if (i < k) {
+              centerWriter.append(vector, value);
+            }
+            i++;
+          }
+        } finally {
+          br.close();
+        }
+      } finally {
+        dataWriter.close();
+      }
+    } finally {
+      centerWriter.close();
+    }
+    return in;
+  }
+
+  /**
+   * Create some random vectors as input and assign the first k vectors as
+   * initial centers. Existing files at {@code out}, {@code center} and
+   * {@code in} are deleted first.
+   *
+   * @param count number of vectors to generate; also the exclusive upper
+   *          bound of every generated component.
+   * @param k number of vectors that are also written to {@code center}.
+   * @param dimension number of components per vector.
+   * @throws IOException if deleting or writing fails.
+   */
+  public static void prepareInput(int count, int k, int dimension,
+      Configuration conf, Path in, Path center, Path out, FileSystem fs)
+      throws IOException {
+    if (fs.exists(out))
+      fs.delete(out, true);
+
+    // remove a stale centers file (the old code deleted "out" a second time)
+    if (fs.exists(center))
+      fs.delete(center, true);
+
+    if (fs.exists(in))
+      fs.delete(in, true);
+
+    final NullWritable value = NullWritable.get();
+    // both writers are closed in finally blocks, so the center writer is
+    // closed even when count <= k
+    final SequenceFile.Writer centerWriter = SequenceFile.createWriter(fs,
+        conf, center, VectorWritable.class, NullWritable.class,
+        CompressionType.NONE);
+    try {
+      final SequenceFile.Writer dataWriter = SequenceFile.createWriter(fs,
+          conf, in, VectorWritable.class, NullWritable.class,
+          CompressionType.NONE);
+      try {
+        Random r = new Random();
+        for (int i = 0; i < count; i++) {
+          double[] arr = new double[dimension];
+          for (int d = 0; d < dimension; d++) {
+            arr[d] = r.nextInt(count);
+          }
+          VectorWritable vector = new VectorWritable(
+              new DenseDoubleVector(arr));
+          dataWriter.append(vector, value);
+          if (i < k) {
+            centerWriter.append(vector, value);
+          }
+        }
+      } finally {
+        dataWriter.close();
+      }
+    } finally {
+      centerWriter.close();
+    }
+  }
+}

Propchange: hama/trunk/ml/src/main/java/org/apache/hama/ml/kmeans/KMeansBSP.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java?rev=1392918&view=auto
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java (added)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java Tue Oct  2 13:55:51 2012
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.kmeans;
+
+import java.io.BufferedWriter;
+import java.io.OutputStreamWriter;
+import java.util.HashMap;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.ml.kmeans.KMeansBSP;
+import org.apache.hama.ml.math.DoubleVector;
+
+/**
+ * Runs the KMeans BSP job with a single center on 100 two-dimensional points
+ * (i, i) and checks the resulting center.
+ */
+public class TestKMeansBSP extends TestCase {
+
+  public void testRunJob() throws Exception {
+    final Configuration conf = new Configuration();
+    final FileSystem fs = FileSystem.get(conf);
+    final Path textInput = new Path("/tmp/clustering/in/in.txt");
+    final Path out = new Path("/tmp/clustering/out/");
+
+    try {
+      final Path center = new Path(textInput.getParent(), "center/cen.seq");
+      final Path centerOut = new Path(out, "center/center_output.seq");
+      final int iterations = 10;
+      final int k = 1;
+      conf.set(KMeansBSP.CENTER_IN_PATH, center.toString());
+      conf.set(KMeansBSP.CENTER_OUT_PATH, centerOut.toString());
+      conf.setInt(KMeansBSP.MAX_ITERATIONS_KEY, iterations);
+
+      writePoints(fs, textInput);
+
+      final Path seqInput = KMeansBSP.prepareInputText(k, conf, textInput,
+          center, out, fs);
+      final BSPJob job = KMeansBSP.createJob(conf, seqInput, out, true);
+
+      // just submit the job
+      assertTrue(job.waitForCompletion(true));
+
+      final HashMap<Integer, DoubleVector> centerMap = KMeansBSP.readOutput(
+          conf, out, centerOut, fs);
+      System.out.println(centerMap);
+      assertEquals(1, centerMap.size());
+      final DoubleVector onlyCenter = centerMap.get(0);
+      for (int dim = 0; dim < 2; dim++) {
+        final double component = onlyCenter.get(dim);
+        assertTrue(component > 50 && component < 51);
+      }
+    } finally {
+      fs.delete(new Path("/tmp/clustering"), true);
+    }
+  }
+
+  /**
+   * Writes the lines "i TAB i" for i in 0..99 to the given path.
+   */
+  private static void writePoints(FileSystem fs, Path target)
+      throws Exception {
+    final FSDataOutputStream stream = fs.create(target);
+    final BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
+        stream));
+    final StringBuilder lines = new StringBuilder();
+    for (int i = 0; i < 100; i++) {
+      lines.append(i).append('\t').append(i).append('\n');
+    }
+    writer.write(lines.toString());
+    writer.close();
+  }
+
+}

Propchange: hama/trunk/ml/src/test/java/org/apache/hama/ml/kmeans/TestKMeansBSP.java
------------------------------------------------------------------------------
    svn:eol-style = native



Mime
View raw message