mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pran...@apache.org
Subject svn commit: r1293874 - in /mahout/trunk/core/src: main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Date Sun, 26 Feb 2012 17:02:20 GMT
Author: pranjan
Date: Sun Feb 26 17:02:20 2012
New Revision: 1293874

URL: http://svn.apache.org/viewvc?rev=1293874&view=rev
Log:
MAHOUT-931, MAHOUT-929. Added emitMostLikely and threshold-based outlier removal capability
in ClusterClassificationDriver.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1293874&r1=1293873&r2=1293874&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Sun Feb 26 17:02:20 2012
@@ -46,6 +46,7 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
 import org.apache.mahout.math.VectorWritable;
 
 /**
@@ -63,8 +64,10 @@ public class ClusterClassificationDriver
     addInputOption();
     addOutputOption();
     addOption(DefaultOptionCreator.methodOption().create());
-    addOption(DefaultOptionCreator.clustersInOption()
-        .withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable,
Cluster/Canopy.")
+    addOption(DefaultOptionCreator
+        .clustersInOption()
+        .withDescription(
+            "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.")
         .create());
     
     if (parseArguments(args) == null) {
@@ -77,16 +80,19 @@ public class ClusterClassificationDriver
     if (getConf() == null) {
       setConf(new Configuration());
     }
-    Path clustersIn = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
-    boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
-        DefaultOptionCreator.SEQUENTIAL_METHOD);
+    Path clustersIn = new Path(
+        getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+    boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION)
+        .equalsIgnoreCase(DefaultOptionCreator.SEQUENTIAL_METHOD);
     
     double clusterClassificationThreshold = 0.0;
     if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
-      clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+      clusterClassificationThreshold = Double
+          .parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
     }
     
-    run(input, clustersIn, output, clusterClassificationThreshold, runSequential);
+    run(input, clustersIn, output, clusterClassificationThreshold, true,
+        runSequential);
     
     return 0;
   }
@@ -97,7 +103,8 @@ public class ClusterClassificationDriver
   private ClusterClassificationDriver() {}
   
   public static void main(String[] args) throws Exception {
-    ToolRunner.run(new Configuration(), new ClusterClassificationDriver(), args);
+    ToolRunner
+        .run(new Configuration(), new ClusterClassificationDriver(), args);
   }
   
   /**
@@ -117,27 +124,36 @@ public class ClusterClassificationDriver
    *          classified for the cluster.
    * @param runSequential
    *          Run the process sequentially or in a mapreduce way.
+   * @param runSequential
    * @throws IOException
    * @throws InterruptedException
    * @throws ClassNotFoundException
    */
-  public static void run(Path input, Path clusteringOutputPath, Path output, Double clusterClassificationThreshold,
-      boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException
{
+  public static void run(Path input, Path clusteringOutputPath, Path output,
+      Double clusterClassificationThreshold, boolean emitMostLikely,
+      boolean runSequential) throws IOException, InterruptedException,
+      ClassNotFoundException {
     if (runSequential) {
-      classifyClusterSeq(input, clusteringOutputPath, output, clusterClassificationThreshold);
+      classifyClusterSeq(input, clusteringOutputPath, output,
+          clusterClassificationThreshold, emitMostLikely);
     } else {
       Configuration conf = new Configuration();
-      classifyClusterMR(conf, input, clusteringOutputPath, output, clusterClassificationThreshold);
+      classifyClusterMR(conf, input, clusteringOutputPath, output,
+          clusterClassificationThreshold, emitMostLikely);
     }
     
   }
   
-  private static void classifyClusterSeq(Path input, Path clusters, Path output, Double clusterClassificationThreshold)
+  private static void classifyClusterSeq(Path input, Path clusters,
+      Path output, Double clusterClassificationThreshold, boolean emitMostLikely)
       throws IOException {
     List<Cluster> clusterModels = populateClusterModels(clusters);
-    ClusteringPolicy policy = ClusterClassifier.readPolicy(finalClustersPath(clusters));
-    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels, policy);
-    selectCluster(input, clusterModels, clusterClassifier, output, clusterClassificationThreshold);
+    ClusteringPolicy policy = ClusterClassifier
+        .readPolicy(finalClustersPath(clusters));
+    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels,
+        policy);
+    selectCluster(input, clusterModels, clusterClassifier, output,
+        clusterClassificationThreshold, emitMostLikely);
     
   }
   
