mahout-commits mailing list archives

From: s..@apache.org
Subject: svn commit: r1142566 [1/2] - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/naivebayes/ main/java/org/apache/mahout/classifier/naivebayes/trainer/ main/java/org/apache/mahout/classifier/naivebayes/training/ main/java/org/apache/mahou...
Date: Mon, 04 Jul 2011 06:49:19 GMT
Author: ssc
Date: Mon Jul  4 06:49:17 2011
New Revision: 1142566

URL: http://svn.apache.org/viewvc?rev=1142566&view=rev
Log:
MAHOUT-746 Refactoring of the parallel Naive Bayes implementation in org.apache.mahout.classifier.naivebayes

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/
      - copied from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
      - copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
      - copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
      - copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
      - copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java
Removed:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesInstanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesSumReducer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesThetaComplementaryMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesThetaMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesWeightsMapper.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Mon Jul  4 06:49:17 2011
@@ -23,11 +23,8 @@ import org.apache.mahout.classifier.Abst
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.Vector.Element;
 
-/**
- * Class implementing the Naive Bayes Classifier Algorithm
- * 
- */
-public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+/** Class implementing the Naive Bayes Classifier Algorithm */
+abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
 
   private final NaiveBayesModel model;
   
@@ -39,27 +36,27 @@ public abstract class AbstractNaiveBayes
     return model;
   }
   
-  public abstract double getScoreForLabelFeature(int label, int feature);
-  
-  public double getScoreForLabelInstance(int label, Vector instance) {
+  protected abstract double getScoreForLabelFeature(int label, int feature);
+
+  protected double getScoreForLabelInstance(int label, Vector instance) {
     double result = 0.0;
-    Iterator<Element> it = instance.iterateNonZero();
-    while (it.hasNext()) {
-      result +=  getScoreForLabelFeature(label, it.next().index());
+    Iterator<Element> elements = instance.iterateNonZero();
+    while (elements.hasNext()) {
+      result += getScoreForLabelFeature(label, elements.next().index());
     }
-    return result;
+    return result / model.thetaNormalizer(label);
   }
   
   @Override
   public int numCategories() {
-    return model.getNumLabels();
+    return model.numLabels();
   }
 
   @Override
   public Vector classify(Vector instance) {
-    Vector score = model.getLabelSum().like();
-    for (int i = 0; i < score.size(); i++) {
-      score.set(i, getScoreForLabelInstance(i, instance));
+    Vector score = model.createScoringVector();
+    for (int label = 0; label < model.numLabels(); label++) {
+      score.set(label, getScoreForLabelInstance(label, instance));
     }
     return score;
   }

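For context, a minimal sketch of how the refactored classifier API fits together once a model has been trained and serialized. This code is not part of the commit; the model path, vocabulary size and feature values are hypothetical placeholders.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class ClassifyExample {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // materialize() reads the single naiveBayesModel.bin file written by NaiveBayesModel.serialize()
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path("/tmp/nb-model"), conf);

    Vector instance = new RandomAccessSparseVector(10000); // hypothetical vocabulary size
    instance.set(42, 1.0);                                 // hypothetical feature counts
    instance.set(97, 2.0);

    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
    Vector scores = classifier.classify(instance);         // one score per label
    System.out.println("best label = " + scores.maxValueIndex());
  }
}
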
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Mon Jul  4 06:49:17 2011
@@ -31,16 +31,10 @@ public class ComplementaryNaiveBayesClas
   @Override
   public double getScoreForLabelFeature(int label, int feature) {
     NaiveBayesModel model = getModel();
-    double result = model.getWeightMatrix().get(label, feature);
-    double vocabCount = model.getVocabCount();
-    double featureSum = model.getFeatureSum().get(feature);
-    double totalSum = model.getTotalSum();
-    double labelSum = model.getLabelSum().get(label);
-    double numerator = featureSum - result + model.getAlphaI();
-    double denominator =  totalSum - labelSum + vocabCount;
-    double weight = Math.log(numerator / denominator);
-    result = weight / model.getPerlabelThetaNormalizer().get(label);
-    return result;
+    double numerator = model.featureWeight(feature) - model.weight(label, feature) + model.alphaI();
+    double denominator =  model.totalWeightSum() - model.labelWeight(label) + model.alphaI() * model.numFeatures();
+
+    return Math.log(numerator / denominator);
   }
 
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Mon Jul  4 06:49:17 2011
@@ -18,182 +18,134 @@
 package org.apache.mahout.classifier.naivebayes;
 
 import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.classifier.naivebayes.trainer.NaiveBayesTrainer;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.SparseMatrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
-/**
- * NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.
- */
-public class NaiveBayesModel {
+import java.io.IOException;
 
-  private static final String MODEL = "NaiveBayesModel";
+/** NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.*/
+public class NaiveBayesModel {
 
-  private Vector labelSum;
+  private Vector weightsPerLabel;
   private Vector perlabelThetaNormalizer;
-  private Vector featureSum;
-  private Matrix weightMatrix;
+  private Vector weightsPerFeature;
+  private Matrix weightsPerLabelAndFeature;
   private float alphaI;
-  private double vocabCount;
-  private double totalSum;
-  
-  private NaiveBayesModel() { 
-    // do nothing
-  }
-  
-  public NaiveBayesModel(Matrix matrix, Vector featureSum, Vector labelSum, Vector thetaNormalizer, float alphaI) {
-    this.weightMatrix = matrix;
-    this.featureSum = featureSum;
-    this.labelSum = labelSum;
+  private double numFeatures;
+  private double totalWeightSum;
+
+  public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+      float alphaI) {
+    this.weightsPerLabelAndFeature = weightMatrix;
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
     this.perlabelThetaNormalizer = thetaNormalizer;
-    this.vocabCount = featureSum.getNumNondefaultElements();
-    this.totalSum = labelSum.zSum();
+    this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+    this.totalWeightSum = weightsPerLabel.zSum();
     this.alphaI = alphaI;
   }
 
-  private void setLabelSum(Vector labelSum) {
-    this.labelSum = labelSum;
+  public double labelWeight(int label) {
+    return weightsPerLabel.getQuick(label);
   }
 
-
-  public void setPerlabelThetaNormalizer(Vector perlabelThetaNormalizer) {
-    this.perlabelThetaNormalizer = perlabelThetaNormalizer;
+  public double thetaNormalizer(int label) {
+    return perlabelThetaNormalizer.get(label);
   }
 
-
-  public void setFeatureSum(Vector featureSum) {
-    this.featureSum = featureSum;
+  public double featureWeight(int feature) {
+    return weightsPerFeature.getQuick(feature);
   }
 
-
-  public void setWeightMatrix(Matrix weightMatrix) {
-    this.weightMatrix = weightMatrix;
+  public double weight(int label, int feature) {
+    return weightsPerLabelAndFeature.getQuick(label, feature);
   }
 
-
-  public void setAlphaI(float alphaI) {
-    this.alphaI = alphaI;
+  public float alphaI() {
+    return alphaI;
   }
 
-
-  public void setVocabCount(double vocabCount) {
-    this.vocabCount = vocabCount;
+  public double numFeatures() {
+    return numFeatures;
   }
 
-
-  public void setTotalSum(double totalSum) {
-    this.totalSum = totalSum;
+  public double totalWeightSum() {
+    return totalWeightSum;
   }
   
-  public Vector getLabelSum() {
-    return labelSum;
-  }
-
-  public Vector getPerlabelThetaNormalizer() {
-    return perlabelThetaNormalizer;
-  }
-
-  public Vector getFeatureSum() {
-    return featureSum;
+  public int numLabels() {
+    return weightsPerLabel.size();
   }
 
-  public Matrix getWeightMatrix() {
-    return weightMatrix;
+  public Vector createScoringVector() {
+    return weightsPerLabel.like();
   }
 
-  public float getAlphaI() {
-    return alphaI;
-  }
+  public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
 
-  public double getVocabCount() {
-    return vocabCount;
-  }
+    Vector weightsPerLabel = null;
+    Vector perLabelThetaNormalizer = null;
+    Vector weightsPerFeature = null;
+    Matrix weightsPerLabelAndFeature;
+    float alphaI;
 
-  public double getTotalSum() {
-    return totalSum;
-  }
-  
-  public int getNumLabels() {
-    return labelSum.size();
-  }
+    FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
+    try {
+      alphaI = in.readFloat();
+      weightsPerFeature = VectorWritable.readVector(in);
+      weightsPerLabel = VectorWritable.readVector(in);
+      perLabelThetaNormalizer = VectorWritable.readVector(in);
 
-  public static String getModelName() {
-    return MODEL;
-  }
-  
-  // CODE USED FOR SERIALIZATION
-  public static NaiveBayesModel fromMRTrainerOutput(Path output, Configuration conf) {
-    Path classVectorPath = new Path(output, NaiveBayesTrainer.CLASS_VECTORS);
-    Path sumVectorPath = new Path(output, NaiveBayesTrainer.SUM_VECTORS);
-    Path thetaSumPath = new Path(output, NaiveBayesTrainer.THETA_SUM);
-
-    NaiveBayesModel model = new NaiveBayesModel();
-    model.setAlphaI(conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f));
-
-    int featureCount = 0;
-    int labelCount = 0;
-    // read feature sums and label sums
-    for (Pair<Text,VectorWritable> record
-         : new SequenceFileIterable<Text, VectorWritable>(sumVectorPath, true, conf)) {
-      Text key = record.getFirst();
-      VectorWritable value = record.getSecond();
-      if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
-        model.setFeatureSum(value.get());
-        featureCount = value.get().getNumNondefaultElements();
-        model.setVocabCount(featureCount);       
-      } else  if (key.toString().equals(BayesConstants.LABEL_SUM)) {
-        model.setLabelSum(value.get());
-        model.setTotalSum(value.get().zSum());
-        labelCount = value.get().size();
+      weightsPerLabelAndFeature = new SparseMatrix(new int[] { weightsPerLabel.size(), weightsPerFeature.size() });
+      for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
       }
+    } finally {
+      Closeables.closeQuietly(in);
     }
+    NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+        perLabelThetaNormalizer, alphaI);
+    model.validate();
+    return model;
+  }
 
-    // read the class matrix
-    Matrix matrix = new SparseMatrix(new int[] {labelCount, featureCount});
-    for (Pair<IntWritable,VectorWritable> record
-         : new SequenceFileIterable<IntWritable,VectorWritable>(classVectorPath, true, conf)) {
-      IntWritable label = record.getFirst();
-      VectorWritable value = record.getSecond();
-      matrix.assignRow(label.get(), value.get());
-    }
-    
-    model.setWeightMatrix(matrix);
-
-    // read theta normalizer
-    for (Pair<Text,VectorWritable> record
-         : new SequenceFileIterable<Text,VectorWritable>(thetaSumPath, true, conf)) {
-      Text key = record.getFirst();
-      VectorWritable value = record.getSecond();
-      if (key.toString().equals(BayesConstants.LABEL_THETA_NORMALIZER)) {
-        model.setPerlabelThetaNormalizer(value.get());
+  public void serialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+    FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
+    try {
+      out.writeFloat(alphaI);
+      VectorWritable.writeVector(out, weightsPerFeature);
+      VectorWritable.writeVector(out, weightsPerLabel);
+      VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+        VectorWritable.writeVector(out, weightsPerLabelAndFeature.getRow(row));
       }
+    } finally {
+      Closeables.closeQuietly(out);
     }
-
-    return model;
   }
   
-  public static void validate(NaiveBayesModel model) {
-    if (model == null) {
-      return; // empty models are valid
-    }
-
-    Preconditions.checkArgument(model.getAlphaI() > 0, "Error: AlphaI has to be greater than 0!");
-    Preconditions.checkArgument(model.getVocabCount() > 0, "Error: The vocab count has to be greater than 0!");
-    Preconditions.checkArgument(model.getTotalSum() > 0, "Error: The vocab count has to be greater than 0!");
-    Preconditions.checkArgument(model.getLabelSum() != null && model.getLabelSum().getNumNondefaultElements() > 0,
-        "Error: The number of labels has to be greater than 0 and defined!");
-    Preconditions.checkArgument(model.getPerlabelThetaNormalizer() != null &&
-        model.getPerlabelThetaNormalizer().getNumNondefaultElements() > 0,
-        "Error: The number of theta normalizers has to be greater than 0 or defined!");
-    Preconditions.checkArgument(model.getFeatureSum() != null && model.getFeatureSum().getNumNondefaultElements() > 0,
-        "Error: The number of features has to be greater than 0 or defined!");
+  public void validate() {
+    Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+    Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+    Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+    Preconditions.checkArgument(weightsPerLabel != null, "the number of labels has to be defined!");
+    Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+        "the number of labels has to be greater than 0!");
+    Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+    Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+        "the number of theta normalizers has to be greater than 0!");
+    Preconditions.checkArgument(weightsPerFeature != null, "the feature sums have to be defined");
+    Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+        "the feature sums have to be greater than 0!");
   }
 }

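The serialization changes above replace the old sequence-file outputs with a single naiveBayesModel.bin file (alphaI, then the feature, label and theta-normalizer vectors, then one weight row per label). A hedged round-trip sketch with a tiny hand-built model; the numbers and the output path are made up for illustration and are not part of the commit.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;

public class ModelRoundTrip {
  public static void main(String[] args) throws Exception {
    // tiny illustrative model: 2 labels x 3 features
    Vector weightsPerFeature = new DenseVector(new double[] { 2, 3, 1 });
    Vector weightsPerLabel = new DenseVector(new double[] { 4, 2 });
    Vector thetaNormalizer = new DenseVector(new double[] { 1, 1 });
    Matrix weights = new SparseMatrix(new int[] { 2, 3 });
    weights.set(0, 0, 2);
    weights.set(0, 1, 2);
    weights.set(1, 1, 1);
    weights.set(1, 2, 1);

    NaiveBayesModel model = new NaiveBayesModel(weights, weightsPerFeature, weightsPerLabel, thetaNormalizer, 1.0f);
    model.validate();

    Configuration conf = new Configuration();
    Path modelDir = new Path("/tmp/nb-model"); // hypothetical, must be writable
    model.serialize(modelDir, conf);           // writes <modelDir>/naiveBayesModel.bin
    NaiveBayesModel restored = NaiveBayesModel.materialize(modelDir, conf);
    System.out.println("labels: " + restored.numLabels());
  }
}
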
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Mon Jul  4 06:49:17 2011
@@ -18,10 +18,7 @@
 package org.apache.mahout.classifier.naivebayes;
 
 
-/**
- * Class implementing the Naive Bayes Classifier Algorithm
- * 
- */
+/** Class implementing the Naive Bayes Classifier Algorithm */
 public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier { 
  
   public StandardNaiveBayesClassifier(NaiveBayesModel model) {
@@ -31,14 +28,11 @@ public class StandardNaiveBayesClassifie
   @Override
   public double getScoreForLabelFeature(int label, int feature) {
     NaiveBayesModel model = getModel();
-    double result = model.getWeightMatrix().get(label, feature);
-    double vocabCount = model.getVocabCount();
-    double sumLabelWeight = model.getLabelSum().get(label);
-    double numerator = result + model.getAlphaI();
-    double denominator = sumLabelWeight + vocabCount;
-    double weight = -Math.log(numerator / denominator);
-    result = weight / model.getPerlabelThetaNormalizer().get(label);
-    return result;
+
+    double numerator = model.weight(label, feature) + model.alphaI();
+    double denominator = model.labelWeight(label) + model.alphaI() * model.numFeatures();
+
+    return -Math.log(numerator / denominator);
   }
   
 }

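Taken together, the two classifier diffs change the scoring in two ways: the per-label theta normalization now happens once per instance in AbstractNaiveBayesClassifier instead of once per feature, and the smoothing term in the denominator is alphaI() * numFeatures() rather than the raw vocabulary count. A reference restatement of the resulting per-feature scores, using only the new NaiveBayesModel accessors; nothing below is part of the commit, it just spells the formulas out in one place.

import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;

final class NaiveBayesScores {
  private NaiveBayesScores() {}

  // standard NB: -log( (weight(l,f) + alphaI) / (labelWeight(l) + alphaI * numFeatures) )
  static double standard(NaiveBayesModel m, int label, int feature) {
    return -Math.log((m.weight(label, feature) + m.alphaI())
        / (m.labelWeight(label) + m.alphaI() * m.numFeatures()));
  }

  // complementary NB: log( (featureWeight(f) - weight(l,f) + alphaI)
  //                        / (totalWeightSum - labelWeight(l) + alphaI * numFeatures) )
  static double complementary(NaiveBayesModel m, int label, int feature) {
    return Math.log((m.featureWeight(feature) - m.weight(label, feature) + m.alphaI())
        / (m.totalWeightSum() - m.labelWeight(label) + m.alphaI() * m.numFeatures()));
  }
}
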
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.Vector;
+
+public abstract class AbstractThetaTrainer {
+
+  private Vector weightsPerFeature;
+  private Vector weightsPerLabel;
+  private Vector perLabelThetaNormalizer;
+  private double alphaI;
+  private double totalWeightSum;
+  private double numFeatures;
+
+  public AbstractThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    Preconditions.checkNotNull(weightsPerFeature);
+    Preconditions.checkNotNull(weightsPerLabel);
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.alphaI = alphaI;
+    perLabelThetaNormalizer = weightsPerLabel.like();
+    totalWeightSum = weightsPerLabel.zSum();
+    numFeatures = weightsPerFeature.getNumNondefaultElements();
+  }
+
+  public abstract void train(int label, Vector instance);
+
+  protected double alphaI() {
+    return alphaI;
+  }
+
+  protected double numFeatures() {
+    return numFeatures;
+  }
+
+  protected double labelWeight(int label) {
+    return weightsPerLabel.get(label);
+  }
+
+  protected double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  protected double featureWeight(int feature) {
+    return weightsPerFeature.get(feature);
+  }
+
+  protected void updatePerLabelThetaNormalizer(int label, double weight) {
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
+  }
+
+  public Vector retrievePerLabelThetaNormalizer() {
+    return perLabelThetaNormalizer.clone();
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.Iterator;
+
+public class ComplementaryThetaTrainer extends AbstractThetaTrainer {
+
+  public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    super(weightsPerFeature, weightsPerLabel, alphaI);
+  }
+
+  @Override
+  public void train(int label, Vector instance) {
+    double sigmaK = labelWeight(label);
+    Iterator<Vector.Element> it = instance.iterateNonZero();
+    while (it.hasNext()) {
+      Vector.Element e = it.next();
+      double numerator = featureWeight(e.index()) - e.get() + alphaI();
+      double denominator = totalWeightSum() - sigmaK + alphaI() * numFeatures();
+      double weight = Math.log(numerator / denominator);
+      updatePerLabelThetaNormalizer(label, weight);
+    }
+  }
+}

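Because the theta trainers are plain objects rather than mapper internals, they can be exercised outside of MapReduce. A small sketch, assuming hand-built aggregate weight vectors; the numbers, label ids and instance vectors are illustrative stand-ins for what the weight and indexing jobs produce.

import org.apache.mahout.classifier.naivebayes.training.ComplementaryThetaTrainer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class ThetaTrainerExample {
  public static void main(String[] args) {
    // hypothetical aggregate weights for 3 features and 2 labels
    Vector weightsPerFeature = new DenseVector(new double[] { 2, 3, 1 });
    Vector weightsPerLabel = new DenseVector(new double[] { 4, 2 });

    ComplementaryThetaTrainer trainer =
        new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, 1.0);

    // one summed observation vector per label, as produced by the upstream summing job
    trainer.train(0, new DenseVector(new double[] { 1, 0, 1 }));
    trainer.train(1, new DenseVector(new double[] { 1, 3, 0 }));

    Vector thetaNormalizers = trainer.retrievePerLabelThetaNormalizer();
    System.out.println(thetaNormalizers);
  }
}
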
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java Mon Jul  4 06:49:17 2011
@@ -15,50 +15,34 @@
  * limitations under the License.
  */
 
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
 
 import java.io.IOException;
-import java.net.URI;
 
-import com.google.common.base.Preconditions;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.map.OpenObjectIntHashMap;
 
-public class NaiveBayesInstanceMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
-  
-  private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
-  
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  public enum Counter { SKIPPED_INSTANCES }
+
+  private OpenObjectIntHashMap<String> labelIndex;
+
   @Override
-  protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
-    if (labelMap.containsKey(key.toString())) {
-      int label = labelMap.get(key.toString());
-      context.write(new IntWritable(label), value);
-    } else {
-      context.getCounter("NaiveBayes", "Skipped instance: not in label list").increment(1);
-    }
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    labelIndex = TrainUtils.readIndexFromCache(ctx.getConfiguration());
   }
-  
+
   @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    super.setup(context);
-    Configuration conf = context.getConfiguration();
-    URI[] localFiles = DistributedCache.getCacheFiles(conf);
-    Preconditions.checkArgument(localFiles != null && localFiles.length >= 1,
-        "missing paths from the DistributedCache");
-    Path labelMapFile = new Path(localFiles[0].getPath());
-    // key is word value is id
-    for (Pair<Writable,IntWritable> record
-         : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
-      labelMap.put(record.getFirst().toString(), record.getSecond().get());
+  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+    String label = labelText.toString();
+    if (labelIndex.containsKey(label)) {
+      ctx.write(new IntWritable(labelIndex.get(label)), instance);
+    } else {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
     }
   }
 }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.mahout.math.Vector;
+
+public class StandardThetaTrainer extends AbstractThetaTrainer {
+
+  public StandardThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    super(weightsPerFeature, weightsPerLabel, alphaI);
+  }
+
+  @Override
+  public void train(int label, Vector instance) {
+    double weight = Math.log((instance.zSum() + alphaI()) / (labelWeight(label) + alphaI() * numFeatures()));
+    updatePerLabelThetaNormalizer(label, weight);
+  }
+}

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java Mon Jul  4 06:49:17 2011
@@ -15,79 +15,50 @@
  * limitations under the License.
  */
 
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
 
 import java.io.IOException;
-import java.net.URI;
+import java.util.Map;
 
-import com.google.common.base.Preconditions;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesConstants;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
 
-public class NaiveBayesThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
-  
-  private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
-  private Vector featureSum;
-  private Vector labelSum;
-  private Vector perLabelThetaNormalizer;
-  private double alphaI = 1.0;
-  private double vocabCount;
-  
-  @Override
-  protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException {
-    Vector vector = value.get();
-    int label = key.get();
-    double weight = Math.log((vector.zSum() + alphaI) / (labelSum.get(label) + vocabCount));
-    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
-  }
-  
-  @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    super.setup(context);
-    Configuration conf = context.getConfiguration();
-    URI[] localFiles = DistributedCache.getCacheFiles(conf);
-    Preconditions.checkArgument(localFiles != null && localFiles.length >= 2,
-        "missing paths from the DistributedCache");
-
-    alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
-    Path weightFile = new Path(localFiles[0].getPath());
-
-    for (Pair<Text,VectorWritable> record
-         : new SequenceFileIterable<Text,VectorWritable>(weightFile, true, conf)) {
-      Text key = record.getFirst();
-      VectorWritable value = record.getSecond();
-      if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
-        featureSum = value.get();
-      } else  if (key.toString().equals(BayesConstants.LABEL_SUM)) {
-        labelSum = value.get();
-      }
-    }
-    perLabelThetaNormalizer = labelSum.like();
-    vocabCount = featureSum.getNumNondefaultElements();
+public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
+  static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
 
-    Path labelMapFile = new Path(localFiles[1].getPath());
+  private AbstractThetaTrainer trainer;
 
-    // key is word value is id
-    for (Pair<Writable,IntWritable> record 
-         : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
-      labelMap.put(record.getFirst().toString(), record.getSecond().get());
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    Configuration conf = ctx.getConfiguration();
+
+    float alphaI = conf.getFloat(ALPHA_I, 1.0f);
+    boolean trainComplemenary = conf.getBoolean(TRAIN_COMPLEMENTARY, false);
+    Map<String,Vector> scores = TrainUtils.readScoresFromCache(conf);
+
+    if (!trainComplemenary) {
+      trainer = new StandardThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+          scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
+    } else {
+      trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+          scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
     }
   }
-  
+
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    trainer.train(key.get(), value.get());
+  }
+
   @Override
-  protected void cleanup(Context context) throws IOException, InterruptedException {
-    context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
-    super.cleanup(context);
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
+        new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
   }
 }

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java Mon Jul  4 06:49:17 2011
@@ -15,201 +15,83 @@
  * limitations under the License.
  */
 
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
 
-import java.io.IOException;
-import java.net.URI;
+import java.util.Map;
 
-import com.google.common.io.Closeables;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.mapreduce.VectorSumReducer;
 import org.apache.mahout.math.VectorWritable;
 
-/**
- * This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
- */
-public final class NaiveBayesTrainer {
-  
-  public static final String THETA_SUM = "thetaSum";
-  public static final String SUM_VECTORS = "sumVectors";
-  public static final String CLASS_VECTORS = "classVectors";
-  public static final String LABEL_MAP = "labelMap";
-  public static final String ALPHA_I = "alphaI";
-
-  private NaiveBayesTrainer() {
-  }
+/** This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes) */
+public final class TrainNaiveBayesJob extends AbstractJob {
 
-  public static void trainNaiveBayes(Path input,
-                                      Configuration conf,
-                                      Iterable<String> inputLabels,
-                                      Path output,
-                                      int numReducers,
-                                      float alphaI,
-                                      boolean trainComplementary)
-      throws IOException, InterruptedException, ClassNotFoundException {
-    conf.setFloat(ALPHA_I, alphaI);
-    Path labelMapPath = createLabelMapFile(inputLabels, conf, new Path(output, LABEL_MAP));
-    Path classVectorPath =  new Path(output, CLASS_VECTORS);
-    runNaiveBayesByLabelSummer(input, conf, labelMapPath, classVectorPath, numReducers);
-    Path weightFilePath = new Path(output, SUM_VECTORS);
-    runNaiveBayesWeightSummer(classVectorPath, conf, labelMapPath, weightFilePath, numReducers);
-    Path thetaFilePath = new Path(output, THETA_SUM);
-    if (trainComplementary) {
-      runNaiveBayesThetaComplementarySummer(classVectorPath, conf, weightFilePath, thetaFilePath, numReducers);
-    } else {
-      runNaiveBayesThetaSummer(classVectorPath, conf, weightFilePath, thetaFilePath, numReducers);
+  public static final String WEIGHTS_PER_FEATURE = "__SPF";
+  public static final String WEIGHTS_PER_LABEL = "__SPL";
+  public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+  public static final String SUMMED_OBSERVATIONS = "summedObservations";
+  public static final String WEIGHTS = "weights";
+  public static final String THETAS = "thetas";
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption("labels", "l", "comma-separated list of labels to include in training", true);
+    addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
+    addOption("trainComplementary", "c", "train complementary?", String.valueOf(false));
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
     }
-  }
-
-  private static void runNaiveBayesByLabelSummer(Path input,
-                                                 Configuration conf,
-                                                 Path labelMapPath,
-                                                 Path output, int numReducers)
-    throws IOException, InterruptedException, ClassNotFoundException {
-    
-    // this conf parameter needs to be set enable serialisation of conf values
-    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
-        + "org.apache.hadoop.io.serializer.WritableSerialization");
-    DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
-  
-    Job job = new Job(conf);
-    job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
-        + labelMapPath.toString());
-    job.setJarByClass(NaiveBayesTrainer.class);
-    FileInputFormat.setInputPaths(job, input);
-    FileOutputFormat.setOutputPath(job, output);
-    job.setMapperClass(NaiveBayesInstanceMapper.class);
-    job.setCombinerClass(VectorSumReducer.class);
-    job.setReducerClass(VectorSumReducer.class);
-    job.setInputFormatClass(SequenceFileInputFormat.class);
-    job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    job.setOutputKeyClass(IntWritable.class);
-    job.setOutputValueClass(VectorWritable.class);
-    job.setNumReduceTasks(numReducers);
-    HadoopUtil.delete(conf, output);
-    job.waitForCompletion(true);
-  }
 
-  private static void runNaiveBayesWeightSummer(Path input,
-                                                Configuration conf,
-                                                Path labelMapPath,
-                                                Path output,
-                                                int numReducers)
-    throws IOException, InterruptedException, ClassNotFoundException {
-    
-    // this conf parameter needs to be set enable serialisation of conf values
-    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
-        + "org.apache.hadoop.io.serializer.WritableSerialization");
-    DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
-    
-    Job job = new Job(conf);
-    job.setJobName("Train Naive Bayes: input-folder: " + input);
-    job.setJarByClass(NaiveBayesTrainer.class);
-    FileInputFormat.setInputPaths(job, input);
-    FileOutputFormat.setOutputPath(job, output);
-    job.setMapperClass(NaiveBayesWeightsMapper.class);
-    job.setReducerClass(VectorSumReducer.class);
-    job.setInputFormatClass(SequenceFileInputFormat.class);
-    job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    job.setOutputKeyClass(Text.class);
-    job.setOutputValueClass(VectorWritable.class);
-    job.setNumReduceTasks(numReducers);
-    HadoopUtil.delete(conf, output);
-    job.waitForCompletion(true);
-  }
-  
-  private static void runNaiveBayesThetaSummer(Path input,
-                                               Configuration conf,
-                                               Path weightFilePath,
-                                               Path output,
-                                               int numReducers)
-    throws IOException, InterruptedException, ClassNotFoundException {
-    
-    // this conf parameter needs to be set enable serialisation of conf values
-    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
-        + "org.apache.hadoop.io.serializer.WritableSerialization");
-    DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri()}, conf);
-  
-    Job job = new Job(conf);
-    job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
-        + weightFilePath.toString());
-    job.setJarByClass(NaiveBayesTrainer.class);
-    FileInputFormat.setInputPaths(job, input);
-    FileOutputFormat.setOutputPath(job, output);
-    job.setMapperClass(NaiveBayesThetaMapper.class);
-    job.setReducerClass(VectorSumReducer.class);
-    job.setInputFormatClass(SequenceFileInputFormat.class);
-    job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    job.setOutputKeyClass(IntWritable.class);
-    job.setOutputValueClass(VectorWritable.class);
-    job.setNumReduceTasks(numReducers);
-    HadoopUtil.delete(conf, output);
-    job.waitForCompletion(true);
-  }
+    Iterable<String> labels = Splitter.on(",").split(parsedArgs.get("--labels"));
+    float alphaI = Float.parseFloat(parsedArgs.get("--alphaI"));
+    boolean trainComplementary = Boolean.parseBoolean(parsedArgs.get("--trainComplementary"));
+
+    TrainUtils.writeLabelIndex(getConf(), labels, getTempPath("labelIndex"));
+    TrainUtils.setSerializations(getConf());
+    TrainUtils.cacheFiles(getTempPath("labelIndex"), getConf());
+
+    Job indexInstances = prepareJob(getInputPath(), getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
+        IndexInstancesMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class, IntWritable.class,
+        VectorWritable.class, SequenceFileOutputFormat.class);
+    indexInstances.setCombinerClass(VectorSumReducer.class);
+    indexInstances.waitForCompletion(true);
+
+    Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(WEIGHTS),
+        SequenceFileInputFormat.class, WeightsMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(Iterables.size(labels)));
+    weightSummer.setCombinerClass(VectorSumReducer.class);
+    weightSummer.waitForCompletion(true);
+
+    TrainUtils.cacheFiles(getTempPath(WEIGHTS), getConf());
+
+    Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(THETAS),
+        SequenceFileInputFormat.class, ThetaMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    thetaSummer.setCombinerClass(VectorSumReducer.class);
+    thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+    thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+    thetaSummer.waitForCompletion(true);
+
+    NaiveBayesModel naiveBayesModel = TrainUtils.readModelFromTempDir(getTempPath(), getConf());
+    naiveBayesModel.validate();
+    naiveBayesModel.serialize(getOutputPath(), getConf());
 
-  private static void runNaiveBayesThetaComplementarySummer(Path input,
-                                                            Configuration conf,
-                                                            Path weightFilePath,
-                                                            Path output,
-                                                            int numReducers)
-    throws IOException, InterruptedException, ClassNotFoundException {
-    
-    // this conf parameter needs to be set enable serialisation of conf values
-    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
-        + "org.apache.hadoop.io.serializer.WritableSerialization");
-    DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri()}, conf);
-  
-    Job job = new Job(conf);
-    job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
-        + weightFilePath.toString());
-    job.setJarByClass(NaiveBayesTrainer.class);
-    FileInputFormat.setInputPaths(job, input);
-    FileOutputFormat.setOutputPath(job, output);
-    job.setMapperClass(NaiveBayesThetaComplementaryMapper.class);
-    job.setReducerClass(VectorSumReducer.class);
-    job.setInputFormatClass(SequenceFileInputFormat.class);
-    job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    job.setOutputKeyClass(IntWritable.class);
-    job.setOutputValueClass(VectorWritable.class);
-    job.setNumReduceTasks(numReducers);
-    HadoopUtil.delete(conf, output);
-    job.waitForCompletion(true);
+    return 0;
   }
 
-  
-  
-  /**
-   * Write the list of labels into a map file
-   */
-  public static Path createLabelMapFile(Iterable<String> labels,
-                                        Configuration conf,
-                                        Path labelMapPathBase) throws IOException {
-    FileSystem fs = FileSystem.get(labelMapPathBase.toUri(), conf);
-    Path labelMapPath = new Path(labelMapPathBase, LABEL_MAP);
-    
-    SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, labelMapPath, Text.class, IntWritable.class);
-    try {
-      int i = 0;
-      for (String label : labels) {
-        Writable key = new Text(label);
-        dictWriter.append(key, new IntWritable(i++));
-      }
-    } finally {
-      Closeables.closeQuietly(dictWriter);
-    }
-    return labelMapPath;
-  }
 }

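A hedged driver sketch for the new job. Only --labels, --alphaI and --trainComplementary are defined in this diff; --input, --output and --tempDir are assumed to come from AbstractJob's addInputOption()/addOutputOption() and Mahout's standard option handling. All paths and label names below are hypothetical.

import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;

public class TrainDriver {
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new TrainNaiveBayesJob(), new String[] {
        "--input", "/data/vectors",        // SequenceFile<Text,VectorWritable> of labeled instances
        "--output", "/data/model",         // receives naiveBayesModel.bin
        "--tempDir", "/tmp/nb-training",   // holds labelIndex, summedObservations, weights, thetas
        "--labels", "alt.atheism,sci.space",
        "--alphaI", "1.0",
        "--trainComplementary", "false"
    });
  }
}
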
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+
+public class TrainUtils {
+
+  private TrainUtils() {}
+
+  static NaiveBayesModel readModelFromTempDir(Path base, Configuration conf) {
+
+    Vector scoresPerLabel = null;
+    Vector perlabelThetaNormalizer = null;
+    Vector scoresPerFeature = null;
+    Matrix scoresPerLabelAndFeature;
+    float alphaI;
+
+    alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+
+    // read feature sums and label sums
+    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature);
+    Preconditions.checkNotNull(scoresPerLabel);
+
+    scoresPerLabelAndFeature = new SparseMatrix(new int[] { scoresPerLabel.size(), scoresPerFeature.size() });
+    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+
+    for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+        perLabelThetaNormalizer = entry.getSecond().get();
+      }
+    }
+
+    Preconditions.checkNotNull(perLabelThetaNormalizer);
+
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+        alphaI);
+  }
+
+  protected static void setSerializations(Configuration conf) {
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+        + "org.apache.hadoop.io.serializer.WritableSerialization");
+  }
+
+  protected static void cacheFiles(Path fileToCache, Configuration conf) {
+    DistributedCache.setCacheFiles(new URI[] { fileToCache.toUri() }, conf);
+  }
+
+  /** Write the list of labels into a sequence file of (label, index) pairs */
+  protected static void writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+      throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+    try {
+      int i = 0;
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    } finally {
+      Closeables.closeQuietly(writer);
+    }
+  }
+
+  private static Path cachedFile(Configuration conf) throws IOException {
+    return new Path(DistributedCache.getCacheFiles(conf)[0].getPath());
+  }
+
+  protected static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<String>();
+    for (Pair<Writable,IntWritable> entry : new SequenceFileIterable<Writable,IntWritable>(cachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  protected static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String,Vector> sumVectors = Maps.newHashMap();
+    for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(cachedFile(conf),
+        PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+}
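
For orientation, here is a minimal sketch of how these helpers are meant to fit together: writeLabelIndex persists the label list as (Text, IntWritable) pairs, cacheFiles registers that file in the DistributedCache, and readIndexFromCache rebuilds the in-memory index on the mapper side. The driver class and path below are illustrative assumptions, not part of this commit; the sketch sits in the training package because the helpers are package-visible.

package org.apache.mahout.classifier.naivebayes.training;

import java.util.Arrays;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

/** Illustrative round trip through the TrainUtils label-index helpers (assumed usage). */
public class LabelIndexRoundTripSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path labelIndexPath = new Path("/tmp/naivebayes/labelindex"); // hypothetical location

    // write "stolen" -> 0, "not_stolen" -> 1 as a SequenceFile of (Text, IntWritable)
    TrainUtils.writeLabelIndex(conf, Arrays.asList("stolen", "not_stolen"), labelIndexPath);

    // register the file so mappers can pick it up from the DistributedCache
    TrainUtils.cacheFiles(labelIndexPath, conf);

    // a mapper's setup() can then rebuild the in-memory index from the cached file
    OpenObjectIntHashMap<String> labelIndex = TrainUtils.readIndexFromCache(conf);
    System.out.println("stolen -> " + labelIndex.get("stolen")); // 0
  }
}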

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java Mon Jul  4 06:49:17 2011
@@ -15,68 +15,50 @@
  * limitations under the License.
  */
 
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
 
 import java.io.IOException;
-import java.net.URI;
 
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesConstants;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
 
-public class NaiveBayesWeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
-  
-  private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
-  private Vector featureSum;
-  private Vector labelSum;
- 
+public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";
+
+  private Vector weightsPerFeature;
+  private Vector weightsPerLabel;
+
   @Override
-  protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException {
-    Vector vector = value.get();
-    if (featureSum == null) {
-      featureSum = new RandomAccessSparseVector(vector.size(), vector.getNumNondefaultElements());
-      labelSum = new RandomAccessSparseVector(labelMap.size());  
-    }
-    
-    int label = key.get();
-    vector.addTo(featureSum);
-    labelSum.set(label, labelSum.get(label) + vector.zSum());
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
+    Preconditions.checkArgument(numLabels > 0);
+    weightsPerLabel = new RandomAccessSparseVector(numLabels);
   }
-  
+
   @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    super.setup(context);
-    Configuration conf = context.getConfiguration();
-    URI[] localFiles = DistributedCache.getCacheFiles(conf);
-    if (localFiles == null || localFiles.length < 1) {
-      throw new IllegalArgumentException("missing paths from the DistributedCache");
+  protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    Vector instance = value.get();
+    if (weightsPerFeature == null) {
+      weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
     }
-    Path labelMapFile = new Path(localFiles[0].getPath());
 
-    // key is word value is id
-    for (Pair<Writable,IntWritable> record 
-         : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
-      labelMap.put(record.getFirst().toString(), record.getSecond().get());
-    }
+    int label = index.get();
+    instance.addTo(weightsPerFeature);
+    weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
   }
-  
+
   @Override
-  protected void cleanup(Context context) throws IOException, InterruptedException {
-    if (featureSum != null) {
-      context.write(new Text(BayesConstants.FEATURE_SUM), new VectorWritable(featureSum));
-      context.write(new Text(BayesConstants.LABEL_SUM), new VectorWritable(labelSum));
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    if (weightsPerFeature != null) {
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
     }
-    super.cleanup(context);
+    super.cleanup(ctx);
   }
 }
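
As a rough sketch of the driver-side contract: the refactored mapper no longer reads a label map from the DistributedCache but expects the number of labels in the job configuration under WeightsMapper.NUM_LABELS, and it emits the two sums keyed WEIGHTS_PER_FEATURE / WEIGHTS_PER_LABEL. The wiring below is an assumption for illustration only (placed in the training package because NUM_LABELS is package-visible); it is not the actual TrainNaiveBayesJob code.

package org.apache.mahout.classifier.naivebayes.training;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.mahout.math.VectorWritable;

/** Hypothetical wiring of WeightsMapper into a job (assumed, not from this commit). */
public class WeightsJobSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // the mapper's setup() parses this value and requires it to be > 0
    conf.set(WeightsMapper.NUM_LABELS, String.valueOf(2));

    Job job = new Job(conf, "sum weights per feature and per label");
    job.setMapperClass(WeightsMapper.class);
    job.setMapOutputKeyClass(Text.class);              // WEIGHTS_PER_FEATURE / WEIGHTS_PER_LABEL
    job.setMapOutputValueClass(VectorWritable.class);
    // input/output formats, paths and the summing reducer are omitted here
  }
}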

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java Mon Jul  4 06:49:17 2011
@@ -187,6 +187,19 @@ public final class VectorWritable extend
   }
 
   @Override
+  public boolean equals(Object o) {
+    if (o instanceof VectorWritable) {
+      return vector.equals(((VectorWritable)o).get());
+    }
+    return false;
+  }
+
+  @Override
+  public int hashCode() {
+    return vector.hashCode();
+  }
+
+  @Override
   public String toString() {
     return vector.toString();
   }
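
The equals/hashCode pair appears to exist so that VectorWritable instances compare by value, e.g. when EasyMock matches the ctx.write(...) arguments in the new mapper tests below; that motivation is an inference, not stated in the commit. A minimal illustration of the new contract:

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;

/** Small check of the value-based equality added above (illustrative only). */
public class VectorWritableEqualitySketch {
  public static void main(String[] args) {
    VectorWritable a = new VectorWritable(new DenseVector(new double[] { 1, 0, 1 }));
    VectorWritable b = new VectorWritable(new DenseVector(new double[] { 1, 0, 1 }));
    // equals delegates to the wrapped Vector, and hashCode stays consistent with it
    System.out.println(a.equals(b));                    // true
    System.out.println(a.hashCode() == b.hashCode());   // true
  }
}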

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Mon Jul  4 06:49:17 2011
@@ -37,10 +37,10 @@ public final class ComplementaryNaiveBay
   @Test
   public void testNaiveBayes() throws Exception {
     assertEquals(4, classifier.numCategories());
-    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
-    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
-    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
-    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
     
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java Mon Jul  4 06:49:17 2011
@@ -26,7 +26,7 @@ public class NaiveBayesModelTest extends
     // make sure we generate a valid random model
     NaiveBayesModel model = getModel();
     // check whether the model is valid
-    NaiveBayesModel.validate(model);
+    model.validate();
   }
 
 }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,117 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+
+public class NaiveBayesTest extends MahoutTestCase {
+
+  Configuration conf;
+  File inputFile;
+  File outputDir;
+  File tempDir;
+
+  static final Text LABEL_STOLEN = new Text("stolen");
+  static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
+
+  static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
+  static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
+  static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1);
+  static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1);
+  static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1);
+  static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1);
+
+
+  @Before
+  public void setup() throws Exception {
+    super.setUp();
+
+    conf = new Configuration();
+
+    inputFile = getTestTempFile("trainingInstances.seq");
+    outputDir = getTestTempDir("output");
+    outputDir.delete();
+    tempDir = getTestTempDir("tmp");
+
+    SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf,
+        new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
+
+    try {
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
+    } finally {
+      Closeables.closeQuietly(writer);
+    }
+  }
+
+  @Test
+  public void toyData() throws Exception {
+    TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+    trainNaiveBayes.setConf(conf);
+    trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+        "--labels", "stolen,not_stolen", "--tempDir", tempDir.getAbsolutePath() });
+
+    NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
+
+    AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel);
+
+    assertEquals(2, classifier.numCategories());
+
+    Vector prediction = classifier.classify(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
+
+    // should be classified as not stolen
+    assertTrue(prediction.get(0) < prediction.get(1));
+  }
+
+  @Test
+  public void toyDataComplementary() throws Exception {
+    TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+    trainNaiveBayes.setConf(conf);
+    trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+        "--labels", "stolen,not_stolen", "--trainComplementary", String.valueOf(true),
+        "--tempDir", tempDir.getAbsolutePath() });
+
+    NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
+
+    AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
+
+    assertEquals(2, classifier.numCategories());
+
+    Vector prediction = classifier.classify(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
+
+    // should be classified as not stolen
+    assertTrue(prediction.get(0) < prediction.get(1));
+  }
+
+  VectorWritable trainingInstance(Vector.Element... elems) {
+    DenseVector trainingInstance = new DenseVector(6);
+    for (Vector.Element elem : elems) {
+      trainingInstance.set(elem.index(), elem.get());
+    }
+    return new VectorWritable(trainingInstance);
+  }
+
+
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java Mon Jul  4 06:49:17 2011
@@ -17,8 +17,6 @@
 
 package org.apache.mahout.classifier.naivebayes;
 
-import java.util.Iterator;
-
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.DenseVector;
@@ -34,19 +32,14 @@ public class NaiveBayesTestBase extends 
   public void setUp() throws Exception {
     super.setUp();
     model = createNaiveBayesModel();
-    
-    // make sure the model is valid :)
-    NaiveBayesModel.validate(model);
+    model.validate();
   }
   
   protected NaiveBayesModel getModel() {
     return model;
   }
   
-  public double complementaryNaiveBayesThetaWeight(int label,
-                                                   Matrix weightMatrix,
-                                                   Vector labelSum,
-                                                   Vector featureSum) {
+  public double complementaryNaiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
     double weight = 0.0;
     double alpha = 1.0;
     for (int i = 0; i < featureSum.size(); i++) {
@@ -61,14 +54,11 @@ public class NaiveBayesTestBase extends 
     return weight;
   }
   
-  public double naiveBayesThetaWeight(int label,
-                                      Matrix weightMatrix,
-                                      Vector labelSum,
-                                      Vector featureSum) {
+  public double naiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
     double weight = 0.0;
     double alpha = 1.0;
-    for (int i = 0; i < featureSum.size(); i++) {
-      double score = weightMatrix.get(i, label);
+    for (int feature = 0; feature < featureSum.size(); feature++) {
+      double score = weightMatrix.get(feature, label);
       double lSum = labelSum.get(label);
       double numerator = score + alpha;
       double denominator = lSum + featureSum.size();
@@ -78,52 +68,60 @@ public class NaiveBayesTestBase extends 
   }
 
   public NaiveBayesModel createNaiveBayesModel() {
-    double[][] matrix = { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
-                         {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
-    double[] labelSumArray = {1.2, 1.0, 1.0, 1.0};
-    double[] featureSumArray = {1.3, 0.6, 1.1, 1.2};
+    double[][] matrix = {
+        { 0.7, 0.1, 0.1, 0.3 },
+        { 0.4, 0.4, 0.1, 0.1 },
+        { 0.1, 0.0, 0.8, 0.1 },
+        { 0.1, 0.1, 0.1, 0.7 } };
+
+    double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+    double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
     
     DenseMatrix weightMatrix = new DenseMatrix(matrix);
     DenseVector labelSum = new DenseVector(labelSumArray);
     DenseVector featureSum = new DenseVector(featureSumArray);
     
-    double[] thetaNormalizerSum = {naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum), 
-                                   naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
-                                   naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
-                                   naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum)};
+    double[] thetaNormalizerSum = {
+        naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
+        naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
+        naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
+        naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
+
     // now generate the model
-    return new NaiveBayesModel(weightMatrix, featureSum,
-        labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
   }
   
   public NaiveBayesModel createComplementaryNaiveBayesModel() {
-    double[][] matrix = { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
-                         {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
-    double[] labelSumArray = {1.2, 1.0, 1.0, 1.0};
-    double[] featureSumArray = {1.3, 0.6, 1.1, 1.2};
+    double[][] matrix = {
+        { 0.7, 0.1, 0.1, 0.3 },
+        { 0.4, 0.4, 0.1, 0.1 },
+        { 0.1, 0.0, 0.8, 0.1 },
+        { 0.1, 0.1, 0.1, 0.7 } };
+
+    double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+    double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
     
     DenseMatrix weightMatrix = new DenseMatrix(matrix);
     DenseVector labelSum = new DenseVector(labelSumArray);
     DenseVector featureSum = new DenseVector(featureSumArray);
     
-    double[] thetaNormalizerSum = {complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum), 
-                                   complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
-                                   complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
-                                   complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum)};
+    double[] thetaNormalizerSum = {
+        complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
+        complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
+        complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
+        complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
+
     // now generate the model
-    return new NaiveBayesModel(weightMatrix, featureSum,
-        labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
   }
   
   public int maxIndex(Vector instance) {
-    Iterator<Element> it = instance.iterator();
     int maxIndex = -1;
-    double val = Integer.MIN_VALUE;
-    while (it.hasNext()) {
-      Element e = it.next();
-      if (val <= e.get()) {
-        maxIndex = e.index();
-        val = e.get();
+    double maxScore = Integer.MIN_VALUE;
+    for (Element label : instance) {
+      if (label.get() >= maxScore) {
+        maxIndex = label.index();
+        maxScore = label.get();
       }
     }
     return maxIndex;
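
For reference, the tidied naiveBayesThetaWeight helper computes, per label c, the standard theta normalizer used by the test base (assuming the elided context line accumulates weight += Math.log(numerator / denominator), which is what the visible variables suggest):

\theta_c \;=\; \sum_{f=1}^{|F|} \log\!\left( \frac{w_{fc} + \alpha}{S_c + |F|} \right)

where w_{fc} = weightMatrix.get(feature, label), S_c = labelSum.get(label), |F| = featureSum.size(), and the smoothing parameter \alpha is fixed at 1.0 in these helpers.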

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Mon Jul  4 06:49:17 2011
@@ -38,10 +38,10 @@ public final class StandardNaiveBayesCla
   @Test
   public void testNaiveBayes() throws Exception {
     assertEquals(4, classifier.numCategories());
-    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
-    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
-    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
-    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
     
   }
   

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+public class IndexInstancesMapperTest extends MahoutTestCase {
+
+  Mapper.Context ctx;
+  OpenObjectIntHashMap<String> labelIndex;
+  VectorWritable instance;
+
+  @Before
+  public void setup() throws Exception {
+    super.setUp();
+
+    ctx = EasyMock.createMock(Mapper.Context.class);
+    instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));
+
+    labelIndex = new OpenObjectIntHashMap<String>();
+    labelIndex.put("bird", 0);
+    labelIndex.put("cat", 1);
+  }
+
+
+  @Test
+  public void index() throws Exception {
+
+    ctx.write(new IntWritable(0), instance);
+
+    EasyMock.replay(ctx);
+
+    IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+    setField(indexInstances, "labelIndex", labelIndex);
+
+    indexInstances.map(new Text("bird"), instance, ctx);
+
+    EasyMock.verify(ctx);
+  }
+
+  @Test
+  public void skip() throws Exception {
+
+    Counter skippedInstances = EasyMock.createMock(Counter.class);
+
+    EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
+    skippedInstances.increment(1);
+
+    EasyMock.replay(ctx, skippedInstances);
+
+    IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+    setField(indexInstances, "labelIndex", labelIndex);
+
+    indexInstances.map(new Text("fish"), instance, ctx);
+
+    EasyMock.verify(ctx, skippedInstances);
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java Mon Jul  4 06:49:17 2011
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+public class ThetaMapperTest extends MahoutTestCase {
+
+  @Test
+  public void standard() throws Exception {
+
+    Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class);
+    AbstractThetaTrainer trainer = EasyMock.createMock(AbstractThetaTrainer.class);
+
+    Vector instance1 = new DenseVector(new double[] { 1, 2, 3 });
+    Vector instance2 = new DenseVector(new double[] { 4, 5, 6 });
+
+    Vector perLabelThetaNormalizer = new DenseVector(new double[] { 7, 8 });
+
+    ThetaMapper thetaMapper = new ThetaMapper();
+    setField(thetaMapper, "trainer", trainer);
+
+    trainer.train(0, instance1);
+    trainer.train(1, instance2);
+    EasyMock.expect(trainer.retrievePerLabelThetaNormalizer()).andReturn(perLabelThetaNormalizer);
+    ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+
+    EasyMock.replay(ctx, trainer);
+
+    thetaMapper.map(new IntWritable(0), new VectorWritable(instance1), ctx);
+    thetaMapper.map(new IntWritable(1), new VectorWritable(instance2), ctx);
+    thetaMapper.cleanup(ctx);
+
+    EasyMock.verify(ctx, trainer);
+  }
+
+
+}


