mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gsing...@apache.org
Subject svn commit: r1147257 - in /mahout/trunk/core/src: main/java/org/apache/mahout/math/hadoop/similarity/ test/java/org/apache/mahout/math/hadoop/similarity/
Date Fri, 15 Jul 2011 17:31:25 GMT
Author: gsingers
Date: Fri Jul 15 17:31:25 2011
New Revision: 1147257

URL: http://svn.apache.org/viewvc?rev=1147257&view=rev
Log:
MAHOUT-763: add map-side distance calculation

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1147257&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
(added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,140 @@
+package org.apache.mahout.math.hadoop.similarity;
+
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ *
+ *
+ **/
+public class VectorDistanceMapper extends Mapper<WritableComparable<?>, VectorWritable,
StringTuple, DoubleWritable> {
+  private transient static Logger log = LoggerFactory.getLogger(VectorDistanceMapper.class);
+  protected DistanceMeasure measure;
+  protected List<NamedVector> seedVectors;
+
+  @Override
+  protected void map(WritableComparable<?> key, VectorWritable value, Context context)
throws IOException, InterruptedException {
+    String keyName;
+    Vector valVec = value.get();
+    if (valVec instanceof NamedVector) {
+      keyName = ((NamedVector) valVec).getName();
+    } else {
+      keyName = key.toString();
+    }
+    for (NamedVector seedVector : seedVectors) {
+      double distance = measure.distance(seedVector, valVec);
+      StringTuple outKey = new StringTuple();
+      outKey.add(seedVector.getName());
+      outKey.add(keyName);
+      context.write(outKey, new DoubleWritable(distance));
+    }
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+      measure = ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
+              .asSubclass(DistanceMeasure.class).newInstance();
+      measure.configure(conf);
+
+
+      String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+      if (seedPathStr != null && seedPathStr.length() > 0) {
+
+        Path thePath = new Path(seedPathStr, "*");
+        Collection<Path> result = Lists.newArrayList();
+
+        // get all filtered file names in result list
+        FileSystem fs = thePath.getFileSystem(conf);
+        FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
+                PathFilters.partFilter());
+
+        for (FileStatus match : matches) {
+          result.add(fs.makeQualified(match.getPath()));
+        }
+        seedVectors = new ArrayList<NamedVector>(100);
+        long item = 0;
+        for (Path seedPath : result) {
+          for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf))
{
+            Class<? extends Writable> valueClass = value.getClass();
+            if (valueClass.equals(Cluster.class)) {
+              // get the cluster info
+              Cluster cluster = (Cluster) value;
+              Vector vector = cluster.getCenter();
+              if (vector instanceof NamedVector) {
+                seedVectors.add((NamedVector) vector);
+              } else {
+                seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
+              }
+            } else if (valueClass.equals(Canopy.class)) {
+              // get the cluster info
+              Canopy canopy = (Canopy) value;
+              Vector vector = canopy.getCenter();
+              if (vector instanceof NamedVector) {
+                seedVectors.add((NamedVector) vector);
+              } else {
+                seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
+              }
+            } else if (valueClass.equals(Vector.class)) {
+              Vector vector = (Vector) value;
+              if (vector instanceof NamedVector) {
+                seedVectors.add((NamedVector) vector);
+              } else {
+                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
+              }
+            } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class))
{
+              VectorWritable vw = (VectorWritable) value;
+              Vector vector = vw.get();
+              if (vector instanceof NamedVector) {
+                seedVectors.add((NamedVector) vector);
+              } else {
+                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
+              }
+            } else {
+              throw new IllegalStateException("Bad value class: " + valueClass);
+            }
+          }
+        }
+        if (seedVectors.isEmpty()) {
+          throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
+        } else {
+          log.info("Seed Vectors size: " + seedVectors.size());
+        }
+      }
+    } catch (ClassNotFoundException e) {
+      throw new IllegalStateException(e);
+    } catch (IllegalAccessException e) {
+      throw new IllegalStateException(e);
+    } catch (InstantiationException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1147257&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
(added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,112 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * This class does a Map-side join between seed vectors (the map side can also be a Cluster)
and a list of other vectors
+ * and emits the a tuple of seed id, other id, distance.  It is a more generic version of
KMean's mapper
+ */
+public class VectorDistanceSimilarityJob extends AbstractJob {
+  private static final Logger log = LoggerFactory.getLogger(VectorDistanceSimilarityJob.class);
+  public static final String SEEDS = "seeds";
+  public static final String SEEDS_PATH_KEY = "seedsPath";
+  public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.distanceMeasureOption().create());
+    addOption(SEEDS, "s", "The set of vectors to compute distances against.  Must fit in
memory on the mapper");
+    addOption(DefaultOptionCreator.overwriteOption().create());
+
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path output = getOutputPath();
+    Path seeds = new Path(getOption(SEEDS));
+    String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+    if (measureClass == null) {
+      measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+    }
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), output);
+    }
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+    DistanceMeasure measure = ccl.loadClass(measureClass).asSubclass(DistanceMeasure.class).newInstance();
+    if (getConf() == null) {
+      setConf(new Configuration());
+    }
+    run(getConf(), input, seeds, output, measure);
+    return 0;
+  }
+
+  public static void run(Configuration conf,
+                         Path input,
+                         Path seeds,
+                         Path output,
+                         DistanceMeasure measure) throws IOException, ClassNotFoundException,
InterruptedException {
+    conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
+    conf.set(SEEDS_PATH_KEY, seeds.toString());
+    Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " +
input);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setMapOutputKeyClass(StringTuple.class);
+    job.setOutputKeyClass(StringTuple.class);
+    job.setMapOutputValueClass(DoubleWritable.class);
+    job.setOutputValueClass(DoubleWritable.class);
+    job.setMapperClass(VectorDistanceMapper.class);
+
+    job.setNumReduceTasks(0);
+    FileInputFormat.addInputPath(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+
+    job.setJarByClass(VectorDistanceSimilarityJob.class);
+    HadoopUtil.delete(conf, output);
+    if (!job.waitForCompletion(true)) {
+      throw new InterruptedException("VectorDistance Similarity failed processing " + seeds);
+    }
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1147257&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
(added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,136 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Tests for {@link VectorDistanceMapper} and {@link VectorDistanceSimilarityJob}: one test
+ * drives the mapper directly against a mocked context, the other runs the whole job through
+ * {@link ToolRunner} on small files written to the test temp directory.
+ */
+public class TestVectorDistanceSimilarityJob extends MahoutTestCase {
+  private FileSystem fs;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    Configuration conf = new Configuration();
+    fs = FileSystem.get(conf);
+  }
+
+  /**
+   * Maps the point (2, 2) against seeds "foo" = (1, 1) and "foo2" = (2, 1) and expects the
+   * Euclidean distances sqrt(2) and 1 to be written, keyed by (seed name, record key).
+   */
+  @Test
+  public void testVectorDistanceMapper() throws Exception {
+    Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable>.Context context =
+            EasyMock.createMock(Mapper.Context.class);
+    // Record the writes the mapper is expected to make, in seed order.
+    StringTuple tuple;
+    tuple = new StringTuple();
+    tuple.add("foo");
+    tuple.add("123");
+    context.write(tuple, new DoubleWritable(Math.sqrt(2.0)));
+    tuple = new StringTuple();
+    tuple.add("foo2");
+    tuple.add("123");
+    context.write(tuple, new DoubleWritable(1));
+
+    EasyMock.replay(context);
+
+    Vector vector = new RandomAccessSparseVector(2);
+    vector.set(0, 2);
+    vector.set(1, 2);
+
+    // Inject the measure and seeds directly; setup() is not called in this test.
+    VectorDistanceMapper mapper = new VectorDistanceMapper();
+    setField(mapper, "measure", new EuclideanDistanceMeasure());
+    List<NamedVector> seedVectors = new ArrayList<NamedVector>();
+    Vector seed1 = new RandomAccessSparseVector(2);
+    seed1.set(0, 1);
+    seed1.set(1, 1);
+    Vector seed2 = new RandomAccessSparseVector(2);
+    seed2.set(0, 2);
+    seed2.set(1, 1);
+
+    seedVectors.add(new NamedVector(seed1, "foo"));
+    seedVectors.add(new NamedVector(seed2, "foo2"));
+    setField(mapper, "seedVectors", seedVectors);
+
+    mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+    EasyMock.verify(context);
+
+  }
+
+  public static final double[][] REFERENCE = {
+          {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}
+  };
+
+  public static final double[][] SEEDS = {
+          {1, 1}, {10, 10}
+  };
+
+  /**
+   * Writes the reference points and the seeds to temp files, runs the job from the command
+   * line with the Euclidean measure, and checks that it reports success.
+   */
+  @Test
+  public void testRun() throws Exception {
+    Path input = getTestTempDirPath("input");
+    Path output = getTestTempDirPath("output");
+    Path seedsPath = getTestTempDirPath("seeds");
+    List<VectorWritable> points = getPointsWritable(REFERENCE);
+    List<VectorWritable> seeds = getPointsWritable(SEEDS);
+    Configuration conf = new Configuration();
+    ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+    String[] args = {
+        optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+        optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
+        optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+        optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName()
+    };
+    int exitCode = ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+    // The exit code used to be discarded, so a failed argument parse went unnoticed.
+    org.junit.Assert.assertEquals("job exit code", 0, exitCode);
+  }
+
+  /**
+   * Wraps each row of {@code raw} in a sparse vector of the same length.
+   *
+   * @param raw one row of coordinates per point
+   * @return the points as {@link VectorWritable}s, in input order
+   */
+  public static List<VectorWritable> getPointsWritable(double[][] raw) {
+    List<VectorWritable> points = Lists.newArrayList();
+    for (double[] fr : raw) {
+      Vector vec = new RandomAccessSparseVector(fr.length);
+      vec.assign(fr);
+      points.add(new VectorWritable(vec));
+    }
+    return points;
+  }
+
+}



Mime
View raw message