@@ -149,12 +165,14 @@ public class ClusterClassificationDriver
    * @return The list of clusters found by the clustering.
    * @throws IOException
    */
-  private static List<Cluster> populateClusterModels(Path clusterOutputPath) throws
IOException {
+  private static List<Cluster> populateClusterModels(Path clusterOutputPath)
+      throws IOException {
     List<Cluster> clusterModels = new ArrayList<Cluster>();
     Cluster cluster = null;
     Path finalClustersPath = finalClustersPath(clusterOutputPath);
-    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(finalClustersPath,
PathType.LIST,
-        PathFilters.partFilter(), null, false, new Configuration());
+    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(
+        finalClustersPath, PathType.LIST, PathFilters.partFilter(), null,
+        false, new Configuration());
     while (it.hasNext()) {
       cluster = (Cluster) it.next();
       clusterModels.add(cluster);
@@ -162,9 +180,12 @@ public class ClusterClassificationDriver
     return clusterModels;
   }
   
-  private static Path finalClustersPath(Path clusterOutputPath) throws IOException {
-    FileSystem fileSystem = clusterOutputPath.getFileSystem(new Configuration());
-    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+  private static Path finalClustersPath(Path clusterOutputPath)
+      throws IOException {
+    FileSystem fileSystem = clusterOutputPath
+        .getFileSystem(new Configuration());
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+        PathFilters.finalPartFilter());
     Path finalClustersPath = clusterFiles[0].getPath();
     return finalClustersPath;
   }
@@ -181,45 +202,84 @@ public class ClusterClassificationDriver
    * @param output
    *          the path to store classified data
    * @param clusterClassificationThreshold
+   * @param emitMostLikely
+   *          TODO
    * @throws IOException
    */
-  private static void selectCluster(Path input, List<Cluster> clusterModels, ClusterClassifier
clusterClassifier,
-      Path output, Double clusterClassificationThreshold) throws IOException {
+  private static void selectCluster(Path input, List<Cluster> clusterModels,
+      ClusterClassifier clusterClassifier, Path output,
+      Double clusterClassificationThreshold, boolean emitMostLikely)
+      throws IOException {
     Configuration conf = new Configuration();
-    SequenceFile.Writer writer = new SequenceFile.Writer(input.getFileSystem(conf), conf,
new Path(output,
-        "part-m-" + 0), IntWritable.class, VectorWritable.class);
-    for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(input,
PathType.LIST,
-        PathFilters.logsCRCFilter(), conf)) {
+    SequenceFile.Writer writer = new SequenceFile.Writer(
+        input.getFileSystem(conf), conf, new Path(output, "part-m-" + 0),
+        IntWritable.class, VectorWritable.class);
+    for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+        input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
       Vector pdfPerCluster = clusterClassifier.classify(vw.get());
       if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
-        int maxValueIndex = pdfPerCluster.maxValueIndex();
-        Cluster cluster = clusterModels.get(maxValueIndex);
-        writer.append(new IntWritable(cluster.getId()), vw);
+        classifyAndWrite(clusterModels, clusterClassificationThreshold,
+            emitMostLikely, writer, vw, pdfPerCluster);
       }
     }
     writer.close();
   }
   
+  private static void classifyAndWrite(List<Cluster> clusterModels,
+      Double clusterClassificationThreshold, boolean emitMostLikely,
+      SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster)
+      throws IOException {
+    if (emitMostLikely) {
+      int maxValueIndex = pdfPerCluster.maxValueIndex();
+      write(clusterModels, writer, vw, maxValueIndex);
+    } else {
+      writeAllAboveThreshold(clusterModels, clusterClassificationThreshold,
+          writer, vw, pdfPerCluster);
+    }
+  }
+  
+  private static void writeAllAboveThreshold(List<Cluster> clusterModels,
+      Double clusterClassificationThreshold, SequenceFile.Writer writer,
+      VectorWritable vw, Vector pdfPerCluster) throws IOException {
+    Iterator<Element> iterateNonZero = pdfPerCluster.iterateNonZero();
+    while (iterateNonZero.hasNext()) {
+      Element pdf = iterateNonZero.next();
+      if (pdf.get() >= clusterClassificationThreshold) {
+        int clusterIndex = pdf.index();
+        write(clusterModels, writer, vw, clusterIndex);
+      }
+    }
+  }
+  
+  private static void write(List<Cluster> clusterModels,
+      SequenceFile.Writer writer, VectorWritable vw, int maxValueIndex)
+      throws IOException {
+    Cluster cluster = clusterModels.get(maxValueIndex);
+    writer.append(new IntWritable(cluster.getId()), vw);
+  }
+  
   /**
    * Decides whether the vector should be classified or not based on the max pdf
    * value of the clusters and threshold value.
    * 
-   * @param pdfPerCluster
-   *          pdf of vector belonging to different clusters.
-   * @param clusterClassificationThreshold
-   *          threshold below which the vectors won't be classified.
    * @return whether the vector should be classified or not.
    */
-  private static boolean shouldClassify(Vector pdfPerCluster, Double clusterClassificationThreshold)
{
-    return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+  private static boolean shouldClassify(Vector pdfPerCluster,
+      Double clusterClassificationThreshold) {
+    boolean isMaxPDFGreatherThanThreshold = pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+    return isMaxPDFGreatherThanThreshold;
   }
   
-  private static void classifyClusterMR(Configuration conf, Path input, Path clustersIn,
Path output,
-      Double clusterClassificationThreshold) throws IOException, InterruptedException, ClassNotFoundException
{
-    Job job = new Job(conf, "Cluster Classification Driver running over input: " + input);
+  private static void classifyClusterMR(Configuration conf, Path input,
+      Path clustersIn, Path output, Double clusterClassificationThreshold,
+      boolean emitMostLikely) throws IOException, InterruptedException,
+      ClassNotFoundException {
+    Job job = new Job(conf,
+        "Cluster Classification Driver running over input: " + input);
     job.setJarByClass(ClusterClassificationDriver.class);
     
-    conf.setFloat(OUTLIER_REMOVAL_THRESHOLD, clusterClassificationThreshold.floatValue());
+    conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
+        clusterClassificationThreshold.floatValue());
     
     conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
     
@@ -235,7 +295,8 @@ public class ClusterClassificationDriver
     FileInputFormat.addInputPath(job, input);
     FileOutputFormat.setOutputPath(job, output);
     if (!job.waitForCompletion(true)) {
-      throw new InterruptedException("Cluster Classification Driver Job failed processing
" + input);
+      throw new InterruptedException(
+          "Cluster Classification Driver Job failed processing " + input);
     }
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1293874&r1=1293873&r2=1293874&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Sun Feb 26 17:02:20 2012
@@ -47,7 +47,8 @@ import com.google.common.collect.Lists;
 
 public class ClusterClassificationDriverTest extends MahoutTestCase {
   
-  private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4,
5}, {5, 5}, {9, 9}, {8, 8}};
+  private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4},
+      {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
   
   private FileSystem fs;
   
@@ -97,7 +98,8 @@ public class ClusterClassificationDriver
     
     conf = new Configuration();
     
-    ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(points,
+        new Path(pointsPath, "file1"), fs, conf);
     runClustering(pointsPath, conf);
     runClassificationWithoutOutlierRemoval(conf);
     collectVectorsForAssertion();
@@ -114,35 +116,42 @@ public class ClusterClassificationDriver
     
     conf = new Configuration();
     
-    ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(points,
+        new Path(pointsPath, "file1"), fs, conf);
     runClustering(pointsPath, conf);
     runClassificationWithOutlierRemoval(conf);
     collectVectorsForAssertion();
     assertVectorsWithOutlierRemoval();
   }
   
-  private void runClustering(Path pointsPath, Configuration conf) throws IOException, InterruptedException,
-      ClassNotFoundException {
-    CanopyDriver.run(conf, pointsPath, clusteringOutputPath, new ManhattanDistanceMeasure(),
3.1, 2.1, false, true);
+  private void runClustering(Path pointsPath, Configuration conf)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
+        new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
     Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
-    ClusterClassifier.writePolicy(new CanopyClusteringPolicy(), finalClustersPath);
+    ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
+        finalClustersPath);
   }
   
-  private void runClassificationWithoutOutlierRemoval(Configuration conf) throws IOException,
InterruptedException,
-      ClassNotFoundException {
-    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath,
0.0, true);
+  private void runClassificationWithoutOutlierRemoval(Configuration conf)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
+        classifiedOutputPath, 0.0, true, true);
   }
   
