mahout-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From Grant Ingersoll <gsing...@apache.org>
Subject Re: svn commit: r1197510 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/bayes/ core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/ core/src/main/java/org/apache/mahout...
Date Fri, 04 Nov 2011 21:52:34 GMT
Is that supposed to be ConfusionMatrixDumper in the driver.classes.props?

On Nov 4, 2011, at 7:20 AM, srowen@apache.org wrote:

> Author: srowen
> Date: Fri Nov  4 11:20:03 2011
> New Revision: 1197510
> 
> URL: http://svn.apache.org/viewvc?rev=1197510&view=rev
> Log:
> MAHOUT-838 Add confusion matrix dumper
> 
> Added:
>    mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/
>    mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
> Modified:
>    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
>    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
>    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
>    mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
>    mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
>    mahout/trunk/src/conf/driver.classes.props
> 
> Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
> +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Fri Nov  4 11:20:03 2011
> @@ -19,35 +19,40 @@ package org.apache.mahout.classifier;
> 
> import java.util.Collection;
> import java.util.Collections;
> -import java.util.LinkedHashMap;
> import java.util.Map;
> 
> -import com.google.common.collect.Maps;
> import org.apache.commons.lang.StringUtils;
> -import org.apache.mahout.math.CardinalityException;
> import org.apache.mahout.math.DenseMatrix;
> import org.apache.mahout.math.Matrix;
> 
> import com.google.common.base.Preconditions;
> +import com.google.common.collect.Maps;
> 
> /**
>  * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
>  * 
> + * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
> + * 
>  * See http://en.wikipedia.org/wiki/Confusion_matrix for background
>  */
> public class ConfusionMatrix {
> -
> -  private final Map<String,Integer> labelMap = new LinkedHashMap<String,Integer>();
> +  private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
>   private final int[][] confusionMatrix;
>   private String defaultLabel = "unknown";
> 
>   public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
>     confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
>     this.defaultLabel = defaultLabel;
> +    int i = 0;
>     for (String label : labels) {
> -      labelMap.put(label, labelMap.size());
> +      labelMap.put(label, i++);
>     }
> -    labelMap.put(defaultLabel, labelMap.size());
> +    labelMap.put(defaultLabel, i);
> +  }
> +  
> +  public ConfusionMatrix(Matrix m) {
> +    confusionMatrix = new int[m.numRows()][m.numRows()];
> +    setMatrix(m);
>   }
> 
>   public int[][] getConfusionMatrix() {
> @@ -76,7 +81,7 @@ public class ConfusionMatrix {
>     return confusionMatrix[labelId][labelId];
>   }
> 
> -  public double getTotal(String label) {
> +  public int getTotal(String label) {
>     int labelId = labelMap.get(label);
>     int labelTotal = 0;
>     for (int i = 0; i < labelMap.size(); i++) {
> @@ -94,25 +99,25 @@ public class ConfusionMatrix {
>   }
> 
>   public int getCount(String correctLabel, String classifiedLabel) {
> -    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
> -                                "Label not found: " + correctLabel);
> -    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
> -                                "Label not found: " + classifiedLabel);
> +    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
> +    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
>     int correctId = labelMap.get(correctLabel);
>     int classifiedId = labelMap.get(classifiedLabel);
>     return confusionMatrix[correctId][classifiedId];
>   }
> 
>   public void putCount(String correctLabel, String classifiedLabel, int count) {
> -    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
> -                                "Label not found: " + correctLabel);
> -    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
> -                                "Label not found: " + classifiedLabel);
> +    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
> +    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
>     int correctId = labelMap.get(correctLabel);
>     int classifiedId = labelMap.get(classifiedLabel);
>     confusionMatrix[correctId][classifiedId] = count;
>   }
> 
> +  public String getDefaultLabel() {
> +    return defaultLabel;
> +  }
> +  
>   public void incrementCount(String correctLabel, String classifiedLabel, int count) {
>     putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
>   }
> @@ -132,45 +137,69 @@ public class ConfusionMatrix {
>   }
> 
>   public Matrix getMatrix() {
> -	  int length = confusionMatrix.length;
> -	  Matrix m = new DenseMatrix(length, length);
> -	  for (int r = 0; r < length; r++) {
> -		  for (int c = 0; c < length; c++) {
> -			  m.set(r, c, confusionMatrix[r][c]);
> -		  }
> -	  }
> -	  Map<String,Integer> labels = Maps.newHashMap();
> -	  for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
> -		  labels.put(entry.getKey(), entry.getValue());
> -	  }
> -	  m.setRowLabelBindings(labels);
> -	  m.setColumnLabelBindings(labels);
> -	  return m;
> +    int length = confusionMatrix.length;
> +    Matrix m = new DenseMatrix(length, length);
> +    for (int r = 0; r < length; r++) {
> +      for (int c = 0; c < length; c++) {
> +        m.set(r, c, confusionMatrix[r][c]);
> +      }
> +    }
> +    Map<String,Integer> labels = Maps.newHashMap();
> +    for(Map.Entry<String, Integer> entry : labelMap.entrySet()) {
> +      labels.put(entry.getKey(), entry.getValue());
> +    }
> +    m.setRowLabelBindings(labels);
> +    m.setColumnLabelBindings(labels);
> +    return m;
>   }
> -
> +  
>   public void setMatrix(Matrix m) {
> -	  int length = confusionMatrix.length;
> -	  if (m.numRows() != m.numCols()) {
> -      throw new CardinalityException(m.numRows(), m.numCols());
> -    }
> -    if (m.numRows() != length) {
> -      throw new CardinalityException(m.numRows(), length);
> -    }
> -	  for (int r = 0; r < length; r++) {
> -		  for (int c = 0; c < length; c++) {
> -			  confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
> -		  }
> -	  }
> -	  Map<String,Integer> labels = m.getRowLabelBindings();
> -	  if (labels == null) {
> +    int length = confusionMatrix.length;
> +    if (m.numRows() != m.numCols()) {
> +      throw new IllegalArgumentException(
> +          String.format("ConfusionMatrix: matrix({},{}) must be square", m.numRows(), m.numCols()));
> +    }
> +    for (int r = 0; r < length; r++) {
> +      for (int c = 0; c < length; c++) {
> +        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
> +      }
> +    }
> +    Map<String,Integer> labels = m.getRowLabelBindings();
> +    if (labels == null) {
>       labels = m.getColumnLabelBindings();
>     }
> -    labelMap.clear();    
> -	  if (labels != null) {
> -      labelMap.putAll(labels);
> -	  }
> +    if (labels != null) {
> +      String[] sorted = sortLabels(labels);
> +      verifyLabels(length, sorted);
> +      labelMap.clear();
> +      for(int i = 0; i < length; i++) {
> +        labelMap.put(sorted[i], i);
> +      }
> +    }
> +  }
> +  
> +  private static String[] sortLabels(Map<String,Integer> labels) {
> +    String[] sorted = new String[labels.keySet().size()];
> +    for(String label: labels.keySet()) {
> +      Integer index = labels.get(label);
> +      sorted[index] = label;
> +    }
> +    return sorted;
> +  }
> +  
> +  private void verifyLabels(int length, String[] sorted) {
> +    Preconditions.checkArgument(sorted.length == length, "One label, one row");
> +    for(int i = 0; i < length; i++) {
> +      if (sorted[i] == null) {
> +        Preconditions.checkArgument(false, "One label, one row");
> +      }
> +    }
>   }
> 
> +  /**
> +   * This is overloaded. toString() is not a formatted report you print for a manager :)
> +   * Assume that if there are no default assignments, the default feature was not used
> +   */
>   @Override
>   public String toString() {
>     StringBuilder returnString = new StringBuilder(200);
> @@ -178,26 +207,37 @@ public class ConfusionMatrix {
>     returnString.append("Confusion Matrix\n");
>     returnString.append("-------------------------------------------------------").append('\n');
> 
> +    int unclassified = getTotal(defaultLabel);
>     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
> +      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
> +        continue;
> +      }
> +      
>       returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
>     }
> 
>     returnString.append("<--Classified as").append('\n');
> -    
>     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
> +      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
> +        continue;
> +      }
>       String correctLabel = entry.getKey();
>       int labelTotal = 0;
>       for (String classifiedLabel : this.labelMap.keySet()) {
> +        if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
> +          continue;
> +        }
>         returnString.append(
> -          StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
> +            StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
>         labelTotal += getCount(correctLabel, classifiedLabel);
>       }
>       returnString.append(" |  ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
> -          .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
> -          .append(" = ").append(correctLabel).append('\n');
> +      .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
> +      .append(" = ").append(correctLabel).append('\n');
> +    }
> +    if (unclassified > 0) {
> +      returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
>     }
> -    returnString.append("Default Category: ").append(defaultLabel).append(": ").append(
> -      labelMap.get(defaultLabel)).append('\n');
>     returnString.append('\n');
>     return returnString.toString();
>   }
> @@ -212,5 +252,5 @@ public class ConfusionMatrix {
>     } while (val > 0);
>     return returnString.toString();
>   }
> -  
> +
> }
> 
> Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
> +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Fri Nov  4 11:20:03 2011
> @@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D
> import org.apache.commons.cli2.builder.GroupBuilder;
> import org.apache.commons.cli2.commandline.Parser;
> import org.apache.mahout.classifier.ClassifierResult;
> +import org.apache.mahout.classifier.ConfusionMatrix;
> import org.apache.mahout.classifier.ResultAnalyzer;
> import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
> import org.apache.mahout.common.CommandLineUtil;
> @@ -104,9 +105,14 @@ public final class TestClassifier {
>       "Method of Classification: sequential|mapreduce. Default Value: sequential").withShortName("method")
>         .create();
> 
> +    Option confusionMatrixOpt = obuilder.withLongName("confusionMatrix").withRequired(false).withArgument(
> +        abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "Export ConfusionMatrix as SequenceFile").withShortName("cm").create();
> +      
>     Group group = gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption(
>       encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
> -        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create();
> +        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt)
> +        .withOption(confusionMatrixOpt).create();
> 
>     try {
>       Parser parser = new Parser();
> @@ -163,6 +169,11 @@ public final class TestClassifier {
>         classificationMethod = (String) cmdLine.getValue(methodOpt);
>       }
> 
> +      String confusionMatrixFile = null;
> +      if (cmdLine.hasOption(confusionMatrixOpt)) {
> +        confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt);
> +      }
> +      
>       params.setGramSize(gramSize);
>       params.set("verbose", Boolean.toString(verbose));
>       params.setBasePath(modelBasePath);
> @@ -172,6 +183,7 @@ public final class TestClassifier {
>       params.set("encoding", encoding);
>       params.set("alpha_i", alphaI);
>       params.set("testDirPath", testDirPath);
> +      params.set("confusionMatrix", confusionMatrixFile);
> 
>       if ("sequential".equalsIgnoreCase(classificationMethod)) {
>         classifySequential(params);
> @@ -253,12 +265,12 @@ public final class TestClassifier {
>           }
>           lineNum++;
>         }
> -        /*
> -         * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel,
> -         * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel),
> -         * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel),
> -         * resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});
> -         */
> +        ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix();
> +        log.info("{}", matrix);
> +        BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix);
> +
> +        log.info("ConfusionMatrix: {}", matrix.toString());
> +           
>         log.info("Classified instances from {}", file.getName());
>         if (verbose) {
>           log.info("Performance stats {}", operationStats.toString());
> 
> Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java (original)
> +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java Fri Nov  4 11:20:03 2011
> @@ -21,10 +21,15 @@ import java.io.IOException;
> import java.util.Map;
> 
> import com.google.common.collect.Maps;
> +import com.google.common.io.Closeables;
> +
> import org.apache.hadoop.conf.Configurable;
> import org.apache.hadoop.conf.Configuration;
> +import org.apache.hadoop.fs.FileSystem;
> import org.apache.hadoop.fs.Path;
> import org.apache.hadoop.io.DoubleWritable;
> +import org.apache.hadoop.io.SequenceFile;
> +import org.apache.hadoop.io.Text;
> import org.apache.hadoop.mapred.FileInputFormat;
> import org.apache.hadoop.mapred.FileOutputFormat;
> import org.apache.hadoop.mapred.JobClient;
> @@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete
> import org.apache.mahout.common.StringTuple;
> import org.apache.mahout.common.iterator.sequencefile.PathType;
> import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
> +import org.apache.mahout.math.MatrixWritable;
> import org.slf4j.Logger;
> import org.slf4j.LoggerFactory;
> 
> @@ -83,6 +89,10 @@ public final class BayesClassifierDriver
>     Path outputFiles = new Path(outPath, "part*");
>     ConfusionMatrix matrix = readResult(outputFiles, conf, params);
>     log.info("{}", matrix);
> +    if (params.get("confusionMatrix") != null) {
> +      confusionMatrixSeqFileExport(params, matrix);
> +    }
> +    
>   }
> 
>   public static ConfusionMatrix readResult(Path pathPattern, Configuration conf, Parameters params) {
> @@ -117,6 +127,24 @@ public final class BayesClassifierDriver
>       }
>     }
>     return matrix;
> +  }
> 
> +  public static void confusionMatrixSeqFileExport(Parameters params, ConfusionMatrix matrix) throws IOException {
> +    if (params.get("confusionMatrix") != null) {
> +      Configuration conf = new Configuration();
> +      FileSystem fs = FileSystem.get(conf);
> +      SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
> +          new Path(params.get("confusionMatrix")), Text.class, MatrixWritable.class);
> +      String name = params.get("confusionMatrix");
> +      // embed file name as sequence key- useful for tuning classifiers
> +      name = name.substring(name.lastIndexOf('/') + 1, name.length());
> +      Text key = new Text(name);
> +      MatrixWritable mw = new MatrixWritable(matrix.getMatrix());
> +      try {
> +        writer.append(key, mw);
> +      } finally {
> +        Closeables.closeQuietly(writer);
> +      }
> +    }
>   }
> }
> 
> Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java (original)
> +++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java Fri Nov  4 11:20:03 2011
> @@ -123,15 +123,15 @@ public class MatrixWritable implements W
>     }
> 
>     if (hasLabels) {
> -    	Map<String,Integer> columnLabelBindings = Maps.newHashMap();
> -    	Map<String,Integer> rowLabelBindings = Maps.newHashMap();
> -    	readLabels(in, columnLabelBindings, rowLabelBindings);
> -    	if (!columnLabelBindings.isEmpty()) {
> -    		r.setColumnLabelBindings(columnLabelBindings);
> -    	}
> -    	if (!rowLabelBindings.isEmpty()) {
> -    		r.setRowLabelBindings(rowLabelBindings);
> -    	}
> +      Map<String,Integer> columnLabelBindings = Maps.newHashMap();
> +      Map<String,Integer> rowLabelBindings = Maps.newHashMap();
> +      readLabels(in, columnLabelBindings, rowLabelBindings);
> +      if (!columnLabelBindings.isEmpty()) {
> +        r.setColumnLabelBindings(columnLabelBindings);
> +      }
> +      if (!rowLabelBindings.isEmpty()) {
> +        r.setRowLabelBindings(rowLabelBindings);
> +      }
>     }
> 
>     return r;
> @@ -159,7 +159,7 @@ public class MatrixWritable implements W
>       VectorWritable.writeVector(out, matrix.viewRow(i), false);
>     }
>     if ((flags & FLAG_LABELS) != 0) {
> -    	writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
> +      writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
>     }
>   }
> }
> 
> Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java (original)
> +++ mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java Fri Nov  4 11:20:03 2011
> @@ -17,6 +17,7 @@
> 
> package org.apache.mahout.math;
> 
> +import com.google.common.collect.Maps;
> import com.google.common.io.Closeables;
> import org.apache.hadoop.io.Writable;
> import org.junit.Test;
> @@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream;
> import java.io.DataInputStream;
> import java.io.DataOutputStream;
> import java.io.IOException;
> -import java.util.HashMap;
> import java.util.Map;
> 
> public final class MatrixWritableTest extends MahoutTestCase {
> @@ -36,13 +36,14 @@ public final class MatrixWritableTest ex
> 		Matrix m = new SparseMatrix(5, 5);
> 		m.set(1, 2, 3.0);
> 		m.set(3, 4, 5.0);
> -		Map<String, Integer> bindings = new HashMap<String, Integer>();
> +		Map<String, Integer> bindings = Maps.newHashMap();
> 		bindings.put("A", 0);
> 		bindings.put("B", 1);
> 		bindings.put("C", 2);
> 		bindings.put("D", 3);
> 		bindings.put("default", 4);
> 		m.setRowLabelBindings(bindings);
> +    m.setColumnLabelBindings(bindings);
> 		doTestMatrixWritableEquals(m);
> 	}
> 
> @@ -51,12 +52,13 @@ public final class MatrixWritableTest ex
> 		Matrix m = new DenseMatrix(5,5);
> 		m.set(1, 2, 3.0);
> 		m.set(3, 4, 5.0);
> -		Map<String, Integer> bindings = new HashMap<String, Integer>();
> +		Map<String, Integer> bindings = Maps.newHashMap();
> 		bindings.put("A", 0);
> 		bindings.put("B", 1);
> 		bindings.put("C", 2);
> 		bindings.put("D", 3);
> 		bindings.put("default", 4);
> +    m.setRowLabelBindings(bindings);
> 		m.setColumnLabelBindings(bindings);
> 		doTestMatrixWritableEquals(m);
> 	}
> @@ -66,7 +68,9 @@ public final class MatrixWritableTest ex
> 		MatrixWritable matrixWritable2 = new MatrixWritable();
> 		writeAndRead(matrixWritable, matrixWritable2);
> 		Matrix m2 = matrixWritable2.get();
> -		compareMatrices(m, m2);  // not sure this works?
> +		compareMatrices(m, m2); 
> +    doCheckBindings(m2.getRowLabelBindings());
> +    doCheckBindings(m2.getColumnLabelBindings());    
> 	}
> 
> 	private static void compareMatrices(Matrix m, Matrix m2) {
> @@ -98,6 +102,14 @@ public final class MatrixWritableTest ex
> 		}
> 	}
> 
> +  private static void doCheckBindings(Map<String,Integer> labels) {
> +    assertTrue("Missing label", labels.keySet().contains("A"));
> +    assertTrue("Missing label", labels.keySet().contains("B"));
> +    assertTrue("Missing label", labels.keySet().contains("C"));
> +    assertTrue("Missing label", labels.keySet().contains("D"));
> +    assertTrue("Missing label", labels.keySet().contains("default"));
> +  }
> +
> 	private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
> 		ByteArrayOutputStream baos = new ByteArrayOutputStream();
> 		DataOutputStream dos = new DataOutputStream(baos);
> 
> Added: mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
> URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto
> ==============================================================================
> --- mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java (added)
> +++ mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java Fri Nov  4 11:20:03 2011
> @@ -0,0 +1,423 @@
> +/*
> + * Licensed to the Apache Software Foundation (ASF) under one or more
> + * contributor license agreements.  See the NOTICE file distributed with
> + * this work for additional information regarding copyright ownership.
> + * The ASF licenses this file to You under the Apache License, Version 2.0
> + * (the "License"); you may not use this file except in compliance with
> + * the License.  You may obtain a copy of the License at
> + *
> + *     http://www.apache.org/licenses/LICENSE-2.0
> + *
> + * Unless required by applicable law or agreed to in writing, software
> + * distributed under the License is distributed on an "AS IS" BASIS,
> + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> + * See the License for the specific language governing permissions and
> + * limitations under the License.
> + */
> +
> +package org.apache.mahout.classifier;
> +
> +import java.io.File;
> +import java.io.FileOutputStream;
> +import java.io.IOException;
> +import java.io.OutputStream;
> +import java.io.PrintStream;
> +import java.util.Iterator;
> +import java.util.List;
> +import java.util.Map;
> +
> +import org.apache.hadoop.conf.Configuration;
> +import org.apache.hadoop.fs.FileSystem;
> +import org.apache.hadoop.fs.Path;
> +import org.apache.hadoop.io.SequenceFile;
> +import org.apache.hadoop.io.Text;
> +import org.apache.hadoop.util.ToolRunner;
> +import org.apache.mahout.common.AbstractJob;
> +import org.apache.mahout.common.commandline.DefaultOptionCreator;
> +import org.apache.mahout.math.Matrix;
> +import org.apache.mahout.math.MatrixWritable;
> +
> +import com.google.common.collect.Lists;
> +
> +/**
> + * Export a ConfusionMatrix in various text formats: 
> + *   ToString version
> + *   Grayscale HTML table
> + *   Summary HTML table 
> + *   Table of counts
> + *   all with optional HTML wrappers
> + * 
> + * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
> + * 
> + * Intended to consume ConfusionMatrix SequenceFile output by Bayes
> + * TestClassifier class
> + */
> +public final class ConfusionMatrixDumper extends AbstractJob {
> +
> +  // HTML wrapper - default CSS
> +  private static final String HEADER = "<html>" +
> +      "<head>\n" +
> +      "<title>TITLE</title>\n" +
> +      "</head>" +
> +      "<body>\n" +
> +      "<style type='text/css'> \n" +
> +      "table\n" +
> +      "{\n" +
> +      "border:3px solid black; text-align:left;\n" +
> +      "}\n" +
> +      "th.normalHeader\n" +
> +      "{\n" +
> +      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
> +      "}\n" +
> +      "th.tallHeader\n" +
> +      "{\n" +
> +      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" +
> +      "}\n" +
> +      "tr.label\n" +
> +      "{\n" +
> +      "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
> +      "}\n" +
> +      "tr.row\n" +
> +      "{\n" +
> +      "border:1px solid gray;text-align:center;background-color:snow\n" +
> +      "}\n" +
> +      "td\n" +
> +      "{\n" +
> +      "min-width:2em\n" +
> +      "}\n" +
> +      "td.cell\n" +
> +      "{\n" +
> +      "border:1px solid black;text-align:right;background-color:snow\n" +
> +      "}\n" +
> +      "td.empty\n" +
> +      "{\n" +
> +      "border:0px;text-align:right;background-color:snow\n" +
> +      "}\n" +
> +      "td.white\n" +
> +      "{\n" +
> +      "border:0px solid black;text-align:right;background-color:white\n" +
> +      "}\n" +
> +      "td.black\n" +
> +      "{\n" +
> +      "border:0px solid red;text-align:right;background-color:black\n" +
> +      "}\n" +
> +      "td.gray1\n" +
> +      "{\n" +
> +      "border:0px solid green;text-align:right; background-color:LightGray\n" +
> +      "}\n" +
> +      "td.gray2\n" +
> +      "{\n" +
> +      "border:0px solid blue;text-align:right;background-color:gray\n" +
> +      "}\n" +
> +      "td.gray3\n" +
> +      "{\n" +
> +      "border:0px solid red;text-align:right;background-color:DarkGray\n" +
> +      "}\n" +
> +      "th" +
> +      "{\n" +
> +      "        text-align: center;\n" +
> +      "        vertical-align: bottom;\n" +
> +      "        padding-bottom: 3px;\n" +
> +      "        padding-left: 5px;\n" +
> +      "        padding-right: 5px;\n" +
> +      "}\n" +
> +      "     .verticalText\n" +
> +      "      {\n" +
> +      "        text-align: center;\n" +
> +      "        vertical-align: middle;\n" +
> +      "        width: 20px;\n" +
> +      "        margin: 0px;\n" +
> +      "        padding: 0px;\n" +
> +      "        padding-left: 3px;\n" +
> +      "        padding-right: 3px;\n" +
> +      "        padding-top: 10px;\n" +
> +      "        white-space: nowrap;\n" +
> +      "        -webkit-transform: rotate(-90deg); \n" +
> +      "        -moz-transform: rotate(-90deg);         \n" +
> +      "      };\n" +
> +      "</style>\n";
> +  private static final String FOOTER = "</html></body>";
> +  
> +  // CSS style names. 
> +  private static final String CSS_TABLE = "table";
> +  private static final String CSS_LABEL = "label";
> +  private static final String CSS_TALL_HEADER = "tall";
> +  private static final String CSS_VERTICAL = "verticalText";
> +  private static final String CSS_CELL = "cell";
> +  private static final String CSS_EMPTY = "empty";
> +  private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
> +  
> +  private ConfusionMatrixDumper() {}
> +  
> +  public static void main(String[] args) throws Exception {
> +    ToolRunner.run(new ConfusionMatrixDumper(), args);
> +  }
> +  
> +  @Override
> +  public int run(String[] args) throws IOException {
> +    addInputOption();
> +    addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
> +    addOption(DefaultOptionCreator.overwriteOption().create());
> +    addFlag("html", null, "Create complete HTML page");
> +    addFlag("text", null, "Dump simple text");
> +    Map<String, String> parsedArgs = parseArguments(args);
> +    if (parsedArgs == null) {
> +      return -1;
> +    }
> +    
> +    Path inputPath = getInputPath();
> +    String outputFile = parsedArgs.containsKey("--output") ? parsedArgs.get("--output") : null;
> +    boolean text = parsedArgs.containsKey("--text");
> +    boolean wrapHtml = parsedArgs.containsKey("--html");
> +    PrintStream out = getPrintStream(outputFile);
> +    if (text) {
> +      exportText(inputPath, out);
> +    } else {
> +      exportTable(inputPath, out, wrapHtml);
> +    }
> +    out.flush();
> +    if (out != System.out) {
> +      out.close();
> +    }
> +    return 0;
> +  }
> +  
> +  private static void exportText(Path inputPath, PrintStream out) throws IOException {
> +    MatrixWritable mw = new MatrixWritable();
> +    Text key = new Text();
> +    readSeqFile(inputPath, key, mw);
> +    Matrix m = mw.get();
> +    ConfusionMatrix cm = new ConfusionMatrix(m);
> +    out.println(cm.toString());
> +  }
> +  
> +  private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
> +    MatrixWritable mw = new MatrixWritable();
> +    Text key = new Text();
> +    readSeqFile(inputPath, key, mw);
> +    String fileName = inputPath.getName();
> +    fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length());
> +    Matrix m = mw.get();
> +    ConfusionMatrix cm = new ConfusionMatrix(m);
> +    if (wrapHtml) {
> +      printHeader(out, fileName);
> +    }
> +    out.println("<p/>");
> +    printSummaryTable(cm, out);
> +    out.println("<p/>");
> +    printGrayTable(cm, out);
> +    out.println("<p/>");
> +    printCountsTable(cm, out);
> +    out.println("<p/>");
> +    printTextInBox(cm, out);
> +    out.println("<p/>");
> +    if (wrapHtml) {
> +      printFooter(out);
> +    }
> +  }
> +  
> +  /**
> +   * Returns a mutable copy of the matrix labels, dropping the default label when
> +   * {@code cm.getTotal(defaultLabel)} is not positive (nothing was left unclassified).
> +   */
> +  private static List<String> stripDefault(ConfusionMatrix cm) {
> +    List<String> visibleLabels = Lists.newArrayList(cm.getLabels().iterator());
> +    String defaultLabel = cm.getDefaultLabel();
> +    if (cm.getTotal(defaultLabel) <= 0) {
> +      visibleLabels.remove(defaultLabel);
> +    }
> +    return visibleLabels;
> +  }
> +  
> +  /**
> +   * Reads the first key/value record of the sequence file at {@code path} into
> +   * {@code key} and {@code m}.
> +   *
> +   * @throws IOException if the file cannot be read or holds no record
> +   */
> +  private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
> +    Configuration conf = new Configuration();
> +    // Resolve the file system from the path itself so fully-qualified paths (e.g. hdfs://...)
> +    // work; FileSystem.get(conf) would always use the default file system.
> +    FileSystem fs = path.getFileSystem(conf);
> +    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
> +    try {
> +      if (!reader.next(key, m)) {
> +        throw new IOException("No confusion matrix record found in " + path);
> +      }
> +    } finally {
> +      reader.close();
> +    }
> +  }
> +  
> +  /**
> +   * Opens the report destination.
> +   *
> +   * <p>The name is treated as a local {@link File} path (not a Hadoop path). An existing
> +   * file is truncated; a missing one is created.
> +   *
> +   * @param outputFilename local file to write, or {@code null} for standard output
> +   * @return a stream on the output file, or {@code System.out} when no file was given
> +   */
> +  private static PrintStream getPrintStream(String outputFilename) throws IOException {
> +    if (outputFilename == null) {
> +      return System.out;
> +    }
> +    File outputFile = new File(outputFilename);
> +    // FileOutputStream creates or truncates the file itself. The former delete/createNewFile
> +    // steps were redundant, ignored their results, and could delete an empty directory
> +    // that happened to have the same name.
> +    OutputStream os = new FileOutputStream(outputFile);
> +    return new PrintStream(os);
> +  }
> +  
> +  /**
> +   * Sums the counts in {@code rowLabel}'s row across every label of the matrix, including
> +   * the default label.
> +   *
> +   * <p>This was a hand-rolled iterator copy of {@code getCount}, so it now delegates to
> +   * keep the two row-sum computations from drifting apart.
> +   */
> +  private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
> +    return getCount(cm, rowLabel);
> +  }
> +  
> +  // HTML generator code
> +  
> +  /** Prints {@code cm.toString()} inside a {@code <pre>} wrapped in a horizontally scrollable div. */
> +  private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
> +    String[] openTags = {"<div style='width:90%;overflow:scroll;'>", "<pre>"};
> +    String[] closeTags = {"</pre>", "</div>"};
> +    for (String tag : openTags) {
> +      out.println(tag);
> +    }
> +    out.println(cm.toString());
> +    for (String tag : closeTags) {
> +      out.println(tag);
> +    }
> +  }
> +  
> +  /**
> +   * Prints an HTML table with a Label/Total/Correct/% header and one summary row per label
> +   * (the default label is left out when it has no instances).
> +   */
> +  public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
> +    format("<table class='%s'>\n", out, CSS_TABLE);
> +    format("<tr class='%s'>", out, CSS_LABEL);
> +    out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
> +    out.println("</tr>");
> +    Iterator<String> rows = stripDefault(cm).iterator();
> +    while (rows.hasNext()) {
> +      printSummaryRow(cm, out, rows.next());
> +    }
> +    out.println("</table>");
> +  }
> +  
> +  /**
> +   * Prints one summary row: the label, its row total, its correct count, and
> +   * {@code cm.getAccuracy(label)} rounded to an int (shown under the "%" column).
> +   */
> +  private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
> +    format("<tr class='%s'>", out, CSS_CELL);
> +    int correctCount = cm.getCorrect(label);
> +    int roundedAccuracy = (int) Math.round(cm.getAccuracy(label));
> +    int rowTotal = getCount(cm, label);
> +    format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>",
> +           out, CSS_CELL, label, rowTotal, correctCount, roundedAccuracy);
> +    out.println("</tr>");
> +  }
> +  
> +  /** Sums {@code cm.getCount(label, column)} over every label of the matrix (the row total). */
> +  private static int getCount(ConfusionMatrix cm, String label) {
> +    int rowSum = 0;
> +    for (Iterator<String> columns = cm.getLabels().iterator(); columns.hasNext();) {
> +      rowSum += cm.getCount(label, columns.next());
> +    }
> +    return rowSum;
> +  }
> +  
> +  public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
> +    format("<table class='%s'>\n", out, CSS_TABLE);
> +    printCountsHeader(cm, out, true);
> +    printGrayRows(cm, out);
> +    out.println("</table>");
> +  }
> +  
> +  /**
> +   * Prints one grayscale row per label (default label omitted when empty). Each cell is
> +   * shaded by its count relative to the row total, giving a mostly white matrix with grays
> +   * for misclassifications and dark cells on the diagonal.
> +   * TODO: rating by sqrt(count/total) instead would be more stringent.
> +   */
> +  private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
> +    List<String> labels = stripDefault(cm);
> +    for (int i = 0; i < labels.size(); i++) {
> +      printGrayRow(cm, out, labels, labels.get(i));
> +    }
> +  }
> +  
> +  /** Prints the grayscale row for {@code rowLabel}: its label cell, then one shaded cell per column label. */
> +  private static void printGrayRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
> +    format("<tr class='%s'>", out, CSS_LABEL);
> +    format("<td>%s</td>", out, rowLabel);
> +    int rowTotal = getLabelTotal(cm, rowLabel);
> +    Iterator<String> columns = labels.iterator();
> +    while (columns.hasNext()) {
> +      printGrayCell(cm, out, rowTotal, rowLabel, columns.next());
> +    }
> +    out.println("</tr>");
> +  }
> +  
> +  /**
> +   * Prints one shaded cell of the grayscale table.
> +   *
> +   * <p>A zero count gets the empty style. Otherwise the shade index is
> +   * {@code (int) (4 * count / total)}: white/light/medium/dark for 0, 1/4, 1/2 and 3/4 of
> +   * the row total, and black (index 4) when {@code count == total}. An alternative rating
> +   * would use sqrt(total) instead of total, which is more drastic.
> +   *
> +   * @param total row total for {@code rowLabel}
> +   */
> +  private static void printGrayCell(ConfusionMatrix cm,
> +                                    PrintStream out,
> +                                    int total,
> +                                    String rowLabel,
> +                                    String columnLabel) {
> +    int count = cm.getCount(rowLabel, columnLabel);
> +    if (count == 0) {
> +      out.format("<td class='%s'/>", CSS_EMPTY);
> +    } else {
> +      int rating = (int) ((count / (double) total) * 4);
> +      // Clamp so an inconsistent total (count above total, zero or negative values) can never
> +      // index outside the shade table.
> +      rating = Math.max(0, Math.min(rating, CSS_GRAY_CELLS.length - 1));
> +      String css = CSS_GRAY_CELLS[rating];
> +      format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count);
> +    }
> +  }
> +  
> +  /**
> +   * Prints the matrix as an HTML table of raw counts under a horizontal label header.
> +   * (The previously computed, never-used label count local has been removed.)
> +   */
> +  public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
> +    format("<table class='%s'>\n", out, CSS_TABLE);
> +    printCountsHeader(cm, out, false);
> +    printCountsRows(cm, out);
> +    out.println("</table>");
> +  }
> +  
> +  /** Prints one counts row per label, omitting the default label when it has no instances. */
> +  private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
> +    List<String> rowLabels = stripDefault(cm);
> +    for (Iterator<String> it = rowLabels.iterator(); it.hasNext();) {
> +      printCountsRow(cm, out, rowLabels, it.next());
> +    }
> +  }
> +  
> +  /** Prints the counts row for {@code rowLabel}: a label cell followed by one count cell per column label. */
> +  private static void printCountsRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
> +    out.println("<tr>");
> +    format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel);
> +    Iterator<String> columns = labels.iterator();
> +    while (columns.hasNext()) {
> +      printCountsCell(cm, out, rowLabel, columns.next());
> +    }
> +    out.println("</tr>");
> +  }
> +  
> +  /** Prints one count cell titled with the column label; zero counts are rendered as empty text. */
> +  private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
> +    int count = cm.getCount(rowLabel, columnLabel);
> +    String cellText;
> +    if (count != 0) {
> +      cellText = String.valueOf(count);
> +    } else {
> +      cellText = "";
> +    }
> +    format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, cellText);
> +  }
> +  
> +  /**
> +   * Prints the header row of a matrix table: an empty upper-left cell followed by one cell
> +   * per label (default label omitted when it has no instances).
> +   *
> +   * @param vertical true for the tall header whose labels use the vertical CSS class (grayscale
> +   *        table); false for the plain horizontal header (counts table)
> +   */
> +  private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
> +    List<String> labels = stripDefault(cm);
> +    if (vertical) {
> +      // Vertically rendered labels need a row tall enough for the longest one; the height
> +      // is set to half its character length, in em.
> +      int longest = getLongestHeader(labels);
> +      out.format("<tr class='%s' style='height:%dem'><th>&nbsp;</th>\n", CSS_TALL_HEADER, longest/2);
> +      for (String label : labels) {
> +        out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label);
> +      }
> +      out.println("</tr>");
> +    } else {
> +      // Header: empty cell in the upper-left corner, then one cell per label.
> +      out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, CSS_LABEL);
> +      for (String label : labels) {
> +        out.format("<td>%s</td>", label);
> +      }
> +      // End the line like the vertical branch, so the next row does not run onto this one.
> +      out.println("</tr>");
> +    }
> +  }
> +  
> +  /** Returns the length of the longest label, or 0 when there are no labels. */
> +  private static int getLongestHeader(Iterable<String> labels) {
> +    int longest = 0;
> +    for (Iterator<String> it = labels.iterator(); it.hasNext();) {
> +      int length = it.next().length();
> +      if (length > longest) {
> +        longest = length;
> +      }
> +    }
> +    return longest;
> +  }
> +  
> +  /** Applies {@code String.format(format, args)} and prints the result followed by a line separator. */
> +  private static void format(String format, PrintStream out, Object ... args) {
> +    out.println(String.format(format, args));
> +  }
> +  
> +  public static void printHeader(PrintStream out, CharSequence title) {
> +    out.println(HEADER.replace("TITLE", title));
> +  }
> +  
> +  /** Prints the FOOTER template followed by a line separator. */
> +  public static void printFooter(PrintStream out) {
> +    out.print(FOOTER);
> +    out.println();
> +  }
> +  
> +}
> 
> Modified: mahout/trunk/src/conf/driver.classes.props
> URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff
> ==============================================================================
> --- mahout/trunk/src/conf/driver.classes.props (original)
> +++ mahout/trunk/src/conf/driver.classes.props Fri Nov  4 11:20:03 2011
> @@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea
> org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = hmmpredict : Generate random sequence of observations by given HMM
> org.apache.mahout.utils.SplitInput = split : Split Input data into test and train sets
> org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = trainnb : Train the Vector-based Bayes classifier
> -org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based Bayes classifier
> \ No newline at end of file
> +org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based Bayes classifier
> +org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump confusion matrix in HTML or text formats
> +org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV format
> \ No newline at end of file
> 
> 

--------------------------------------------
Grant Ingersoll
http://www.lucidimagination.com




Mime
  • Unnamed multipart/alternative (inline, None, 0 bytes)
View raw message