mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject svn commit: r1197510 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/bayes/ core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/ core/src/main/java/org/apache/mahout...
Date Fri, 04 Nov 2011 11:20:04 GMT
Author: srowen
Date: Fri Nov  4 11:20:03 2011
New Revision: 1197510

URL: http://svn.apache.org/viewvc?rev=1197510&view=rev
Log:
MAHOUT-838 Add confusion matrix dumper

Added:
    mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/
    mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
    mahout/trunk/src/conf/driver.classes.props

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Fri Nov  4 11:20:03 2011
@@ -19,35 +19,40 @@ package org.apache.mahout.classifier;
 
 import java.util.Collection;
 import java.util.Collections;
-import java.util.LinkedHashMap;
 import java.util.Map;
 
-import com.google.common.collect.Maps;
 import org.apache.commons.lang.StringUtils;
-import org.apache.mahout.math.CardinalityException;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.Matrix;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
 
 /**
  * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
  * 
+ * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
+ * 
  * See http://en.wikipedia.org/wiki/Confusion_matrix for background
  */
 public class ConfusionMatrix {
-
-  private final Map<String,Integer> labelMap = new LinkedHashMap<String,Integer>();
+  private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
   private final int[][] confusionMatrix;
   private String defaultLabel = "unknown";
   
   public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
     confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
     this.defaultLabel = defaultLabel;
+    int i = 0;
     for (String label : labels) {
-      labelMap.put(label, labelMap.size());
+      labelMap.put(label, i++);
     }
-    labelMap.put(defaultLabel, labelMap.size());
+    labelMap.put(defaultLabel, i);
+  }
+  
+  public ConfusionMatrix(Matrix m) {
+    confusionMatrix = new int[m.numRows()][m.numRows()];
+    setMatrix(m);
   }
   
   public int[][] getConfusionMatrix() {
@@ -76,7 +81,7 @@ public class ConfusionMatrix {
     return confusionMatrix[labelId][labelId];
   }
   
-  public double getTotal(String label) {
+  public int getTotal(String label) {
     int labelId = labelMap.get(label);
     int labelTotal = 0;
     for (int i = 0; i < labelMap.size(); i++) {
@@ -94,25 +99,25 @@ public class ConfusionMatrix {
   }
   
   public int getCount(String correctLabel, String classifiedLabel) {
-    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
-                                "Label not found: " + correctLabel);
-    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
-                                "Label not found: " + classifiedLabel);
+    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
+    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
     int correctId = labelMap.get(correctLabel);
     int classifiedId = labelMap.get(classifiedLabel);
     return confusionMatrix[correctId][classifiedId];
   }
   
   public void putCount(String correctLabel, String classifiedLabel, int count) {
-    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
-                                "Label not found: " + correctLabel);
-    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
-                                "Label not found: " + classifiedLabel);
+    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
+    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
     int correctId = labelMap.get(correctLabel);
     int classifiedId = labelMap.get(classifiedLabel);
     confusionMatrix[correctId][classifiedId] = count;
   }
   
+  public String getDefaultLabel() {
+    return defaultLabel;
+  }
+  
   public void incrementCount(String correctLabel, String classifiedLabel, int count) {
     putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
   }
@@ -132,45 +137,69 @@ public class ConfusionMatrix {
   }
   
   public Matrix getMatrix() {
-	  int length = confusionMatrix.length;
-	  Matrix m = new DenseMatrix(length, length);
-	  for (int r = 0; r < length; r++) {
-		  for (int c = 0; c < length; c++) {
-			  m.set(r, c, confusionMatrix[r][c]);
-		  }
-	  }
-	  Map<String,Integer> labels = Maps.newHashMap();
-	  for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
-		  labels.put(entry.getKey(), entry.getValue());
-	  }
-	  m.setRowLabelBindings(labels);
-	  m.setColumnLabelBindings(labels);
-	  return m;
+    int length = confusionMatrix.length;
+    Matrix m = new DenseMatrix(length, length);
+    for (int r = 0; r < length; r++) {
+      for (int c = 0; c < length; c++) {
+        m.set(r, c, confusionMatrix[r][c]);
+      }
+    }
+    Map<String,Integer> labels = Maps.newHashMap();
+    for(Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+      labels.put(entry.getKey(), entry.getValue());
+    }
+    m.setRowLabelBindings(labels);
+    m.setColumnLabelBindings(labels);
+    return m;
   }
