mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From adene...@apache.org
Subject svn commit: r1213034 [2/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/df/ core/src/main/java/org/apache/mahout/cla...
Date Sun, 11 Dec 2011 17:53:51 GMT
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java Sun Dec 11 17:53:50 2011
@@ -38,17 +38,23 @@ public final class Split {
     this(attr, ig, Double.NaN);
   }
 
-  /** attribute to split for */
+  /**
+   * @return attribute to split for
+   */
   public int getAttr() {
     return attr;
   }
 
-  /** Information Gain of the split */
+  /**
+   * @return Information Gain of the split
+   */
   public double getIg() {
     return ig;
   }
 
-  /** split value for NUMERICAL attributes */
+  /**
+   * @return split value for NUMERICAL attributes
+   */
   public double getSplit() {
     return split;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java Sun Dec 11 17:53:50 2011
@@ -39,7 +39,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Compute the frequency distribution of the "class label"
+ * Compute the frequency distribution of the "class label"<br>
+ * This class can be used when the criterion variable is the categorical attribute.
  */
 public final class Frequencies extends Configured implements Tool {
   

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java Sun Dec 11 17:53:50 2011
@@ -50,7 +50,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Temporary class used to compute the frequency distribution of the "class attribute".
+ * Temporary class used to compute the frequency distribution of the "class attribute".<br>
+ * This class can be used when the criterion variable is the categorical attribute.
  */
 public class FrequenciesJob {
   
@@ -124,7 +125,7 @@ public class FrequenciesJob {
    * 
    * @return counts[partition][label] = num tuples from 'partition' with class == label
    */
-  protected int[][] parseOutput(JobContext job) throws IOException {
+  int[][] parseOutput(JobContext job) throws IOException {
     Configuration conf = job.getConfiguration();
     
     int numMaps = conf.getInt("mapred.map.tasks", -1);
@@ -176,7 +177,7 @@ public class FrequenciesJob {
     /**
      * Useful when testing
      */
-    protected void setup(Dataset dataset) {
+    void setup(Dataset dataset) {
       converter = new DataConverter(dataset);
     }
     
@@ -189,7 +190,7 @@ public class FrequenciesJob {
       
       Instance instance = converter.convert(value.toString());
       
-      context.write(firstId, new IntWritable(dataset.getLabel(instance)));
+      context.write(firstId, new IntWritable((int) dataset.getLabel(instance)));
     }
     
   }
@@ -208,7 +209,7 @@ public class FrequenciesJob {
     /**
      * Useful when testing
      */
-    protected void setup(int nblabels) {
+    void setup(int nblabels) {
       this.nblabels = nblabels;
     }
     
@@ -236,7 +237,9 @@ public class FrequenciesJob {
     /** counts[c] = num tuples from the partition with label == c */
     private int[] counts;
     
-    protected Frequencies(long firstId, int[] counts) {
+    public Frequencies() { }
+    
+    Frequencies(long firstId, int[] counts) {
       this.firstId = firstId;
       this.counts = Arrays.copyOf(counts, counts.length);
     }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java Sun Dec 11 17:53:50 2011
@@ -50,7 +50,8 @@ import com.google.common.base.Preconditi
 
 /**
  * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of
- * partitions.
+ * partitions.<br>
+ * This class can be used when the criterion variable is the categorical attribute.
  */
 public final class UDistrib {
   
@@ -63,7 +64,8 @@ public final class UDistrib {
    * Launch the uniform distribution tool. Requires the following command line arguments:<br>
    * 
    * data : data path dataset : dataset path numpartitions : num partitions output : output path
-   * 
+   *
+   * @throws java.io.IOException
    */
   public static void main(String[] args) throws IOException {
     
@@ -175,7 +177,7 @@ public final class UDistrib {
       
       // write the tuple in files[tuple.label]
       Instance instance = converter.convert(line);
-      int label = dataset.getLabel(instance);
+      int label = (int) dataset.getLabel(instance);
       files[currents[label]].writeBytes(line);
       files[currents[label]].writeChar('\n');
       

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java Sun Dec 11 17:53:50 2011
@@ -44,5 +44,17 @@ public final class DataConverterTest ext
     for (int index = 0; index < data.size(); index++) {
       assertEquals(data.get(index), converter.convert(sData[index]));
     }
+
+    // regression
+    source = Utils.randomDoubles(rng, descriptor, true, INSTANCE_COUNT);
+    sData = Utils.double2String(source);
+    dataset = DataLoader.generateDataset(descriptor, true, sData);
+    data = DataLoader.loadData(dataset, sData);
+    
+    converter = new DataConverter(dataset);
+    
+    for (int index = 0; index < data.size(); index++) {
+      assertEquals(data.get(index), converter.convert(sData[index]));
+    }
   }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java Sun Dec 11 17:53:50 2011
@@ -57,6 +57,16 @@ public final class DataLoaderTest extend
 
     testLoadedData(data, attrs, missings, loaded);
     testLoadedDataset(data, attrs, missings, loaded);
+
+    // regression
+    data = Utils.randomDoubles(rng, descriptor, true, datasize);
+    missings = Lists.newArrayList();
+    sData = prepareData(data, attrs, missings);
+    dataset = DataLoader.generateDataset(descriptor, true, sData);
+    loaded = DataLoader.loadData(dataset, sData);
+
+    testLoadedData(data, attrs, missings, loaded);
+    testLoadedDataset(data, attrs, missings, loaded);
   }
 
   /**
@@ -81,7 +91,17 @@ public final class DataLoaderTest extend
     Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     
     assertEquals(expected, dataset);
-  }
+
+    // regression
+    data = Utils.randomDoubles(rng, descriptor, true, datasize);
+    missings = Lists.newArrayList();
+    sData = prepareData(data, attrs, missings);
+    expected = DataLoader.generateDataset(descriptor, true, sData);
+
+    dataset = DataLoader.generateDataset(descriptor, true, sData);
+    
+    assertEquals(expected, dataset);
+}
 
   /**
    * Converts the data to an array of comma-separated strings and adds some
@@ -153,14 +173,21 @@ public final class DataLoaderTest extend
         }
 
         if (attrs[attr].isNumerical()) {
-          assertEquals(vector[attr], instance.get(aId++), EPSILON);
-        } else if (attrs[attr].isCategorical()||attrs[attr].isLabel()) {
+          assertEquals(vector[attr], instance.get(aId), EPSILON);
+          aId++;
+        } else if (attrs[attr].isCategorical()) {
           checkCategorical(data, missings, loaded, attr, aId, vector[attr],
               instance.get(aId));
           aId++;
-        } /*else if (attrs[attr].isLabel()) {
-          checkLabel(data, missings, loaded, attr, vector[attr]);
-        }*/
+        } else if (attrs[attr].isLabel()) {
+          if (loaded.getDataset().isNumerical(aId)) {
+            assertEquals(vector[attr], instance.get(aId), EPSILON);
+          } else {
+            checkCategorical(data, missings, loaded, attr, aId, vector[attr],
+              instance.get(aId));
+          }
+          aId++;
+        }
       }
       
       lind++;
@@ -193,14 +220,21 @@ public final class DataLoaderTest extend
           continue;
         }
 
-        assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId));
-        
-        if (attrs[attr].isCategorical()) {
-          double nValue = instance.get(aId);
-          String oValue = Double.toString(data[index][attr]);
-          assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON);
+        if (attrs[attr].isLabel()) {
+          if (!loaded.getDataset().isNumerical(aId)) {
+            double nValue = instance.get(aId);
+            String oValue = Double.toString(data[index][attr]);
+            assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON);
+          }
+        } else {
+          assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId));
+          
+          if (attrs[attr].isCategorical()) {
+            double nValue = instance.get(aId);
+            String oValue = Double.toString(data[index][attr]);
+            assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON);
+          }
         }
-
         aId++;
       }
     }
@@ -227,7 +261,19 @@ public final class DataLoaderTest extend
     Data loaded = DataLoader.loadData(dataset, fs, dataPath);
 
     testLoadedData(source, attrs, missings, loaded);
-  }
+
+    // regression
+    source = Utils.randomDoubles(rng, descriptor, true, datasize);
+    missings = Lists.newArrayList();
+    sData = prepareData(source, attrs, missings);
+    dataset = DataLoader.generateDataset(descriptor, true, sData);
+
+    dataPath = Utils.writeDataToTestFile(sData);
+    fs = dataPath.getFileSystem(new Configuration());
+    loaded = DataLoader.loadData(dataset, fs, dataPath);
+
+    testLoadedData(source, attrs, missings, loaded);
+}
 
   /**
    * Test method for
@@ -254,6 +300,19 @@ public final class DataLoaderTest extend
     Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path);
     
     assertEquals(expected, dataset);
+
+    // regression
+    source = Utils.randomDoubles(rng, descriptor, false, datasize);
+    missings = Lists.newArrayList();
+    sData = prepareData(source, attrs, missings);
+    expected = DataLoader.generateDataset(descriptor, false, sData);
+
+    path = Utils.writeDataToTestFile(sData);
+    fs = path.getFileSystem(new Configuration());
+    
+    dataset = DataLoader.generateDataset(descriptor, false, fs, path);
+    
+    assertEquals(expected, dataset);
   }
 
   /**
@@ -288,38 +347,4 @@ public final class DataLoaderTest extend
       lind++;
     }
   }
-
-  /**
-   * each time value appears in data as a label, its corresponding code must
-   * appear in all the instances with the same label.
-   *
-   * @param labelInd label's index in source
-   * @param value source label's value
-   */
-  static void checkLabel(double[][] source,
-                         Collection<Integer> missings,
-                         Data loaded,
-                         int labelInd,
-                         double value) {
-  	Dataset dataset = loaded.getDataset();
-  	
-    // label's code that corresponds to the value
-    int code = loaded.getDataset().labelCode(Double.toString(value));
-
-    int lind = 0;
-
-    for (int index = 0; index < source.length; index++) {
-      if (missings.contains(index)) {
-        continue;
-      }
-
-      if (source[index][labelInd] == value) {
-        assertEquals(code, dataset.getLabel(loaded.get(lind)));
-      } else {
-        assertFalse(code == dataset.getLabel(loaded.get(lind)));
-      }
-
-      lind++;
-    }
-  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java Sun Dec 11 17:53:50 2011
@@ -33,13 +33,16 @@ public class DataTest extends MahoutTest
 
   private Random rng;
 
-  private Data data;
+  private Data classifierData;
+
+  private Data regressionData;
 
   @Override
   public void setUp() throws Exception {
     super.setUp();
     rng = RandomUtils.getRandom();
-    data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+    classifierData = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+    regressionData = Utils.randomData(rng, ATTRIBUTE_COUNT, true, DATA_SIZE);
   }
 
   /**
@@ -51,17 +54,45 @@ public class DataTest extends MahoutTest
     int n = 10;
 
     for (int nloop = 0; nloop < n; nloop++) {
-      int attr = rng.nextInt(data.getDataset().nbAttributes());
+      int attr = rng.nextInt(classifierData.getDataset().nbAttributes());
 
-      double[] values = data.values(attr);
+      double[] values = classifierData.values(attr);
       double value = values[rng.nextInt(values.length)];
 
-      Data eSubset = data.subset(Condition.equals(attr, value));
-      Data lSubset = data.subset(Condition.lesser(attr, value));
-      Data gSubset = data.subset(Condition.greaterOrEquals(attr, value));
+      Data eSubset = classifierData.subset(Condition.equals(attr, value));
+      Data lSubset = classifierData.subset(Condition.lesser(attr, value));
+      Data gSubset = classifierData.subset(Condition.greaterOrEquals(attr, value));
 
       for (int index = 0; index < DATA_SIZE; index++) {
-        Instance instance = data.get(index);
+        Instance instance = classifierData.get(index);
+
+        if (instance.get(attr) < value) {
+          assertTrue(lSubset.contains(instance));
+          assertFalse(eSubset.contains(instance));
+          assertFalse(gSubset.contains(instance));
+        } else if (instance.get(attr) == value) {
+          assertFalse(lSubset.contains(instance));
+          assertTrue(eSubset.contains(instance));
+          assertTrue(gSubset.contains(instance));
+        } else {
+          assertFalse(lSubset.contains(instance));
+          assertFalse(eSubset.contains(instance));
+          assertTrue(gSubset.contains(instance));
+        }
+      }
+
+      // regression
+      attr = rng.nextInt(regressionData.getDataset().nbAttributes());
+
+      values = regressionData.values(attr);
+      value = values[rng.nextInt(values.length)];
+
+      eSubset = regressionData.subset(Condition.equals(attr, value));
+      lSubset = regressionData.subset(Condition.lesser(attr, value));
+      gSubset = regressionData.subset(Condition.greaterOrEquals(attr, value));
+
+      for (int index = 0; index < DATA_SIZE; index++) {
+        Instance instance = regressionData.get(index);
 
         if (instance.get(attr) < value) {
           assertTrue(lSubset.contains(instance));
@@ -82,17 +113,23 @@ public class DataTest extends MahoutTest
 
   @Test
   public void testValues() throws Exception {
-    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
-
-    for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) {
-      double[] values = data.values(attr);
+    for (int attr = 0; attr < classifierData.getDataset().nbAttributes(); attr++) {
+      double[] values = classifierData.values(attr);
 
       // each value of the attribute should appear exactly one time in values
       for (int index = 0; index < DATA_SIZE; index++) {
-        assertEquals(1, count(values, data.get(index).get(attr)));
+        assertEquals(1, count(values, classifierData.get(index).get(attr)));
       }
     }
 
+    for (int attr = 0; attr < regressionData.getDataset().nbAttributes(); attr++) {
+      double[] values = regressionData.values(attr);
+
+      // each value of the attribute should appear exactly one time in values
+      for (int index = 0; index < DATA_SIZE; index++) {
+        assertEquals(1, count(values, regressionData.get(index).get(attr)));
+      }
+    }
   }
 
   private static int count(double[] values, double value) {
@@ -194,19 +231,33 @@ public class DataTest extends MahoutTest
    */
   @Test
   public void testBagging() {
-    Data bag = data.bagging(rng);
+    Data bag = classifierData.bagging(rng);
 
     // the bag should have the same size as the data
-    assertEquals(data.size(), bag.size());
+    assertEquals(classifierData.size(), bag.size());
 
     // at least one element from the data should not be in the bag
     boolean found = false;
-    for (int index = 0; index < data.size() && !found; index++) {
-      found = !bag.contains(data.get(index));
+    for (int index = 0; index < classifierData.size() && !found; index++) {
+      found = !bag.contains(classifierData.get(index));
     }
     
     assertTrue("some instances from data should not be in the bag", found);
-  }
+
+    // regression
+    bag = regressionData.bagging(rng);
+
+    // the bag should have the same size as the data
+    assertEquals(regressionData.size(), bag.size());
+
+    // at least one element from the data should not be in the bag
+    found = false;
+    for (int index = 0; index < regressionData.size() && !found; index++) {
+      found = !bag.contains(regressionData.get(index));
+    }
+    
+    assertTrue("some instances from data should not be in the bag", found);
+}
 
   /**
    * Test method for
@@ -216,42 +267,61 @@ public class DataTest extends MahoutTest
   public void testRsplit() {
 
     // rsplit should handle empty subsets
-    Data source = data.clone();
+    Data source = classifierData.clone();
     Data subset = source.rsplit(rng, 0);
     assertTrue("subset should be empty", subset.isEmpty());
     assertEquals("source.size is incorrect", DATA_SIZE, source.size());
 
     // rsplit should handle full size subsets
-    source = data.clone();
+    source = classifierData.clone();
     subset = source.rsplit(rng, DATA_SIZE);
     assertEquals("subset.size is incorrect", DATA_SIZE, subset.size());
     assertTrue("source should be empty", source.isEmpty());
 
     // random case
     int subsize = rng.nextInt(DATA_SIZE);
-    source = data.clone();
+    source = classifierData.clone();
     subset = source.rsplit(rng, subsize);
     assertEquals("subset.size is incorrect", subsize, subset.size());
     assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size());
-  }
+
+    // regression
+    // rsplit should handle empty subsets
+    source = regressionData.clone();
+    subset = source.rsplit(rng, 0);
+    assertTrue("subset should be empty", subset.isEmpty());
+    assertEquals("source.size is incorrect", DATA_SIZE, source.size());
+
+    // rsplit should handle full size subsets
+    source = regressionData.clone();
+    subset = source.rsplit(rng, DATA_SIZE);
+    assertEquals("subset.size is incorrect", DATA_SIZE, subset.size());
+    assertTrue("source should be empty", source.isEmpty());
+
+    // random case
+    subsize = rng.nextInt(DATA_SIZE);
+    source = regressionData.clone();
+    subset = source.rsplit(rng, subsize);
+    assertEquals("subset.size is incorrect", subsize, subset.size());
+    assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size());
+}
 
   @Test
   public void testCountLabel() throws Exception {
-    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
-    Dataset dataset = data.getDataset();
+    Dataset dataset = classifierData.getDataset();
     int[] counts = new int[dataset.nblabels()];
 
     int n = 10;
 
     for (int nloop = 0; nloop < n; nloop++) {
       Arrays.fill(counts, 0);
-      data.countLabels(counts);
+      classifierData.countLabels(counts);
       
-      for (int index=0;index<data.size();index++) {
-        counts[dataset.getLabel(data.get(index))]--;
+      for (int index = 0; index < classifierData.size(); index++) {
+        counts[(int) dataset.getLabel(classifierData.get(index))]--;
       }
       
-      for (int label = 0; label < data.getDataset().nblabels(); label++) {
+      for (int label = 0; label < classifierData.getDataset().nblabels(); label++) {
         assertEquals("Wrong label 'equals' count", 0, counts[0]);
       }
     }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java Sun Dec 11 17:53:50 2011
@@ -56,6 +56,15 @@ public final class DatasetTest extends M
       dataset.write(out);
       
       assertEquals(dataset, readDataset(byteOutStream.toByteArray()));
+
+      // regression
+      byteOutStream.reset();
+      
+      dataset = Utils.randomData(rng, NUM_ATTRIBUTES, true, 1).getDataset();
+      
+      dataset.write(out);
+      
+      assertEquals(dataset, readDataset(byteOutStream.toByteArray()));
     }
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java Sun Dec 11 17:53:50 2011
@@ -67,13 +67,13 @@ public class PartialSequentialBuilder ex
   }
 
   @Override
-  protected void configureJob(Job job, int nbTrees)
+  protected void configureJob(Job job)
       throws IOException {
     Configuration conf = job.getConfiguration();
     
     int num = conf.getInt("mapred.map.tasks", -1);
 
-    super.configureJob(job, nbTrees);
+    super.configureJob(job);
 
     // PartialBuilder sets the number of maps to 1 if we are running in 'local'
     conf.setInt("mapred.map.tasks", num);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java Sun Dec 11 17:53:50 2011
@@ -56,13 +56,13 @@ public final class NodeTest extends Maho
   public void testReadTree() throws Exception {
     Node node1 = new CategoricalNode(rng.nextInt(), 
         new double[] { rng.nextDouble(), rng.nextDouble() }, 
-        new Node[] { new Leaf(rng.nextInt()), new Leaf(rng.nextInt()) });
+        new Node[] { new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()) });
     Node node2 = new NumericalNode(rng.nextInt(), rng.nextDouble(), 
-        new Leaf(rng.nextInt()), new Leaf(rng.nextInt()));
+        new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()));
     
     Node root = new CategoricalNode(rng.nextInt(), 
         new double[] { rng.nextDouble(), rng.nextDouble(), rng.nextDouble() }, 
-        new Node[] { node1, node2, new Leaf(rng.nextInt()) });
+        new Node[] { node1, node2, new Leaf(rng.nextDouble()) });
 
     // write the node to a DataOutput
     root.write(out);
@@ -80,7 +80,7 @@ public final class NodeTest extends Maho
   @Test
   public void testReadLeaf() throws Exception {
 
-    Node leaf = new Leaf(rng.nextInt());
+    Node leaf = new Leaf(rng.nextDouble());
     leaf.write(out);
     assertEquals(leaf, readNode());
   }
@@ -89,7 +89,7 @@ public final class NodeTest extends Maho
   public void testParseNumerical() throws Exception {
 
     Node node = new NumericalNode(rng.nextInt(), rng.nextDouble(), new Leaf(rng
-        .nextInt()), new Leaf(rng.nextInt()));
+        .nextInt()), new Leaf(rng.nextDouble()));
     node.write(out);
     assertEquals(node, readNode());
   }
@@ -98,8 +98,8 @@ public final class NodeTest extends Maho
 
     Node node = new CategoricalNode(rng.nextInt(), new double[]{rng.nextDouble(),
         rng.nextDouble(), rng.nextDouble()}, new Node[]{
-        new Leaf(rng.nextInt()), new Leaf(rng.nextInt()),
-        new Leaf(rng.nextInt())});
+        new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()),
+        new Leaf(rng.nextDouble())});
 
     node.write(out);
     assertEquals(node, readNode());

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java Sun Dec 11 17:53:50 2011
@@ -113,8 +113,8 @@ public class BreimanExample extends Conf
     numNodesOne += forestOne.nbNodes();
     
     // compute the test set error (Selection Error), and mean tree error (One Tree Error),
-    int[] testLabels = test.extractLabels();
-    int[] predictions = new int[test.size()];
+    double[] testLabels = test.extractLabels();
+    double[] predictions = new double[test.size()];
     
     forestM.classify(test, predictions);
     sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions);

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Sun Dec 11 17:53:50 2011
@@ -37,6 +37,7 @@ import org.apache.mahout.common.CommandL
 import org.apache.mahout.classifier.df.DFUtils;
 import org.apache.mahout.classifier.df.DecisionForest;
 import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
 import org.apache.mahout.classifier.df.data.Data;
 import org.apache.mahout.classifier.df.data.DataLoader;
 import org.apache.mahout.classifier.df.data.Dataset;
@@ -65,9 +66,12 @@ public class BuildForest extends Configu
   private Long seed; // Random seed
   
   private boolean isPartial; // use partial data implementation
+  
+  private String builderName; // Tree builder class name
 
   @Override
-  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException
{
+  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException,
+    InstantiationException, IllegalAccessException {
     
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
@@ -99,12 +103,16 @@ public class BuildForest extends Configu
         abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
         withDescription("Output path, will contain the Decision Forest").create();
 
+    Option builderOpt = obuilder.withLongName("builder").withShortName("b").withRequired(false)
+      .withArgument(abuilder.withName("builder").withMinimum(1).withMaximum(1).create()).
+      withDescription("Tree builder class name").create();
+
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
     
     Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
         .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
-        .withOption(outputOpt).withOption(helpOpt).create();
+        .withOption(outputOpt).withOption(builderOpt).withOption(helpOpt).create();
     
     try {
       Parser parser = new Parser();
@@ -127,6 +135,10 @@ public class BuildForest extends Configu
         seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
       }
 
+      if (cmdLine.hasOption(builderOpt)) {
+        builderName = cmdLine.getValue(builderOpt).toString();
+      }
+
       if (log.isDebugEnabled()) {
         log.debug("data : {}", dataName);
         log.debug("dataset : {}", datasetName);
@@ -135,6 +147,7 @@ public class BuildForest extends Configu
         log.debug("seed : {}", seed);
         log.debug("nbtrees : {}", nbTrees);
         log.debug("isPartial : {}", isPartial);
+        log.debug("builder : {}", builderName);
       }
      
       dataPath = new Path(dataName);
@@ -152,7 +165,8 @@ public class BuildForest extends Configu
     return 0;
   }
   
-  private void buildForest() throws IOException, ClassNotFoundException, InterruptedException
{
+  private void buildForest() throws IOException, ClassNotFoundException, InterruptedException,
+    InstantiationException, IllegalAccessException {
     // make sure the output path does not exist
     FileSystem ofs = outputPath.getFileSystem(getConf());
     if (ofs.exists(outputPath)) {
@@ -160,8 +174,14 @@ public class BuildForest extends Configu
       return;
     }
 
-    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
-    treeBuilder.setM(m);
+    TreeBuilder treeBuilder;
+    if (builderName == null) {
+      treeBuilder = new DefaultTreeBuilder();
+      ((DefaultTreeBuilder) treeBuilder).setM(m);
+    } else {
+      Class<?> clazz = Class.forName(builderName);
+      treeBuilder = (TreeBuilder) clazz.newInstance();
+    }
     
     Builder forestBuilder;
     

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java?rev=1213034&r1=1213033&r2=1213034&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java Sun Dec 11 17:53:50 2011
@@ -18,6 +18,8 @@
 package org.apache.mahout.classifier.df.mapreduce;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Random;
 import java.util.Scanner;
 import java.util.Arrays;
@@ -44,6 +46,7 @@ import org.apache.mahout.common.RandomUt
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.classifier.df.DFUtils;
 import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.RegressionResultAnalyzer;
 import org.apache.mahout.classifier.ResultAnalyzer;
 import org.apache.mahout.classifier.ClassifierResult;
 import org.apache.mahout.classifier.df.data.DataConverter;
@@ -179,12 +182,27 @@ public class TestForest extends Configur
       throw new IllegalArgumentException("You must specify the ouputPath when using the mapreduce
implementation");
     }
 
-    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath,
getConf(), analyze);
+    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath,
getConf());
 
     classifier.run();
 
     if (analyze) {
-      log.info("{}", classifier.getAnalyzer());
+      double[][] results = classifier.getResults();
+      if (results != null) {
+        Dataset dataset = Dataset.load(getConf(), datasetPath);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          RegressionResultAnalyzer regressionAnalyzer = new RegressionResultAnalyzer();
+          regressionAnalyzer.setInstances(results);
+          log.info("{}", regressionAnalyzer);
+        } else {
+          ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+          for (double[] res : results) {
+            analyzer.addInstance(dataset.getLabelString(res[0]),
+              new ClassifierResult(dataset.getLabelString(res[1]), 1.0));
+          }
+          log.info("{}", analyzer);
+        }
+      }
     }
   }
 
@@ -206,37 +224,49 @@ public class TestForest extends Configur
     long time = System.currentTimeMillis();
 
     Random rng = RandomUtils.getRandom();
-    ResultAnalyzer analyzer = analyze ? new ResultAnalyzer(Arrays.asList(dataset.labels()),
"unknown") : null;
 
+    List<double[]> resList = new ArrayList<double[]>();
     if (dataFS.getFileStatus(dataPath).isDir()) {
       //the input is a directory of files
-      testDirectory(outputPath, converter, forest, dataset, analyzer, rng);
+      testDirectory(outputPath, converter, forest, dataset, resList, rng);
     }  else {
       // the input is one single file
-      testFile(dataPath, outputPath, converter, forest, dataset, analyzer, rng);
+      testFile(dataPath, outputPath, converter, forest, dataset, resList, rng);
     }
 
     time = System.currentTimeMillis() - time;
     log.info("Classification Time: {}", DFUtils.elapsedTime(time));
 
-    if (analyzer != null) {
-      log.info("{}", analyzer);
+    if (analyze) {
+      if (dataset.isNumerical(dataset.getLabelId())) {
+        RegressionResultAnalyzer regressionAnalyzer = new RegressionResultAnalyzer();
+        double[][] results = new double[resList.size()][2];
+        regressionAnalyzer.setInstances(resList.toArray(results));
+        log.info("{}", regressionAnalyzer);
+      } else {
+        ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+        for (double[] r : resList) {
+          analyzer.addInstance(dataset.getLabelString(r[0]),
+            new ClassifierResult(dataset.getLabelString(r[1]), 1.0));
+        }
+        log.info("{}", analyzer);
+      }
     }
   }
 
-  private void testDirectory(Path outPath, DataConverter converter, DecisionForest forest,
Dataset dataset,
-                        ResultAnalyzer analyzer, Random rng) throws IOException {
+  private void testDirectory(Path outPath, DataConverter converter, DecisionForest forest,
+    Dataset dataset, List<double[]> results, Random rng) throws IOException {
     Path[] infiles = DFUtils.listOutputFiles(dataFS, dataPath);
 
     for (Path path : infiles) {
       log.info("Classifying : {}", path);
       Path outfile = outPath != null ? new Path(outPath, path.getName()).suffix(".out") :
null;
-      testFile(path, outfile, converter, forest, dataset, analyzer, rng);
+      testFile(path, outfile, converter, forest, dataset, results, rng);
     }
   }
 
-  private void testFile(Path inPath, Path outPath, DataConverter converter, DecisionForest
forest, Dataset dataset,
-                        ResultAnalyzer analyzer, Random rng) throws IOException {
+  private void testFile(Path inPath, Path outPath, DataConverter converter, DecisionForest
forest,
+    Dataset dataset, List<double[]> results, Random rng) throws IOException {
     // create the predictions file
     FSDataOutputStream ofile = null;
 
@@ -255,17 +285,14 @@ public class TestForest extends Configur
         }
 
         Instance instance = converter.convert(line);
-        int prediction = forest.classify(rng, instance);
+        double prediction = forest.classify(dataset, rng, instance);
 
         if (outputPath != null) {
-          ofile.writeChars(Integer.toString(prediction)); // write the prediction
+          ofile.writeChars(Double.toString(prediction)); // write the prediction
           ofile.writeChar('\n');
         }
-
-        if (analyzer != null) {
-          analyzer.addInstance(dataset.getLabelString(dataset.getLabel(instance)),
-                               new ClassifierResult(dataset.getLabelString(prediction), 1.0));
-        }
+        
+        results.add(new double[] {dataset.getLabel(instance), prediction});
       }
 
       scanner.close();



Mime
View raw message