mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From robina...@apache.org
Subject svn commit: r1005262 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/naivebayes/ main/java/org/apache/mahout/classifier/naivebayes/trainer/ main/java/org/apache/mahout/common/ test/java/org/apache/mahout/classifier/naivebayes/
Date Wed, 06 Oct 2010 21:38:42 GMT
Author: robinanil
Date: Wed Oct  6 21:38:41 2010
New Revision: 1005262

URL: http://svn.apache.org/viewvc?rev=1005262&view=rev
Log:
MAHOUT-287 Vector input based NaiveBayes classifier

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.util.Iterator;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Base class for vector-input Naive Bayes classifiers backed by a {@link NaiveBayesModel}.
+ * Subclasses supply the score of a single (label, feature) pair; this class accumulates
+ * those scores over the non-zero entries of an instance and produces one score per label.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+  protected NaiveBayesModel model;
+
+  public AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+    this.model = model;
+  }
+
+  /**
+   * Returns the score contributed by one unit of {@code feature} for {@code label}.
+   *
+   * @param label   label index
+   * @param feature feature index
+   * @return the per-unit score of the feature under the label
+   */
+  public abstract double getScoreForLabelFeature(int label, int feature);
+
+  /**
+   * Scores {@code instance} against {@code label}. Every non-zero feature contributes its
+   * per-unit score multiplied by the feature's value in the instance.
+   *
+   * @param label    label index
+   * @param instance feature vector; only its non-zero entries are visited
+   * @return the accumulated score of the instance under the label
+   */
+  public double getScoreForLabelInstance(int label, Vector instance) {
+    double result = 0.0;
+    Iterator<Element> it = instance.iterateNonZero();
+    while (it.hasNext()) {
+      Element e = it.next();
+      // Multinomial Naive Bayes: a feature with value k (e.g. a term seen k times) must
+      // contribute k times its per-unit score, not just once.
+      result += e.get() * getScoreForLabelFeature(label, e.index());
+    }
+    return result;
+  }
+
+  /**
+   * @return the number of labels known to the model
+   */
+  @Override
+  public int numCategories() {
+    return model.getNumLabels();
+  }
+
+  /**
+   * Scores {@code instance} against every label.
+   *
+   * NOTE(review): the result holds one raw score per label (created via
+   * {@code labelSum.like()}), not probabilities; confirm this matches the
+   * AbstractVectorClassifier contract expected by callers.
+   *
+   * @param instance feature vector to classify
+   * @return vector whose i-th entry is {@link #getScoreForLabelInstance(int, Vector)} for label i
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector score = model.getLabelSum().like();
+    for (int i = 0; i < score.size(); i++) {
+      score.set(i, getScoreForLabelInstance(i, instance));
+    }
+    return score;
+  }
+
+  /**
+   * Not supported: Naive Bayes produces one score per label, not a single scalar.
+   *
+   * @throws UnsupportedOperationException always
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    throw new UnsupportedOperationException("Not supported in Naive Bayes");
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+/**
+ * Class containing Constants used by Naive Bayes classifier classes
+ * 
+ */
+public final class BayesConstants {
+  
+  // Ensure all the strings are unique
+  public static final String ALPHA_SMOOTHING_FACTOR = "__SF"; // -
+  
+  public static final String WEIGHT = "__WT";
+  
+  public static final String FEATURE_SUM = "__SJ";
+  
+  public static final String LABEL_SUM = "__SK";
+  
+  public static final String LABEL_THETA_NORMALIZER = "_LTN";
+  
+  private BayesConstants() { }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Complementary Naive Bayes classifier: a feature is scored for a label using the weight
+ * that the feature carries in all the other labels (the label's complement).
+ */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /**
+   * Returns log(smoothed complement weight of {@code feature} for {@code label}), divided by
+   * the label's theta normalizer.
+   *
+   * NOTE(review): the numerator adds alphaI once, but the denominator adds vocabCount rather
+   * than alphaI * vocabCount, so the two smoothings agree only when alphaI == 1. The trainer's
+   * theta mapper uses the same formula, so any change must be made in both places.
+   */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    double featureLabelWeight = model.getWeightMatrix().get(label, feature);
+    double featureWeight = model.getFeatureSum().get(feature);
+    double labelWeight = model.getLabelSum().get(label);
+    // weight of this feature / of all features, summed over every label except this one
+    double complementNumerator = featureWeight - featureLabelWeight + model.getAlphaI();
+    double complementDenominator = model.getTotalSum() - labelWeight + model.getVocabCount();
+    double logTheta = Math.log(complementNumerator / complementDenominator);
+    return logTheta / model.getPerlabelThetaNormalizer().get(label);
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,304 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.naivebayes.trainer.NaiveBayesTrainer;
+import org.apache.mahout.math.JsonMatrixAdapter;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * NaiveBayesModel holds the weight matrix, the feature and label sums and the per-label theta
+ * normalizer vector, together with the smoothing parameter alphaI, the vocabulary count and
+ * the total weight.
+ */
+public class NaiveBayesModel implements JsonDeserializer<NaiveBayesModel>, JsonSerializer<NaiveBayesModel>, Cloneable {
+
+  private static final String MODEL = "NaiveBayesModel";
+
+  private Vector labelSum;
+  private Vector perlabelThetaNormalizer;
+  private Vector featureSum;
+  private Matrix weightMatrix;
+  private float alphaI;
+  private double vocabCount;
+  private double totalSum;
+
+  private NaiveBayesModel() {
+    // used by the loaders below and as a stateless Gson type adapter
+  }
+
+  /**
+   * Creates a model from trained components. The vocabulary count is the number of non-default
+   * entries of {@code featureSum}; the total sum is the sum of {@code labelSum}.
+   */
+  public NaiveBayesModel(Matrix matrix, Vector featureSum, Vector labelSum, Vector thetaNormalizer, float alphaI) {
+    this.weightMatrix = matrix;
+    this.featureSum = featureSum;
+    this.labelSum = labelSum;
+    this.perlabelThetaNormalizer = thetaNormalizer;
+    this.vocabCount = featureSum.getNumNondefaultElements();
+    this.totalSum = labelSum.zSum();
+    this.alphaI = alphaI;
+  }
+
+  private void setLabelSum(Vector labelSum) {
+    this.labelSum = labelSum;
+  }
+
+  public void setPerlabelThetaNormalizer(Vector perlabelThetaNormalizer) {
+    this.perlabelThetaNormalizer = perlabelThetaNormalizer;
+  }
+
+  public void setFeatureSum(Vector featureSum) {
+    this.featureSum = featureSum;
+  }
+
+  public void setWeightMatrix(Matrix weightMatrix) {
+    this.weightMatrix = weightMatrix;
+  }
+
+  public void setAlphaI(float alphaI) {
+    this.alphaI = alphaI;
+  }
+
+  public void setVocabCount(double vocabCount) {
+    this.vocabCount = vocabCount;
+  }
+
+  public void setTotalSum(double totalSum) {
+    this.totalSum = totalSum;
+  }
+
+  public Vector getLabelSum() {
+    return labelSum;
+  }
+
+  public Vector getPerlabelThetaNormalizer() {
+    return perlabelThetaNormalizer;
+  }
+
+  public Vector getFeatureSum() {
+    return featureSum;
+  }
+
+  public Matrix getWeightMatrix() {
+    return weightMatrix;
+  }
+
+  public float getAlphaI() {
+    return alphaI;
+  }
+
+  public double getVocabCount() {
+    return vocabCount;
+  }
+
+  public double getTotalSum() {
+    return totalSum;
+  }
+
+  /** @return the cardinality of the label-sum vector */
+  public int getNumLabels() {
+    return labelSum.size();
+  }
+
+  public static String getModelName() {
+    return MODEL;
+  }
+
+  // CODE USED FOR SERIALIZATION
+
+  /**
+   * Assembles a model from the output directory of the map/reduce trainer.
+   *
+   * @param output directory holding the trainer's class-vector, sum-vector and theta-sum files
+   * @param conf   configuration used to open the files; alphaI is read from
+   *               {@link NaiveBayesTrainer#ALPHA_I} (default 1.0)
+   * @return the assembled model
+   * @throws IOException if any of the files cannot be read
+   */
+  public static NaiveBayesModel fromMRTrainerOutput(Path output, Configuration conf) throws IOException {
+    Path classVectorPath = new Path(output, NaiveBayesTrainer.CLASS_VECTORS);
+    Path sumVectorPath = new Path(output, NaiveBayesTrainer.SUM_VECTORS);
+    Path thetaSumPath = new Path(output, NaiveBayesTrainer.THETA_SUM);
+
+    NaiveBayesModel model = new NaiveBayesModel();
+    model.setAlphaI(conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f));
+
+    FileSystem fs = sumVectorPath.getFileSystem(conf);
+    Text key = new Text();
+    VectorWritable value = new VectorWritable();
+
+    // The weight matrix needs as many columns as the feature vectors' cardinality. That can
+    // exceed the vocabulary count (non-default entries of the feature sum), and assignRow
+    // rejects rows whose size differs from the matrix column count.
+    int featureCardinality = 0;
+    int labelCount = 0;
+
+    // read feature sums and label sums
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, sumVectorPath, conf);
+    try {
+      while (reader.next(key, value)) {
+        String name = key.toString();
+        if (name.equals(BayesConstants.FEATURE_SUM)) {
+          Vector features = value.get();
+          model.setFeatureSum(features);
+          model.setVocabCount(features.getNumNondefaultElements());
+          featureCardinality = features.size();
+        } else if (name.equals(BayesConstants.LABEL_SUM)) {
+          Vector labels = value.get();
+          model.setLabelSum(labels);
+          model.setTotalSum(labels.zSum());
+          labelCount = labels.size();
+        }
+      }
+    } finally {
+      reader.close();
+    }
+
+    // read the per-label weight vectors into the rows of the weight matrix
+    Matrix matrix = new SparseMatrix(new int[] {labelCount, featureCardinality});
+    IntWritable label = new IntWritable();
+    reader = new SequenceFile.Reader(fs, classVectorPath, conf);
+    try {
+      while (reader.next(label, value)) {
+        matrix.assignRow(label.get(), value.get());
+      }
+    } finally {
+      reader.close();
+    }
+    model.setWeightMatrix(matrix);
+
+    // read theta normalizer
+    reader = new SequenceFile.Reader(fs, thetaSumPath, conf);
+    try {
+      while (reader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.LABEL_THETA_NORMALIZER)) {
+          model.setPerlabelThetaNormalizer(value.get());
+        }
+      }
+    } finally {
+      reader.close();
+    }
+
+    return model;
+  }
+
+  /**
+   * Encode this NaiveBayesModel as a JSON string
+   *
+   * @return String containing the JSON of this model
+   */
+  public String toJson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(NaiveBayesModel.class, this);
+    Gson gson = builder.create();
+    return gson.toJson(this);
+  }
+
+  /**
+   * Decode a NaiveBayesModel from a JSON string
+   *
+   * @param json String containing JSON representation of the model
+   * @return Initialized model
+   */
+  public static NaiveBayesModel fromJson(String json) {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(NaiveBayesModel.class, new NaiveBayesModel());
+    Gson gson = builder.create();
+    return gson.fromJson(json, NaiveBayesModel.class);
+  }
+
+  /**
+   * Wraps the model's JSON (produced with the Matrix/Vector adapters registered) as a string
+   * primitive under the {@code "NaiveBayesModel"} property.
+   */
+  @Override
+  public JsonElement serialize(NaiveBayesModel model,
+                               Type type,
+                               JsonSerializationContext context) {
+    // now register the builders for matrix / vector
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Matrix.class, new JsonMatrixAdapter());
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    // create a model
+    JsonObject json = new JsonObject();
+    // first, we add the model
+    json.add(MODEL, new JsonPrimitive(gson.toJson(model)));
+    return json;
+  }
+
+  /**
+   * Inverse of {@link #serialize}: reads the string stored under {@code "NaiveBayesModel"} and
+   * decodes it with the Matrix/Vector adapters registered.
+   */
+  @Override
+  public NaiveBayesModel deserialize(JsonElement json,
+                                     Type type,
+                                     JsonDeserializationContext context) throws JsonParseException {
+    // register the builders for matrix / vector
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Matrix.class, new JsonMatrixAdapter());
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    // now decode the original model
+    JsonObject obj = json.getAsJsonObject();
+    String modelString = obj.get(MODEL).getAsString();
+    return gson.fromJson(modelString, NaiveBayesModel.class);
+  }
+
+  /**
+   * Checks that a model's parameters are usable.
+   *
+   * @param model the model to check; {@code null} is accepted as valid
+   * @throws IllegalArgumentException if alphaI, the vocabulary count or the total sum is not
+   *         positive, or if the label sums, theta normalizers or feature sums are missing or
+   *         have no non-default entries
+   */
+  public static void validate(NaiveBayesModel model) {
+    if (model == null) {
+      return; // empty models are valid
+    }
+
+    if (model.getAlphaI() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: AlphaI has to be greater than 0!");
+    }
+
+    if (model.getVocabCount() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The vocab count has to be greater than 0!");
+    }
+
+    if (model.getTotalSum() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The total sum has to be greater than 0!");
+    }
+
+    if (model.getLabelSum() == null || model.getLabelSum().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of labels has to be greater than 0 or defined!");
+    }
+
+    if (model.getPerlabelThetaNormalizer() == null
+        || model.getPerlabelThetaNormalizer().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of theta normalizers has to be greater than 0 or defined!");
+    }
+
+    if (model.getFeatureSum() == null || model.getFeatureSum().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of features has to be greater than 0 or defined!");
+    }
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Standard (multinomial) Naive Bayes classifier: a feature is scored for a label from the
+ * weight that feature carries within that label.
+ */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /**
+   * Returns -log((w[label][feature] + alphaI) / (labelSum[label] + vocabCount)), divided by
+   * the label's theta normalizer.
+   *
+   * NOTE(review): the denominator adds vocabCount rather than alphaI * vocabCount, so the
+   * smoothing is consistent only for alphaI == 1.
+   */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    double featureLabelWeight = model.getWeightMatrix().get(label, feature);
+    double labelWeight = model.getLabelSum().get(label);
+    double smoothedNumerator = featureLabelWeight + model.getAlphaI();
+    double smoothedDenominator = labelWeight + model.getVocabCount();
+    double negLogTheta = -Math.log(smoothedNumerator / smoothedDenominator);
+    return negLogTheta / model.getPerlabelThetaNormalizer().get(label);
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Replaces each instance's label name with its integer label id, using the label map read from
+ * the first DistributedCache file. Instances whose label is not in the map are dropped and
+ * counted under the "NaiveBayes" counter group.
+ */
+public class NaiveBayesInstanceMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+
+  @Override
+  protected void map(Text key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+    String labelName = key.toString();
+    if (!labelMap.containsKey(labelName)) {
+      // getCounter() only looks the counter up; it must be incremented to record the skip.
+      context.getCounter("NaiveBayes", "Skipped instance: not in label list").increment(1);
+      return;
+    }
+    context.write(new IntWritable(labelMap.get(labelName)), value);
+  }
+
+  /**
+   * Loads the label-name to label-id map from the first DistributedCache file.
+   *
+   * @throws IllegalArgumentException if the DistributedCache holds no files
+   * @throws IllegalStateException    wrapping any IOException raised while reading the map
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 1) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      Path labelMapFile = new Path(localFiles[0].getPath());
+      FileSystem fs = labelMapFile.getFileSystem(conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      try {
+        Writable key = new Text();
+        IntWritable value = new IntWritable();
+        // key is the label name, value is its integer id
+        while (reader.next(key, value)) {
+          labelMap.put(key.toString(), value.get());
+        }
+      } finally {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Adds up all vectors that share a key and emits (key, sum). Because vector addition is
+ * associative and commutative, the same class can also serve as a local combiner.
+ */
+public class NaiveBayesSumReducer extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+  @Override
+  protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context)
+      throws IOException, InterruptedException {
+    // The first vector becomes the accumulator; each later vector is added into it in place.
+    Vector accumulator = null;
+    for (VectorWritable writable : values) {
+      Vector current = writable.get();
+      if (accumulator != null) {
+        current.addTo(accumulator);
+      } else {
+        accumulator = current;
+      }
+    }
+    context.write(key, new VectorWritable(accumulator));
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * For each (label id, vector) record, adds up over the vector's non-zero entries
+ * log((featureSum[j] - value[j] + alphaI) / (totalSum - labelSum[label] + vocabCount)).
+ * The sum goes into the label's slot of the theta-normalizer vector, which is emitted once
+ * in cleanup under {@link BayesConstants#LABEL_THETA_NORMALIZER}.
+ */
+public class NaiveBayesThetaComplementaryMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  // NOTE(review): loaded in setup but not read anywhere in this class.
+  private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  private Vector featureSum;
+  private Vector labelSum;
+  private Vector perLabelThetaNormalizer;
+  private double alphaI = 1.0;
+  private double vocabCount;
+  private double totalSum = 0;
+
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+    Vector vector = value.get();
+    int label = key.get();
+    double sigmaK = labelSum.get(label);
+    // The denominator depends only on the label, so compute it once per record.
+    double denominator = totalSum - sigmaK + vocabCount;
+    Iterator<Element> it = vector.iterateNonZero();
+    while (it.hasNext()) {
+      Element e = it.next();
+      double numerator = featureSum.get(e.index()) - e.get() + alphaI;
+      double weight = Math.log(numerator / denominator);
+      perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
+    }
+  }
+
+  /**
+   * Reads alphaI from the configuration. Loads the feature-sum and label-sum vectors from the
+   * first DistributedCache file and the label map from the second one.
+   *
+   * @throws IllegalArgumentException if fewer than two DistributedCache files are present
+   * @throws IllegalStateException    wrapping any IOException raised while reading the files
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 2) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
+
+      // first cache file: feature-sum and label-sum vectors
+      Path weightFile = new Path(localFiles[0].getPath());
+      FileSystem fs = weightFile.getFileSystem(conf);
+      Text key = new Text();
+      VectorWritable value = new VectorWritable();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, weightFile, conf);
+      try {
+        while (reader.next(key, value)) {
+          if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
+            featureSum = value.get();
+          } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
+            labelSum = value.get();
+          }
+        }
+      } finally {
+        reader.close();
+      }
+      perLabelThetaNormalizer = labelSum.like();
+      totalSum = labelSum.zSum();
+      vocabCount = featureSum.getNumNondefaultElements();
+
+      // second cache file: label name -> label id
+      Path labelMapFile = new Path(localFiles[1].getPath());
+      fs = labelMapFile.getFileSystem(conf);
+      reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      try {
+        IntWritable intValue = new IntWritable();
+        while (reader.next(key, intValue)) {
+          labelMap.put(key.toString(), intValue.get());
+        }
+      } finally {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+    super.cleanup(context);
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class NaiveBayesThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+  
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  private Vector featureSum;
+  private Vector labelSum;
+  private Vector perLabelThetaNormalizer;
+  private double alphaI = 1.0;
+  private double vocabCount;
+  
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+    Vector vector = value.get();
+    int label = key.get();
+    double weight = Math.log((vector.zSum() + alphaI) / (labelSum.get(label) + vocabCount));
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
+  }
+  
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 2) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
+      Path weightFile = new Path(localFiles[0].getPath());
+      FileSystem fs = weightFile.getFileSystem(conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, weightFile, conf);
+      Text key = new Text();
+      VectorWritable value = new VectorWritable();
+
+      while (reader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
+          featureSum = value.get();
+        } else  if (key.toString().equals(BayesConstants.LABEL_SUM)) {
+          labelSum = value.get();
+        }
+      }
+      perLabelThetaNormalizer = labelSum.like();
+      vocabCount = featureSum.getNumNondefaultElements();
+      
+      Path labelMapFile = new Path(localFiles[1].getPath());
+      fs = labelMapFile.getFileSystem(conf);
+
+      reader.close();
+      reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      IntWritable intValue = new IntWritable();
+
+      // key is word value is id
+      while (reader.next(key, intValue)) {
+        labelMap.put(key.toString(), intValue.get());
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+  
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+    super.cleanup(context);
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,202 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
+ * 
+ * 
+ */
+public final class NaiveBayesTrainer {
+  
+  public static final String THETA_SUM = "thetaSum";
+  public static final String SUM_VECTORS = "sumVectors";
+  public static final String CLASS_VECTORS = "classVectors";
+  public static final String LABEL_MAP = "labelMap";
+  public static final String ALPHA_I = "alphaI";
+
+  /**
+   * Runs the full training pipeline. It writes the label map, sums the input per label, computes the
+   * feature/label weight sums and finally the (standard or complementary) per-label theta normalizer.
+   *
+   * @param input input path of the first (per-label summing) job
+   * @param conf configuration shared by all jobs; ALPHA_I is stored into it
+   * @param inputLabels labels whose list position defines their integer id
+   * @param output base output directory; sub-directories are named by the constants of this class
+   * @param numReducers number of reduce tasks per job
+   * @param alphaI smoothing parameter handed to the theta mappers
+   * @param trainComplementary whether to compute the complementary instead of the standard normalizer
+   * @throws IllegalStateException if one of the jobs fails
+   */
+  public static void trainNaiveBayes(Path input,
+                                      Configuration conf,
+                                      List<String> inputLabels,
+                                      Path output,
+                                      int numReducers,
+                                      float alphaI,
+                                      boolean trainComplementary)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    conf.setFloat(ALPHA_I, alphaI);
+    Path labelMapPath = createLabelMapFile(inputLabels, conf, new Path(output, LABEL_MAP));
+    Path classVectorPath = new Path(output, CLASS_VECTORS);
+    runNaiveBayesByLabelSummer(input, conf, labelMapPath, classVectorPath, numReducers);
+    Path weightFilePath = new Path(output, SUM_VECTORS);
+    runNaiveBayesWeightSummer(classVectorPath, conf, labelMapPath, weightFilePath, numReducers);
+    Path thetaFilePath = new Path(output, THETA_SUM);
+    // NOTE(review): weightFilePath is a job output directory, but the theta mappers open their cache
+    // slot 0 directly as a SequenceFile. With more than one reducer, FEATURE_SUM and LABEL_SUM may also
+    // land in different part files. Confirm how the weight file is meant to be located.
+    if (trainComplementary) {
+      runNaiveBayesThetaComplementarySummer(classVectorPath, conf, weightFilePath, labelMapPath,
+          thetaFilePath, numReducers);
+    } else {
+      runNaiveBayesThetaSummer(classVectorPath, conf, weightFilePath, labelMapPath, thetaFilePath, numReducers);
+    }
+  }
+
+  /**
+   * Registers JavaSerialization next to WritableSerialization. According to the original code, this is
+   * needed to enable serialisation of conf values.
+   */
+  private static void enableJavaSerialization(Configuration conf) {
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+        + "org.apache.hadoop.io.serializer.WritableSerialization");
+  }
+
+  /** Waits for the job and fails fast, so later stages never consume missing or partial output. */
+  private static void runToCompletion(Job job)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    if (!job.waitForCompletion(true)) {
+      throw new IllegalStateException("Job failed: " + job.getJobName());
+    }
+  }
+
+  private static void runNaiveBayesByLabelSummer(Path input, Configuration conf, Path labelMapPath,
+                                                 Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    enableJavaSerialization(conf);
+    DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
+  
+    Job job = new Job(conf);
+    job.setJobName("Train Naive Bayes: input-folder: " + input + ", label-map-file: "
+        + labelMapPath.toString());
+    job.setJarByClass(NaiveBayesTrainer.class);
+    FileInputFormat.setInputPaths(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    job.setMapperClass(NaiveBayesInstanceMapper.class);
+    job.setCombinerClass(NaiveBayesSumReducer.class);
+    job.setReducerClass(NaiveBayesSumReducer.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+    job.setNumReduceTasks(numReducers);
+    HadoopUtil.overwriteOutput(output);
+    runToCompletion(job);
+  }
+
+  private static void runNaiveBayesWeightSummer(Path input, Configuration conf,
+                                                Path labelMapPath, Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    enableJavaSerialization(conf);
+    DistributedCache.setCacheFiles(new URI[] {labelMapPath.toUri()}, conf);
+    
+    Job job = new Job(conf);
+    job.setJobName("Train Naive Bayes: input-folder: " + input);
+    job.setJarByClass(NaiveBayesTrainer.class);
+    FileInputFormat.setInputPaths(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    job.setMapperClass(NaiveBayesWeightsMapper.class);
+    job.setReducerClass(NaiveBayesSumReducer.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setOutputKeyClass(Text.class);
+    job.setOutputValueClass(VectorWritable.class);
+    job.setNumReduceTasks(numReducers);
+    HadoopUtil.overwriteOutput(output);
+    runToCompletion(job);
+  }
+  
+  private static void runNaiveBayesThetaSummer(Path input, Configuration conf, Path weightFilePath,
+                                               Path labelMapPath, Path output, int numReducers)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    enableJavaSerialization(conf);
+    // the theta mapper reads the weight file from cache slot 0 and the label map from slot 1
+    DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri(), labelMapPath.toUri()}, conf);
+  
+    Job job = new Job(conf);
+    job.setJobName("Train Naive Bayes: input-folder: " + input + ", weight-file: "
+        + weightFilePath.toString());
+    job.setJarByClass(NaiveBayesTrainer.class);
+    FileInputFormat.setInputPaths(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    job.setMapperClass(NaiveBayesThetaMapper.class);
+    job.setReducerClass(NaiveBayesSumReducer.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    // the theta mapper emits Text keys (LABEL_THETA_NORMALIZER)
+    job.setOutputKeyClass(Text.class);
+    job.setOutputValueClass(VectorWritable.class);
+    job.setNumReduceTasks(numReducers);
+    HadoopUtil.overwriteOutput(output);
+    runToCompletion(job);
+  }
+
+  private static void runNaiveBayesThetaComplementarySummer(Path input, Configuration conf, Path weightFilePath,
+                                                            Path labelMapPath, Path output, int numReducers)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    enableJavaSerialization(conf);
+    // the complementary theta mapper reads the weight file from cache slot 0 and the label map from slot 1
+    DistributedCache.setCacheFiles(new URI[] {weightFilePath.toUri(), labelMapPath.toUri()}, conf);
+  
+    Job job = new Job(conf);
+    job.setJobName("Train Naive Bayes: input-folder: " + input + ", weight-file: "
+        + weightFilePath.toString());
+    job.setJarByClass(NaiveBayesTrainer.class);
+    FileInputFormat.setInputPaths(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    job.setMapperClass(NaiveBayesThetaComplementaryMapper.class);
+    job.setReducerClass(NaiveBayesSumReducer.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    // the complementary theta mapper emits Text keys (LABEL_THETA_NORMALIZER)
+    job.setOutputKeyClass(Text.class);
+    job.setOutputValueClass(VectorWritable.class);
+    job.setNumReduceTasks(numReducers);
+    HadoopUtil.overwriteOutput(output);
+    runToCompletion(job);
+  }
+
+  /**
+   * Writes the labels into a SequenceFile that maps each label (Text) to its position in
+   * {@code labels} (IntWritable).
+   *
+   * @param labels the labels, in the order that defines their integer ids
+   * @param conf configuration used to resolve the file system
+   * @param labelMapPathBase directory in which a file named {@value #LABEL_MAP} is created
+   * @return path of the written label map file
+   * @throws IOException if the file cannot be written
+   */
+  public static Path createLabelMapFile(List<String> labels,
+                                         Configuration conf,
+                                         Path labelMapPathBase) throws IOException {
+    FileSystem fs = FileSystem.get(labelMapPathBase.toUri(), conf);
+    Path labelMapPath = new Path(labelMapPathBase, LABEL_MAP);
+    
+    SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, labelMapPath, Text.class, IntWritable.class);
+    try {
+      int i = 0;
+      for (String label : labels) {
+        Writable key = new Text(label);
+        dictWriter.append(key, new IntWritable(i++));
+      }
+    } finally {
+      // closing flushes buffered records; without it the label map may be empty or truncated
+      dictWriter.close();
+    }
+    return labelMapPath;
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class NaiveBayesWeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+  
+  // label name -> label id, read from DistributedCache slot 0; its size sizes labelSum
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  // element-wise sum of every input vector seen by this task (null until the first record)
+  Vector featureSum;
+  // per label id: sum of the zSum of that label's input vectors (null until the first record)
+  Vector labelSum;
+ 
+  /**
+   * Adds the input vector into featureSum and its total weight into labelSum[label].
+   *
+   * @param key   the label id
+   * @param value the vector to accumulate
+   */
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+    Vector vector = value.get();
+    if (featureSum == null) {
+      // lazily sized from the first record, since the feature cardinality is only known here
+      featureSum = new RandomAccessSparseVector(vector.size(), vector.getNumNondefaultElements());
+      labelSum = new RandomAccessSparseVector(labelMap.size());
+    }
+    
+    int label = key.get();
+    vector.addTo(featureSum);
+    labelSum.set(label, labelSum.get(label) + vector.zSum());
+  }
+  
+  /**
+   * Loads the label map from DistributedCache slot 0.
+   *
+   * @throws IllegalArgumentException if no cache file is registered
+   * @throws IllegalStateException if the label map cannot be read
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 1) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      Path labelMapFile = new Path(localFiles[0].getPath());
+      FileSystem fs = labelMapFile.getFileSystem(conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      try {
+        Writable key = new Text();
+        IntWritable value = new IntWritable();
+        // key is the label name, value is its id
+        while (reader.next(key, value)) {
+          labelMap.put(key.toString(), value.get());
+        }
+      } finally {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+  
+  /**
+   * Emits the accumulated FEATURE_SUM and LABEL_SUM vectors. A task that received no records has
+   * nothing to contribute and emits nothing, instead of failing on null vectors.
+   */
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    if (featureSum != null) {
+      context.write(new Text(BayesConstants.FEATURE_SUM), new VectorWritable(featureSum));
+      context.write(new Text(BayesConstants.LABEL_SUM), new VectorWritable(labelSum));
+    }
+    super.cleanup(context);
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.list.IntArrayList;
+
+/**
+ * An Ordered List of Integers which can be used in a Hadoop Map/Reduce Job
+ */
+public final class IntTuple implements WritableComparable<IntTuple> {
+  
+  private IntArrayList tuple = new IntArrayList();
+  
+  public IntTuple() {}
+  
+  public IntTuple(int firstEntry) {
+    add(firstEntry);
+  }
+  
+  public IntTuple(Iterable<Integer> entries) {
+    for (Integer entry : entries) {
+      add(entry);
+    }
+  }
+  
+  public IntTuple(int[] entries) {
+    for (int entry : entries) {
+      add(entry);
+    }
+  }
+  
+  /**
+   * Appends an entry to the end of the tuple.
+   * 
+   * @param entry the value to append
+   */
+  public void add(int entry) {
+    tuple.add(entry);
+  }
+  
+  /**
+   * Fetches the integer at the given position.
+   * 
+   * @param index position in the tuple
+   * @return the integer value at the given position
+   */
+  public int at(int index) {
+    return tuple.get(index);
+  }
+  
+  /**
+   * Replaces the integer at the given position with {@code newInteger}.
+   * 
+   * @param index position in the tuple
+   * @param newInteger the new value
+   * @return the previous value at that position
+   */
+  public int replaceAt(int index, int newInteger) {
+    int old = tuple.get(index);
+    tuple.set(index, newInteger);
+    return old;
+  }
+  
+  /**
+   * Returns an independent copy of the entries in insertion order.
+   * The copy holds exactly {@link #length()} entries and does not share storage with this tuple.
+   * 
+   * @return a new list containing the integers in the order of insertion
+   */
+  public IntArrayList getEntries() {
+    int len = tuple.size();
+    IntArrayList copy = new IntArrayList(len);
+    for (int i = 0; i < len; i++) {
+      copy.add(tuple.get(i));
+    }
+    return copy;
+  }
+  
+  /**
+   * Returns the length of the tuple.
+   * 
+   * @return number of entries
+   */
+  public int length() {
+    return this.tuple.size();
+  }
+  
+  /** Value-based hash over the entries (same formula as java.util.List), consistent with equals. */
+  @Override
+  public int hashCode() {
+    int hash = 1;
+    for (int i = 0; i < tuple.size(); i++) {
+      hash = 31 * hash + tuple.get(i);
+    }
+    return hash;
+  }
+  
+  /** Two tuples are equal when they hold the same entries in the same order. */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+    IntTuple other = (IntTuple) obj;
+    int len = tuple.size();
+    if (len != other.tuple.size()) {
+      return false;
+    }
+    for (int i = 0; i < len; i++) {
+      if (tuple.get(i) != other.tuple.get(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+  
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int len = in.readInt();
+    tuple = new IntArrayList(len);
+    IntWritable value = new IntWritable();
+    for (int i = 0; i < len; i++) {
+      value.readFields(in);
+      tuple.add(value.get());
+    }
+  }
+  
+  /** Writes the length followed by exactly that many entries (the format readFields expects). */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    int len = tuple.size();
+    out.writeInt(len);
+    IntWritable value = new IntWritable();
+    // iterate by size: elements() exposes the backing array, including unused capacity
+    for (int i = 0; i < len; i++) {
+      value.set(tuple.get(i));
+      value.write(out);
+    }
+  }
+  
+  /**
+   * Lexicographic order: the first differing entry decides; if one tuple is a prefix of the other,
+   * the shorter one orders first.
+   */
+  @Override
+  public int compareTo(IntTuple otherTuple) {
+    int thisLength = length();
+    int otherLength = otherTuple.length();
+    int min = Math.min(thisLength, otherLength);
+    for (int i = 0; i < min; i++) {
+      int mine = this.tuple.get(i);
+      int theirs = otherTuple.at(i);
+      if (mine != theirs) {
+        // explicit comparison; subtraction can overflow for large magnitudes
+        return mine < theirs ? -1 : 1;
+      }
+    }
+    if (thisLength < otherLength) {
+      return -1;
+    } else if (thisLength > otherLength) {
+      return 1;
+    } else {
+      return 0;
+    }
+  }
+  
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,31 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+
+public final class ComplementaryNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+  private NaiveBayesModel model;
+  private ComplementaryNaiveBayesClassifier classifier;
+  
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createComplementaryNaiveBayesModel();
+    classifier = new ComplementaryNaiveBayesClassifier(model);
+  }
+  
+  /** Each unit vector along feature i is expected to score highest for label i. */
+  @Test
+  public void testNaiveBayes() throws Exception {
+    // JUnit convention: expected value first, so failure messages read correctly
+    assertEquals(4, classifier.numCategories());
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+  }
+  
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.junit.Test;
+
+public class NaiveBayesModelTest extends NaiveBayesTestBase {
+
+  /** The fixture model prepared by the base class must pass NaiveBayesModel validation. */
+  @Test
+  public void testRandomModelGeneration() {
+    NaiveBayesModel.validate(getModel());
+  }
+
+  /** Serializing, deserializing and re-serializing must reproduce identical JSON text. */
+  @Test
+  public void testSerialization() {
+    String original = getModel().toJson();
+    String roundTripped = NaiveBayesModel.fromJson(original).toJson();
+    // the model's components have no equals(), so the JSON strings serve as the identity check
+    assertEquals(original, roundTripped);
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,116 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import java.util.Iterator;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+public class NaiveBayesTestBase extends MahoutTestCase {
+  
+  private NaiveBayesModel model;
+  
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createNaiveBayesModel();
+    
+    // make sure the model is valid :)
+    NaiveBayesModel.validate(model);
+  }
+  
+  /** @return the standard Naive Bayes fixture model built in setUp() */
+  protected NaiveBayesModel getModel() {
+    return model;
+  }
+  
+  /**
+   * Complementary theta normalizer for one label (alpha = 1): the sum over features i of
+   * log((featureSum[i] - weightMatrix[i][label] + alpha) / (zSum(featureSum) - labelSum[label] + numFeatures)).
+   *
+   * NOTE(review): weightMatrix is indexed as (feature, label) here, while the fixture's row sums equal
+   * labelSum (i.e. rows look like labels). Confirm the intended orientation against NaiveBayesModel.
+   */
+  public double complementaryNaiveBayesThetaWeight(int label,
+                                                   Matrix weightMatrix,
+                                                   Vector labelSum,
+                                                   Vector featureSum) {
+    double alpha = 1.0d;
+    // loop-invariant terms, hoisted out of the per-feature loop
+    double lSum = labelSum.get(label);
+    double totalSum = featureSum.zSum();
+    double denominator = totalSum - lSum + featureSum.size();
+    double weight = 0.0;
+    for (int i = 0; i < featureSum.size(); i++) {
+      double score = weightMatrix.get(i, label);
+      double numerator = featureSum.get(i) - score + alpha;
+      weight += Math.log(numerator / denominator);
+    }
+    return weight;
+  }
+  
+  /**
+   * Standard theta normalizer for one label (alpha = 1): the sum over features i of
+   * log((weightMatrix[i][label] + alpha) / (labelSum[label] + numFeatures)).
+   */
+  public double naiveBayesThetaWeight(int label,
+                                      Matrix weightMatrix,
+                                      Vector labelSum,
+                                      Vector featureSum) {
+    double alpha = 1.0d;
+    // loop-invariant denominator, hoisted out of the per-feature loop
+    double denominator = labelSum.get(label) + featureSum.size();
+    double weight = 0.0;
+    for (int i = 0; i < featureSum.size(); i++) {
+      double numerator = weightMatrix.get(i, label) + alpha;
+      weight += Math.log(numerator / denominator);
+    }
+    return weight;
+  }
+
+  /** Builds the 4-label fixture model with standard Naive Bayes theta normalizers. */
+  public NaiveBayesModel createNaiveBayesModel() {
+    DenseMatrix weightMatrix = fixtureWeightMatrix();
+    DenseVector labelSum = fixtureLabelSum();
+    DenseVector featureSum = fixtureFeatureSum();
+    double[] thetaNormalizerSum = new double[labelSum.size()];
+    for (int label = 0; label < thetaNormalizerSum.length; label++) {
+      thetaNormalizerSum[label] = naiveBayesThetaWeight(label, weightMatrix, labelSum, featureSum);
+    }
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+  }
+  
+  /** Builds the 4-label fixture model with complementary Naive Bayes theta normalizers. */
+  public NaiveBayesModel createComplementaryNaiveBayesModel() {
+    DenseMatrix weightMatrix = fixtureWeightMatrix();
+    DenseVector labelSum = fixtureLabelSum();
+    DenseVector featureSum = fixtureFeatureSum();
+    double[] thetaNormalizerSum = new double[labelSum.size()];
+    for (int label = 0; label < thetaNormalizerSum.length; label++) {
+      thetaNormalizerSum[label] = complementaryNaiveBayesThetaWeight(label, weightMatrix, labelSum, featureSum);
+    }
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+  }
+
+  /** Fresh 4x4 fixture matrix; its row sums equal fixtureLabelSum() and its column sums fixtureFeatureSum(). */
+  private static DenseMatrix fixtureWeightMatrix() {
+    double[][] matrix = { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+                          {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
+    return new DenseMatrix(matrix);
+  }
+
+  /** Fresh fixture label sums. */
+  private static DenseVector fixtureLabelSum() {
+    return new DenseVector(new double[] {1.2, 1.0, 1.0, 1.0});
+  }
+
+  /** Fresh fixture feature sums. */
+  private static DenseVector fixtureFeatureSum() {
+    return new DenseVector(new double[] {1.3, 0.6, 1.1, 1.2});
+  }
+  
+  /**
+   * Returns the index of the largest element; on ties the last maximal index wins.
+   * Returns -1 only if the vector yields no comparable element (e.g. it is empty or all NaN).
+   */
+  public int maxIndex(Vector instance) {
+    Iterator<Element> it = instance.iterator();
+    int maxIndex = -1;
+    // -Infinity, not Integer.MIN_VALUE: log-scores may be arbitrarily negative
+    double val = Double.NEGATIVE_INFINITY;
+    while (it.hasNext()) {
+      Element e = it.next();
+      if (val <= e.get()) {
+        maxIndex = e.index();
+        val = e.get();
+      }
+    }
+    return maxIndex;
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Wed Oct  6 21:38:41 2010
@@ -0,0 +1,31 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+
+public final class StandardNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+  private NaiveBayesModel model;
+  private StandardNaiveBayesClassifier classifier;
+  
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createNaiveBayesModel();
+    classifier = new StandardNaiveBayesClassifier(model);
+  }
+  
+  /** Each unit vector along feature i is expected to score highest for label i. */
+  @Test
+  public void testNaiveBayes() throws Exception {
+    // JUnit convention: expected value first, so failure messages read correctly
+    assertEquals(4, classifier.numCategories());
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+  }
+  
+}



Mime
View raw message