-
+  
   public void setMatrix(Matrix m) {
-	  int length = confusionMatrix.length;
-	  if (m.numRows() != m.numCols()) {
-      throw new CardinalityException(m.numRows(), m.numCols());
-    }
-    if (m.numRows() != length) {
-      throw new CardinalityException(m.numRows(), length);
-    }
-	  for (int r = 0; r < length; r++) {
-		  for (int c = 0; c < length; c++) {
-			  confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
-		  }
-	  }
-	  Map<String,Integer> labels = m.getRowLabelBindings();
-	  if (labels == null) {
+    int length = confusionMatrix.length;
+    if (m.numRows() != m.numCols()) {
+      throw new IllegalArgumentException(
+          String.format("ConfusionMatrix: matrix({},{}) must be square", m.numRows(), m.numCols()));
+    }
+    for (int r = 0; r < length; r++) {
+      for (int c = 0; c < length; c++) {
+        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+      }
+    }
+    Map<String,Integer> labels = m.getRowLabelBindings();
+    if (labels == null) {
       labels = m.getColumnLabelBindings();
     }
-    labelMap.clear();    
-	  if (labels != null) {
-      labelMap.putAll(labels);
-	  }
+    if (labels != null) {
+      String[] sorted = sortLabels(labels);
+      verifyLabels(length, sorted);
+      labelMap.clear();
+      for(int i = 0; i < length; i++) {
+        labelMap.put(sorted[i], i);
+      }
+    }
+  }
+  
+  private static String[] sortLabels(Map<String,Integer> labels) {
+    String[] sorted = new String[labels.keySet().size()];
+    for(String label: labels.keySet()) {
+      Integer index = labels.get(label);
+      sorted[index] = label;
+    }
+    return sorted;
+  }
+  
+  private void verifyLabels(int length, String[] sorted) {
+    Preconditions.checkArgument(sorted.length == length, "One label, one row");
+    for(int i = 0; i < length; i++) {
+      if (sorted[i] == null) {
+        Preconditions.checkArgument(false, "One label, one row");
+      }
+    }
   }
   
+  /**
+   * This is overloaded. toString() is not a formatted report you print for a manager :)
+   * Assume that if there are no default assignments, the default feature was not used
+   */
   @Override
   public String toString() {
     StringBuilder returnString = new StringBuilder(200);
@@ -178,26 +207,37 @@ public class ConfusionMatrix {
     returnString.append("Confusion Matrix\n");
     returnString.append("-------------------------------------------------------").append('\n');
     
+    int unclassified = getTotal(defaultLabel);
     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+        continue;
+      }
+      
       returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
     }
     
     returnString.append("<--Classified as").append('\n');
-    
     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+        continue;
+      }
       String correctLabel = entry.getKey();
       int labelTotal = 0;
       for (String classifiedLabel : this.labelMap.keySet()) {
+        if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+          continue;
+        }
         returnString.append(
-          StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+            StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
         labelTotal += getCount(correctLabel, classifiedLabel);
       }
       returnString.append(" |  ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
-          .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
-          .append(" = ").append(correctLabel).append('\n');
+      .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+      .append(" = ").append(correctLabel).append('\n');
+    }
+    if (unclassified > 0) {
+      returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
     }
-    returnString.append("Default Category: ").append(defaultLabel).append(": ").append(
-      labelMap.get(defaultLabel)).append('\n');
     returnString.append('\n');
     return returnString.toString();
   }
@@ -212,5 +252,5 @@ public class ConfusionMatrix {
     } while (val > 0);
     return returnString.toString();
   }
-  
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Fri Nov  4 11:20:03 2011
@@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ConfusionMatrix;
 import org.apache.mahout.classifier.ResultAnalyzer;
 import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
 import org.apache.mahout.common.CommandLineUtil;
@@ -104,9 +105,14 @@ public final class TestClassifier {
       "Method of Classification: sequential|mapreduce. Default Value: sequential").withShortName("method")
         .create();
     
