Author: ssc
Date: Mon Jul 4 06:49:17 2011
New Revision: 1142566
URL: http://svn.apache.org/viewvc?rev=1142566&view=rev
Log:
MAHOUT-746 Refactoring of the parallel Naive Bayes implementation in org.apache.mahout.classifier.naivebayes
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/
- copied from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
- copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
- copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
- copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
- copied, changed from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesInstanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesSumReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesThetaComplementaryMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesThetaMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/NaiveBayesWeightsMapper.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Mon Jul 4 06:49:17 2011
@@ -23,11 +23,8 @@ import org.apache.mahout.classifier.Abst
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
-/**
- * Class implementing the Naive Bayes Classifier Algorithm
- *
- */
-public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+/** Class implementing the Naive Bayes Classifier Algorithm */
+abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
private final NaiveBayesModel model;
@@ -39,27 +36,27 @@ public abstract class AbstractNaiveBayes
return model;
}
- public abstract double getScoreForLabelFeature(int label, int feature);
-
- public double getScoreForLabelInstance(int label, Vector instance) {
+ protected abstract double getScoreForLabelFeature(int label, int feature);
+
+ protected double getScoreForLabelInstance(int label, Vector instance) {
double result = 0.0;
- Iterator<Element> it = instance.iterateNonZero();
- while (it.hasNext()) {
- result += getScoreForLabelFeature(label, it.next().index());
+ Iterator<Element> elements = instance.iterateNonZero();
+ while (elements.hasNext()) {
+ result += getScoreForLabelFeature(label, elements.next().index());
}
- return result;
+ return result / model.thetaNormalizer(label);
}
@Override
public int numCategories() {
- return model.getNumLabels();
+ return model.numLabels();
}
@Override
public Vector classify(Vector instance) {
- Vector score = model.getLabelSum().like();
- for (int i = 0; i < score.size(); i++) {
- score.set(i, getScoreForLabelInstance(i, instance));
+ Vector score = model.createScoringVector();
+ for (int label = 0; label < model.numLabels(); label++) {
+ score.set(label, getScoreForLabelInstance(label, instance));
}
return score;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Mon Jul 4 06:49:17 2011
@@ -31,16 +31,10 @@ public class ComplementaryNaiveBayesClas
@Override
public double getScoreForLabelFeature(int label, int feature) {
NaiveBayesModel model = getModel();
- double result = model.getWeightMatrix().get(label, feature);
- double vocabCount = model.getVocabCount();
- double featureSum = model.getFeatureSum().get(feature);
- double totalSum = model.getTotalSum();
- double labelSum = model.getLabelSum().get(label);
- double numerator = featureSum - result + model.getAlphaI();
- double denominator = totalSum - labelSum + vocabCount;
- double weight = Math.log(numerator / denominator);
- result = weight / model.getPerlabelThetaNormalizer().get(label);
- return result;
+ double numerator = model.featureWeight(feature) - model.weight(label, feature) + model.alphaI();
+ double denominator = model.totalWeightSum() - model.labelWeight(label) + model.alphaI() * model.numFeatures();
+
+ return Math.log(numerator / denominator);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Mon Jul 4 06:49:17 2011
@@ -18,182 +18,134 @@
package org.apache.mahout.classifier.naivebayes;
import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.mahout.classifier.naivebayes.trainer.NaiveBayesTrainer;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-/**
- * NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.
- */
-public class NaiveBayesModel {
+import java.io.IOException;
- private static final String MODEL = "NaiveBayesModel";
+/** NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.*/
+public class NaiveBayesModel {
- private Vector labelSum;
+ private Vector weightsPerLabel;
private Vector perlabelThetaNormalizer;
- private Vector featureSum;
- private Matrix weightMatrix;
+ private Vector weightsPerFeature;
+ private Matrix weightsPerLabelAndFeature;
private float alphaI;
- private double vocabCount;
- private double totalSum;
-
- private NaiveBayesModel() {
- // do nothing
- }
-
- public NaiveBayesModel(Matrix matrix, Vector featureSum, Vector labelSum, Vector thetaNormalizer, float alphaI) {
- this.weightMatrix = matrix;
- this.featureSum = featureSum;
- this.labelSum = labelSum;
+ private double numFeatures;
+ private double totalWeightSum;
+
+ public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+ float alphaI) {
+ this.weightsPerLabelAndFeature = weightMatrix;
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
this.perlabelThetaNormalizer = thetaNormalizer;
- this.vocabCount = featureSum.getNumNondefaultElements();
- this.totalSum = labelSum.zSum();
+ this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+ this.totalWeightSum = weightsPerLabel.zSum();
this.alphaI = alphaI;
}
- private void setLabelSum(Vector labelSum) {
- this.labelSum = labelSum;
+ public double labelWeight(int label) {
+ return weightsPerLabel.getQuick(label);
}
-
- public void setPerlabelThetaNormalizer(Vector perlabelThetaNormalizer) {
- this.perlabelThetaNormalizer = perlabelThetaNormalizer;
+ public double thetaNormalizer(int label) {
+ return perlabelThetaNormalizer.get(label);
}
-
- public void setFeatureSum(Vector featureSum) {
- this.featureSum = featureSum;
+ public double featureWeight(int feature) {
+ return weightsPerFeature.getQuick(feature);
}
-
- public void setWeightMatrix(Matrix weightMatrix) {
- this.weightMatrix = weightMatrix;
+ public double weight(int label, int feature) {
+ return weightsPerLabelAndFeature.getQuick(label, feature);
}
-
- public void setAlphaI(float alphaI) {
- this.alphaI = alphaI;
+ public float alphaI() {
+ return alphaI;
}
-
- public void setVocabCount(double vocabCount) {
- this.vocabCount = vocabCount;
+ public double numFeatures() {
+ return numFeatures;
}
-
- public void setTotalSum(double totalSum) {
- this.totalSum = totalSum;
+ public double totalWeightSum() {
+ return totalWeightSum;
}
- public Vector getLabelSum() {
- return labelSum;
- }
-
- public Vector getPerlabelThetaNormalizer() {
- return perlabelThetaNormalizer;
- }
-
- public Vector getFeatureSum() {
- return featureSum;
+ public int numLabels() {
+ return weightsPerLabel.size();
}
- public Matrix getWeightMatrix() {
- return weightMatrix;
+ public Vector createScoringVector() {
+ return weightsPerLabel.like();
}
- public float getAlphaI() {
- return alphaI;
- }
+ public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
- public double getVocabCount() {
- return vocabCount;
- }
+ Vector weightsPerLabel = null;
+ Vector perLabelThetaNormalizer = null;
+ Vector weightsPerFeature = null;
+ Matrix weightsPerLabelAndFeature;
+ float alphaI;
- public double getTotalSum() {
- return totalSum;
- }
-
- public int getNumLabels() {
- return labelSum.size();
- }
+ FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
+ try {
+ alphaI = in.readFloat();
+ weightsPerFeature = VectorWritable.readVector(in);
+ weightsPerLabel = VectorWritable.readVector(in);
+ perLabelThetaNormalizer = VectorWritable.readVector(in);
- public static String getModelName() {
- return MODEL;
- }
-
- // CODE USED FOR SERIALIZATION
- public static NaiveBayesModel fromMRTrainerOutput(Path output, Configuration conf) {
- Path classVectorPath = new Path(output, NaiveBayesTrainer.CLASS_VECTORS);
- Path sumVectorPath = new Path(output, NaiveBayesTrainer.SUM_VECTORS);
- Path thetaSumPath = new Path(output, NaiveBayesTrainer.THETA_SUM);
-
- NaiveBayesModel model = new NaiveBayesModel();
- model.setAlphaI(conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f));
-
- int featureCount = 0;
- int labelCount = 0;
- // read feature sums and label sums
- for (Pair<Text,VectorWritable> record
- : new SequenceFileIterable<Text, VectorWritable>(sumVectorPath, true, conf)) {
- Text key = record.getFirst();
- VectorWritable value = record.getSecond();
- if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
- model.setFeatureSum(value.get());
- featureCount = value.get().getNumNondefaultElements();
- model.setVocabCount(featureCount);
- } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
- model.setLabelSum(value.get());
- model.setTotalSum(value.get().zSum());
- labelCount = value.get().size();
+ weightsPerLabelAndFeature = new SparseMatrix(new int[] { weightsPerLabel.size(), weightsPerFeature.size() });
+ for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+ weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
}
+ } finally {
+ Closeables.closeQuietly(in);
}
+ NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+ perLabelThetaNormalizer, alphaI);
+ model.validate();
+ return model;
+ }
- // read the class matrix
- Matrix matrix = new SparseMatrix(new int[] {labelCount, featureCount});
- for (Pair<IntWritable,VectorWritable> record
- : new SequenceFileIterable<IntWritable,VectorWritable>(classVectorPath, true, conf)) {
- IntWritable label = record.getFirst();
- VectorWritable value = record.getSecond();
- matrix.assignRow(label.get(), value.get());
- }
-
- model.setWeightMatrix(matrix);
-
- // read theta normalizer
- for (Pair<Text,VectorWritable> record
- : new SequenceFileIterable<Text,VectorWritable>(thetaSumPath, true, conf)) {
- Text key = record.getFirst();
- VectorWritable value = record.getSecond();
- if (key.toString().equals(BayesConstants.LABEL_THETA_NORMALIZER)) {
- model.setPerlabelThetaNormalizer(value.get());
+ public void serialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
+ FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
+ try {
+ out.writeFloat(alphaI);
+ VectorWritable.writeVector(out, weightsPerFeature);
+ VectorWritable.writeVector(out, weightsPerLabel);
+ VectorWritable.writeVector(out, perlabelThetaNormalizer);
+ for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+ VectorWritable.writeVector(out, weightsPerLabelAndFeature.getRow(row));
}
+ } finally {
+ Closeables.closeQuietly(out);
}
-
- return model;
}
- public static void validate(NaiveBayesModel model) {
- if (model == null) {
- return; // empty models are valid
- }
-
- Preconditions.checkArgument(model.getAlphaI() > 0, "Error: AlphaI has to be greater than 0!");
- Preconditions.checkArgument(model.getVocabCount() > 0, "Error: The vocab count has to be greater than 0!");
- Preconditions.checkArgument(model.getTotalSum() > 0, "Error: The vocab count has to be greater than 0!");
- Preconditions.checkArgument(model.getLabelSum() != null && model.getLabelSum().getNumNondefaultElements() > 0,
- "Error: The number of labels has to be greater than 0 and defined!");
- Preconditions.checkArgument(model.getPerlabelThetaNormalizer() != null &&
- model.getPerlabelThetaNormalizer().getNumNondefaultElements() > 0,
- "Error: The number of theta normalizers has to be greater than 0 or defined!");
- Preconditions.checkArgument(model.getFeatureSum() != null && model.getFeatureSum().getNumNondefaultElements() > 0,
- "Error: The number of features has to be greater than 0 or defined!");
+ public void validate() { // sanity-checks the model fields; a failing check throws from Guava's Preconditions
+ Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!"); // NOTE(review): the only checkState here; every other check uses checkArgument
+ Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+ Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+ Preconditions.checkArgument(weightsPerLabel != null, "the number of labels has to be defined!"); // NOTE(review): the constructor already calls weightsPerLabel.zSum(), so a null fails there first
+ Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+ "the number of labels has to be greater than 0!");
+ Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined"); // not dereferenced by the constructor, so this null check is the first guard
+ Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+ "the number of theta normalizers has to be greater than 0!");
+ Preconditions.checkArgument(weightsPerFeature != null, "the feature sums have to be defined"); // NOTE(review): the constructor already calls weightsPerFeature.getNumNondefaultElements()
+ Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+ "the feature sums have to be greater than 0!");
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Mon Jul 4 06:49:17 2011
@@ -18,10 +18,7 @@
package org.apache.mahout.classifier.naivebayes;
-/**
- * Class implementing the Naive Bayes Classifier Algorithm
- *
- */
+/** Class implementing the Naive Bayes Classifier Algorithm */
public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
public StandardNaiveBayesClassifier(NaiveBayesModel model) {
@@ -31,14 +28,11 @@ public class StandardNaiveBayesClassifie
@Override
public double getScoreForLabelFeature(int label, int feature) {
NaiveBayesModel model = getModel();
- double result = model.getWeightMatrix().get(label, feature);
- double vocabCount = model.getVocabCount();
- double sumLabelWeight = model.getLabelSum().get(label);
- double numerator = result + model.getAlphaI();
- double denominator = sumLabelWeight + vocabCount;
- double weight = -Math.log(numerator / denominator);
- result = weight / model.getPerlabelThetaNormalizer().get(label);
- return result;
+
+ double numerator = model.weight(label, feature) + model.alphaI();
+ double denominator = model.labelWeight(label) + model.alphaI() * model.numFeatures();
+
+ return -Math.log(numerator / denominator);
}
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.Vector;
+
+public abstract class AbstractThetaTrainer {
+ // Base for the theta trainers: keeps the per-feature and per-label weights and accumulates a per-label normalizer.
+ private Vector weightsPerFeature;
+ private Vector weightsPerLabel;
+ private Vector perLabelThetaNormalizer; // running per-label sums, changed only via updatePerLabelThetaNormalizer()
+ private double alphaI; // NOTE(review): presumably the smoothing parameter; confirm against the callers
+ private double totalWeightSum; // weightsPerLabel.zSum(), computed once in the constructor
+ private double numFeatures; // non-default element count of weightsPerFeature, not its size()
+ /** Stores both weight vectors (neither may be null) and derives the cached sums from them. */
+ public AbstractThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+ Preconditions.checkNotNull(weightsPerFeature);
+ Preconditions.checkNotNull(weightsPerLabel);
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
+ this.alphaI = alphaI;
+ perLabelThetaNormalizer = weightsPerLabel.like(); // presumably an empty vector shaped like weightsPerLabel; TODO confirm
+ totalWeightSum = weightsPerLabel.zSum();
+ numFeatures = weightsPerFeature.getNumNondefaultElements();
+ }
+ /** Processes one instance vector for {@code label}; the weight formula is defined by the subclass. */
+ public abstract void train(int label, Vector instance);
+ /** Returns the alphaI value passed to the constructor. */
+ protected double alphaI() {
+ return alphaI;
+ }
+ /** Returns the number of non-default elements of the per-feature weight vector. */
+ protected double numFeatures() {
+ return numFeatures;
+ }
+ /** Returns the entry of the per-label weight vector at index {@code label}. */
+ protected double labelWeight(int label) {
+ return weightsPerLabel.get(label);
+ }
+ /** Returns the zSum() of the per-label weight vector, as cached by the constructor. */
+ protected double totalWeightSum() {
+ return totalWeightSum;
+ }
+ /** Returns the entry of the per-feature weight vector at index {@code feature}. */
+ protected double featureWeight(int feature) {
+ return weightsPerFeature.get(feature);
+ }
+ /** Adds {@code weight} to the normalizer entry of {@code label}. */
+ protected void updatePerLabelThetaNormalizer(int label, double weight) {
+ perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
+ }
+ /** Returns a clone of the accumulated normalizer, so the caller does not receive the internal vector. */
+ public Vector retrievePerLabelThetaNormalizer() {
+ return perLabelThetaNormalizer.clone();
+ }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.Iterator;
+
+public class ComplementaryThetaTrainer extends AbstractThetaTrainer {
+ // Theta trainer whose log-weights are built from the feature and total weights minus the label's own contribution.
+ public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+ super(weightsPerFeature, weightsPerLabel, alphaI);
+ }
+ /** Adds log(numerator / denominator) to the normalizer of {@code label}, once per non-zero element of {@code instance}. */
+ @Override
+ public void train(int label, Vector instance) {
+ double sigmaK = labelWeight(label);
+ Iterator<Vector.Element> it = instance.iterateNonZero();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ double numerator = featureWeight(e.index()) - e.get() + alphaI(); // feature's overall weight minus this element's value, plus alphaI
+ double denominator = totalWeightSum() - sigmaK + alphaI() * numFeatures(); // does not depend on e: same value on every iteration
+ double weight = Math.log(numerator / denominator);
+ updatePerLabelThetaNormalizer(label, weight);
+ }
+ }
+}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java Mon Jul 4 06:49:17 2011
@@ -15,50 +15,34 @@
* limitations under the License.
*/
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
import java.io.IOException;
-import java.net.URI;
-import com.google.common.base.Preconditions;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
-public class NaiveBayesInstanceMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
-
- private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
-
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+ public enum Counter { SKIPPED_INSTANCES }
+
+ private OpenObjectIntHashMap<String> labelIndex;
+
@Override
- protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
- if (labelMap.containsKey(key.toString())) {
- int label = labelMap.get(key.toString());
- context.write(new IntWritable(label), value);
- } else {
- context.getCounter("NaiveBayes", "Skipped instance: not in label list").increment(1);
- }
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ labelIndex = TrainUtils.readIndexFromCache(ctx.getConfiguration());
}
-
+
@Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Configuration conf = context.getConfiguration();
- URI[] localFiles = DistributedCache.getCacheFiles(conf);
- Preconditions.checkArgument(localFiles != null && localFiles.length >= 1,
- "missing paths from the DistributedCache");
- Path labelMapFile = new Path(localFiles[0].getPath());
- // key is word value is id
- for (Pair<Writable,IntWritable> record
- : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
- labelMap.put(record.getFirst().toString(), record.getSecond().get());
+ protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+ String label = labelText.toString();
+ if (labelIndex.containsKey(label)) {
+ ctx.write(new IntWritable(labelIndex.get(label)), instance);
+ } else {
+ ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
}
}
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.mahout.math.Vector;
+
+public class StandardThetaTrainer extends AbstractThetaTrainer {
+ // Theta trainer that adds one log-ratio per instance, built from the instance's zSum() and the label's weight.
+ public StandardThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+ super(weightsPerFeature, weightsPerLabel, alphaI);
+ }
+ /** Adds log((instance.zSum() + alphaI) / (labelWeight(label) + alphaI * numFeatures)) to the normalizer of {@code label}. */
+ @Override
+ public void train(int label, Vector instance) {
+ double weight = Math.log((instance.zSum() + alphaI()) / (labelWeight(label) + alphaI() * numFeatures()));
+ updatePerLabelThetaNormalizer(label, weight);
+ }
+}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java Mon Jul 4 06:49:17 2011
@@ -15,79 +15,50 @@
* limitations under the License.
*/
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
import java.io.IOException;
-import java.net.URI;
+import java.util.Map;
-import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesConstants;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-public class NaiveBayesThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
-
- private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
- private Vector featureSum;
- private Vector labelSum;
- private Vector perLabelThetaNormalizer;
- private double alphaI = 1.0;
- private double vocabCount;
-
- @Override
- protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException {
- Vector vector = value.get();
- int label = key.get();
- double weight = Math.log((vector.zSum() + alphaI) / (labelSum.get(label) + vocabCount));
- perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
- }
-
- @Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Configuration conf = context.getConfiguration();
- URI[] localFiles = DistributedCache.getCacheFiles(conf);
- Preconditions.checkArgument(localFiles != null && localFiles.length >= 2,
- "missing paths from the DistributedCache");
-
- alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
- Path weightFile = new Path(localFiles[0].getPath());
-
- for (Pair<Text,VectorWritable> record
- : new SequenceFileIterable<Text,VectorWritable>(weightFile, true, conf)) {
- Text key = record.getFirst();
- VectorWritable value = record.getSecond();
- if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
- featureSum = value.get();
- } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
- labelSum = value.get();
- }
- }
- perLabelThetaNormalizer = labelSum.like();
- vocabCount = featureSum.getNumNondefaultElements();
+/**
+ * Mapper that hands every (label index, vector) pair to a theta trainer and emits the trainer's per-label theta
+ * normalizer once, when the mapper is cleaned up.
+ */
+public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
+  static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
- Path labelMapFile = new Path(localFiles[1].getPath());
+
+  private AbstractThetaTrainer trainer;
- // key is word value is id
- for (Pair<Writable,IntWritable> record
- : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
- labelMap.put(record.getFirst().toString(), record.getSecond().get());
+
+  /**
+   * Reads the smoothing parameter and the training mode from the job configuration, loads the cached score
+   * vectors and creates the matching trainer (standard unless {@link #TRAIN_COMPLEMENTARY} is set to true).
+   */
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    Configuration conf = ctx.getConfiguration();
+    float alphaI = conf.getFloat(ALPHA_I, 1.0f);
+    boolean complementary = conf.getBoolean(TRAIN_COMPLEMENTARY, false);
+
+    Map<String,Vector> scores = TrainUtils.readScoresFromCache(conf);
+    Vector weightsPerFeature = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE);
+    Vector weightsPerLabel = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL);
+
+    if (complementary) {
+      trainer = new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, alphaI);
+    } else {
+      trainer = new StandardThetaTrainer(weightsPerFeature, weightsPerLabel, alphaI);
}
}
-
+
+  /** Passes the label index (key) and its vector (value) on to the trainer; nothing is emitted here. */
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    int label = key.get();
+    Vector instance = value.get();
+    trainer.train(label, instance);
+  }
+
+  /** Emits the theta normalizer collected by the trainer under the LABEL_THETA_NORMALIZER key. */
@Override
- protected void cleanup(Context context) throws IOException, InterruptedException {
- context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
- super.cleanup(context);
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    Text normalizerKey = new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER);
+    ctx.write(normalizerKey, new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
}
}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java Mon Jul 4 06:49:17 2011
@@ -15,201 +15,83 @@
* limitations under the License.
*/
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
-import java.io.IOException;
-import java.net.URI;
+import java.util.Map;
-import com.google.common.io.Closeables;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
-/**
- * This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
- */
-public final class NaiveBayesTrainer {
-
- public static final String THETA_SUM = "thetaSum";
- public static final String SUM_VECTORS = "sumVectors";
- public static final String CLASS_VECTORS = "classVectors";
- public static final String LABEL_MAP = "labelMap";
- public static final String ALPHA_I = "alphaI";
-
- private NaiveBayesTrainer() {
- }
+/** This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes) */
+public final class TrainNaiveBayesJob extends AbstractJob {
- public static void trainNaiveBayes(Path input,
- Configuration conf,
- Iterable<String> inputLabels,
- Path output,
- int numReducers,
- float alphaI,
- boolean trainComplementary)
- throws IOException, InterruptedException, ClassNotFoundException {
- conf.setFloat(ALPHA_I, alphaI);
- Path labelMapPath = createLabelMapFile(inputLabels, conf, new Path(output, LABEL_MAP));
- Path classVectorPath = new Path(output, CLASS_VECTORS);
- runNaiveBayesByLabelSummer(input, conf, labelMapPath, classVectorPath, numReducers);
- Path weightFilePath = new Path(output, SUM_VECTORS);
- runNaiveBayesWeightSummer(classVectorPath, conf, labelMapPath, weightFilePath, numReducers);
- Path thetaFilePath = new Path(output, THETA_SUM);
- if (trainComplementary) {
- runNaiveBayesThetaComplementarySummer(classVectorPath, conf, weightFilePath, thetaFilePath, numReducers);
- } else {
- runNaiveBayesThetaSummer(classVectorPath, conf, weightFilePath, thetaFilePath, numReducers);
+  // keys under which the mappers emit their summed vectors
+  public static final String WEIGHTS_PER_FEATURE = "__SPF";
+  public static final String WEIGHTS_PER_LABEL = "__SPL";
+  public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+  // names of the intermediate outputs below the temp path
+  public static final String SUMMED_OBSERVATIONS = "summedObservations";
+  public static final String WEIGHTS = "weights";
+  public static final String THETAS = "thetas";
+
+  /**
+   * Runs the training pipeline: writes the label index, runs the three jobs (index instances, sum weights,
+   * sum thetas), then reads the model from the temp path, validates it and serializes it to the output path.
+   *
+   * @param args command line arguments (input, output, --labels, --alphaI, --trainComplementary)
+   * @return 0 on success, -1 if the arguments could not be parsed or one of the jobs did not complete
+   *         successfully
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption("labels", "l", "comma-separated list of labels to include in training", true);
+    addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
+    addOption("trainComplementary", "c", "train complementary?", String.valueOf(false));
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
}
- }
-
- private static void runNaiveBayesByLabelSummer(Path input,
- Configuration conf,
- Path labelMapPath,
- Path output, int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
-
- // this conf parameter needs to be set enable serialisation of conf values
- conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
- DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
-
- Job job = new Job(conf);
- job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
- + labelMapPath.toString());
- job.setJarByClass(NaiveBayesTrainer.class);
- FileInputFormat.setInputPaths(job, input);
- FileOutputFormat.setOutputPath(job, output);
- job.setMapperClass(NaiveBayesInstanceMapper.class);
- job.setCombinerClass(VectorSumReducer.class);
- job.setReducerClass(VectorSumReducer.class);
- job.setInputFormatClass(SequenceFileInputFormat.class);
- job.setOutputFormatClass(SequenceFileOutputFormat.class);
- job.setOutputKeyClass(IntWritable.class);
- job.setOutputValueClass(VectorWritable.class);
- job.setNumReduceTasks(numReducers);
- HadoopUtil.delete(conf, output);
- job.waitForCompletion(true);
- }
- private static void runNaiveBayesWeightSummer(Path input,
- Configuration conf,
- Path labelMapPath,
- Path output,
- int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
-
- // this conf parameter needs to be set enable serialisation of conf values
- conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
- DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
-
- Job job = new Job(conf);
- job.setJobName("Train Naive Bayes: input-folder: " + input);
- job.setJarByClass(NaiveBayesTrainer.class);
- FileInputFormat.setInputPaths(job, input);
- FileOutputFormat.setOutputPath(job, output);
- job.setMapperClass(NaiveBayesWeightsMapper.class);
- job.setReducerClass(VectorSumReducer.class);
- job.setInputFormatClass(SequenceFileInputFormat.class);
- job.setOutputFormatClass(SequenceFileOutputFormat.class);
- job.setOutputKeyClass(Text.class);
- job.setOutputValueClass(VectorWritable.class);
- job.setNumReduceTasks(numReducers);
- HadoopUtil.delete(conf, output);
- job.waitForCompletion(true);
- }
-
- private static void runNaiveBayesThetaSummer(Path input,
- Configuration conf,
- Path weightFilePath,
- Path output,
- int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
-
- // this conf parameter needs to be set enable serialisation of conf values
- conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
- DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri()}, conf);
-
- Job job = new Job(conf);
- job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
- + weightFilePath.toString());
- job.setJarByClass(NaiveBayesTrainer.class);
- FileInputFormat.setInputPaths(job, input);
- FileOutputFormat.setOutputPath(job, output);
- job.setMapperClass(NaiveBayesThetaMapper.class);
- job.setReducerClass(VectorSumReducer.class);
- job.setInputFormatClass(SequenceFileInputFormat.class);
- job.setOutputFormatClass(SequenceFileOutputFormat.class);
- job.setOutputKeyClass(IntWritable.class);
- job.setOutputValueClass(VectorWritable.class);
- job.setNumReduceTasks(numReducers);
- HadoopUtil.delete(conf, output);
- job.waitForCompletion(true);
- }
+    Iterable<String> labels = Splitter.on(",").split(parsedArgs.get("--labels"));
+    float alphaI = Float.parseFloat(parsedArgs.get("--alphaI"));
+    boolean trainComplementary = Boolean.parseBoolean(parsedArgs.get("--trainComplementary"));
+
+    // TrainUtils.readModelFromTempDir() looks up ALPHA_I in getConf(). A Hadoop Job works on its own copy of
+    // the configuration it is created from, so setting the value on the theta job alone would leave getConf()
+    // without it and the serialized model would always carry the default of 1.0.
+    getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
+
+    TrainUtils.writeLabelIndex(getConf(), labels, getTempPath("labelIndex"));
+    TrainUtils.setSerializations(getConf());
+    TrainUtils.cacheFiles(getTempPath("labelIndex"), getConf());
+
+    Job indexInstances = prepareJob(getInputPath(), getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
+        IndexInstancesMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class, IntWritable.class,
+        VectorWritable.class, SequenceFileOutputFormat.class);
+    indexInstances.setCombinerClass(VectorSumReducer.class);
+    // every later step consumes the output of the previous one, so stop as soon as a job fails
+    if (!indexInstances.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(WEIGHTS),
+        SequenceFileInputFormat.class, WeightsMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(Iterables.size(labels)));
+    weightSummer.setCombinerClass(VectorSumReducer.class);
+    if (!weightSummer.waitForCompletion(true)) {
+      return -1;
+    }
+
+    // the theta job reads the summed weights from the cache, replacing the label index cached above
+    TrainUtils.cacheFiles(getTempPath(WEIGHTS), getConf());
+
+    Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(THETAS),
+        SequenceFileInputFormat.class, ThetaMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    thetaSummer.setCombinerClass(VectorSumReducer.class);
+    thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+    thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+    if (!thetaSummer.waitForCompletion(true)) {
+      return -1;
+    }
+
+    NaiveBayesModel naiveBayesModel = TrainUtils.readModelFromTempDir(getTempPath(), getConf());
+    naiveBayesModel.validate();
+    naiveBayesModel.serialize(getOutputPath(), getConf());
- private static void runNaiveBayesThetaComplementarySummer(Path input,
- Configuration conf,
- Path weightFilePath,
- Path output,
- int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
-
- // this conf parameter needs to be set enable serialisation of conf values
- conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
- + "org.apache.hadoop.io.serializer.WritableSerialization");
- DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri()}, conf);
-
- Job job = new Job(conf);
- job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
- + weightFilePath.toString());
- job.setJarByClass(NaiveBayesTrainer.class);
- FileInputFormat.setInputPaths(job, input);
- FileOutputFormat.setOutputPath(job, output);
- job.setMapperClass(NaiveBayesThetaComplementaryMapper.class);
- job.setReducerClass(VectorSumReducer.class);
- job.setInputFormatClass(SequenceFileInputFormat.class);
- job.setOutputFormatClass(SequenceFileOutputFormat.class);
- job.setOutputKeyClass(IntWritable.class);
- job.setOutputValueClass(VectorWritable.class);
- job.setNumReduceTasks(numReducers);
- HadoopUtil.delete(conf, output);
- job.waitForCompletion(true);
+    return 0;
}
-
-
- /**
- * Write the list of labels into a map file
- */
- public static Path createLabelMapFile(Iterable<String> labels,
- Configuration conf,
- Path labelMapPathBase) throws IOException {
- FileSystem fs = FileSystem.get(labelMapPathBase.toUri(), conf);
- Path labelMapPath = new Path(labelMapPathBase, LABEL_MAP);
-
- SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, labelMapPath, Text.class, IntWritable.class);
- try {
- int i = 0;
- for (String label : labels) {
- Writable key = new Text(label);
- dictWriter.append(key, new IntWritable(i++));
- }
- } finally {
- Closeables.closeQuietly(dictWriter);
- }
- return labelMapPath;
- }
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainUtils.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+
+/** Static helpers for the Naive Bayes training classes of this package. */
+public final class TrainUtils {
+
+  private TrainUtils() {}
+
+  /**
+   * Assembles a {@link NaiveBayesModel} from the intermediate results stored below {@code base}: the per-feature
+   * and per-label weights, the summed observations per label and the per-label theta normalizer.
+   *
+   * @param base directory containing the WEIGHTS, SUMMED_OBSERVATIONS and THETAS outputs
+   * @param conf configuration used to read the files; also supplies {@link ThetaMapper#ALPHA_I} (default 1.0)
+   * @throws NullPointerException if one of the expected vectors is not found in the intermediate results
+   */
+  static NaiveBayesModel readModelFromTempDir(Path base, Configuration conf) {
+
+    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+
+    // read feature sums and label sums
+    Vector scoresPerLabel = null;
+    Vector scoresPerFeature = null;
+    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature, "no weights per feature found in the intermediate results");
+    Preconditions.checkNotNull(scoresPerLabel, "no weights per label found in the intermediate results");
+
+    // one row per label, holding the summed observations of that label
+    Matrix scoresPerLabelAndFeature =
+        new SparseMatrix(new int[] { scoresPerLabel.size(), scoresPerFeature.size() });
+    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+
+    Vector perlabelThetaNormalizer = null;
+    for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+        perlabelThetaNormalizer = entry.getSecond().get();
+      }
+    }
+
+    Preconditions.checkNotNull(perlabelThetaNormalizer,
+        "no per-label theta normalizer found in the intermediate results");
+
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perlabelThetaNormalizer,
+        alphaI);
+  }
+
+  /** Registers Java serialization in addition to Writable serialization in the given configuration. */
+  protected static void setSerializations(Configuration conf) {
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+        + "org.apache.hadoop.io.serializer.WritableSerialization");
+  }
+
+  /** Makes {@code fileToCache} the one and only entry of the distributed cache, replacing earlier entries. */
+  protected static void cacheFiles(Path fileToCache, Configuration conf) {
+    DistributedCache.setCacheFiles(new URI[] { fileToCache.toUri() }, conf);
+  }
+
+  /** Writes the labels to a sequence file that maps each label to its position in {@code labels}. */
+  protected static void writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+    throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+    try {
+      int i = 0;
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    } finally {
+      Closeables.closeQuietly(writer);
+    }
+  }
+
+  /**
+   * Returns the path of the first file registered in the distributed cache.
+   *
+   * @throws IllegalArgumentException if no file has been registered in the distributed cache
+   */
+  private static Path cachedFile(Configuration conf) throws IOException {
+    URI[] cachedFiles = DistributedCache.getCacheFiles(conf);
+    // fail with a clear message instead of a NullPointerException or ArrayIndexOutOfBoundsException
+    Preconditions.checkArgument(cachedFiles != null && cachedFiles.length > 0,
+        "missing paths from the DistributedCache");
+    return new Path(cachedFiles[0].getPath());
+  }
+
+  /** Reads the label index (label to index) from the sequence file registered in the distributed cache. */
+  protected static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<String>();
+    for (Pair<Writable,IntWritable> entry : new SequenceFileIterable<Writable,IntWritable>(cachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  /** Reads all (key, vector) pairs from the part files of the cached directory into a map keyed by name. */
+  protected static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String,Vector> sumVectors = Maps.newHashMap();
+    for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(cachedFile(conf),
+        PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java (from r1139275, mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java&r1=1139275&r2=1142566&rev=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java Mon Jul 4 06:49:17 2011
@@ -15,68 +15,50 @@
* limitations under the License.
*/
-package org.apache.mahout.classifier.naivebayes.trainer;
+package org.apache.mahout.classifier.naivebayes.training;
import java.io.IOException;
-import java.net.URI;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.filecache.DistributedCache;
-import org.apache.hadoop.fs.Path;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesConstants;
-import org.apache.mahout.common.Pair;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-public class NaiveBayesWeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
-
- private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
- private Vector featureSum;
- private Vector labelSum;
-
+/**
+ * Mapper that sums all incoming vectors into one per-feature weight vector and the {@code zSum()} of each
+ * vector into a per-label weight vector; both are emitted once, when the mapper is cleaned up.
+ */
+public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";
+
+  private Vector weightsPerFeature;
+  private Vector weightsPerLabel;
+
+  /** Creates the per-label weight vector with the cardinality configured under {@link #NUM_LABELS}. */
@Override
- protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, InterruptedException {
- Vector vector = value.get();
- if (featureSum == null) {
- featureSum = new RandomAccessSparseVector(vector.size(), vector.getNumNondefaultElements());
- labelSum = new RandomAccessSparseVector(labelMap.size());
- }
-
- int label = key.get();
- vector.addTo(featureSum);
- labelSum.set(label, labelSum.get(label) + vector.zSum());
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    // getInt() falls back to 0 when the job did not set NUM_LABELS, so a missing value is reported by the
+    // check below instead of surfacing as a NumberFormatException caused by parsing null
+    int numLabels = ctx.getConfiguration().getInt(NUM_LABELS, 0);
+    Preconditions.checkArgument(numLabels > 0, "%s must be set to a positive number of labels, but was %s",
+        NUM_LABELS, numLabels);
+    weightsPerLabel = new RandomAccessSparseVector(numLabels);
}
-
+
+  /** Adds the vector to the per-feature weights and its {@code zSum()} to the weight of label {@code index}. */
@Override
- protected void setup(Context context) throws IOException, InterruptedException {
- super.setup(context);
- Configuration conf = context.getConfiguration();
- URI[] localFiles = DistributedCache.getCacheFiles(conf);
- if (localFiles == null || localFiles.length < 1) {
- throw new IllegalArgumentException("missing paths from the DistributedCache");
+  protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    Vector instance = value.get();
+    // the feature cardinality is only known once the first vector has been seen
+    if (weightsPerFeature == null) {
+      weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
}
- Path labelMapFile = new Path(localFiles[0].getPath());
- // key is word value is id
- for (Pair<Writable,IntWritable> record
- : new SequenceFileIterable<Writable,IntWritable>(labelMapFile, true, conf)) {
- labelMap.put(record.getFirst().toString(), record.getSecond().get());
- }
+    int label = index.get();
+    instance.addTo(weightsPerFeature);
+    weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
}
-
+
+  /** Emits both weight vectors, unless this mapper has not seen a single vector. */
@Override
- protected void cleanup(Context context) throws IOException, InterruptedException {
- if (featureSum != null) {
- context.write(new Text(BayesConstants.FEATURE_SUM), new VectorWritable(featureSum));
- context.write(new Text(BayesConstants.LABEL_SUM), new VectorWritable(labelSum));
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    if (weightsPerFeature != null) {
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
}
- super.cleanup(context);
+    super.cleanup(ctx);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/VectorWritable.java Mon Jul 4 06:49:17 2011
@@ -187,6 +187,19 @@ public final class VectorWritable extend
}
@Override
+  public boolean equals(Object o) {
+    // anything that is not a VectorWritable (including null) is never equal
+    if (!(o instanceof VectorWritable)) {
+      return false;
+    }
+    return vector.equals(((VectorWritable) o).get());
+  }
+
+ @Override
+ public int hashCode() {
+ return vector.hashCode(); // delegates to the wrapped vector, just like equals() does
+ }
+
+ @Override
public String toString() {
return vector.toString();
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Mon Jul 4 06:49:17 2011
@@ -37,10 +37,10 @@ public final class ComplementaryNaiveBay
@Test
public void testNaiveBayes() throws Exception {
assertEquals(4, classifier.numCategories());
- assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
- assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
- assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
- assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+ assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 })))); // unit vector i must yield maxIndex i
+ assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java Mon Jul 4 06:49:17 2011
@@ -26,7 +26,7 @@ public class NaiveBayesModelTest extends
// make sure we generate a valid random model
NaiveBayesModel model = getModel();
// check whether the model is valid
- NaiveBayesModel.validate(model);
+ model.validate();
}
}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,117 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+
+public class NaiveBayesTest extends MahoutTestCase { // end-to-end check: trains on a small toy data set, then classifies one unseen instance
+
+ Configuration conf;
+ File inputFile; // sequence file that setup() fills with the training instances
+ File outputDir; // handed to the training job as --output; the model is read back from here
+ File tempDir; // handed to the training job as --tempDir
+
+ static final Text LABEL_STOLEN = new Text("stolen");
+ static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
+
+ static final Vector.Element COLOR_RED = MathHelper.elem(0, 1); // elem(...) arguments are presumably (index, value); trainingInstance() reads them back via index()/get()
+ static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
+ static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1);
+ static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1);
+ static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1);
+ static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1);
+
+
+ @Before
+ public void setup() throws Exception { // prepares the temp locations and writes the labeled training instances to inputFile
+ super.setUp(); // NOTE(review): if the inherited setUp() is itself a @Before method it would run twice — confirm
+
+ conf = new Configuration();
+
+ inputFile = getTestTempFile("trainingInstances.seq");
+ outputDir = getTestTempDir("output");
+ outputDir.delete(); // return value is ignored; presumably the job is expected to create the directory itself — TODO confirm
+ tempDir = getTestTempDir("tmp");
+
+ SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf, // keys are Text labels, values are VectorWritable instances
+ new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
+
+ try { // ten instances in total: five labeled "stolen", five labeled "not_stolen"
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
+ } finally {
+ Closeables.closeQuietly(writer); // NOTE(review): a failure while closing is presumably not reported to the test — confirm against Guava docs
+ }
+ }
+
+ @Test
+ public void toyData() throws Exception { // trains without --trainComplementary and classifies with StandardNaiveBayesClassifier
+ TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+ trainNaiveBayes.setConf(conf);
+ trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), // NOTE(review): the value returned by run() is not checked
+ "--labels", "stolen,not_stolen", "--tempDir", tempDir.getAbsolutePath() });
+
+ NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); // read the model back from the job's output directory
+
+ AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel);
+
+ assertEquals(2, classifier.numCategories());
+
+ Vector prediction = classifier.classify(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); // red/SUV/domestic is a combination that setup() never writes
+
+ // should be classified as not stolen
+ assertTrue(prediction.get(0) < prediction.get(1)); // assumes index 0 is "stolen" and 1 is "not_stolen", i.e. the --labels order — TODO confirm
+ }
+
+ @Test
+ public void toyDataComplementary() throws Exception { // same scenario with --trainComplementary true and ComplementaryNaiveBayesClassifier
+ TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+ trainNaiveBayes.setConf(conf);
+ trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), // NOTE(review): the value returned by run() is not checked
+ "--labels", "stolen,not_stolen", "--trainComplementary", String.valueOf(true),
+ "--tempDir", tempDir.getAbsolutePath() });
+
+ NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); // read the model back from the job's output directory
+
+ AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
+
+ assertEquals(2, classifier.numCategories());
+
+ Vector prediction = classifier.classify(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); // red/SUV/domestic is a combination that setup() never writes
+
+ // should be classified as not stolen
+ assertTrue(prediction.get(0) < prediction.get(1)); // assumes index 0 is "stolen" and 1 is "not_stolen", i.e. the --labels order — TODO confirm
+ }
+
+ VectorWritable trainingInstance(Vector.Element... elems) { // builds a 6-dimensional dense vector from the given elements and wraps it in a VectorWritable
+ DenseVector trainingInstance = new DenseVector(6); // 6 matches the feature constants above, which use indices 0 through 5
+ for (Vector.Element elem : elems) {
+ trainingInstance.set(elem.index(), elem.get()); // copy each element's value to its own index
+ }
+ return new VectorWritable(trainingInstance);
+ }
+
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java Mon Jul 4 06:49:17 2011
@@ -17,8 +17,6 @@
package org.apache.mahout.classifier.naivebayes;
-import java.util.Iterator;
-
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
@@ -34,19 +32,14 @@ public class NaiveBayesTestBase extends
public void setUp() throws Exception {
super.setUp();
model = createNaiveBayesModel();
-
- // make sure the model is valid :)
- NaiveBayesModel.validate(model);
+ model.validate();
}
protected NaiveBayesModel getModel() {
return model;
}
- public double complementaryNaiveBayesThetaWeight(int label,
- Matrix weightMatrix,
- Vector labelSum,
- Vector featureSum) {
+ public double complementaryNaiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
double weight = 0.0;
double alpha = 1.0;
for (int i = 0; i < featureSum.size(); i++) {
@@ -61,14 +54,11 @@ public class NaiveBayesTestBase extends
return weight;
}
- public double naiveBayesThetaWeight(int label,
- Matrix weightMatrix,
- Vector labelSum,
- Vector featureSum) {
+ public double naiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) {
double weight = 0.0;
double alpha = 1.0;
- for (int i = 0; i < featureSum.size(); i++) {
- double score = weightMatrix.get(i, label);
+ for (int feature = 0; feature < featureSum.size(); feature++) {
+ double score = weightMatrix.get(feature, label);
double lSum = labelSum.get(label);
double numerator = score + alpha;
double denominator = lSum + featureSum.size();
@@ -78,52 +68,60 @@ public class NaiveBayesTestBase extends
}
public NaiveBayesModel createNaiveBayesModel() {
- double[][] matrix = { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
- {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
- double[] labelSumArray = {1.2, 1.0, 1.0, 1.0};
- double[] featureSumArray = {1.3, 0.6, 1.1, 1.2};
+ double[][] matrix = {
+ { 0.7, 0.1, 0.1, 0.3 },
+ { 0.4, 0.4, 0.1, 0.1 },
+ { 0.1, 0.0, 0.8, 0.1 },
+ { 0.1, 0.1, 0.1, 0.7 } };
+
+ double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+ double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
DenseMatrix weightMatrix = new DenseMatrix(matrix);
DenseVector labelSum = new DenseVector(labelSumArray);
DenseVector featureSum = new DenseVector(featureSumArray);
- double[] thetaNormalizerSum = {naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
- naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
- naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
- naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum)};
+ double[] thetaNormalizerSum = {
+ naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
+ naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
+ naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
+ naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
+
// now generate the model
- return new NaiveBayesModel(weightMatrix, featureSum,
- labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+ return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
}
public NaiveBayesModel createComplementaryNaiveBayesModel() {
- double[][] matrix = { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
- {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
- double[] labelSumArray = {1.2, 1.0, 1.0, 1.0};
- double[] featureSumArray = {1.3, 0.6, 1.1, 1.2};
+ double[][] matrix = {
+ { 0.7, 0.1, 0.1, 0.3 },
+ { 0.4, 0.4, 0.1, 0.1 },
+ { 0.1, 0.0, 0.8, 0.1 },
+ { 0.1, 0.1, 0.1, 0.7 } };
+
+ double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+ double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
DenseMatrix weightMatrix = new DenseMatrix(matrix);
DenseVector labelSum = new DenseVector(labelSumArray);
DenseVector featureSum = new DenseVector(featureSumArray);
- double[] thetaNormalizerSum = {complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
- complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
- complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
- complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum)};
+ double[] thetaNormalizerSum = {
+ complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
+
// now generate the model
- return new NaiveBayesModel(weightMatrix, featureSum,
- labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+ return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
}
public int maxIndex(Vector instance) {
- Iterator<Element> it = instance.iterator();
int maxIndex = -1;
- double val = Integer.MIN_VALUE;
- while (it.hasNext()) {
- Element e = it.next();
- if (val <= e.get()) {
- maxIndex = e.index();
- val = e.get();
+ double maxScore = Integer.MIN_VALUE;
+ for (Element label : instance) {
+ if (label.get() >= maxScore) {
+ maxIndex = label.index();
+ maxScore = label.get();
}
}
return maxIndex;
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1142566&r1=1142565&r2=1142566&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Mon Jul 4 06:49:17 2011
@@ -38,10 +38,10 @@ public final class StandardNaiveBayesCla
@Test
public void testNaiveBayes() throws Exception {
assertEquals(4, classifier.numCategories());
- assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
- assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
- assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
- assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+ assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+ assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+public class IndexInstancesMapperTest extends MahoutTestCase { // exercises IndexInstancesMapper.map() against an EasyMock context for a known and an unknown label
+
+ Mapper.Context ctx; // EasyMock mock created in setup()
+ OpenObjectIntHashMap<String> labelIndex; // label -> index table injected into the mapper via setField
+ VectorWritable instance; // the single instance passed to map() in both tests
+
+ @Before
+ public void setup() throws Exception { // creates the mock context, the test instance and a two-entry label index
+ super.setUp(); // NOTE(review): if the inherited setUp() is itself a @Before method it would run twice — confirm
+
+ ctx = EasyMock.createMock(Mapper.Context.class);
+ instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));
+
+ labelIndex = new OpenObjectIntHashMap<String>();
+ labelIndex.put("bird", 0);
+ labelIndex.put("cat", 1);
+ }
+
+
+ @Test
+ public void index() throws Exception { // a label present in labelIndex must be written out under its integer index
+
+ ctx.write(new IntWritable(0), instance); // recorded expectation (the mock is not in replay mode yet): "bird" maps to 0
+
+ EasyMock.replay(ctx);
+
+ IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+ setField(indexInstances, "labelIndex", labelIndex); // inject the table directly into the mapper's "labelIndex" field
+
+ indexInstances.map(new Text("bird"), instance, ctx);
+
+ EasyMock.verify(ctx);
+ }
+
+ @Test
+ public void skip() throws Exception { // a label missing from labelIndex must only bump the SKIPPED_INSTANCES counter
+
+ Counter skippedInstances = EasyMock.createMock(Counter.class);
+
+ EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
+ skippedInstances.increment(1); // recorded expectation: the counter is incremented by 1
+
+ EasyMock.replay(ctx, skippedInstances); // no ctx.write(...) was recorded; presumably any such call would fail the mock — confirm against EasyMock docs
+
+ IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+ setField(indexInstances, "labelIndex", labelIndex);
+
+ indexInstances.map(new Text("fish"), instance, ctx); // "fish" has no entry in labelIndex (only "bird" and "cat" are registered)
+
+ EasyMock.verify(ctx, skippedInstances);
+ }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java?rev=1142566&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java Mon Jul 4 06:49:17 2011
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+public class ThetaMapperTest extends MahoutTestCase { // checks ThetaMapper's calls on its trainer and on the context, using EasyMock mocks
+
+ @Test
+ public void standard() throws Exception { // two map() calls followed by cleanup() must produce exactly the expectations recorded below
+
+ Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class);
+ AbstractThetaTrainer trainer = EasyMock.createMock(AbstractThetaTrainer.class);
+
+ Vector instance1 = new DenseVector(new double[] { 1, 2, 3 });
+ Vector instance2 = new DenseVector(new double[] { 4, 5, 6 });
+
+ Vector perLabelThetaNormalizer = new DenseVector(new double[] { 7, 8 }); // canned value the mocked trainer hands back
+
+ ThetaMapper thetaMapper = new ThetaMapper();
+ setField(thetaMapper, "trainer", trainer); // inject the mock directly into the mapper's "trainer" field
+
+ trainer.train(0, instance1); // recorded expectations (mocks not yet in replay mode): one train(...) call per mapped instance
+ trainer.train(1, instance2);
+ EasyMock.expect(trainer.retrievePerLabelThetaNormalizer()).andReturn(perLabelThetaNormalizer);
+ ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer)); // the only write expected on the context
+
+ EasyMock.replay(ctx, trainer);
+
+ thetaMapper.map(new IntWritable(0), new VectorWritable(instance1), ctx);
+ thetaMapper.map(new IntWritable(1), new VectorWritable(instance2), ctx);
+ thetaMapper.cleanup(ctx);
+
+ EasyMock.verify(ctx, trainer); // fails if any recorded expectation was not met
+ }
+
+
+}
|