mahout-commits mailing list archives

From: s..@apache.org
Subject: svn commit: r1589027 - in /mahout/trunk: ./ mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/ mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/ mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ mr...
Date: Tue, 22 Apr 2014 06:36:04 GMT
Author: ssc
Date: Tue Apr 22 06:36:03 2014
New Revision: 1589027

URL: http://svn.apache.org/r1589027
Log:
MAHOUT-1519 Remove StandardThetaTrainer

Removed:
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
    mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
    mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
    mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
    mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
    mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Tue Apr 22 06:36:03 2014
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 1.0 - unreleased
 
+  MAHOUT-1519: Remove StandardThetaTrainer (Andrew Palumbo via ssc)
+
   MAHOUT-1496: Create a website describing the distributed ALS recommender (Jian Wang via ssc)
 
   MAHOUT-1502: Update Naive Bayes Webpage to Current Implementation (Andrew Palumbo via ssc)

Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java Tue Apr 22 06:36:03 2014
@@ -58,6 +58,7 @@ public final class BayesUtils {
   public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
 
     float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+    boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
 
     // read feature sums and label sums
     Vector scoresPerLabel = null;
@@ -81,19 +82,22 @@ public final class BayesUtils {
         new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
       scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
     }
-
-    Vector perLabelThetaNormalizer = scoresPerLabel.like();
-     for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
-        new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
-      if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
-        perLabelThetaNormalizer = entry.getSecond().get();
+    
+    // perLabelThetaNormalizer is only used by the complementary model; we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      perLabelThetaNormalizer=scoresPerLabel.like();    
+      for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+          new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+        if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+          perLabelThetaNormalizer = entry.getSecond().get();
+        }
       }
+      Preconditions.checkNotNull(perLabelThetaNormalizer);
     }
-
-    Preconditions.checkNotNull(perLabelThetaNormalizer);
-    
+     
     return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
-        alphaI);
+        alphaI, isComplementary);
   }
 
   /** Write the list of labels into a map file */

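For context on the change above: readModelFromDir() now reads the new NaiveBayesModel.COMPLEMENTARY_MODEL flag from the Configuration and only loads the per-label theta normalizers when it is set (TrainNaiveBayesJob, further down in this commit, writes that flag before calling it). A minimal calling sketch, assuming a training temp directory laid out by TrainNaiveBayesJob; the class name and default values here are illustrative, not part of the commit:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;

public class ReadModelSketch {
  public static NaiveBayesModel read(Path trainingTempDir, boolean complementary) {
    Configuration conf = new Configuration();
    conf.setFloat(ThetaMapper.ALPHA_I, 1.0f);                            // smoothing parameter, unchanged by this commit
    conf.setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, complementary); // new: decides whether theta normalizers are read
    NaiveBayesModel model = BayesUtils.readModelFromDir(trainingTempDir, conf);
    model.validate();
    return model;
  }
}
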
Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Tue Apr 22 06:36:03 2014
@@ -43,12 +43,12 @@ public class NaiveBayesModel {
   private final float alphaI;
   private final double numFeatures;
   private final double totalWeightSum;
+  private final boolean isComplementary;  
+   
+  public final static String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";
 
-  public NaiveBayesModel(Matrix weightMatrix,
-                         Vector weightsPerFeature,
-                         Vector weightsPerLabel,
-                         Vector thetaNormalizer,
-                         float alphaI) {
+  public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+                         float alphaI, boolean isComplementary) {
     this.weightsPerLabelAndFeature = weightMatrix;
     this.weightsPerFeature = weightsPerFeature;
     this.weightsPerLabel = weightsPerLabel;
@@ -56,6 +56,7 @@ public class NaiveBayesModel {
     this.numFeatures = weightsPerFeature.getNumNondefaultElements();
     this.totalWeightSum = weightsPerLabel.zSum();
     this.alphaI = alphaI;
+    this.isComplementary=isComplementary;
   }
 
   public double labelWeight(int label) {
@@ -93,7 +94,11 @@ public class NaiveBayesModel {
   public Vector createScoringVector() {
     return weightsPerLabel.like();
   }
-
+  
+  public boolean isComplementary() {
+    return isComplementary;
+  }
+  
   public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
     FileSystem fs = output.getFileSystem(conf);
 
@@ -102,14 +107,17 @@ public class NaiveBayesModel {
     Vector weightsPerFeature = null;
     Matrix weightsPerLabelAndFeature;
     float alphaI;
+    boolean isComplementary;
 
     FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
     try {
       alphaI = in.readFloat();
+      isComplementary = in.readBoolean();
       weightsPerFeature = VectorWritable.readVector(in);
       weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
-      perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
-
+      if (isComplementary){
+        perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
+      }
       weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
       for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
         weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
@@ -118,7 +126,7 @@ public class NaiveBayesModel {
       Closeables.close(in, true);
     }
     NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
-        perLabelThetaNormalizer, alphaI);
+        perLabelThetaNormalizer, alphaI, isComplementary);
     model.validate();
     return model;
   }
@@ -128,9 +136,12 @@ public class NaiveBayesModel {
     FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
     try {
       out.writeFloat(alphaI);
+      out.writeBoolean(isComplementary);
       VectorWritable.writeVector(out, weightsPerFeature);
-      VectorWritable.writeVector(out, weightsPerLabel);
-      VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      VectorWritable.writeVector(out, weightsPerLabel); 
+      if (isComplementary){
+        VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      }
       for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
         VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
       }
@@ -149,15 +160,17 @@ public class NaiveBayesModel {
     Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
     Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
         "the feature sums have to be greater than 0!");
-    Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
-    Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
-        "the number of theta normalizers has to be greater than 0!");    
-    Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue()) 
-            == Math.signum(perlabelThetaNormalizer.maxValue()), 
-       "Theta normalizers do not all have the same sign");            
-    Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements() 
-            == perlabelThetaNormalizer.size(), 
-       "Theta normalizers can not have zero value.");
+    if (isComplementary){
+        Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+        Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+            "the number of theta normalizers has to be greater than 0!");    
+        Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue()) 
+                == Math.signum(perlabelThetaNormalizer.maxValue()), 
+           "Theta normalizers do not all have the same sign");            
+        Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements() 
+                == perlabelThetaNormalizer.size(), 
+           "Theta normalizers can not have zero value.");
+    }
     
   }
 }

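The constructor above now takes the isComplementary flag, and a standard model no longer needs a theta-normalizer vector at all (null is fine, as the updated test fixtures below show). A small construction sketch with made-up numbers, loosely modeled on NaiveBayesTestBase; not part of the commit:

import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;

public class ModelConstructionSketch {
  public static void main(String[] args) {
    DenseMatrix weightsPerLabelAndFeature = new DenseMatrix(new double[][] {
        { 0.7, 0.1, 0.1, 0.3 },   // label 0
        { 0.4, 0.4, 0.1, 0.1 },   // label 1
    });
    DenseVector weightsPerLabel = new DenseVector(new double[] { 1.2, 1.0 });
    DenseVector weightsPerFeature = new DenseVector(new double[] { 1.1, 0.5, 0.2, 0.4 });

    // Standard NB: no theta normalizer, flag false.
    NaiveBayesModel standard =
        new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel, null, 1.0f, false);
    standard.validate();

    // Complementary NB: one normalizer entry per label, all non-zero and of the same sign, flag true.
    DenseVector thetaNormalizers = new DenseVector(new double[] { -1.0, -1.0 });
    NaiveBayesModel complementary =
        new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel, thetaNormalizers, 1.0f, true);
    complementary.validate();
  }
}
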
Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Tue Apr 22 06:36:03 2014
@@ -28,8 +28,7 @@ public class StandardNaiveBayesClassifie
   @Override
   public double getScoreForLabelFeature(int label, int feature) {
     NaiveBayesModel model = getModel();
-    // Standard Naive Bayes does not use weight normalization, uncomment following line for weight normalized NB
-    // weight = weight / model.thetaNormalizer(label);
+    // Standard Naive Bayes does not use weight normalization
     return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures());
   }
 

Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java Tue Apr 22 06:36:03 2014
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.classifier.naivebayes.test;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.Text;
@@ -49,8 +50,17 @@ public class BayesTestMapper extends Map
     Configuration conf = context.getConfiguration();
     Path modelPath = HadoopUtil.getSingleCachedFile(conf);
     NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
-    boolean compl = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
-    if (compl) {
+    boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+    
+    // Ensure that if we are testing in complementary mode, the model has been
+    // trained complementary. A complementary model will work for standard classification;
+    // a standard model will not work for complementary classification.
+    if (isComplementary) {
+      Preconditions.checkArgument(model.isComplementary() == isComplementary,
+          "Complementary mode in model is different from test mode");
+    }
+    
+    if (isComplementary) {
       classifier = new ComplementaryNaiveBayesClassifier(model);
     } else {
       classifier = new StandardNaiveBayesClassifier(model);

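The comment in the mapper states the compatibility rule this commit enforces: a complementary model can serve both test modes, but a standard model cannot serve complementary testing, since it carries no theta normalizers. A hypothetical helper that restates the same rule using the accessor added to NaiveBayesModel; it is a sketch, not part of the commit:

import com.google.common.base.Preconditions;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;

public class ClassifierForTestModeSketch {
  public static AbstractNaiveBayesClassifier create(NaiveBayesModel model, boolean testComplementary) {
    if (testComplementary) {
      // Complementary scoring needs the theta normalizers, which only a complementary model carries.
      Preconditions.checkArgument(model.isComplementary(),
          "Complementary mode in model is different from test mode");
      return new ComplementaryNaiveBayesClassifier(model);
    }
    // Standard scoring works with either kind of model.
    return new StandardNaiveBayesClassifier(model);
  }
}
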
Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java Tue Apr 22 06:36:03 2014
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.classifier.naivebayes.test;
 
+import com.google.common.base.Preconditions;
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
@@ -111,14 +112,23 @@ public class TestNaiveBayesDriver extend
     boolean complementary = hasOption("testComplementary");
     FileSystem fs = FileSystem.get(getConf());
     NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+    
+    // Ensure that if we are testing in complementary mode, the model has been
+    // trained complementary. A complementary model will work for standard classification;
+    // a standard model will not work for complementary classification.
+    if (complementary) {
+        Preconditions.checkArgument(model.isComplementary() == complementary,
+            "Complementary mode in model is different from test mode");
+    }
+    
     AbstractNaiveBayesClassifier classifier;
     if (complementary) {
       classifier = new ComplementaryNaiveBayesClassifier(model);
     } else {
       classifier = new StandardNaiveBayesClassifier(model);
     }
-    SequenceFile.Writer writer =
-        SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"), Text.class, VectorWritable.class);
+    SequenceFile.Writer writer = SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
+        Text.class, VectorWritable.class);
 
     try {
       SequenceFileDirIterable<Text, VectorWritable> dirIterable =
@@ -142,8 +152,8 @@ public class TestNaiveBayesDriver extend
         Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
     //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
 
-    //boolean complementary = parsedArgs.containsKey("testComplementary"); //always result to false as key in hash map is "--testComplementary"
-    boolean complementary = hasOption("testComplementary"); //or  complementary = parsedArgs.containsKey("--testComplementary");
+
+    boolean complementary = hasOption("testComplementary");
     testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
     return testJob.waitForCompletion(true);
   }

Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java Tue Apr 22 06:36:03 2014
@@ -17,24 +17,67 @@
 
 package org.apache.mahout.classifier.naivebayes.training;
 
+import com.google.common.base.Preconditions;
 import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
 import org.apache.mahout.math.Vector;
 
-public class ComplementaryThetaTrainer extends AbstractThetaTrainer {
+public class ComplementaryThetaTrainer {
+
+  private final Vector weightsPerFeature;
+  private final Vector weightsPerLabel;
+  private final Vector perLabelThetaNormalizer;
+  private final double alphaI;
+  private final double totalWeightSum;
+  private final double numFeatures;
 
   public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
-    super(weightsPerFeature, weightsPerLabel, alphaI);
+    Preconditions.checkNotNull(weightsPerFeature);
+    Preconditions.checkNotNull(weightsPerLabel);
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.alphaI = alphaI;
+    perLabelThetaNormalizer = weightsPerLabel.like();
+    totalWeightSum = weightsPerLabel.zSum();
+    numFeatures = weightsPerFeature.getNumNondefaultElements();
   }
 
-  @Override
   public void train(int label, Vector perLabelWeight) {
     double labelWeight = labelWeight(label);
     // sum weights for each label including those with zero word counts
-    for(int i=0; i < perLabelWeight.size(); i++){
+    for(int i = 0; i < perLabelWeight.size(); i++){
       Vector.Element perLabelWeightElement = perLabelWeight.getElement(i);
       updatePerLabelThetaNormalizer(label,
-          ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()), perLabelWeightElement.get(),
-              totalWeightSum(), labelWeight, alphaI(), numFeatures()));
+          ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()),
+              perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures()));
     }
   }
+
+  protected double alphaI() {
+    return alphaI;
+  }
+
+  protected double numFeatures() {
+    return numFeatures;
+  }
+
+  protected double labelWeight(int label) {
+    return weightsPerLabel.get(label);
+  }
+
+  protected double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  protected double featureWeight(int feature) {
+    return weightsPerFeature.get(feature);
+  }
+
+  // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+  protected void updatePerLabelThetaNormalizer(int label, double weight) {
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight));
+  }
+
+  public Vector retrievePerLabelThetaNormalizer() {
+    return perLabelThetaNormalizer.clone();
+  }
 }

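With AbstractThetaTrainer and StandardThetaTrainer gone, ComplementaryThetaTrainer is now a self-contained class. A standalone usage sketch with illustrative numbers (in the actual job, ThetaMapper feeds it the summed per-label observation rows); the class name is made up:

import org.apache.mahout.classifier.naivebayes.training.ComplementaryThetaTrainer;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

public class ThetaTrainerSketch {
  public static void main(String[] args) {
    Matrix summedObservations = new DenseMatrix(new double[][] {
        { 0.7, 0.1, 0.1, 0.3 },
        { 0.4, 0.4, 0.1, 0.1 },
    });
    Vector weightsPerLabel = new DenseVector(new double[] { 1.2, 1.0 });
    Vector weightsPerFeature = new DenseVector(new double[] { 1.1, 0.5, 0.2, 0.4 });

    ComplementaryThetaTrainer trainer =
        new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, 1.0 /* alphaI */);
    for (int label = 0; label < summedObservations.numRows(); label++) {
      trainer.train(label, summedObservations.viewRow(label));
    }
    // One entry per label; the MapReduce job emits this under TrainNaiveBayesJob.LABEL_THETA_NORMALIZER.
    Vector thetaNormalizers = trainer.retrievePerLabelThetaNormalizer();
    System.out.println(thetaNormalizers);
  }
}
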
Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java Tue Apr 22 06:36:03 2014
@@ -33,7 +33,7 @@ public class ThetaMapper extends Mapper<
   public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
   static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
 
-  private AbstractThetaTrainer trainer;
+  private ComplementaryThetaTrainer trainer;
 
   @Override
   protected void setup(Context ctx) throws IOException, InterruptedException {
@@ -41,15 +41,10 @@ public class ThetaMapper extends Mapper<
     Configuration conf = ctx.getConfiguration();
 
     float alphaI = conf.getFloat(ALPHA_I, 1.0f);
-    Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);
-
-    if (conf.getBoolean(TRAIN_COMPLEMENTARY, false)) {
-      trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
-                                              scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
-    } else {
-      trainer = new StandardThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
-                                         scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
-    }
+    Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);    
+    
+    trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+                                            scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
   }
 
   @Override

Modified: mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java (original)
+++ mahout/trunk/mrlegacy/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java Tue Apr 22 06:36:03 2014
@@ -133,24 +133,26 @@ public final class TrainNaiveBayesJob ex
     // Put the per label and per feature vectors into the cache
     HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());
 
-    // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
-    // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
-    Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
-                                 getTempPath(THETAS),
-                                 SequenceFileInputFormat.class,
-                                 ThetaMapper.class,
-                                 Text.class,
-                                 VectorWritable.class,
-                                 VectorSumReducer.class,
-                                 Text.class,
-                                 VectorWritable.class,
-                                 SequenceFileOutputFormat.class);
-    thetaSummer.setCombinerClass(VectorSumReducer.class);
-    thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
-    thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
-    succeeded = thetaSummer.waitForCompletion(true);
-    if (!succeeded) {
-      return -1;
+    if (trainComplementary){
+      // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
+      // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+      Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+                                   getTempPath(THETAS),
+                                   SequenceFileInputFormat.class,
+                                   ThetaMapper.class,
+                                   Text.class,
+                                   VectorWritable.class,
+                                   VectorSumReducer.class,
+                                   Text.class,
+                                   VectorWritable.class,
+                                   SequenceFileOutputFormat.class);
+      thetaSummer.setCombinerClass(VectorSumReducer.class);
+      thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+      thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+      succeeded = thetaSummer.waitForCompletion(true);
+      if (!succeeded) {
+        return -1;
+      }
     }
     
     // Put the per label theta normalizers into the cache
@@ -158,6 +160,7 @@ public final class TrainNaiveBayesJob ex
     
     // Validate our model and then write it out to the official output
     getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
+    getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary);
     NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
     naiveBayesModel.validate();
     naiveBayesModel.serialize(getOutputPath(), getConf());

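Net effect of the driver change: for standard training the theta-summing job is skipped entirely, and the model is validated and serialized without normalizers; the COMPLEMENTARY_MODEL flag set here ends up inside naiveBayesModel.bin (written right after alphaI), so models serialized before this commit presumably need to be re-trained or re-serialized to be readable. A round-trip sketch, assuming an already-constructed model and a writable path; the class name is made up:

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;

public class SerializeRoundTripSketch {
  public static NaiveBayesModel roundTrip(NaiveBayesModel model, Path dir, Configuration conf) throws IOException {
    model.serialize(dir, conf);                  // alphaI, isComplementary, vectors, then one matrix row per label
    NaiveBayesModel reloaded = NaiveBayesModel.materialize(dir, conf);
    reloaded.validate();                         // normalizer checks only run for complementary models
    return reloaded;
  }
}
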
Modified: mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java (original)
+++ mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java Tue Apr 22 06:36:03 2014
@@ -24,9 +24,13 @@ public class NaiveBayesModelTest extends
   @Test
   public void testRandomModelGeneration() {
     // make sure we generate a valid random model
-    NaiveBayesModel model = getModel();
+    NaiveBayesModel standardModel = getStandardModel();
     // check whether the model is valid
-    model.validate();
+    standardModel.validate();
+    
+    // same for Complementary
+    NaiveBayesModel complementaryModel = getComplementaryModel();
+    complementaryModel.validate();
   }
 
 }

Modified: mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java (original)
+++ mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java Tue Apr 22 06:36:03 2014
@@ -26,17 +26,23 @@ import org.apache.mahout.math.Vector.Ele
 
 public abstract class NaiveBayesTestBase extends MahoutTestCase {
   
-  private NaiveBayesModel model;
+  private NaiveBayesModel standardModel;
+  private NaiveBayesModel complementaryModel;
   
   @Override
   public void setUp() throws Exception {
     super.setUp();
-    model = createNaiveBayesModel();
-    model.validate();
+    standardModel = createStandardNaiveBayesModel();
+    standardModel.validate();
+    complementaryModel = createComplementaryNaiveBayesModel();
+    complementaryModel.validate();
   }
   
-  protected NaiveBayesModel getModel() {
-    return model;
+  protected NaiveBayesModel getStandardModel() {
+    return standardModel;
+  }
+  protected NaiveBayesModel getComplementaryModel() {
+    return complementaryModel;
   }
   
   protected static double complementaryNaiveBayesThetaWeight(int label,
@@ -73,7 +79,7 @@ public abstract class NaiveBayesTestBase
     return weight;
   }
 
-  protected static NaiveBayesModel createNaiveBayesModel() {
+  protected static NaiveBayesModel createStandardNaiveBayesModel() {
     double[][] matrix = {
         { 0.7, 0.1, 0.1, 0.3 },
         { 0.4, 0.4, 0.1, 0.1 },
@@ -85,16 +91,10 @@ public abstract class NaiveBayesTestBase
     
     DenseMatrix weightMatrix = new DenseMatrix(matrix);
     DenseVector labelSum = new DenseVector(labelSumArray);
-    DenseVector featureSum = new DenseVector(featureSumArray);
+    DenseVector featureSum = new DenseVector(featureSumArray);    
     
-    double[] thetaNormalizerSum = {
-        naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
-        naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
-        naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
-        naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
-
     // now generate the model
-    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, null, 1.0f, false);
   }
   
   protected static NaiveBayesModel createComplementaryNaiveBayesModel() {
@@ -118,7 +118,7 @@ public abstract class NaiveBayesTestBase
         complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
 
     // now generate the model
-    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f, true);
   }
   
   protected static int maxIndex(Vector instance) {

Modified: mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Tue Apr 22 06:36:03 2014
@@ -31,7 +31,7 @@ public final class StandardNaiveBayesCla
   @Before
   public void setUp() throws Exception {
     super.setUp();
-    NaiveBayesModel model = createNaiveBayesModel();
+    NaiveBayesModel model = createStandardNaiveBayesModel();
     classifier = new StandardNaiveBayesClassifier(model);
   }
   
@@ -42,7 +42,6 @@ public final class StandardNaiveBayesCla
     assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
     assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
     assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
-    
   }
   
 }

Modified: mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java?rev=1589027&r1=1589026&r2=1589027&view=diff
==============================================================================
--- mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java (original)
+++ mahout/trunk/mrlegacy/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java Tue Apr 22 06:36:03 2014
@@ -33,7 +33,7 @@ public class ThetaMapperTest extends Mah
   public void standard() throws Exception {
 
     Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class);
-    AbstractThetaTrainer trainer = EasyMock.createMock(AbstractThetaTrainer.class);
+    ComplementaryThetaTrainer trainer = EasyMock.createMock(ComplementaryThetaTrainer.class);
 
     Vector instance1 = new DenseVector(new double[] { 1, 2, 3 });
     Vector instance2 = new DenseVector(new double[] { 4, 5, 6 });