+    Option confusionMatrixOpt = obuilder.withLongName("confusionMatrix").withRequired(false).withArgument(
+        abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Export ConfusionMatrix as SequenceFile").withShortName("cm").create();
+      
     Group group = gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption(
       encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
-        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create();
+        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt)
+        .withOption(confusionMatrixOpt).create();
     
     try {
       Parser parser = new Parser();
@@ -163,6 +169,11 @@ public final class TestClassifier {
         classificationMethod = (String) cmdLine.getValue(methodOpt);
       }
       
+      String confusionMatrixFile = null;
+      if (cmdLine.hasOption(confusionMatrixOpt)) {
+        confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt);
+      }
+      
       params.setGramSize(gramSize);
       params.set("verbose", Boolean.toString(verbose));
       params.setBasePath(modelBasePath);
@@ -172,6 +183,7 @@ public final class TestClassifier {
       params.set("encoding", encoding);
       params.set("alpha_i", alphaI);
       params.set("testDirPath", testDirPath);
+      params.set("confusionMatrix", confusionMatrixFile);
       
       if ("sequential".equalsIgnoreCase(classificationMethod)) {
         classifySequential(params);
@@ -253,12 +265,12 @@ public final class TestClassifier {
           }
           lineNum++;
         }
-        /*
-         * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel,
-         * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel),
-         * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel),
-         * resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});
-         */
+        ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix();
+        log.info("{}", matrix);
+        BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix);
+
+        log.info("ConfusionMatrix: {}", matrix.toString());
+           
         log.info("Classified instances from {}", file.getName());
         if (verbose) {
           log.info("Performance stats {}", operationStats.toString());

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java Fri Nov  4 11:20:03 2011
@@ -21,10 +21,15 @@ import java.io.IOException;
 import java.util.Map;
 
 import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
 import org.apache.hadoop.conf.Configurable;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.FileInputFormat;
 import org.apache.hadoop.mapred.FileOutputFormat;
 import org.apache.hadoop.mapred.JobClient;
@@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete
 import org.apache.mahout.common.StringTuple;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.MatrixWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -83,6 +89,10 @@ public final class BayesClassifierDriver
     Path outputFiles = new Path(outPath, "part*");
     ConfusionMatrix matrix = readResult(outputFiles, conf, params);
     log.info("{}", matrix);
+    if (params.get("confusionMatrix") != null) {
+      confusionMatrixSeqFileExport(params, matrix);
+    }
+    
   }
   
   public static ConfusionMatrix readResult(Path pathPattern, Configuration conf, Parameters params) {
@@ -117,6 +127,24 @@ public final class BayesClassifierDriver
       }
     }
     return matrix;
+  }
     
+  public static void confusionMatrixSeqFileExport(Parameters params, ConfusionMatrix matrix) throws IOException {
+    if (params.get("confusionMatrix") != null) {
+      Configuration conf = new Configuration();
+      FileSystem fs = FileSystem.get(conf);
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
+          new Path(params.get("confusionMatrix")), Text.class, MatrixWritable.class);
+      String name = params.get("confusionMatrix");
+      // embed file name as sequence key- useful for tuning classifiers
+      name = name.substring(name.lastIndexOf('/') + 1, name.length());
+      Text key = new Text(name);
+      MatrixWritable mw = new MatrixWritable(matrix.getMatrix());
+      try {
+        writer.append(key, mw);
+      } finally {
+        Closeables.closeQuietly(writer);
+      }
+    }
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java Fri Nov  4 11:20:03 2011
@@ -123,15 +123,15 @@ public class MatrixWritable implements W
     }
 
     if (hasLabels) {
-    	Map<String,Integer> columnLabelBindings = Maps.newHashMap();
-    	Map<String,Integer> rowLabelBindings = Maps.newHashMap();
-    	readLabels(in, columnLabelBindings, rowLabelBindings);
-    	if (!columnLabelBindings.isEmpty()) {
-    		r.setColumnLabelBindings(columnLabelBindings);
-    	}
-    	if (!rowLabelBindings.isEmpty()) {
-    		r.setRowLabelBindings(rowLabelBindings);
-    	}
+      Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+      Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+      readLabels(in, columnLabelBindings, rowLabelBindings);
+      if (!columnLabelBindings.isEmpty()) {
+        r.setColumnLabelBindings(columnLabelBindings);
+      }
+      if (!rowLabelBindings.isEmpty()) {
+        r.setRowLabelBindings(rowLabelBindings);
+      }
     }
 
     return r;
