mahout-commits mailing list archives

From adene...@apache.org
Subject svn commit: r820181 [2/3] - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/df/mapred/partial/ core/src/main/java/org/apache/mahout/df/mapreduce/partial/ core/src/test/java/org/apache/mahout/df/mapred/partial/ core/src/test/java/org/apach...
Date Wed, 30 Sep 2009 05:29:24 GMT
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialBuilderTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialBuilderTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialBuilderTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialBuilderTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,226 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.partial;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.mapred.partial.PartialBuilder;
+import org.apache.mahout.df.mapred.partial.Step1Mapper;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class PartialBuilderTest extends TestCase {
+
+  protected static final int numMaps = 5;
+
+  protected static final int numTrees = 32;
+
+  /** instances per partition */
+  protected static final int numInstances = 20;
+
+  public void testProcessOutput() throws Exception {
+    JobConf job = new JobConf();
+    job.setNumMapTasks(numMaps);
+
+    Random rng = new Random();
+
+    // prepare the output
+    TreeID[] keys = new TreeID[numTrees];
+    MapredOutput[] values = new MapredOutput[numTrees];
+    int[] firstIds = new int[numMaps];
+    randomKeyValues(rng, keys, values, firstIds);
+
+    // store the output in a sequence file
+    Path base = new Path("testdata");
+    FileSystem fs = base.getFileSystem(job);
+    if (fs.exists(base))
+      fs.delete(base, true);
+
+    Path outputFile = new Path(base, "PartialBuilderTest.seq");
+    Writer writer = SequenceFile.createWriter(fs, job, outputFile,
+        TreeID.class, MapredOutput.class);
+
+    try {
+      for (int index = 0; index < numTrees; index++) {
+        writer.append(keys[index], values[index]);
+      }
+    } finally {
+      writer.close();
+    }
+
+    // load the output and make sure it's valid
+    TreeID[] newKeys = new TreeID[numTrees];
+    Node[] newTrees = new Node[numTrees];
+    
+    PartialBuilder.processOutput(job, base, firstIds, newKeys, newTrees, 
+        new TestCallback(keys, values));
+
+    // check the forest
+    for (int tree = 0; tree < numTrees; tree++) {
+      assertEquals(values[tree].getTree(), newTrees[tree]);
+    }
+
+    assertTrue("keys not equal", Arrays.deepEquals(keys, newKeys));
+  }
+
+  /**
+   * Make sure that the builder passes the correct parameters to the job
+   * 
+   */
+  public void testConfigure() {
+    TreeBuilder treeBuilder = new DefaultTreeBuilder();
+    Path dataPath = new Path("notUsedDataPath");
+    Path datasetPath = new Path("notUsedDatasetPath");
+    Long seed = 5L;
+
+    new PartialBuilderChecker(treeBuilder, dataPath, datasetPath, seed);
+  }
+
+  /**
+   * Generates random (key, value) pairs. Shuffles the partitions' order
+   * 
+   * @param rng random number generator
+   * @param keys (output) generated keys
+   * @param values (output) generated values
+   * @param firstIds (output) partitions' first ids in hadoop's order
+   */
+  protected void randomKeyValues(Random rng, TreeID[] keys,
+      MapredOutput[] values, int[] firstIds) {
+    int index = 0;
+    int firstId = 0;
+    List<Integer> partitions = new ArrayList<Integer>();
+    int partition;
+
+    for (int p = 0; p < numMaps; p++) {
+      // select a random partition, not yet selected
+      do {
+        partition = rng.nextInt(numMaps);
+      } while (partitions.contains(partition));
+
+      partitions.add(partition);
+
+      int nbTrees = Step1Mapper.nbTrees(numMaps, numTrees, partition);
+
+      for (int treeId = 0; treeId < nbTrees; treeId++) {
+        Node tree = new Leaf(rng.nextInt(100));
+
+        keys[index] = new TreeID(partition, treeId);
+        values[index] = new MapredOutput(tree, nextIntArray(rng, numInstances));
+
+        index++;
+      }
+      
+      firstIds[p] = firstId;
+      firstId += numInstances;
+    }
+
+  }
+
+  protected int[] nextIntArray(Random rng, int size) {
+    int[] array = new int[size];
+    for (int index = 0; index < size; index++) {
+      array[index] = rng.nextInt(101) - 1;
+    }
+
+    return array;
+  }
+
+  protected static class PartialBuilderChecker extends PartialBuilder {
+
+    protected Long _seed;
+
+    protected TreeBuilder _treeBuilder;
+
+    protected Path _datasetPath;
+
+    public PartialBuilderChecker(TreeBuilder treeBuilder, Path dataPath,
+        Path datasetPath, Long seed) {
+      super(treeBuilder, dataPath, datasetPath, seed);
+
+      _seed = seed;
+      _treeBuilder = treeBuilder;
+      _datasetPath = datasetPath;
+    }
+
+    @Override
+    protected void runJob(JobConf job) throws Exception {
+      // no need to run the job, just check if the params are correct
+
+      assertEquals(_seed, getRandomSeed(job));
+
+      // PartialBuilder should detect the 'local' mode and override the number
+      // of map tasks
+      assertEquals(1, job.getNumMapTasks());
+
+      assertEquals(numTrees, getNbTrees(job));
+
+      assertFalse(isOutput(job));
+      assertTrue(isOobEstimate(job));
+
+      assertEquals(_treeBuilder, getTreeBuilder(job));
+
+      assertEquals(_datasetPath, getDistributedCacheFile(job, 0));
+    }
+
+  }
+
+  /**
+   * Mock Callback. Make sure that the callback receives the correct predictions
+   * 
+   */
+  protected static class TestCallback extends PredictionCallback {
+
+    protected final TreeID[] keys;
+
+    protected final MapredOutput[] values;
+
+    public TestCallback(TreeID[] keys, MapredOutput[] values) {
+      this.keys = keys;
+      this.values = values;
+    }
+
+    @Override
+    public void prediction(int treeId, int instanceId, int prediction) {
+      int partition = instanceId / numInstances;
+
+      TreeID key = new TreeID(partition, treeId);
+      int index = ArrayUtils.indexOf(keys, key);
+      assertTrue("key not found", index >= 0);
+
+      assertEquals(values[index].getPredictions()[instanceId % numInstances],
+          prediction);
+    }
+
+  }
+}
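
The TestCallback above recovers the partition and the local offset from a global instance id with integer division and modulo, which only works because every partition holds exactly numInstances instances. A minimal standalone sketch of that arithmetic; the class and method names are illustrative and not part of this commit:

    // Sketch of the id arithmetic used by TestCallback; assumes each partition
    // holds exactly numInstances instances, as in the test above.
    public class InstanceIdMath {
      static final int numInstances = 20;

      static int partitionOf(int instanceId) {
        return instanceId / numInstances; // which partition produced the instance
      }

      static int localIndexOf(int instanceId) {
        return instanceId % numInstances; // offset of the instance inside its partition
      }

      public static void main(String[] args) {
        System.out.println(partitionOf(47));  // 2
        System.out.println(localIndexOf(47)); // 7
      }
    }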

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialOutputCollector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialOutputCollector.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialOutputCollector.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialOutputCollector.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.partial;
+
+import java.io.IOException;
+
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+
+
+public class PartialOutputCollector implements OutputCollector<TreeID, MapredOutput> {
+
+  public final TreeID[] keys;
+
+  public final MapredOutput[] values;
+
+  private int index = 0;
+
+  public PartialOutputCollector(int nbTrees) {
+    keys = new TreeID[nbTrees];
+    values = new MapredOutput[nbTrees];
+  }
+
+  public void collect(TreeID key, MapredOutput value) throws IOException {
+    if (index == keys.length) {
+      throw new IOException("Received more output than expected : " + index);
+    }
+
+    keys[index] = key.clone();
+    values[index] = value.clone();
+
+    index++;
+  }
+
+  /**
+   * Number of outputs collected
+   * 
+   * @return the number of outputs collected so far
+   */
+  public int nbOutputs() {
+    return index;
+  }
+}
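
PartialOutputCollector is a fixed-capacity stand-in for Hadoop's OutputCollector: it clones each (key, value) pair into pre-sized arrays and fails fast once a mapper emits more records than expected. A rough usage sketch, reusing only constructors that appear in the tests of this commit; the wrapper class name is illustrative:

    import org.apache.mahout.df.mapred.partial.PartialOutputCollector;
    import org.apache.mahout.df.mapreduce.MapredOutput;
    import org.apache.mahout.df.mapreduce.partial.TreeID;
    import org.apache.mahout.df.node.Leaf;

    public class PartialOutputCollectorExample {
      public static void main(String[] args) throws Exception {
        PartialOutputCollector out = new PartialOutputCollector(2);
        // collect() clones both key and value into the backing arrays
        out.collect(new TreeID(0, 0), new MapredOutput(new Leaf(1), new int[] {0, 1}));
        out.collect(new TreeID(0, 1), new MapredOutput(new Leaf(0), new int[] {1, 0}));
        System.out.println(out.nbOutputs()); // 2; a third collect() would throw IOException
      }
    }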

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialSequentialBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialSequentialBuilder.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialSequentialBuilder.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartialSequentialBuilder.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,285 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.partial;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.TextInputFormat;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.mapred.Builder;
+import org.apache.mahout.df.mapred.partial.PartialBuilder;
+import org.apache.mahout.df.mapred.partial.Step1Mapper;
+import org.apache.mahout.df.mapred.partial.Step2Mapper;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.mapreduce.partial.InterResults;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Simulates the Partial mapreduce implementation in a sequential manner. Must
+ * receive a seed
+ */
+public class PartialSequentialBuilder extends PartialBuilder {
+
+  private static final Logger log = LoggerFactory.getLogger(PartialSequentialBuilder.class);
+
+  protected PartialOutputCollector firstOutput;
+
+  protected PartialOutputCollector secondOutput;
+
+  protected final Dataset dataset;
+
+  /** partitions' first instance ids in hadoop's order */
+  protected int[] firstIds;
+  
+  /** partitions' sizes in hadoop order */
+  protected int[] sizes;
+
+  public PartialSequentialBuilder(TreeBuilder treeBuilder, Path dataPath,
+      Dataset dataset, long seed, Configuration conf) {
+    super(treeBuilder, dataPath, new Path("notUsed"), seed, conf);
+    this.dataset = dataset;
+  }
+
+  public PartialSequentialBuilder(TreeBuilder treeBuilder, Path dataPath,
+      Dataset dataset, long seed) {
+    this(treeBuilder, dataPath, dataset, seed, new Configuration());
+  }
+
+  @Override
+  protected void configureJob(JobConf job, int nbTrees, boolean oobEstimate)
+      throws IOException {
+    
+    int numMaps = job.getNumMapTasks();
+
+    super.configureJob(job, nbTrees, oobEstimate);
+
+    // PartialBuilder sets the number of maps to 1 if we are running in 'local' mode
+    job.setNumMapTasks(numMaps);
+  }
+
+  @Override
+  protected void runJob(JobConf job) throws IOException {
+    // retrieve the splits
+    TextInputFormat input = (TextInputFormat) job.getInputFormat();
+    InputSplit[] splits = input.getSplits(job, job.getNumMapTasks());
+    log.debug("Nb splits : " + splits.length);
+
+    InputSplit[] sorted = Arrays.copyOf(splits, splits.length);
+    Builder.sortSplits(sorted);
+
+    int numTrees = Builder.getNbTrees(job); // total number of trees
+
+    firstOutput = new PartialOutputCollector(numTrees);
+    Reporter reporter = Reporter.NULL;
+    long slowest = 0; // duration of slowest map
+
+    int firstId = 0;
+    firstIds = new int[splits.length];
+    sizes = new int[splits.length];
+    
+    // to compute firstIds, process the splits in file order
+    for (int p = 0; p < splits.length; p++) {
+      InputSplit split = splits[p];
+      int hp = ArrayUtils.indexOf(sorted, split); // hadoop's partition
+      
+      RecordReader<LongWritable, Text> reader = input.getRecordReader(split, job, reporter);
+
+      LongWritable key = reader.createKey();
+      Text value = reader.createValue();
+
+      Step1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, seed,
+          hp, splits.length, numTrees);
+
+      long time = System.currentTimeMillis();
+
+      firstIds[hp] = firstId;
+
+      while (reader.next(key, value)) {
+        mapper.map(key, value, firstOutput, reporter);
+        firstId++;
+        sizes[hp]++;
+      }
+
+      mapper.close();
+
+      time = System.currentTimeMillis() - time;
+      log.info("Duration : " + DFUtils.elapsedTime(time));
+
+      if (time > slowest) {
+        slowest = time;
+      }
+    }
+
+    log.info("Longest duration : " + DFUtils.elapsedTime(slowest));
+  }
+
+  @Override
+  protected DecisionForest parseOutput(JobConf job, PredictionCallback callback)
+      throws IOException {
+    int numMaps = job.getNumMapTasks();
+
+    DecisionForest forest = processOutput(firstOutput.keys, firstOutput.values, callback);
+
+    if (isStep2(job)) {
+      Path forestPath = new Path(getOutputPath(job), "step1.inter");
+      FileSystem fs = forestPath.getFileSystem(job);
+      
+      Node[] trees = new Node[forest.getTrees().size()];
+      forest.getTrees().toArray(trees);
+      InterResults.store(fs, forestPath, firstOutput.keys, trees, sizes);
+
+      log.info("***********");
+      log.info("Second Step");
+      log.info("***********");
+      secondStep(job, forestPath, callback);
+
+      processOutput(secondOutput.keys, secondOutput.values, callback);
+    }
+
+    return forest;
+  }
+
+  /**
+   * Extracts the decision forest and calls the callback after correcting the instance ids
+   * 
+   * @param keys tree keys, as output by the mappers
+   * @param values corresponding trees and their predictions
+   * @param callback callback that receives the corrected predictions
+   * @return the decision forest built from the collected trees
+   */
+  protected DecisionForest processOutput(TreeID[] keys, MapredOutput[] values, PredictionCallback callback) {
+    List<Node> trees = new ArrayList<Node>();
+
+    for (int index = 0; index < keys.length; index++) {
+      TreeID key = keys[index];
+      MapredOutput value = values[index];
+
+      trees.add(value.getTree());
+
+      int[] predictions = value.getPredictions();
+      for (int id = 0; id < predictions.length; id++) {
+        callback.prediction(key.treeId(), firstIds[key.partition()] + id,
+            predictions[id]);
+      }
+    }
+    
+    return new DecisionForest(trees);
+  }
+
+  /**
+   * The second step uses each tree to classify the instances that lie outside
+   * the partition the tree was built on
+   * 
+   * @throws IOException
+   * 
+   */
+  protected void secondStep(JobConf job, Path forestPath,
+      PredictionCallback callback) throws IOException {
+    // retrieve the splits
+    TextInputFormat input = (TextInputFormat) job.getInputFormat();
+    InputSplit[] splits = input.getSplits(job, job.getNumMapTasks());
+    log.debug("Nb splits : " + splits.length);
+
+    Builder.sortSplits(splits);
+
+    int numTrees = Builder.getNbTrees(job); // total number of trees
+
+    // compute the expected number of outputs
+    int total = 0;
+    for (int p = 0; p < splits.length; p++) {
+      total += Step2Mapper.nbConcerned(splits.length, numTrees, p);
+    }
+
+    secondOutput = new PartialOutputCollector(total);
+    Reporter reporter = Reporter.NULL;
+    long slowest = 0; // duration of slowest map
+
+    for (int partition = 0; partition < splits.length; partition++) {
+      InputSplit split = splits[partition];
+      RecordReader<LongWritable, Text> reader = input.getRecordReader(split,
+          job, reporter);
+
+      LongWritable key = reader.createKey();
+      Text value = reader.createValue();
+
+      // load the output of the 1st step
+      int nbConcerned = Step2Mapper.nbConcerned(splits.length, numTrees,
+          partition);
+      TreeID[] fsKeys = new TreeID[nbConcerned];
+      Node[] fsTrees = new Node[nbConcerned];
+
+      FileSystem fs = forestPath.getFileSystem(job);
+      int numInstances = InterResults.load(fs, forestPath, splits.length,
+          numTrees, partition, fsKeys, fsTrees);
+
+      Step2Mapper mapper = new Step2Mapper();
+      mapper.configure(partition, dataset, fsKeys, fsTrees, numInstances);
+
+      long time = System.currentTimeMillis();
+
+      while (reader.next(key, value)) {
+        mapper.map(key, value, secondOutput, reporter);
+      }
+
+      mapper.close();
+
+      time = System.currentTimeMillis() - time;
+      log.info("Duration : " + DFUtils.elapsedTime(time));
+
+      if (time > slowest) {
+        slowest = time;
+      }
+    }
+
+    log.info("Longest duration : " + DFUtils.elapsedTime(slowest));
+  }
+
+  /**
+   * Special Step1Mapper that can be configured without using a Configuration
+   * 
+   */
+  protected static class MockStep1Mapper extends Step1Mapper {
+    public MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+        int partition, int numMapTasks, int numTrees) {
+      configure(false, true, treeBuilder, dataset);
+      configure(seed, partition, numMapTasks, numTrees);
+    }
+
+  }
+
+}
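
The bookkeeping that runJob() maintains is a firstIds array giving the global id of each partition's first instance, plus a sizes array with the number of instances per partition. A plain-Java sketch of that accumulation, ignoring the file-order versus sorted-order distinction that the real method handles via Builder.sortSplits(); the split sizes below are made up:

    public class FirstIdsSketch {
      public static void main(String[] args) {
        int[] sizes = {40, 40, 40, 40, 40}; // instances per split, made-up values
        int[] firstIds = new int[sizes.length];
        int firstId = 0;
        for (int p = 0; p < sizes.length; p++) {
          firstIds[p] = firstId; // global id of this partition's first instance
          firstId += sizes[p];
        }
        System.out.println(java.util.Arrays.toString(firstIds)); // [0, 40, 80, 120, 160]
      }
    }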

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartitionBugTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartitionBugTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartitionBugTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/PartitionBugTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,169 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.partial;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.node.Node;
+
+public class PartitionBugTest extends TestCase {
+  int numAttributes = 40;
+
+  int numInstances = 200;
+
+  int numTrees = 10;
+
+  int numMaps = 5;
+
+  /**
+   * Make sure that the correct instance ids are being computed
+   * 
+   * @throws Exception
+   * 
+   */
+  public void testProcessOutput() throws Exception {
+    Random rng = new Random();
+    //long seed = rng.nextLong();
+    long seed = 1L;
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+
+    // each instance label is its index in the dataset
+    int labelId = Utils.findLabel(descriptor);
+    for (int index = 0; index < numInstances; index++) {
+      source[index][labelId] = index;
+    }
+
+    // store the data into a file
+    String[] sData = Utils.double2String(source);
+    Path dataPath = Utils.writeDataToTestFile(sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Data data = DataLoader.loadData(dataset, sData);
+
+    JobConf jobConf = new JobConf();
+    jobConf.setNumMapTasks(numMaps);
+
+    // prepare a custom TreeBuilder that will classify each
+    // instance with its own label (in this case its index in the dataset)
+    TreeBuilder treeBuilder = new MockTreeBuilder();
+    
+    // disable the second step because we can test without it
+    // and we won't be able to serialize the MockNode
+    PartialBuilder.setStep2(jobConf, false);
+    PartialSequentialBuilder builder = new PartialSequentialBuilder(
+        treeBuilder, dataPath, dataset, seed, jobConf);
+
+    // remove the output path (it's only used for testing)
+    Path outputPath = builder.getOutputPath(jobConf);
+    FileSystem fs = outputPath.getFileSystem(jobConf);
+    if (fs.exists(outputPath)) {
+      fs.delete(outputPath, true);
+    }
+    
+    builder.build(numTrees, new MockCallback(data));
+  }
+
+  /**
+   * Asserts that the instance ids are correct
+   *
+   */
+  private static class MockCallback extends PredictionCallback {
+    private final Data data;
+
+    public MockCallback(Data data) {
+      this.data = data;
+    }
+
+    @Override
+    public void prediction(int treeId, int instanceId, int prediction) {
+      // because of the bagging, prediction can be -1
+      if (prediction == -1) {
+        return;
+      }
+
+      assertEquals(String.format("treeId: %d, InstanceId: %d, Prediction: %d",
+          treeId, instanceId, prediction), data.get(instanceId).label, prediction);
+    }
+
+  }
+
+  /**
+   * Custom Leaf node that returns for each instance its own label
+   * 
+   */
+  private static class MockLeaf extends Node {
+
+    @Override
+    public int classify(Instance instance) {
+      return instance.label;
+    }
+
+    @Override
+    protected String getString() {
+      return "[MockLeaf]";
+    }
+
+    @Override
+    public long maxDepth() {
+      // depth is not relevant for this mock
+      return 0;
+    }
+
+    @Override
+    public long nbNodes() {
+      // node count is not relevant for this mock
+      return 0;
+    }
+
+    @Override
+    protected void writeNode(DataOutput out) throws IOException {
+    }
+
+    public void readFields(DataInput in) throws IOException {
+    }
+
+    
+  }
+
+  private static class MockTreeBuilder extends TreeBuilder {
+
+    @Override
+    public Node build(Random rng, Data data) {
+      // always return the mock leaf, which echoes each instance's label
+      return new MockLeaf();
+    }
+
+  }
+}
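
PartitionBugTest relies on an identity-label trick: each instance's label is its own index in the dataset, and MockTreeBuilder builds leaves that simply echo the label back, so any prediction that differs from its instanceId reveals shuffled ids. A self-contained sketch of that idea, outside Mahout's classes:

    public class IdentityLabelSketch {
      public static void main(String[] args) {
        int numInstances = 10;
        int[] labels = new int[numInstances];
        for (int index = 0; index < numInstances; index++) {
          labels[index] = index; // the label column holds the instance's own index
        }
        for (int instanceId = 0; instanceId < numInstances; instanceId++) {
          int prediction = labels[instanceId]; // the mock classifier echoes the label
          if (prediction != instanceId) {
            throw new AssertionError("instance ids were shuffled");
          }
        }
        System.out.println("instance ids preserved");
      }
    }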

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step0JobTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step0JobTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step0JobTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step0JobTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,203 @@
+package org.apache.mahout.df.mapred.partial;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.RecordReader;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.TextInputFormat;
+import org.apache.mahout.df.data.DataConverter;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapred.Builder;
+import org.apache.mahout.df.mapred.partial.Step0Job.Step0Mapper;
+import org.apache.mahout.df.mapred.partial.Step0Job.Step0Output;
+
+public class Step0JobTest extends TestCase {
+
+  // the generated data must be big enough to be split up by FileInputFormat
+
+  int numAttributes = 40;
+
+  int numInstances = 200;
+
+  int numTrees = 10;
+
+  int numMaps = 5;
+
+  public void testStep0Mapper() throws Exception {
+    Random rng = new Random();
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+    String[] sData = Utils.double2String(source);
+
+    // write the data to a file
+    Path dataPath = Utils.writeDataToTestFile(sData);
+    
+    JobConf job = new JobConf();
+    job.setNumMapTasks(numMaps);
+
+    TextInputFormat.setInputPaths(job, dataPath);
+
+    // retrieve the splits
+    TextInputFormat input = (TextInputFormat) job.getInputFormat();
+    InputSplit[] splits = input.getSplits(job, numMaps);
+
+    InputSplit[] sorted = Arrays.copyOf(splits, splits.length);
+    Builder.sortSplits(sorted);
+
+    Step0OutputCollector collector = new Step0OutputCollector(numMaps);
+    Reporter reporter = Reporter.NULL;
+
+    for (int p = 0; p < numMaps; p++) {
+      InputSplit split = sorted[p];
+      RecordReader<LongWritable, Text> reader = input.getRecordReader(split, job, reporter);
+
+      LongWritable key = reader.createKey();
+      Text value = reader.createValue();
+
+      Step0Mapper mapper = new Step0Mapper();
+      mapper.configure(p);
+
+      Long firstKey = null;
+      int size = 0;
+
+      while (reader.next(key, value)) {
+        if (firstKey == null) {
+          firstKey = key.get();
+        }
+
+        mapper.map(key, value, collector, reporter);
+
+        size++;
+      }
+
+      mapper.close();
+
+      // validate the mapper's output
+      assertEquals(p, collector.keys[p]);
+      assertEquals(firstKey.longValue(), collector.values[p].firstId);
+      assertEquals(size, collector.values[p].size);
+    }
+
+  }
+
+  public void testProcessOutput() throws Exception {
+    Random rng = new Random();
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+
+    // each instance label is its index in the dataset
+    int labelId = Utils.findLabel(descriptor);
+    for (int index = 0; index < numInstances; index++) {
+      source[index][labelId] = index;
+    }
+
+    String[] sData = Utils.double2String(source);
+
+    // write the data to a file
+    Path dataPath = Utils.writeDataToTestFile(sData);
+    
+    // prepare a data converter
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    DataConverter converter = new DataConverter(dataset);
+    
+    JobConf job = new JobConf();
+    job.setNumMapTasks(numMaps);
+    TextInputFormat.setInputPaths(job, dataPath);
+
+    // retrieve the splits
+    TextInputFormat input = (TextInputFormat) job.getInputFormat();
+    InputSplit[] splits = input.getSplits(job, numMaps);
+
+    InputSplit[] sorted = Arrays.copyOf(splits, splits.length);
+    Builder.sortSplits(sorted);
+
+    Reporter reporter = Reporter.NULL;
+
+    int[] keys = new int[numMaps];
+    Step0Output[] values = new Step0Output[numMaps];
+    
+    int[] expectedIds = new int[numMaps];
+    
+    for (int p = 0; p < numMaps; p++) {
+      InputSplit split = sorted[p];
+      RecordReader<LongWritable, Text> reader = input.getRecordReader(split, job, reporter);
+
+      LongWritable key = reader.createKey();
+      Text value = reader.createValue();
+
+      Long firstKey = null;
+      int size = 0;
+      
+      while (reader.next(key, value)) {
+        if (firstKey == null) {
+          firstKey = key.get();
+          expectedIds[p] = converter.convert(0, value.toString()).label;
+        }
+
+        size++;
+      }
+      
+      keys[p] = p;
+      values[p] = new Step0Output(firstKey, size);
+    }
+
+    Step0Output[] partitions = Step0Job.processOutput(keys, values);
+    
+    int[] actualIds = Step0Output.extractFirstIds(partitions);
+    
+    assertTrue("Expected: " + Arrays.toString(expectedIds) + " But was: "
+        + Arrays.toString(actualIds), Arrays.equals(expectedIds, actualIds));
+  }
+
+  protected static class Step0OutputCollector implements
+      OutputCollector<IntWritable, Step0Output> {
+
+    public final int[] keys;
+
+    public final Step0Output[] values;
+
+    private int index = 0;
+
+    public Step0OutputCollector(int numMaps) {
+      keys = new int[numMaps];
+      values = new Step0Output[numMaps];
+    }
+
+    public void collect(IntWritable key, Step0Output value) throws IOException {
+      if (index == keys.length) {
+        throw new IOException("Received more output than expected : " + index);
+      }
+
+      keys[index] = key.get();
+      values[index] = value.clone();
+
+      index++;
+    }
+
+    /**
+     * Number of outputs collected
+     * 
+     * @return the number of outputs collected so far
+     */
+    public int nbOutputs() {
+      return index;
+    }
+  }
+}
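
testStep0Mapper expects one record per partition holding the key of the split's first record and the split's size; with TextInputFormat the keys are byte offsets. A standalone sketch of that per-split accumulation; the offset model, line length plus one for the newline, is an assumption about TextInputFormat rather than anything this commit defines:

    public class Step0Sketch {
      public static void main(String[] args) {
        String[] lines = {"a", "bb", "ccc", "dddd"}; // one split's records
        long offset = 0;
        Long firstKey = null;
        int size = 0;
        for (String line : lines) {
          if (firstKey == null) {
            firstKey = offset; // key of the first record seen by this mapper
          }
          offset += line.length() + 1; // byte offset of the next record
          size++;
        }
        System.out.println("firstKey=" + firstKey + " size=" + size); // firstKey=0 size=4
      }
    }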

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step1MapperTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step1MapperTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step1MapperTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,131 @@
+package org.apache.mahout.df.mapred.partial;
+
+import static org.apache.mahout.df.data.Utils.double2String;
+import static org.apache.mahout.df.data.Utils.randomDescriptor;
+import static org.apache.mahout.df.data.Utils.randomDoubles;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapred.partial.Step1Mapper;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class Step1MapperTest extends TestCase {
+
+  /**
+   * Make sure that the data used to build the trees is from the mapper's
+   * partition
+   * 
+   */
+  private static class MockTreeBuilder extends TreeBuilder {
+
+    protected Data expected;
+
+    public void setExpected(Data data) {
+      expected = data;
+    }
+
+    @Override
+    public Node build(Random rng, Data data) {
+      for (int index = 0; index < data.size(); index++) {
+        assertTrue(expected.contains(data.get(index)));
+      }
+
+      return new Leaf(-1);
+    }
+  }
+
+  /**
+   * Special Step1Mapper that can be configured without using a Configuration
+   * 
+   */
+  protected static class MockStep1Mapper extends Step1Mapper {
+    public MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+        int partition, int numMapTasks, int numTrees) {
+      configure(false, true, treeBuilder, dataset);
+      configure(seed, partition, numMapTasks, numTrees);
+    }
+
+    public int getFirstTreeId() {
+      return firstTreeId;
+    }
+
+  }
+
+  /** nb attributes per generated data instance */
+  protected final int nbAttributes = 4;
+
+  /** nb generated data instances */
+  protected final int nbInstances = 100;
+
+  /** nb trees to build */
+  protected final int nbTrees = 10;
+
+  /** nb mappers to use */
+  protected final int nbMappers = 2;
+
+  public void testMapper() throws Exception {
+    Long seed = null;
+    Random rng = new Random();
+
+    // prepare the data
+    String descriptor = randomDescriptor(rng, nbAttributes);
+    double[][] source = randomDoubles(rng, descriptor, nbInstances);
+    String[] sData = double2String(source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    String[][] splits = Utils.splitData(sData, nbMappers);
+
+    MockTreeBuilder treeBuilder = new MockTreeBuilder();
+
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+
+    int treeIndex = 0;
+    
+    for (int partition = 0; partition < nbMappers; partition++) {
+      String[] split = splits[partition];
+      treeBuilder.setExpected(DataLoader.loadData(dataset, split));
+
+      // expected number of trees that this mapper will build
+      int mapNbTrees = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
+
+      PartialOutputCollector output = new PartialOutputCollector(mapNbTrees);
+
+      MockStep1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, seed,
+          partition, nbMappers, nbTrees);
+
+      // make sure the mapper computed firstTreeId correctly
+      assertEquals(treeIndex, mapper.getFirstTreeId());
+
+      for (int index = 0; index < split.length; index++) {
+        key.set(index);
+        value.set(split[index]);
+        mapper.map(key, value, output, Reporter.NULL);
+      }
+
+      mapper.close();
+
+      // make sure the mapper built all its trees
+      assertEquals(mapNbTrees, output.nbOutputs());
+
+      // check the returned keys
+      for (TreeID k : output.keys) {
+        assertEquals(partition, k.partition());
+        assertEquals(treeIndex, k.treeId());
+
+        treeIndex++;
+      }
+    }
+  }
+}
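
testMapper() only constrains Step1Mapper.nbTrees() indirectly: the per-partition counts must add up to nbTrees, and firstTreeId must equal the number of trees assigned to earlier partitions. The sketch below shows one distribution consistent with that, an even split with the remainder given to partition 0; the actual rule inside Step1Mapper may differ:

    public class TreeDistributionSketch {
      // assumed distribution, not necessarily Step1Mapper.nbTrees()'s exact rule
      static int nbTrees(int nbMappers, int nbTrees, int partition) {
        int base = nbTrees / nbMappers;
        return partition == 0 ? base + nbTrees % nbMappers : base;
      }

      public static void main(String[] args) {
        int nbMappers = 2;
        int total = 10;
        int firstTreeId = 0;
        for (int partition = 0; partition < nbMappers; partition++) {
          int count = nbTrees(nbMappers, total, partition);
          System.out.println("partition " + partition + ": firstTreeId=" + firstTreeId
              + " trees=" + count);
          firstTreeId += count;
        }
      }
    }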

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step2MapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step2MapperTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step2MapperTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/partial/Step2MapperTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.partial;
+
+import static org.apache.mahout.df.data.Utils.double2String;
+import static org.apache.mahout.df.data.Utils.randomDescriptor;
+import static org.apache.mahout.df.data.Utils.randomDoubles;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapreduce.partial.InterResults;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class Step2MapperTest extends TestCase {
+
+  /**
+   * Special Step2Mapper that can be configured without using a Configuration
+   * 
+   */
+  private static class MockStep2Mapper extends Step2Mapper {
+    public MockStep2Mapper(int partition, Dataset dataset, TreeID[] keys,
+        Node[] trees, int numInstances) {
+      configure(partition, dataset, keys, trees, numInstances);
+    }
+
+  }
+
+  /** nb attributes per generated data instance */
+  protected final int nbAttributes = 4;
+
+  /** nb generated data instances */
+  protected final int nbInstances = 100;
+
+  /** nb trees to build */
+  protected final int nbTrees = 11;
+
+  /** nb mappers to use */
+  protected final int nbMappers = 5;
+
+  public void testMapper() throws Exception {
+    Random rng = new Random();
+
+    // prepare the data
+    String descriptor = randomDescriptor(rng, nbAttributes);
+    double[][] source = randomDoubles(rng, descriptor, nbInstances);
+    String[] sData = double2String(source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    String[][] splits = Utils.splitData(sData, nbMappers);
+
+    // prepare first step output
+    TreeID[] keys = new TreeID[nbTrees];
+    Node[] trees = new Node[nbTrees];
+    int[] sizes = new int[nbMappers];
+    
+    int treeIndex = 0;
+    for (int partition = 0; partition < nbMappers; partition++) {
+      int nbMapTrees = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
+
+      for (int tree = 0; tree < nbMapTrees; tree++, treeIndex++) {
+        keys[treeIndex] = new TreeID(partition, treeIndex);
+        // put the partition in the leaf's label
+        // this way we can track the outputs
+        trees[treeIndex] = new Leaf(partition);
+      }
+      
+      sizes[partition] = splits[partition].length;
+    }
+
+    // store the first step outputs in a file
+    FileSystem fs = FileSystem.getLocal(new Configuration());
+    Path forestPath = new Path("testdata/Step2MapperTest.forest");
+    InterResults.store(fs, forestPath, keys, trees, sizes);
+
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+
+    for (int partition = 0; partition < nbMappers; partition++) {
+      String[] split = splits[partition];
+
+      // number of trees that will be handled by the mapper
+      int nbConcerned = Step2Mapper.nbConcerned(nbMappers, nbTrees, partition);
+
+      PartialOutputCollector output = new PartialOutputCollector(nbConcerned);
+
+      // load the current mapper's (key, tree) pairs
+      TreeID[] curKeys = new TreeID[nbConcerned];
+      Node[] curTrees = new Node[nbConcerned];
+      InterResults.load(fs, forestPath, nbMappers, nbTrees, partition, curKeys, curTrees);
+
+      // simulate the job
+      MockStep2Mapper mapper = new MockStep2Mapper(partition, dataset, curKeys, curTrees, split.length);
+
+      for (int index = 0; index < split.length; index++) {
+        key.set(index);
+        value.set(split[index]);
+        mapper.map(key, value, output, Reporter.NULL);
+      }
+
+      mapper.close();
+
+      // make sure the mapper did not return its own trees
+      assertEquals(nbConcerned, output.nbOutputs());
+
+      // check the returned results
+      int current = 0;
+      for (int index = 0; index < nbTrees; index++) {
+        if (keys[index].partition() == partition) {
+          // should not be part of the results
+          continue;
+        }
+
+        TreeID k = output.keys[current];
+
+        // the tree should receive the partition's index
+        assertEquals(partition, k.partition());
+
+        // make sure all the trees of the other partitions are handled in the
+        // correct order
+        assertEquals(index, k.treeId());
+
+        int[] predictions = output.values[current].getPredictions();
+
+        // all the instances of the partition should be classified
+        assertEquals(split.length, predictions.length);
+        assertEquals(
+            "at least one instance of the partition was not classified", -1,
+            ArrayUtils.indexOf(predictions, -1));
+
+        // the tree must not belong to the mapper's partition
+        int treePartition = predictions[0];
+        assertFalse("Step2Mapper returned a tree from its own partition",
+            partition == treePartition);
+
+        current++;
+      }
+    }
+  }
+}
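
The number of trees "concerned" by a second-step mapper is the complement of what it built in the first step: every tree that belongs to another partition. A small sketch of that count, reusing the distribution assumption from the previous note, so the printed numbers are only illustrative:

    public class NbConcernedSketch {
      static int nbTrees(int nbMappers, int nbTrees, int partition) {
        int base = nbTrees / nbMappers;
        return partition == 0 ? base + nbTrees % nbMappers : base; // assumed split
      }

      static int nbConcerned(int nbMappers, int nbTrees, int partition) {
        return nbTrees - nbTrees(nbMappers, nbTrees, partition); // trees from other partitions
      }

      public static void main(String[] args) {
        System.out.println(nbConcerned(5, 11, 0)); // 8 under the assumed split
        System.out.println(nbConcerned(5, 11, 3)); // 9 under the assumed split
      }
    }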

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/InterResultsTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/InterResultsTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/InterResultsTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/InterResultsTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.df.mapreduce.partial;
+
+import static org.apache.mahout.df.data.Utils.double2String;
+import static org.apache.mahout.df.data.Utils.randomDoubles;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class InterResultsTest extends TestCase {
+
+  /** nb attributes per generated data instance */
+  protected final int nbAttributes = 4;
+
+  /** nb generated data instances */
+  protected final int nbInstances = 100;
+
+  /** nb trees to build */
+  protected final int nbTrees = 11;
+
+  /** nb mappers to use */
+  protected final int nbMappers = 5;
+
+  protected String[][] splits;
+
+  TreeID[] keys;
+
+  Node[] trees;
+  
+  int[] sizes;
+
+  protected void setUp() throws Exception {
+    Random rng = new Random();
+
+    // prepare the data
+    double[][] source = randomDoubles(rng, nbAttributes, nbInstances);
+    String[] sData = double2String(source);
+
+    splits = Utils.splitData(sData, nbMappers);
+
+    sizes = new int[nbMappers];
+    for (int p = 0; p < nbMappers; p++) {
+      sizes[p] = splits[p].length;
+    }
+
+    // prepare first step output
+    keys = new TreeID[nbTrees];
+    trees = new Node[nbTrees];
+    
+    int treeIndex = 0;
+    for (int partition = 0; partition < nbMappers; partition++) {
+      int nbMapTrees = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
+
+      for (int index = 0; index < nbMapTrees; index++, treeIndex++) {
+        keys[treeIndex] = new TreeID(partition, treeIndex);
+
+        // put the tree index in the leaf's label
+        // this way we can check the stored data
+        trees[treeIndex] = new Leaf(treeIndex);
+      }
+    }
+  }
+
+  public void testLoad() throws Exception {
+    // store the intermediate results
+    Path forestPath = new Path("testdata/InterResultsTest/test.forest");
+    FileSystem fs = forestPath.getFileSystem(new Configuration());
+
+    InterResults.store(fs, forestPath, keys, trees, sizes);
+
+    for (int partition = 0; partition < nbMappers; partition++) {
+      int nbConcerned = Step2Mapper.nbConcerned(nbMappers, nbTrees, partition);
+
+      TreeID[] newKeys = new TreeID[nbConcerned];
+      Node[] newValues = new Node[nbConcerned];
+
+      int numInstances = InterResults.load(fs, forestPath, nbMappers,
+          nbTrees, partition, newKeys, newValues);
+
+      // verify the partition's size
+      assertEquals(splits[partition].length, numInstances);
+
+      // verify (key, tree)
+      int current = 0;
+      for (int index = 0; index < nbTrees; index++) {
+        // the trees of the current partition should not be loaded
+        if (current < nbConcerned) {
+          assertFalse("A tree from the current partition has been loaded",
+              newKeys[current].partition() == partition);
+        }
+        if (keys[index].partition() == partition) {
+          continue;
+        }
+
+        assertEquals("index: " + index, keys[index], newKeys[current]);
+        assertEquals("index: " + index, trees[index], newValues[current]);
+
+        current++;
+      }
+    }
+  }
+
+  public void testStore() throws Exception {
+    // store the intermediate results
+    Path forestPath = new Path("testdata/InterResultsTest/test.forest");
+    FileSystem fs = forestPath.getFileSystem(new Configuration());
+    
+    InterResults.store(fs, forestPath, keys, trees, sizes);
+
+    // load the file and check the stored values
+
+    FSDataInputStream in = fs.open(forestPath);
+
+    try {
+      // partitions' sizes
+      for (int p = 0; p < nbMappers; p++) {
+        assertEquals(splits[p].length, in.readInt());
+      }
+
+      // load (key, tree)
+      TreeID key = new TreeID();
+      Node value;
+      for (int index = 0; index < nbTrees; index++) {
+        key.readFields(in);
+        value = Node.read(in);
+
+        assertEquals("index: " + index, keys[index], key);
+        assertEquals("index: " + index, trees[index], value);
+      }
+    } finally {
+      in.close();
+    }
+  }
+
+}
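
testStore() doubles as documentation of the on-disk layout written by InterResults.store(): first one int per partition with its size, then the (TreeID, Node) records in key order. A sketch of a standalone reader for that layout; the local file path is hypothetical, and the test itself goes through the Hadoop FileSystem API instead:

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;

    import org.apache.mahout.df.mapreduce.partial.TreeID;
    import org.apache.mahout.df.node.Node;

    public class InterResultsLayoutSketch {
      public static void main(String[] args) throws IOException {
        int nbMappers = 5;
        int nbTrees = 11;
        DataInputStream in = new DataInputStream(
            new FileInputStream("testdata/InterResultsTest/test.forest"));
        try {
          int[] sizes = new int[nbMappers];
          for (int p = 0; p < nbMappers; p++) {
            sizes[p] = in.readInt(); // partition sizes, in hadoop's order
          }
          for (int index = 0; index < nbTrees; index++) {
            TreeID key = new TreeID();
            key.readFields(in);        // Writable key, same call as in the test
            Node tree = Node.read(in); // factory read, same call as in the test
          }
        } finally {
          in.close();
        }
      }
    }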

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/MockContext.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/MockContext.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/MockContext.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/MockContext.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.hadoop.mapreduce.Mapper.Context;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+
+/**
+ * Special implementation that collects the output of the mappers
+ */
+public class MockContext extends Context {
+
+  public final TreeID[] keys;
+
+  public final MapredOutput[] values;
+
+  private int index = 0;
+
+  @SuppressWarnings("unchecked")
+  public MockContext(Mapper mapper, Configuration conf, TaskAttemptID taskid,
+      int nbTrees) throws IOException, InterruptedException {
+    mapper.super(conf, taskid, null, null, null, null, null);
+
+    keys = new TreeID[nbTrees];
+    values = new MapredOutput[nbTrees];
+  }
+
+  @Override
+  public void write(Object key, Object value) throws IOException,
+      InterruptedException {
+    if (index == keys.length) {
+      throw new IOException("Received more output than expected : " + index);
+    }
+
+    keys[index] = ((TreeID) key).clone();
+    values[index] = ((MapredOutput) value).clone();
+
+    index++;
+  }
+
+  /**
+   * Number of outputs collected
+   * 
+   * @return the number of outputs collected so far
+   */
+  public int nbOutputs() {
+    return index;
+  }
+
+}

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,229 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class PartialBuilderTest extends TestCase {
+
+  protected static final int numMaps = 5;
+
+  protected static final int numTrees = 32;
+
+  /** instances per partition */
+  protected static final int numInstances = 20;
+
+  public void testProcessOutput() throws Exception {
+    Configuration conf = new Configuration();
+    conf.setInt("mapred.map.tasks", numMaps);
+
+    Random rng = new Random();
+
+    // prepare the output
+    TreeID[] keys = new TreeID[numTrees];
+    MapredOutput[] values = new MapredOutput[numTrees];
+    int[] firstIds = new int[numMaps];
+    randomKeyValues(rng, keys, values, firstIds);
+
+    // store the output in a sequence file
+    Path base = new Path("testdata");
+    FileSystem fs = base.getFileSystem(conf);
+    if (fs.exists(base))
+      fs.delete(base, true);
+
+    Path outputFile = new Path(base, "PartialBuilderTest.seq");
+    Writer writer = SequenceFile.createWriter(fs, conf, outputFile,
+        TreeID.class, MapredOutput.class);
+
+    try {
+      for (int index = 0; index < numTrees; index++) {
+        writer.append(keys[index], values[index]);
+      }
+    } finally {
+      writer.close();
+    }
+
+    // load the output and make sure it's valid
+    TreeID[] newKeys = new TreeID[numTrees];
+    Node[] newTrees = new Node[numTrees];
+    
+    PartialBuilder.processOutput(new Job(conf), base, firstIds, newKeys, newTrees, 
+        new TestCallback(keys, values));
+
+    // check the forest
+    for (int tree = 0; tree < numTrees; tree++) {
+      assertEquals(values[tree].getTree(), newTrees[tree]);
+    }
+
+    assertTrue("keys not equal", Arrays.deepEquals(keys, newKeys));
+  }
+
+  /**
+   * Make sure that the builder passes the correct parameters to the job
+   * 
+   */
+  public void testConfigure() {
+    TreeBuilder treeBuilder = new DefaultTreeBuilder();
+    Path dataPath = new Path("notUsedDataPath");
+    Path datasetPath = new Path("notUsedDatasetPath");
+    Long seed = 5L;
+
+    new PartialBuilderChecker(treeBuilder, dataPath, datasetPath, seed);
+  }
+
+  /**
+   * Generates random (key, value) pairs and shuffles the partitions' order
+   * 
+   * @param rng random number generator
+   * @param keys output array for the generated keys
+   * @param values output array for the generated values
+   * @param firstIds partitions' first instance ids, in Hadoop's order
+   */
+  protected void randomKeyValues(Random rng, TreeID[] keys,
+      MapredOutput[] values, int[] firstIds) {
+    int index = 0;
+    int firstId = 0;
+    List<Integer> partitions = new ArrayList<Integer>();
+    int partition;
+
+    for (int p = 0; p < numMaps; p++) {
+      // select a random partition, not yet selected
+      do {
+        partition = rng.nextInt(numMaps);
+      } while (partitions.contains(partition));
+
+      partitions.add(partition);
+
+      int nbTrees = Step1Mapper.nbTrees(numMaps, numTrees, partition);
+
+      for (int treeId = 0; treeId < nbTrees; treeId++) {
+        Node tree = new Leaf(rng.nextInt(100));
+
+        keys[index] = new TreeID(partition, treeId);
+        values[index] = new MapredOutput(tree, nextIntArray(rng, numInstances));
+
+        index++;
+      }
+      
+      firstIds[p] = firstId;
+      firstId += numInstances;
+    }
+
+  }
+
+  protected int[] nextIntArray(Random rng, int size) {
+    int[] array = new int[size];
+    for (int index = 0; index < size; index++) {
+      array[index] = rng.nextInt(101) - 1;
+    }
+
+    return array;
+  }
+
+  protected static class PartialBuilderChecker extends PartialBuilder {
+
+    protected Long _seed;
+
+    protected TreeBuilder _treeBuilder;
+
+    protected Path _datasetPath;
+
+    public PartialBuilderChecker(TreeBuilder treeBuilder, Path dataPath,
+        Path datasetPath, Long seed) {
+      super(treeBuilder, dataPath, datasetPath, seed);
+
+      _seed = seed;
+      _treeBuilder = treeBuilder;
+      _datasetPath = datasetPath;
+    }
+
+    @Override
+    protected boolean runJob(Job job) throws Exception {
+      // no need to run the job, just check if the params are correct
+
+      Configuration conf = job.getConfiguration();
+      
+      assertEquals(_seed, getRandomSeed(conf));
+
+      // PartialBuilder should detect 'local' mode and override the number
+      // of map tasks
+      assertEquals(1, conf.getInt("mapred.map.tasks", -1));
+
+      assertEquals(numTrees, getNbTrees(conf));
+
+      assertFalse(isOutput(conf));
+      assertTrue(isOobEstimate(conf));
+
+      assertEquals(_treeBuilder, getTreeBuilder(conf));
+
+      assertEquals(_datasetPath, getDistributedCacheFile(conf, 0));
+      
+      return true;
+    }
+
+  }
+
+  /**
+   * Mock callback. Makes sure that the callback receives the correct predictions
+   * 
+   */
+  protected static class TestCallback extends PredictionCallback {
+
+    protected final TreeID[] keys;
+
+    protected final MapredOutput[] values;
+
+    public TestCallback(TreeID[] keys, MapredOutput[] values) {
+      this.keys = keys;
+      this.values = values;
+    }
+
+    @Override
+    public void prediction(int treeId, int instanceId, int prediction) {
+      int partition = instanceId / numInstances;
+
+      TreeID key = new TreeID(partition, treeId);
+      int index = ArrayUtils.indexOf(keys, key);
+      assertTrue("key not found", index >= 0);
+
+      assertEquals(values[index].getPredictions()[instanceId % numInstances],
+          prediction);
+    }
+
+  }
+}
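
randomKeyValues() above relies on one invariant: Step1Mapper.nbTrees(numMaps, numTrees, partition) must hand out exactly numTrees trees across all the partitions, otherwise keys[] and values[] would not be filled completely. Below is a minimal standalone sketch of that invariant check, assuming only the static nbTrees() call already used in the test; the class name TreeDistributionSketch is illustrative and not part of the patch.

    // Sketch only: checks the tree-distribution invariant randomKeyValues() relies on.
    // Placed in the same package as the test, in case nbTrees() is not public.
    package org.apache.mahout.df.mapreduce.partial;

    public class TreeDistributionSketch {
      public static void main(String[] args) {
        int numMaps = 5;
        int numTrees = 32;
        int total = 0;
        for (int partition = 0; partition < numMaps; partition++) {
          int nbTrees = Step1Mapper.nbTrees(numMaps, numTrees, partition);
          System.out.println("partition " + partition + " builds " + nbTrees + " trees");
          total += nbTrees;
        }
        // keys[] and values[] hold exactly numTrees entries, so the counts must add up
        if (total != numTrees) {
          throw new IllegalStateException("expected " + numTrees + " trees, got " + total);
        }
      }
    }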

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,292 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.JobID;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.mapreduce.Builder;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Simulates the Partial mapreduce implementation in a sequential manner. Must
+ * be given a seed.
+ */
+public class PartialSequentialBuilder extends PartialBuilder {
+
+  private static final Logger log = LoggerFactory.getLogger(PartialSequentialBuilder.class);
+
+  protected MockContext firstOutput;
+
+  protected MockContext secondOutput;
+
+  protected final Dataset dataset;
+
+  /** first instance id in hadoop's order */
+  protected int[] firstIds;
+  
+  /** partitions' sizes in hadoop order */
+  protected int[] sizes;
+
+  public PartialSequentialBuilder(TreeBuilder treeBuilder, Path dataPath,
+      Dataset dataset, long seed, Configuration conf) {
+    super(treeBuilder, dataPath, new Path("notUsed"), seed, conf);
+    this.dataset = dataset;
+  }
+
+  public PartialSequentialBuilder(TreeBuilder treeBuilder, Path dataPath,
+      Dataset dataset, long seed) {
+    this(treeBuilder, dataPath, dataset, seed, new Configuration());
+  }
+
+  @Override
+  protected void configureJob(Job job, int nbTrees, boolean oobEstimate)
+      throws IOException {
+    Configuration conf = job.getConfiguration();
+    
+    int num = conf.getInt("mapred.map.tasks", -1);
+
+    super.configureJob(job, nbTrees, oobEstimate);
+
+    // PartialBuilder sets the number of maps to 1 when running in 'local' mode,
+    // so restore the value we read before calling super
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  protected boolean runJob(Job job) throws Exception {
+    Configuration conf = job.getConfiguration();
+    
+    // retrieve the splits
+    TextInputFormat input = new TextInputFormat();
+    List<InputSplit> splits = input.getSplits(job);
+    
+    int nbSplits = splits.size();
+    log.debug("Nb splits : " + nbSplits);
+
+    InputSplit[] sorted = new InputSplit[nbSplits];
+    splits.toArray(sorted);
+    Builder.sortSplits(sorted);
+
+    int numTrees = Builder.getNbTrees(conf); // total number of trees
+
+    TaskAttemptContext task = new TaskAttemptContext(conf, new TaskAttemptID());
+
+    firstOutput = new MockContext(new Step1Mapper(), conf, task.getTaskAttemptID(), numTrees);
+    long slowest = 0; // duration of slowest map
+
+    int firstId = 0;
+    firstIds = new int[nbSplits];
+    sizes = new int[nbSplits];
+    
+    // to compute firstIds, process the splits in file order
+    for (int p = 0; p < nbSplits; p++) {
+      InputSplit split = splits.get(p);
+      int hp = ArrayUtils.indexOf(sorted, split); // hadoop's partition
+      
+      RecordReader<LongWritable, Text> reader = input.createRecordReader(split, task);
+      reader.initialize(split, task);
+      
+      Step1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, seed,
+          hp, nbSplits, numTrees);
+
+      long time = System.currentTimeMillis();
+
+      firstIds[hp] = firstId;
+
+      while (reader.nextKeyValue()) {
+        mapper.map(reader.getCurrentKey(), reader.getCurrentValue(), firstOutput);
+        firstId++;
+        sizes[hp]++;
+      }
+
+      mapper.cleanup(firstOutput);
+
+      time = System.currentTimeMillis() - time;
+      log.info("Duration : " + DFUtils.elapsedTime(time));
+
+      if (time > slowest) {
+        slowest = time;
+      }
+    }
+
+    log.info("Longest duration : " + DFUtils.elapsedTime(slowest));
+    return true;
+  }
+
+  @Override
+  protected DecisionForest parseOutput(Job job, PredictionCallback callback)
+      throws Exception {
+    Configuration conf = job.getConfiguration();
+    
+    DecisionForest forest = processOutput(firstOutput.keys, firstOutput.values, callback);
+
+    if (isStep2(conf)) {
+      Path forestPath = new Path(getOutputPath(conf), "step1.inter");
+      FileSystem fs = forestPath.getFileSystem(conf);
+      
+      Node[] trees = new Node[forest.getTrees().size()];
+      forest.getTrees().toArray(trees);
+      InterResults.store(fs, forestPath, firstOutput.keys, trees, sizes);
+
+      log.info("***********");
+      log.info("Second Step");
+      log.info("***********");
+      secondStep(conf, forestPath, callback);
+
+      processOutput(secondOutput.keys, secondOutput.values, callback);
+    }
+
+    return forest;
+  }
+
+  /**
+   * Extracts the decision forest and calls the callback after correcting the instance ids
+   * 
+   * @param keys tree keys, one per built tree
+   * @param values mapper outputs holding the trees and their predictions
+   * @param callback receives the corrected predictions
+   * @return the decision forest built from all the trees
+   */
+  protected DecisionForest processOutput(TreeID[] keys, MapredOutput[] values, PredictionCallback callback) {
+    List<Node> trees = new ArrayList<Node>();
+
+    for (int index = 0; index < keys.length; index++) {
+      TreeID key = keys[index];
+      MapredOutput value = values[index];
+
+      trees.add(value.getTree());
+
+      int[] predictions = value.getPredictions();
+      for (int id = 0; id < predictions.length; id++) {
+        callback.prediction(key.treeId(), firstIds[key.partition()] + id,
+            predictions[id]);
+      }
+    }
+    
+    return new DecisionForest(trees);
+  }
+
+  /**
+   * The second step uses the trees to classify the instances that lie outside
+   * each tree's own partition
+   * 
+   * @throws Exception
+   * 
+   */
+  @SuppressWarnings("unchecked")
+  protected void secondStep(Configuration conf, Path forestPath, PredictionCallback callback) throws Exception {
+    JobContext jobContext = new JobContext(conf, new JobID());
+    
+    // retrieve the splits
+    TextInputFormat input = new TextInputFormat();
+    List<InputSplit> splits = input.getSplits(jobContext);
+    
+    int nbSplits = splits.size();
+    log.debug("Nb splits : " + nbSplits);
+
+    InputSplit[] sorted = new InputSplit[nbSplits];
+    splits.toArray(sorted);
+    Builder.sortSplits(sorted);
+
+    int numTrees = Builder.getNbTrees(conf); // total number of trees
+
+    // compute the expected number of outputs
+    int total = 0;
+    for (int p = 0; p < nbSplits; p++) {
+      total += Step2Mapper.nbConcerned(nbSplits, numTrees, p);
+    }
+
+    TaskAttemptContext task = new TaskAttemptContext(conf, new TaskAttemptID());
+
+    secondOutput = new MockContext(new Step2Mapper(), conf, task.getTaskAttemptID(), numTrees);
+    long slowest = 0; // duration of slowest map
+
+    for (int partition = 0; partition < nbSplits; partition++) {
+      
+      InputSplit split = sorted[partition];
+      RecordReader<LongWritable, Text> reader = input.createRecordReader(split, task);
+
+      // load the output of the 1st step
+      int nbConcerned = Step2Mapper.nbConcerned(nbSplits, numTrees, partition);
+      TreeID[] fsKeys = new TreeID[nbConcerned];
+      Node[] fsTrees = new Node[nbConcerned];
+
+      FileSystem fs = forestPath.getFileSystem(conf);
+      int numInstances = InterResults.load(fs, forestPath, nbSplits,
+          numTrees, partition, fsKeys, fsTrees);
+
+      Step2Mapper mapper = new Step2Mapper();
+      mapper.configure(partition, dataset, fsKeys, fsTrees, numInstances);
+
+      long time = System.currentTimeMillis();
+
+      while (reader.nextKeyValue()) {
+        mapper.map(reader.getCurrentKey(), reader.getCurrentValue(), secondOutput);
+      }
+
+      mapper.cleanup(secondOutput);
+
+      time = System.currentTimeMillis() - time;
+      log.info("Duration : " + DFUtils.elapsedTime(time));
+
+      if (time > slowest) {
+        slowest = time;
+      }
+    }
+
+    log.info("Longest duration : " + DFUtils.elapsedTime(slowest));
+  }
+
+  /**
+   * Special Step1Mapper that can be configured without using a Configuration
+   * 
+   */
+  protected static class MockStep1Mapper extends Step1Mapper {
+    public MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+        int partition, int numMapTasks, int numTrees) {
+      configure(false, true, treeBuilder, dataset);
+      configure(seed, partition, numMapTasks, numTrees);
+    }
+
+  }
+
+}
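
The instance-id correction in processOutput() above is the core of this class: every mapper numbers the instances of its split from 0, and firstIds[partition] (a prefix sum of the split sizes, computed in runJob() in Hadoop's split order) shifts those local ids back to positions in the original file. A small sketch of that mapping follows, with illustrative sizes; the class name InstanceIdSketch is not part of the patch.

    // Sketch only: how a partition-local prediction index becomes a global instance id,
    // mirroring callback.prediction(key.treeId(), firstIds[key.partition()] + id, predictions[id]).
    public class InstanceIdSketch {
      public static void main(String[] args) {
        // illustrative: 3 splits of 20 instances each, listed in Hadoop's order
        int[] sizes = {20, 20, 20};
        int[] firstIds = new int[sizes.length];
        for (int p = 1; p < sizes.length; p++) {
          firstIds[p] = firstIds[p - 1] + sizes[p - 1];   // prefix sum: 0, 20, 40
        }

        int partition = 2;
        int localId = 5;                                  // 6th instance seen by mapper 2
        int globalId = firstIds[partition] + localId;     // 45 in the original file order
        System.out.println("global instance id = " + globalId);
      }
    }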

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartitionBugTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartitionBugTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartitionBugTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartitionBugTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,169 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.callback.PredictionCallback;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.node.Node;
+
+public class PartitionBugTest extends TestCase {
+  int numAttributes = 40;
+
+  int numInstances = 200;
+
+  int numTrees = 10;
+
+  int numMaps = 5;
+
+  /**
+   * Make sure that the correct instance ids are being computed
+   * 
+   * @throws Exception
+   * 
+   */
+  public void testProcessOutput() throws Exception {
+    Random rng = new Random();
+    //long seed = rng.nextLong();
+    long seed = 1L;
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+
+    // each instance label is its index in the dataset
+    int labelId = Utils.findLabel(descriptor);
+    for (int index = 0; index < numInstances; index++) {
+      source[index][labelId] = index;
+    }
+
+    // store the data into a file
+    String[] sData = Utils.double2String(source);
+    Path dataPath = Utils.writeDataToTestFile(sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Data data = DataLoader.loadData(dataset, sData);
+
+    Configuration conf = new Configuration();
+    Step0JobTest.setMaxSplitSize(conf, dataPath, numMaps);
+
+    // prepare a custom TreeBuilder that will classify each
+    // instance with its own label (in this case its index in the dataset)
+    TreeBuilder treeBuilder = new MockTreeBuilder();
+    
+    // disable the second step because we can test without it
+    // and we won't be able to serialize the MockNode
+    PartialBuilder.setStep2(conf, false);
+    PartialSequentialBuilder builder = new PartialSequentialBuilder(
+        treeBuilder, dataPath, dataset, seed, conf);
+
+    // remove the output path (it's only used for testing)
+    Path outputPath = builder.getOutputPath(conf);
+    FileSystem fs = outputPath.getFileSystem(conf);
+    if (fs.exists(outputPath)) {
+      fs.delete(outputPath, true);
+    }
+    
+    builder.build(numTrees, new MockCallback(data));
+  }
+
+  /**
+   * Asserts that the instance ids are correct
+   *
+   */
+  private static class MockCallback extends PredictionCallback {
+    private final Data data;
+
+    public MockCallback(Data data) {
+      this.data = data;
+    }
+
+    @Override
+    public void prediction(int treeId, int instanceId, int prediction) {
+      // because of the bagging, prediction can be -1
+      if (prediction == -1) {
+        return;
+      }
+
+      assertEquals(String.format("treeId: %d, InstanceId: %d, Prediction: %d",
+          treeId, instanceId, prediction), data.get(instanceId).label, prediction);
+    }
+
+  }
+
+  /**
+   * Custom leaf node that returns each instance's own label as its prediction
+   * 
+   */
+  private static class MockLeaf extends Node {
+
+    @Override
+    public int classify(Instance instance) {
+      return instance.label;
+    }
+
+    @Override
+    protected String getString() {
+      return "[MockLeaf]";
+    }
+
+    @Override
+    public long maxDepth() {
+      // depth is irrelevant for this mock
+      return 0;
+    }
+
+    @Override
+    public long nbNodes() {
+      // node count is irrelevant for this mock
+      return 0;
+    }
+
+    @Override
+    protected void writeNode(DataOutput out) throws IOException {
+    }
+
+    public void readFields(DataInput in) throws IOException {
+    }
+
+    
+  }
+
+  private static class MockTreeBuilder extends TreeBuilder {
+
+    @Override
+    public Node build(Random rng, Data data) {
+      // always return a leaf that classifies each instance with its own label
+      return new MockLeaf();
+    }
+
+  }
+}
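
PartitionBugTest works by construction: every instance carries its own dataset index as its label, and MockLeaf echoes that label back, so any prediction that disagrees with the instance id exposes a wrong id mapping somewhere between the mappers and the callback. A condensed, self-contained sketch of that identity-label pattern follows; all names here are illustrative and independent of Mahout.

    // Sketch only: the identity-label trick used by PartitionBugTest.
    public class IdentityLabelSketch {
      public static void main(String[] args) {
        int numInstances = 10;
        int[] labels = new int[numInstances];
        for (int i = 0; i < numInstances; i++) {
          labels[i] = i;                        // label == position in the dataset
        }
        for (int instanceId = 0; instanceId < numInstances; instanceId++) {
          int prediction = labels[instanceId];  // stands in for MockLeaf.classify()
          // the real test additionally skips -1 (instance was in the tree's bag)
          if (prediction != instanceId) {
            throw new AssertionError("instance id mapping is broken at " + instanceId);
          }
        }
        System.out.println("all instance ids map correctly");
      }
    }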

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step0JobTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step0JobTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step0JobTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step0JobTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,248 @@
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.hadoop.mapreduce.Mapper.Context;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.mahout.df.data.DataConverter;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapreduce.Builder;
+import org.apache.mahout.df.mapreduce.partial.Step0Job.Step0Mapper;
+import org.apache.mahout.df.mapreduce.partial.Step0Job.Step0Output;
+
+public class Step0JobTest extends TestCase {
+
+  // the generated data must be big enough to be split up by FileInputFormat
+
+  int numAttributes = 40;
+
+  int numInstances = 2000;
+
+  int numTrees = 10;
+
+  int numMaps = 5;
+
+  Step0Context context;
+
+  /**
+   * Computes the "mapred.max.split.size" that will generate the desired number
+   * of input splits
+   * 
+   * @param conf configuration to update
+   * @param inputPath path to the input file
+   * @param numMaps desired number of input splits
+   * @throws Exception
+   */
+  public static void setMaxSplitSize(Configuration conf, Path inputPath,
+      int numMaps) throws Exception {
+    FileSystem fs = inputPath.getFileSystem(conf);
+    FileStatus status = fs.getFileStatus(inputPath);
+    long goalSize = status.getLen() / numMaps;
+    conf.setLong("mapred.max.split.size", goalSize);
+  }
+
+  @SuppressWarnings("unchecked")
+  public void testStep0Mapper() throws Exception {
+    Random rng = new Random();
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+    String[] sData = Utils.double2String(source);
+
+    // write the data to a file
+    Path dataPath = Utils.writeDataToTestFile(sData);
+
+    Job job = new Job();
+    job.setInputFormatClass(TextInputFormat.class);
+    TextInputFormat.setInputPaths(job, dataPath);
+
+    setMaxSplitSize(job.getConfiguration(), dataPath, numMaps);
+
+    // retrieve the splits
+    TextInputFormat input = new TextInputFormat();
+    List<InputSplit> splits = input.getSplits(job);
+    assertEquals(numMaps, splits.size());
+
+    InputSplit[] sorted = new InputSplit[numMaps];
+    splits.toArray(sorted);
+    Builder.sortSplits(sorted);
+
+    context = new Step0Context(new Step0Mapper(), job.getConfiguration(),
+        new TaskAttemptID(), numMaps);
+
+    for (int p = 0; p < numMaps; p++) {
+      InputSplit split = sorted[p];
+
+      RecordReader<LongWritable, Text> reader = input.createRecordReader(split,
+          context);
+      reader.initialize(split, context);
+
+      Step0Mapper mapper = new Step0Mapper();
+      mapper.configure(p);
+
+      Long firstKey = null;
+      int size = 0;
+
+      while (reader.nextKeyValue()) {
+        LongWritable key = reader.getCurrentKey();
+
+        if (firstKey == null) {
+          firstKey = key.get();
+        }
+
+        mapper.map(key, reader.getCurrentValue(), context);
+
+        size++;
+      }
+
+      mapper.cleanup(context);
+
+      // validate the mapper's output
+      assertEquals(p, context.keys[p]);
+      assertEquals(firstKey.longValue(), context.values[p].firstId);
+      assertEquals(size, context.values[p].size);
+    }
+
+  }
+
+  public void testProcessOutput() throws Exception {
+    Random rng = new Random();
+
+    // create a dataset large enough to be split up
+    String descriptor = Utils.randomDescriptor(rng, numAttributes);
+    double[][] source = Utils.randomDoubles(rng, descriptor, numInstances);
+
+    // each instance label is its index in the dataset
+    int labelId = Utils.findLabel(descriptor);
+    for (int index = 0; index < numInstances; index++) {
+      source[index][labelId] = index;
+    }
+
+    String[] sData = Utils.double2String(source);
+
+    // write the data to a file
+    Path dataPath = Utils.writeDataToTestFile(sData);
+
+    // prepare a data converter
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    DataConverter converter = new DataConverter(dataset);
+
+    Job job = new Job();
+    job.setInputFormatClass(TextInputFormat.class);
+    TextInputFormat.setInputPaths(job, dataPath);
+
+    setMaxSplitSize(job.getConfiguration(), dataPath, numMaps);
+
+    // retrieve the splits
+    TextInputFormat input = new TextInputFormat();
+    List<InputSplit> splits = input.getSplits(job);
+    assertEquals(numMaps, splits.size());
+
+    InputSplit[] sorted = new InputSplit[numMaps];
+    splits.toArray(sorted);
+    Builder.sortSplits(sorted);
+
+    List<Integer> keys = new ArrayList<Integer>();
+    List<Step0Output> values = new ArrayList<Step0Output>();
+
+    int[] expectedIds = new int[numMaps];
+
+    TaskAttemptContext context = new TaskAttemptContext(job.getConfiguration(),
+        new TaskAttemptID());
+
+    for (int p = 0; p < numMaps; p++) {
+      InputSplit split = sorted[p];
+      RecordReader<LongWritable, Text> reader = input.createRecordReader(split,
+          context);
+      reader.initialize(split, context);
+
+      Long firstKey = null;
+      int size = 0;
+
+      while (reader.nextKeyValue()) {
+        LongWritable key = reader.getCurrentKey();
+        Text value = reader.getCurrentValue();
+
+        if (firstKey == null) {
+          firstKey = key.get();
+          expectedIds[p] = converter.convert(0, value.toString()).label;
+        }
+
+        size++;
+      }
+
+      keys.add(p);
+      values.add(new Step0Output(firstKey, size));
+    }
+
+    Step0Output[] partitions = Step0Job.processOutput(keys, values);
+
+    int[] actualIds = Step0Output.extractFirstIds(partitions);
+
+    assertTrue("Expected: " + Arrays.toString(expectedIds) + " But was: "
+        + Arrays.toString(actualIds), Arrays.equals(expectedIds, actualIds));
+  }
+
+  public class Step0Context extends Context {
+
+    public final int[] keys;
+
+    public final Step0Output[] values;
+
+    private int index = 0;
+
+    @SuppressWarnings("unchecked")
+    public Step0Context(Mapper mapper, Configuration conf,
+        TaskAttemptID taskid, int numMaps) throws IOException,
+        InterruptedException {
+      mapper.super(conf, taskid, null, null, null, null, null);
+
+      keys = new int[numMaps];
+      values = new Step0Output[numMaps];
+    }
+
+    @Override
+    public void write(Object key, Object value) throws IOException,
+        InterruptedException {
+      if (index == keys.length) {
+        throw new IOException("Received more output than expected : " + index);
+      }
+
+      keys[index] = ((IntWritable) key).get();
+      values[index] = ((Step0Output) value).clone();
+
+      index++;
+    }
+
+    /**
+     * Number of outputs collected so far
+     * 
+     * @return the number of (key, value) pairs written to this context
+     */
+    public int nbOutputs() {
+      return index;
+    }
+  }
+}
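
setMaxSplitSize() above is the knob both PartitionBugTest and the tests in this file use to force TextInputFormat into a predictable number of splits: capping "mapred.max.split.size" at fileLength / numMaps makes a single input file come out as roughly numMaps splits. Below is a hedged usage sketch that reuses only the configuration key from the patch; the input path and class name are made up for illustration.

    // Sketch only: forcing roughly numMaps input splits for one input file.
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;

    public class SplitSizeSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path inputPath = new Path("testdata/sample.txt"); // illustrative path
        int numMaps = 5;

        FileSystem fs = inputPath.getFileSystem(conf);
        long goalSize = fs.getFileStatus(inputPath).getLen() / numMaps;

        // FileInputFormat never builds a split larger than this cap, so the file
        // ends up in about numMaps splits
        conf.setLong("mapred.max.split.size", goalSize);
      }
    }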

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java Wed Sep 30 05:29:22 2009
@@ -0,0 +1,133 @@
+package org.apache.mahout.df.mapreduce.partial;
+
+import static org.apache.mahout.df.data.Utils.double2String;
+import static org.apache.mahout.df.data.Utils.randomDescriptor;
+import static org.apache.mahout.df.data.Utils.randomDoubles;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.mahout.df.builder.TreeBuilder;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class Step1MapperTest extends TestCase {
+
+  /**
+   * Make sure that the data used to build the trees is from the mapper's
+   * partition
+   * 
+   */
+  private static class MockTreeBuilder extends TreeBuilder {
+
+    protected Data expected;
+
+    public void setExpected(Data data) {
+      expected = data;
+    }
+
+    @Override
+    public Node build(Random rng, Data data) {
+      for (int index = 0; index < data.size(); index++) {
+        assertTrue(expected.contains(data.get(index)));
+      }
+
+      return new Leaf(-1);
+    }
+  }
+
+  /**
+   * Special Step1Mapper that can be configured without using a Configuration
+   * 
+   */
+  protected static class MockStep1Mapper extends Step1Mapper {
+    public MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+        int partition, int numMapTasks, int numTrees) {
+      configure(false, true, treeBuilder, dataset);
+      configure(seed, partition, numMapTasks, numTrees);
+    }
+
+    public int getFirstTreeId() {
+      return firstTreeId;
+    }
+
+  }
+
+  /** nb attributes per generated data instance */
+  protected final int nbAttributes = 4;
+
+  /** nb generated data instances */
+  protected final int nbInstances = 100;
+
+  /** nb trees to build */
+  protected final int nbTrees = 10;
+
+  /** nb mappers to use */
+  protected final int nbMappers = 2;
+
+  @SuppressWarnings("unchecked")
+  public void testMapper() throws Exception {
+    Long seed = null;
+    Random rng = new Random();
+
+    // prepare the data
+    String descriptor = randomDescriptor(rng, nbAttributes);
+    double[][] source = randomDoubles(rng, descriptor, nbInstances);
+    String[] sData = double2String(source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    String[][] splits = Utils.splitData(sData, nbMappers);
+
+    MockTreeBuilder treeBuilder = new MockTreeBuilder();
+
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+
+    int treeIndex = 0;
+
+    for (int partition = 0; partition < nbMappers; partition++) {
+      String[] split = splits[partition];
+      treeBuilder.setExpected(DataLoader.loadData(dataset, split));
+
+      // expected number of trees that this mapper will build
+      int mapNbTrees = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
+
+      MockContext context = new MockContext(new Step1Mapper(),
+          new Configuration(), new TaskAttemptID(), mapNbTrees);
+
+      MockStep1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, seed,
+          partition, nbMappers, nbTrees);
+
+      // make sure the mapper computed firstTreeId correctly
+      assertEquals(treeIndex, mapper.getFirstTreeId());
+
+      for (int index = 0; index < split.length; index++) {
+        key.set(index);
+        value.set(split[index]);
+        mapper.map(key, value, context);
+      }
+
+      mapper.cleanup(context);
+
+      // make sure the mapper built all its trees
+      assertEquals(mapNbTrees, context.nbOutputs());
+
+      // check the returned keys
+      for (TreeID k : context.keys) {
+        assertEquals(partition, k.partition());
+        assertEquals(treeIndex, k.treeId());
+
+        treeIndex++;
+      }
+    }
+  }
+}
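
testMapper() above also pins down how tree ids are handed out across mappers: partition p starts numbering its trees at the sum of nbTrees() over all earlier partitions, which is exactly what the running treeIndex counter tracks when it is compared against mapper.getFirstTreeId(). A short sketch of that prefix sum, assuming the same static nbTrees() helper used in the test; the class name FirstTreeIdSketch is illustrative.

    // Sketch only: expected firstTreeId for each mapper.
    // Placed in the same package as the test, in case nbTrees() is not public.
    package org.apache.mahout.df.mapreduce.partial;

    public class FirstTreeIdSketch {
      public static void main(String[] args) {
        int nbMappers = 2;
        int nbTrees = 10;
        int firstTreeId = 0;
        for (int partition = 0; partition < nbMappers; partition++) {
          System.out.println("partition " + partition + " starts at tree " + firstTreeId);
          firstTreeId += Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
        }
        // after the loop every one of the nbTrees ids has been assigned exactly once
      }
    }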


