mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1455095 - in /mahout/trunk/core/src: main/java/org/apache/mahout/math/hadoop/similarity/ test/java/org/apache/mahout/math/hadoop/similarity/
Date Mon, 11 Mar 2013 11:04:58 GMT
Author: ssc
Date: Mon Mar 11 11:04:58 2013
New Revision: 1455095

URL: http://svn.apache.org/r1455095
Log:
MAHOUT-1019 VectorDistanceSimilarityJob

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java	Mon Mar 11 11:04:58 2013
@@ -36,6 +36,8 @@ public final class VectorDistanceMapper
 
   private DistanceMeasure measure;
   private List<NamedVector> seedVectors;
+  private boolean usesThreshold = false;
+  private double maxDistance;
 
   @Override
   protected void map(WritableComparable<?> key, VectorWritable value, Context context)
@@ -47,12 +49,15 @@ public final class VectorDistanceMapper
     } else {
       keyName = key.toString();
     }
+    
     for (NamedVector seedVector : seedVectors) {
       double distance = measure.distance(seedVector, valVec);
-      StringTuple outKey = new StringTuple();
-      outKey.add(seedVector.getName());
-      outKey.add(keyName);
-      context.write(outKey, new DoubleWritable(distance));
+      if (!usesThreshold || distance <= maxDistance) {
+          StringTuple outKey = new StringTuple();
+          outKey.add(seedVector.getName());
+          outKey.add(keyName);
+          context.write(outKey, new DoubleWritable(distance));          
+      }
     }
   }
 