@@ -159,7 +159,7 @@ public class MatrixWritable implements W
       VectorWritable.writeVector(out, matrix.viewRow(i), false);
     }
     if ((flags & FLAG_LABELS) != 0) {
-    	writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
+      writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
     }
   }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java Fri Nov  4 11:20:03 2011
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.math;
 
+import com.google.common.collect.Maps;
 import com.google.common.io.Closeables;
 import org.apache.hadoop.io.Writable;
 import org.junit.Test;
@@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
-import java.util.HashMap;
 import java.util.Map;
 
 public final class MatrixWritableTest extends MahoutTestCase {
@@ -36,13 +36,14 @@ public final class MatrixWritableTest ex
 		Matrix m = new SparseMatrix(5, 5);
 		m.set(1, 2, 3.0);
 		m.set(3, 4, 5.0);
-		Map<String, Integer> bindings = new HashMap<String, Integer>();
+		Map<String, Integer> bindings = Maps.newHashMap();
 		bindings.put("A", 0);
 		bindings.put("B", 1);
 		bindings.put("C", 2);
 		bindings.put("D", 3);
 		bindings.put("default", 4);
 		m.setRowLabelBindings(bindings);
+    m.setColumnLabelBindings(bindings);
 		doTestMatrixWritableEquals(m);
 	}
 
@@ -51,12 +52,13 @@ public final class MatrixWritableTest ex
 		Matrix m = new DenseMatrix(5,5);
 		m.set(1, 2, 3.0);
 		m.set(3, 4, 5.0);
-		Map<String, Integer> bindings = new HashMap<String, Integer>();
+		Map<String, Integer> bindings = Maps.newHashMap();
 		bindings.put("A", 0);
 		bindings.put("B", 1);
 		bindings.put("C", 2);
 		bindings.put("D", 3);
 		bindings.put("default", 4);
+    m.setRowLabelBindings(bindings);
 		m.setColumnLabelBindings(bindings);
 		doTestMatrixWritableEquals(m);
 	}
@@ -66,7 +68,9 @@ public final class MatrixWritableTest ex
 		MatrixWritable matrixWritable2 = new MatrixWritable();
 		writeAndRead(matrixWritable, matrixWritable2);
 		Matrix m2 = matrixWritable2.get();
-		compareMatrices(m, m2);  // not sure this works?
+		compareMatrices(m, m2); 
+    doCheckBindings(m2.getRowLabelBindings());
+    doCheckBindings(m2.getColumnLabelBindings());    
 	}
 
 	private static void compareMatrices(Matrix m, Matrix m2) {
@@ -98,6 +102,14 @@ public final class MatrixWritableTest ex
 		}
 	}
 
+  private static void doCheckBindings(Map<String,Integer> labels) {
+    assertTrue("Missing label", labels.keySet().contains("A"));
+    assertTrue("Missing label", labels.keySet().contains("B"));
+    assertTrue("Missing label", labels.keySet().contains("C"));
+    assertTrue("Missing label", labels.keySet().contains("D"));
+    assertTrue("Missing label", labels.keySet().contains("default"));
+  }
+
 	private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
 		ByteArrayOutputStream baos = new ByteArrayOutputStream();
 		DataOutputStream dos = new DataOutputStream(baos);