-  private void runClassificationWithOutlierRemoval(Configuration conf2) throws IOException,
InterruptedException,
-      ClassNotFoundException {
-    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath,
0.73, true);
+  private void runClassificationWithOutlierRemoval(Configuration conf2)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
+        classifiedOutputPath, 0.73, true, true);
   }
   
   private void collectVectorsForAssertion() throws IOException {
-    Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
+    Path[] partFilePaths = FileUtil.stat2Paths(fs
+        .globStatus(classifiedOutputPath));
     FileStatus[] listStatus = fs.listStatus(partFilePaths);
     for (FileStatus partFile : listStatus) {
-      SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs, partFile.getPath(),
conf);
+      SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
+          partFile.getPath(), conf);
       Writable clusterIdAsKey = new IntWritable();
       VectorWritable point = new VectorWritable();
       while (classifiedVectors.next(clusterIdAsKey, point)) {
@@ -176,30 +185,33 @@ public class ClusterClassificationDriver
   private void assertThirdClusterWithoutOutlierRemoval() {
     Assert.assertEquals(2, thirdCluster.size());
     for (Vector vector : thirdCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}", "{1:8.0,0:8.0}"},
vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}",
+          "{1:8.0,0:8.0}"}, vector.asFormatString()));
     }
   }
   
   private void assertSecondClusterWithoutOutlierRemoval() {
     Assert.assertEquals(4, secondCluster.size());
     for (Vector vector : secondCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}", "{1:4.0,0:5.0}",
"{1:5.0,0:4.0}",
-          "{1:5.0,0:5.0}"}, vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}",
+          "{1:4.0,0:5.0}", "{1:5.0,0:4.0}", "{1:5.0,0:5.0}"},
+          vector.asFormatString()));
     }
   }
   
   private void assertFirstClusterWithoutOutlierRemoval() {
     Assert.assertEquals(3, firstCluster.size());
     for (Vector vector : firstCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}", "{1:1.0,0:2.0}",
"{1:2.0,0:1.0}"},
-          vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}",
+          "{1:1.0,0:2.0}", "{1:2.0,0:1.0}"}, vector.asFormatString()));
     }
   }
   
   private void assertThirdClusterWithOutlierRemoval() {
     Assert.assertEquals(1, thirdCluster.size());
     for (Vector vector : thirdCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"}, vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"},
+          vector.asFormatString()));
     }
   }
   
@@ -210,7 +222,8 @@ public class ClusterClassificationDriver
   private void assertFirstClusterWithOutlierRemoval() {
     Assert.assertEquals(1, firstCluster.size());
     for (Vector vector : firstCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"}, vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"},
+          vector.asFormatString()));
     }
   }
   



Mime
View raw message