@@ -60,8 +65,15 @@ public final class VectorDistanceMapper
   protected void setup(Context context) throws IOException, InterruptedException {
     super.setup(context);
     Configuration conf = context.getConfiguration();
-    measure =
-        ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY), DistanceMeasure.class);
+
+    String maxDistanceParam = conf.get(VectorDistanceSimilarityJob.MAX_DISTANCE);
+    if (maxDistanceParam != null) {
+      usesThreshold = true;
+      maxDistance = Double.parseDouble(maxDistanceParam);
+    }
+    
+    measure = ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY),
+        DistanceMeasure.class);
     measure.configure(conf);
     seedVectors = SeedVectorUtil.loadSeedVectors(conf);
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java	Mon Mar 11 11:04:58 2013
@@ -36,6 +36,8 @@ import org.apache.mahout.common.distance
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
+import com.google.common.base.Preconditions;
+
 import java.io.IOException;
 
 /**
@@ -48,6 +50,7 @@ public class VectorDistanceSimilarityJob
   public static final String SEEDS_PATH_KEY = "seedsPath";
   public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
   public static final String OUT_TYPE_KEY = "outType";
+  public static final String MAX_DISTANCE = "maxDistance";
 
   public static void main(String[] args) throws Exception {
     ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
@@ -60,11 +63,13 @@ public class VectorDistanceSimilarityJob
     addOutputOption();
     addOption(DefaultOptionCreator.distanceMeasureOption().create());
     addOption(SEEDS, "s", "The set of vectors to compute distances against.  Must fit in memory on the mapper");
+    addOption(MAX_DISTANCE, "mx", "set an upper-bound on distance (double) such that any pair of vectors with a" +
+        " distance greater than this value is ignored in the output. Ignored for non pairwise output!");
     addOption(DefaultOptionCreator.overwriteOption().create());
-    addOption(OUT_TYPE_KEY, "ot",
-              "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v).  Pairwise is a "
-                  + "tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.",
-              "pw");
+    addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v).  " +
+        "Pairwise is a tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.",
+        "pw");
+
     if (parseArguments(args) == null) {
       return -1;
     }
@@ -83,12 +88,19 @@ public class VectorDistanceSimilarityJob
     if (getConf() == null) {
       setConf(new Configuration());
     }
-    String outType = getOption(OUT_TYPE_KEY);
-    if (outType == null) {
-      outType = "pw";
+    String outType = getOption(OUT_TYPE_KEY, "pw");
+    
+    Double maxDistance = null;
+
+    if ("pw".equals(outType)) {
+      String maxDistanceArg = getOption(MAX_DISTANCE);
+      if (maxDistanceArg != null) {
+        maxDistance = Double.parseDouble(maxDistanceArg);
+        Preconditions.checkArgument(maxDistance > 0d, "value for " + MAX_DISTANCE + " must be greater than zero");
+      }
     }
 
-    run(getConf(), input, seeds, output, measure, outType);
+    run(getConf(), input, seeds, output, measure, outType, maxDistance);
     return 0;
   }
 
@@ -98,6 +110,18 @@ public class VectorDistanceSimilarityJob
                          Path output,
                          DistanceMeasure measure, String outType)
     throws IOException, ClassNotFoundException, InterruptedException {
+      run(conf, input, seeds, output, measure, outType, null);
+  }      
+  
+  public static void run(Configuration conf,
+          Path input,
+          Path seeds,
+          Path output,
+          DistanceMeasure measure, String outType, Double maxDistance)
+    throws IOException, ClassNotFoundException, InterruptedException {
+    if (maxDistance != null) {
+      conf.set(MAX_DISTANCE, String.valueOf(maxDistance));
+    }
     conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
     conf.set(SEEDS_PATH_KEY, seeds.toString());
     Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " + input);
@@ -119,7 +143,6 @@ public class VectorDistanceSimilarityJob
       throw new IllegalArgumentException("Invalid outType specified: " + outType);
     }
 
-
     job.setNumReduceTasks(0);
     FileInputFormat.addInputPath(job, input);
     FileOutputFormat.setOutputPath(job, output);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1455095&r1=1455094&r2=1455095&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java	(original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java	Mon Mar 11 11:04:58 2013
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.math.hadoop.similarity;
 
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
@@ -50,14 +51,19 @@ import java.util.List;
 import java.util.Map;
 
 public class TestVectorDistanceSimilarityJob extends MahoutTestCase {
+
   private FileSystem fs;
 
+  private static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 },
+      { 4, 5 }, { 5, 5 } };
+
+  private static final double[][] SEEDS = { { 1, 1 }, { 10, 10 } };
+
   @Override
   @Before
   public void setUp() throws Exception {
     super.setUp();
-    Configuration conf = new Configuration();
-    fs = FileSystem.get(conf);
+    fs = FileSystem.get(new Configuration());
   }
 
   @Test
@@ -96,7 +102,6 @@ public class TestVectorDistanceSimilarit
     mapper.map(new IntWritable(123), new VectorWritable(vector), context);
 
     EasyMock.verify(context);
-
   }
 
   @Test
@@ -130,39 +135,65 @@ public class TestVectorDistanceSimilarit
 
   }
 
-  private static final double[][] REFERENCE = {
-          {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}
-  };
-
-  private static final double[][] SEEDS = {
-          {1, 1}, {10, 10}
-  };
-
   @Test
   public void testRun() throws Exception {
     Path input = getTestTempDirPath("input");
     Path output = getTestTempDirPath("output");
     Path seedsPath = getTestTempDirPath("seeds");
+
     List<VectorWritable> points = getPointsWritable(REFERENCE);
     List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
     Configuration conf = new Configuration();
     ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
     ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
-    String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
-            optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
-            output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName()
-    };
+
+    String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+        optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+        output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+        EuclideanDistanceMeasure.class.getName() };
+
     ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
-    int expect = SEEDS.length * REFERENCE.length;
-    DummyOutputCollector<StringTuple, DoubleWritable> collector =
-            new DummyOutputCollector<StringTuple, DoubleWritable>();
-    //
-    for (Pair<StringTuple, DoubleWritable> record :
-            new SequenceFileIterable<StringTuple, DoubleWritable>(
-                    new Path(output, "part-m-00000"), conf)) {
-      collector.collect(record.getFirst(), record.getSecond());
+
+    int expectedOutputSize = SEEDS.length * REFERENCE.length;
+    int outputSize = Iterables.size(new SequenceFileIterable<StringTuple, DoubleWritable>(new Path(output,
+        "part-m-00000"), conf));
+    assertEquals(expectedOutputSize, outputSize);
+  }
+
+  @Test
+  public void testMaxDistance() throws Exception {
+
+    Path input = getTestTempDirPath("input");
+    Path output = getTestTempDirPath("output");
+    Path seedsPath = getTestTempDirPath("seeds");
+
+    List<VectorWritable> points = getPointsWritable(REFERENCE);
+    List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
+    Configuration conf = new Configuration();
+    ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+
+    double maxDistance = 10;
+
+    String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+        optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+        output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+        EuclideanDistanceMeasure.class.getName(),
+        optKey(VectorDistanceSimilarityJob.MAX_DISTANCE), String.valueOf(maxDistance) };
+
+    ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+
+    int outputSize = 0;
+
+    for (Pair<StringTuple, DoubleWritable> record : new SequenceFileIterable<StringTuple, DoubleWritable>(
+        new Path(output, "part-m-00000"), conf)) {
+      outputSize++;
+      assertTrue(record.getSecond().get() <= maxDistance);
     }
-    assertEquals(expect, collector.getData().size());
+
+    assertEquals(14, outputSize);
   }
 
   @Test
@@ -176,18 +207,17 @@ public class TestVectorDistanceSimilarit
     ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
     ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
     String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
-            optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
-            output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
-            optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
+        optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+        output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+        EuclideanDistanceMeasure.class.getName(),
+        optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
     };
     ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
 
-    DummyOutputCollector<Text, VectorWritable> collector =
-            new DummyOutputCollector<Text, VectorWritable>();
-    //
-    for (Pair<Text, VectorWritable> record :
-            new SequenceFileIterable<Text, VectorWritable>(
-                    new Path(output, "part-m-00000"), conf)) {
+    DummyOutputCollector<Text, VectorWritable> collector = new DummyOutputCollector<Text, VectorWritable>();
+
+    for (Pair<Text, VectorWritable> record :  new SequenceFileIterable<Text, VectorWritable>(
+        new Path(output, "part-m-00000"), conf)) {
       collector.collect(record.getFirst(), record.getSecond());
     }
     assertEquals(REFERENCE.length, collector.getData().size());
@@ -196,7 +226,7 @@ public class TestVectorDistanceSimilarit
     }
   }
 
-  public static List<VectorWritable> getPointsWritable(double[][] raw) {
+  private List<VectorWritable> getPointsWritable(double[][] raw) {
     List<VectorWritable> points = Lists.newArrayList();
     for (double[] fr : raw) {
       Vector vec = new RandomAccessSparseVector(fr.length);



Mime
View raw message