Added: mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto
==============================================================================
--- mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java (added)
+++ mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java Fri Nov  4 11:20:03 2011
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Export a ConfusionMatrix in various text formats: 
+ *   ToString version
+ *   Grayscale HTML table
+ *   Summary HTML table 
+ *   Table of counts
+ *   all with optional HTML wrappers
+ * 
+ * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
+ * 
+ * Intended to consume ConfusionMatrix SequenceFile output by Bayes
+ * TestClassifier class
+ */
+public final class ConfusionMatrixDumper extends AbstractJob {
+
+  // HTML wrapper - default CSS
+  private static final String HEADER = "<html>" +
+      "<head>\n" +
+      "<title>TITLE</title>\n" +
+      "</head>" +
+      "<body>\n" +
+      "<style type='text/css'> \n" +
+      "table\n" +
+      "{\n" +
+      "border:3px solid black; text-align:left;\n" +
+      "}\n" +
+      "th.normalHeader\n" +
+      "{\n" +
+      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
+      "}\n" +
+      "th.tallHeader\n" +
+      "{\n" +
+      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" +
+      "}\n" +
+      "tr.label\n" +
+      "{\n" +
+      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
+      "}\n" +
+      "tr.row\n" +
+      "{\n" +
+      "border:1px solid gray;text-align:center;background-color:snow\n" +
+      "}\n" +
+      "td\n" +
+      "{\n" +
+      "min-width:2em\n" +
+      "}\n" +
+      "td.cell\n" +
+      "{\n" +
+      "border:1px solid black;text-align:right;background-color:snow\n" +
+      "}\n" +
+      "td.empty\n" +
+      "{\n" +
+      "border:0px;text-align:right;background-color:snow\n" +
+      "}\n" +
+      "td.white\n" +
+      "{\n" +
+      "border:0px solid black;text-align:right;background-color:white\n" +
+      "}\n" +
+      "td.black\n" +
+      "{\n" +
+      "border:0px solid red;text-align:right;background-color:black\n" +
+      "}\n" +
+      "td.gray1\n" +
+      "{\n" +
+      "border:0px solid green;text-align:right; background-color:LightGray\n" +
+      "}\n" +
+      "td.gray2\n" +
+      "{\n" +
+      "border:0px solid blue;text-align:right;background-color:gray\n" +
+      "}\n" +
+      "td.gray3\n" +
+      "{\n" +
+      "border:0px solid red;text-align:right;background-color:DarkGray\n" +
+      "}\n" +
+      "th" +
+      "{\n" +
+      "        text-align: center;\n" +
+      "        vertical-align: bottom;\n" +
+      "        padding-bottom: 3px;\n" +
+      "        padding-left: 5px;\n" +
+      "        padding-right: 5px;\n" +
+      "}\n" +
+      "     .verticalText\n" +
+      "      {\n" +
+      "        text-align: center;\n" +
+      "        vertical-align: middle;\n" +
+      "        width: 20px;\n" +
+      "        margin: 0px;\n" +
+      "        padding: 0px;\n" +
+      "        padding-left: 3px;\n" +
+      "        padding-right: 3px;\n" +
+      "        padding-top: 10px;\n" +
+      "        white-space: nowrap;\n" +
+      "        -webkit-transform: rotate(-90deg); \n" +
+      "        -moz-transform: rotate(-90deg);         \n" +
+      "      };\n" +
+      "</style>\n";
+  private static final String FOOTER = "</html></body>";
+  
+  // CSS style names. 
+  private static final String CSS_TABLE = "table";
+  private static final String CSS_LABEL = "label";
+  private static final String CSS_TALL_HEADER = "tall";
+  private static final String CSS_VERTICAL = "verticalText";
+  private static final String CSS_CELL = "cell";
+  private static final String CSS_EMPTY = "empty";
+  private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
+  
+  private ConfusionMatrixDumper() {}
+  
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new ConfusionMatrixDumper(), args);
+  }
+  
+  @Override
+  public int run(String[] args) throws IOException {
+    addInputOption();
+    addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addFlag("html", null, "Create complete HTML page");
+    addFlag("text", null, "Dump simple text");
+    Map<String, String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+    
+    Path inputPath = getInputPath();
+    String outputFile = parsedArgs.containsKey("--output") ? parsedArgs.get("--output") : null;
+    boolean text = parsedArgs.containsKey("--text");
+    boolean wrapHtml = parsedArgs.containsKey("--html");
+    PrintStream out = getPrintStream(outputFile);
+    if (text) {
+      exportText(inputPath, out);
+    } else {
+      exportTable(inputPath, out, wrapHtml);
+    }
+    out.flush();
+    if (out != System.out) {
+      out.close();
+    }
+    return 0;
+  }
+  
+  private static void exportText(Path inputPath, PrintStream out) throws IOException {
+    MatrixWritable mw = new MatrixWritable();
+    Text key = new Text();
+    readSeqFile(inputPath, key, mw);
+    Matrix m = mw.get();
+    ConfusionMatrix cm = new ConfusionMatrix(m);
+    out.println(cm.toString());
+  }
+  
+  private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
+    MatrixWritable mw = new MatrixWritable();
+    Text key = new Text();
+    readSeqFile(inputPath, key, mw);
+    String fileName = inputPath.getName();
+    fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length());
+    Matrix m = mw.get();
+    ConfusionMatrix cm = new ConfusionMatrix(m);
+    if (wrapHtml) {
+      printHeader(out, fileName);
+    }
+    out.println("<p/>");
+    printSummaryTable(cm, out);
+    out.println("<p/>");
+    printGrayTable(cm, out);
+    out.println("<p/>");
+    printCountsTable(cm, out);
+    out.println("<p/>");
+    printTextInBox(cm, out);
+    out.println("<p/>");
+    if (wrapHtml) {
+      printFooter(out);
+    }
+  }
+  
+  private static List<String> stripDefault(ConfusionMatrix cm) {
+    List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); 
+    String defaultLabel = cm.getDefaultLabel();
+    int unclassified = cm.getTotal(defaultLabel);
+    if (unclassified > 0) {
+      return stripped;
+    }
+    stripped.remove(defaultLabel);
+    return stripped;
+  }
+  
+  // TODO: test - this should work with HDFS files
+  private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    reader.next(key, m);
+  }
+  
+  /**
+   * Opens the stream the report is written to.
+   * TODO: test - this might not work with HDFS files?
+   *     after all, it does no seeks
+   *
+   * @param outputFilename local file to (over)write, or null to use standard output
+   * @return a stream on the named file, or {@code System.out} when no name is given
+   * @throws IOException if the file cannot be opened for writing
+   */
+  private static PrintStream getPrintStream(String outputFilename) throws IOException {
+    if (outputFilename == null) {
+      return System.out;
+    }
+    File outputFile = new File(outputFilename);
+    // FileOutputStream creates a missing file and truncates an existing one, so the
+    // former delete()/createNewFile() calls (whose results were ignored) are unnecessary
+    OutputStream os = new FileOutputStream(outputFile);
+    return new PrintStream(os);
+  }
+  
+  /**
+   * Adds up {@code cm.getCount(rowLabel, x)} over every label {@code x} of the matrix.
+   *
+   * @param cm matrix to read
+   * @param rowLabel first argument passed to each {@code getCount} call
+   * @return the sum of the counts
+   */
+  private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
+    int sum = 0;
+    for (String columnLabel : cm.getLabels()) {
+      sum += cm.getCount(rowLabel, columnLabel);
+    }
+    return sum;
+  }
+  
+  // HTML generator code
+  
+  /**
+   * Writes the matrix's {@code toString()} rendering inside a scrollable
+   * {@code <div>} / {@code <pre>} pair.
+   *
+   * @param cm matrix to render
+   * @param out stream receiving the HTML
+   */
+  private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
+    // opening markup
+    out.println("<div style='width:90%;overflow:scroll;'>");
+    out.println("<pre>");
+    String matrixText = cm.toString();
+    out.println(matrixText);
+    // closing markup, innermost first
+    out.println("</pre>");
+    out.println("</div>");
+  }
+  
+  /**
+   * Writes an HTML table with one row per label (as returned by {@code stripDefault})
+   * under the column headings Label, Total, Correct and %.
+   *
+   * @param cm matrix to summarize
+   * @param out stream receiving the HTML
+   */
+  public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
+    format("<table class='%s'>\n", out, CSS_TABLE);
+    // heading row
+    format("<tr class='%s'>", out, CSS_LABEL);
+    out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
+    out.println("</tr>");
+    // one data row per label
+    for (String rowLabel : stripDefault(cm)) {
+      printSummaryRow(cm, out, rowLabel);
+    }
+    out.println("</table>");
+  }
+  
+  /**
+   * Writes one summary row: the label, the sum of its counts, the value of
+   * {@code cm.getCorrect(label)} and {@code cm.getAccuracy(label)} rounded to an integer.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   * @param label label the row describes
+   */
+  private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
+    format("<tr class='%s'>", out, CSS_CELL);
+    int correctCount = cm.getCorrect(label);
+    int roundedAccuracy = (int) Math.round(cm.getAccuracy(label));
+    int totalCount = getCount(cm, label);
+    String rowTemplate = "<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>";
+    format(rowTemplate, out, CSS_CELL, label, totalCount, correctCount, roundedAccuracy);
+    out.println("</tr>");
+  }
+  
+  /**
+   * Adds up {@code cm.getCount(label, x)} over every label {@code x} of the matrix.
+   *
+   * @param cm matrix to read
+   * @param label first argument passed to each {@code getCount} call
+   * @return the sum of the counts
+   */
+  private static int getCount(ConfusionMatrix cm, String label) {
+    int sum = 0;
+    Iterator<String> others = cm.getLabels().iterator();
+    while (others.hasNext()) {
+      sum += cm.getCount(label, others.next());
+    }
+    return sum;
+  }
+  
+  public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
+    format("<table class='%s'>\n", out, CSS_TABLE);
+    printCountsHeader(cm, out, true);
+    printGrayRows(cm, out);
+    out.println("</table>");
+  }
+  
+  /**
+   * Writes one grayscale row for every label returned by {@code stripDefault}.
+   * Each cell is shaded by its share of the row total, which gives a mostly white
+   * matrix with grays in misclassified cells and black on the diagonal.
+   * TODO: Using the sqrt(count/max) as the rating is more stringent
+   */
+  private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
+    List<String> shownLabels = stripDefault(cm);
+    for (String rowLabel : shownLabels) {
+      printGrayRow(cm, out, shownLabels, rowLabel);
+    }
+  }
+  
+  /**
+   * Writes a single grayscale row: a cell holding the row label, then one shaded cell
+   * per column label.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   * @param labels column labels, in output order
+   * @param rowLabel label of the row being written
+   */
+  private static void printGrayRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
+    format("<tr class='%s'>", out, CSS_LABEL);
+    format("<td>%s</td>", out, rowLabel);
+    // every cell of the row is rated against the same row total
+    int rowTotal = getLabelTotal(cm, rowLabel);
+    for (String column : labels) {
+      printGrayCell(cm, out, rowTotal, rowLabel, column);
+    }
+    out.println("</tr>");
+  }
+  
+  /**
+   * Writes one cell of the grayscale table.
+   * A zero count produces an empty cell with the {@code CSS_EMPTY} class. Any other
+   * count is rated as {@code (int) (count / total * 4)} and the rating selects the
+   * cell's class from {@code CSS_GRAY_CELLS}: white/light/medium/dark for 0, 1/4, 1/2
+   * and 3/4 of the total, and black when count equals total (complete success).
+   * An alternative rating is to use sqrt(total) instead of total - this is more drastic.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   * @param total value the count is rated against
+   * @param rowLabel row of the cell
+   * @param columnLabel column of the cell, also used as the cell's title
+   */
+  private static void printGrayCell(ConfusionMatrix cm,
+                                    PrintStream out,
+                                    int total,
+                                    String rowLabel,
+                                    String columnLabel) {
+    int count = cm.getCount(rowLabel, columnLabel);
+    if (count == 0) {
+      out.format("<td class='%s'/>", CSS_EMPTY);
+    } else {
+      // 0 is white, full is black, everything else gray
+      int rating = (int) ((count / (double) total) * 4);
+      // rating reaches 4 when count == total; clamp so the lookup can never leave the
+      // array, whatever its length (the array is declared outside this method)
+      int index = Math.max(0, Math.min(rating, CSS_GRAY_CELLS.length - 1));
+      String css = CSS_GRAY_CELLS[index];
+      format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count);
+    }
+  }
+  
+  /**
+   * Writes the counts HTML table: a horizontal header row followed by one row of
+   * counts per label.
+   *
+   * @param cm matrix to render
+   * @param out stream receiving the HTML
+   */
+  public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
+    format("<table class='%s'>\n", out, CSS_TABLE);
+    printCountsHeader(cm, out, false);
+    printCountsRows(cm, out);
+    out.println("</table>");
+  }
+  
+  /**
+   * Writes one counts row for every label returned by {@code stripDefault}.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   */
+  private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
+    List<String> shownLabels = stripDefault(cm);
+    for (String rowLabel : shownLabels) {
+      printCountsRow(cm, out, shownLabels, rowLabel);
+    }
+  }
+  
+  /**
+   * Writes a single counts row: a cell holding the row label, then one count cell per
+   * column label.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   * @param labels column labels, in output order
+   * @param rowLabel label of the row being written
+   */
+  private static void printCountsRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
+    out.println("<tr>");
+    // leading cell names the row
+    format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel);
+    for (String column : labels) {
+      printCountsCell(cm, out, rowLabel, column);
+    }
+    out.println("</tr>");
+  }
+  
+  /**
+   * Writes one cell of the counts table; a zero count is shown as an empty cell.
+   *
+   * @param cm matrix to read
+   * @param out stream receiving the HTML
+   * @param rowLabel row of the cell
+   * @param columnLabel column of the cell, also used as the cell's title
+   */
+  private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
+    int cellCount = cm.getCount(rowLabel, columnLabel);
+    String text;
+    if (cellCount == 0) {
+      text = "";
+    } else {
+      text = Integer.toString(cellCount);
+    }
+    format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, text);
+  }
+  
+  /**
+   * Writes the header row naming the columns of a table.
+   *
+   * @param cm matrix supplying the labels (filtered through {@code stripDefault})
+   * @param out stream receiving the HTML
+   * @param vertical if true, labels are wrapped in {@code CSS_VERTICAL} divs and the row
+   *        height is set to half the longest label length, in em; otherwise a plain row
+   *        of cells is written
+   */
+  private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
+    List<String> columns = stripDefault(cm);
+    int widest = getLongestHeader(columns);
+    if (!vertical) {
+      // plain header - the upper-left corner cell stays empty
+      out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, CSS_LABEL);
+      for (String column : columns) {
+        out.format("<td>%s</td>", column);
+      }
+      out.format("</tr>");
+      return;
+    }
+    // vertical header - the row must be tall enough for the rotated labels
+    out.format("<tr class='%s' style='height:%dem'><th>&nbsp;</th>\n", CSS_TALL_HEADER, widest / 2);
+    for (String column : columns) {
+      out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, column);
+    }
+    out.println("</tr>");
+  }
+  
+  /**
+   * Finds the length of the longest label.
+   *
+   * @param labels labels to measure
+   * @return the greatest {@code length()} among the labels, or 0 if there are none
+   */
+  private static int getLongestHeader(Iterable<String> labels) {
+    int longest = 0;
+    for (String candidate : labels) {
+      int size = candidate.length();
+      if (size > longest) {
+        longest = size;
+      }
+    }
+    return longest;
+  }
+  
+  /**
+   * Formats the arguments with {@code String.format} and prints the result followed
+   * by a line break.
+   *
+   * @param format format string
+   * @param out stream receiving the formatted line
+   * @param args values substituted into the format string
+   */
+  private static void format(String format, PrintStream out, Object ... args) {
+    out.println(String.format(format, args));
+  }
+  
+  public static void printHeader(PrintStream out, CharSequence title) {
+    out.println(HEADER.replace("TITLE", title));
+  }
+  
+  /**
+   * Prints {@code FOOTER} followed by a line break.
+   *
+   * @param out stream receiving the footer
+   */
+  public static void printFooter(PrintStream out) {
+    out.println(FOOTER);
+  }
+  
+}

Modified: mahout/trunk/src/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Fri Nov  4 11:20:03 2011
@@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea
 org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = hmmpredict :
Generate random sequence of observations by given HMM
 org.apache.mahout.utils.SplitInput = split : Split Input data into test and train sets
 org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = trainnb : Train the
Vector-based Bayes classifier
-org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based
Bayes classifier
\ No newline at end of file
+org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based
Bayes classifier
+org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump confusion matrix in HTML
or text formats
+org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV format
\ No newline at end of file



Mime
View raw message