Return-Path: Delivered-To: apmail-lucene-mahout-commits-archive@minotaur.apache.org Received: (qmail 28005 invoked from network); 29 Sep 2009 04:57:18 -0000 Received: from hermes.apache.org (HELO mail.apache.org) (140.211.11.3) by minotaur.apache.org with SMTP; 29 Sep 2009 04:57:18 -0000 Received: (qmail 24267 invoked by uid 500); 29 Sep 2009 04:57:18 -0000 Delivered-To: apmail-lucene-mahout-commits-archive@lucene.apache.org Received: (qmail 24228 invoked by uid 500); 29 Sep 2009 04:57:18 -0000 Mailing-List: contact mahout-commits-help@lucene.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: mahout-dev@lucene.apache.org Delivered-To: mailing list mahout-commits@lucene.apache.org Received: (qmail 24219 invoked by uid 99); 29 Sep 2009 04:57:18 -0000 Received: from nike.apache.org (HELO nike.apache.org) (192.87.106.230) by apache.org (qpsmtpd/0.29) with ESMTP; Tue, 29 Sep 2009 04:57:18 +0000 X-ASF-Spam-Status: No, hits=-2000.0 required=10.0 tests=ALL_TRUSTED X-Spam-Check-By: apache.org Received: from [140.211.11.4] (HELO eris.apache.org) (140.211.11.4) by apache.org (qpsmtpd/0.29) with ESMTP; Tue, 29 Sep 2009 04:57:04 +0000 Received: by eris.apache.org (Postfix, from userid 65534) id 90C7A23888D6; Tue, 29 Sep 2009 04:56:41 +0000 (UTC) Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit Subject: svn commit: r819830 [1/2] - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/df/mapred/ main/java/org/apache/mahout/df/mapred/inmem/ main/java/org/apache/mahout/df/mapreduce/ main/java/org/apache/mahout/df/mapreduce/inmem/ test/java/org/ap... Date: Tue, 29 Sep 2009 04:56:39 -0000 To: mahout-commits@lucene.apache.org From: adeneche@apache.org X-Mailer: svnmailer-1.0.8 Message-Id: <20090929045641.90C7A23888D6@eris.apache.org> X-Virus-Checked: Checked by ClamAV on apache.org Author: adeneche Date: Tue Sep 29 04:56:25 2009 New Revision: 819830 URL: http://svn.apache.org/viewvc?rev=819830&view=rev Log: MAHOUT-140 In-memory mapreduce Decision Forests Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/package.html lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/package.html 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputFormatTest.java lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputSplitTest.java lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormatTest.java lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputSplitTest.java Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/Builder.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,311 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapred; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Comparator; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.mapred.JobClient; +import org.apache.hadoop.mapred.JobConf; +import org.apache.mahout.df.DecisionForest; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.callback.PredictionCallback; +import org.apache.mahout.df.data.Dataset; +import org.apache.mahout.common.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for Mapred DecisionForest builders. Takes care of storing the + * parameters common to the mapred implementations.
+ * The child classes must implement at least:
+ * <ul>
+ * <li>void configureJob(JobConf): to further configure the job before its
+ * launch; and</li>
+ * <li>DecisionForest parseOutput(JobConf, PredictionCallback): to convert the
+ * job outputs into a DecisionForest and its corresponding oob predictions</li>
+ * </ul>
+ * + */ +public abstract class Builder { + + private static final Logger log = LoggerFactory.getLogger(Builder.class); + + /** Tree Builder Component */ + protected final TreeBuilder treeBuilder; + + protected final Path dataPath; + + protected final Path datasetPath; + + protected final Long seed; + + protected final Configuration conf; + + protected String outputDirName = "output"; + + /** + * Used only for DEBUG purposes. if false, the mappers doesn't output anything, + * so the builder has nothing to process + * + * @param conf + * @return + */ + protected static boolean isOutput(Configuration conf) { + return conf.getBoolean("debug.mahout.rf.output", true); + } + + protected static boolean isOobEstimate(Configuration conf) { + return conf.getBoolean("mahout.rf.oob", false); + } + + protected static void setOobEstimate(Configuration conf, boolean value) { + conf.setBoolean("mahout.rf.oob", value); + } + + /** + * Returns the random seed + * + * @param conf + * @return null if no seed is available + */ + public static Long getRandomSeed(Configuration conf) { + String seed = conf.get("mahout.rf.random.seed"); + if (seed == null) + return null; + + return Long.valueOf(seed); + } + + /** + * Sets the random seed value + * + * @param conf + * @param seed + */ + protected static void setRandomSeed(Configuration conf, long seed) { + conf.setLong("mahout.rf.random.seed", seed); + } + + public static TreeBuilder getTreeBuilder(Configuration conf) { + String string = conf.get("mahout.rf.treebuilder"); + if (string == null) + return null; + + return (TreeBuilder) StringUtils.fromString(string); + } + + protected static void setTreeBuilder(Configuration conf, TreeBuilder treeBuilder) { + conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder)); + } + + /** + * Get the number of trees for the map-reduce job. The default value is 100 + * + * @param conf + * @return + */ + public static int getNbTrees(Configuration conf) { + return conf.getInt("mahout.rf.nbtrees", -1); + } + + /** + * Set the number of trees to grow for the map-reduce job + * + * @param conf + * @param nbTrees + * @throws IllegalArgumentException if (nbTrees <= 0) + */ + public static void setNbTrees(Configuration conf, int nbTrees) { + if (nbTrees <= 0) + throw new IllegalArgumentException("nbTrees should be greater than 0"); + + conf.setInt("mahout.rf.nbtrees", nbTrees); + } + + /** + * Sets the Output directory name, will be creating in the working directory + * @param name + */ + public void setOutputDirName(String name) { + outputDirName = name; + } + + /** + * Output Directory name + * @param conf + * @return + * @throws IOException + */ + public Path getOutputPath(Configuration conf) throws IOException { + // the output directory is accessed only by this class, so use the default + // file system + FileSystem fs = FileSystem.get(conf); + return new Path(fs.getWorkingDirectory(), outputDirName); + } + + public Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, + Long seed, Configuration conf) { + this.treeBuilder = treeBuilder; + this.dataPath = dataPath; + this.datasetPath = datasetPath; + this.seed = seed; + this.conf = conf; + } + + /** + * Helper method. 
Get a path from the DistributedCache + * + * @param job + * @param index index of the path in the DistributedCache files + * @return + * @throws IOException + */ + public static Path getDistributedCacheFile(JobConf job, int index) + throws IOException { + URI[] files = new URI[0]; + files = DistributedCache.getCacheFiles(job); + + if (files == null || files.length < index) { + throw new IOException("path not found in the DistributedCache"); + } + + return new Path(files[index].getPath()); + } + + /** + * Helper method. Load a Dataset stored in the DistributedCache + * + * @param job + * @return + * @throws IOException + */ + public static Dataset loadDataset(JobConf job) throws IOException { + Path datasetPath = getDistributedCacheFile(job, 0); + + return Dataset.load(job, datasetPath); + } + + /** + * Used by the inheriting classes to configure the job + * + * @param conf + * @param nbTrees number of trees to grow + * @param oobEstimate true, if oob error should be estimated + * @throws IOException + */ + protected abstract void configureJob(JobConf conf, int nbTrees, + boolean oobEstimate) throws IOException; + + /** + * Sequential implementation should override this method to simulate the job + * execution + * + * @param job + * @throws Exception + */ + protected void runJob(JobConf job) throws Exception { + JobClient.runJob(job); + } + + /** + * Parse the output files to extract the trees and pass the predictions to the + * callback + * + * @param job + * @param callback can be null + * @return + * @throws IOException + */ + protected abstract DecisionForest parseOutput(JobConf job, + PredictionCallback callback) throws IOException; + + public DecisionForest build(int nbTrees, PredictionCallback callback) + throws Exception { + JobConf job = new JobConf(conf, Builder.class); + + Path outputPath = getOutputPath(job); + FileSystem fs = outputPath.getFileSystem(job); + + // check the output + if (fs.exists(outputPath)) + throw new RuntimeException("Ouput path already exists : " + outputPath); + + if (seed != null) + setRandomSeed(job, seed); + setNbTrees(job, nbTrees); + setTreeBuilder(job, treeBuilder); + setOobEstimate(job, callback != null); + + // put the dataset into the DistributedCache + DistributedCache.addCacheFile(datasetPath.toUri(), job); + + log.debug("Configuring the job..."); + configureJob(job, nbTrees, callback != null); + + log.debug("Running the job..."); + runJob(job); + + if (isOutput(job)) { + log.debug("Parsing the output..."); + DecisionForest forest = parseOutput(job, callback); + + // delete the output path + fs.delete(outputPath, true); + + return forest; + } + + return null; + } + + /** + * sort the splits into order based on size, so that the biggest go first.
+ * This is the same code used by Hadoop's JobClient. + * + * @param splits + */ + public static void sortSplits(InputSplit[] splits) { + Arrays.sort(splits, new Comparator() { + public int compare(InputSplit a, InputSplit b) { + try { + long left = a.getLength(); + long right = b.getLength(); + if (left == right) { + return 0; + } else if (left < right) { + return 1; + } else { + return -1; + } + } catch (IOException ie) { + throw new RuntimeException("Problem getting input split size", ie); + } + } + }); + } + +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/MapredMapper.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapred; + +import java.io.IOException; + +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.MapReduceBase; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.data.Dataset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for Mapred mappers. 
Loads common parameters from the job + */ +public class MapredMapper extends MapReduceBase { + + protected final static Logger log = LoggerFactory.getLogger(MapredMapper.class); + + private boolean noOutput; + + private boolean oobEstimate; + + private TreeBuilder treeBuilder; + + private Dataset dataset; + + /** + * + * @return if false, the mapper does not output + */ + protected boolean isOobEstimate() { + return oobEstimate; + } + + /** + * + * @return if false, the mapper does not estimate and output predictions + */ + protected boolean isNoOutput() { + return noOutput; + } + + protected TreeBuilder getTreeBuilder() { + return treeBuilder; + } + + protected Dataset getDataset() { + return dataset; + } + + @Override + public void configure(JobConf conf) { + super.configure(conf); + + try { + configure(!Builder.isOutput(conf), Builder.isOobEstimate(conf), Builder + .getTreeBuilder(conf), Builder.loadDataset(conf)); + } catch (IOException e) { + throw new RuntimeException( + "Exception caught while configuring the mapper: " + e.getMessage()); + } + } + + /** + * Useful for testing + * + * @param noOutput + * @param oobEstimate + * @param treeBuilder + * @param dataset + */ + protected void configure(boolean noOutput, boolean oobEstimate, + TreeBuilder treeBuilder, Dataset dataset) { + this.noOutput = noOutput; + this.oobEstimate = oobEstimate; + + if (treeBuilder == null) { + throw new RuntimeException("TreeBuilder not found in the Job parameters"); + } + this.treeBuilder = treeBuilder; + + this.dataset = dataset; + } +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemBuilder.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.df.mapred.inmem; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile.Reader; +import org.apache.hadoop.mapred.FileOutputFormat; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.mahout.df.DFUtils; +import org.apache.mahout.df.DecisionForest; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.callback.PredictionCallback; +import org.apache.mahout.df.mapred.Builder; +import org.apache.mahout.df.mapreduce.MapredOutput; +import org.apache.mahout.df.node.Node; + +/** + * MapReduce implementation where each mapper loads a full copy of the data + * in-memory. The forest trees are splitted across all the mappers + */ +public class InMemBuilder extends Builder { + + public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, + Long seed, Configuration conf) { + super(treeBuilder, dataPath, datasetPath, seed, conf); + } + + public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath) { + this(treeBuilder, dataPath, datasetPath, null, new Configuration()); + } + + @Override + protected void configureJob(JobConf conf, int nbTrees, boolean oobEstimate) + throws IOException { + FileOutputFormat.setOutputPath(conf, getOutputPath(conf)); + + // put the data in the DistributedCache + DistributedCache.addCacheFile(dataPath.toUri(), conf); + + conf.setOutputKeyClass(IntWritable.class); + conf.setOutputValueClass(MapredOutput.class); + + conf.setMapperClass(InMemMapper.class); + conf.setNumReduceTasks(0); // no reducers + + conf.setInputFormat(InMemInputFormat.class); + conf.setOutputFormat(SequenceFileOutputFormat.class); + } + + @Override + protected DecisionForest parseOutput(JobConf conf, PredictionCallback callback) + throws IOException { + Map output = new HashMap(); + + Path outputPath = getOutputPath(conf); + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + // import the InMemOutputs + IntWritable key = new IntWritable(); + MapredOutput value = new MapredOutput(); + + for (Path path : outfiles) { + Reader reader = new Reader(fs, path, conf); + + try { + while (reader.next(key, value)) { + output.put(key.get(), value.clone()); + } + } finally { + reader.close(); + } + } + + return processOutput(output, callback); + } + + /** + * Process the output, extracting the trees and passing the predictions to the + * callback + * + * @param output + * @param callback + * @return + */ + protected DecisionForest processOutput(Map output, + PredictionCallback callback) { + List trees = new ArrayList(); + + for (Integer key : output.keySet()) { + MapredOutput value = output.get(key); + + trees.add(value.getTree()); + + if (callback != null) { + int[] predictions = value.getPredictions(); + for (int index = 0; index < predictions.length; index++) { + callback.prediction(key, index, predictions[index]); + } + } + } + + return new DecisionForest(trees); + } +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemInputFormat.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,271 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapred.inmem; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Random; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.InputFormat; +import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.RecordReader; +import org.apache.hadoop.mapred.Reporter; +import org.apache.mahout.df.mapred.Builder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Custom InputFormat that generates InputSplits given the desired number of + * trees.
+ * Each input split contains a subset of the trees.
+ * The number of splits is equal to the number of requested splits + */ +public class InMemInputFormat implements InputFormat { + + protected static final Logger log = LoggerFactory.getLogger(InMemInputSplit.class); + + protected Random rng; + + protected Long seed; + + protected boolean isSingleSeed; + + /** + * Used for DEBUG purposes only. if true and a seed is available, all the + * mappers use the same seed, thus all the mapper should take the same time to + * build their trees. + * + * @param conf + * @return + */ + public static boolean isSingleSeed(Configuration conf) { + return conf.getBoolean("debug.mahout.rf.single.seed", false); + } + + @Override + public RecordReader getRecordReader( + InputSplit split, JobConf conf, Reporter reporter) throws IOException { + return new InMemRecordReader((InMemInputSplit) split); + } + + @Override + public InputSplit[] getSplits(JobConf conf, int numSplits) throws IOException { + int nbTrees = Builder.getNbTrees(conf); + int splitSize = nbTrees / numSplits; + + seed = Builder.getRandomSeed(conf); + isSingleSeed = isSingleSeed(conf); + + if (rng != null && seed != null) { + log.warn("getSplits() was called more than once and the 'seed' is set, " + + "this can lead to no-repeatable behavior"); + } + + rng = (seed == null || isSingleSeed) ? null : new Random(seed); + + int id = 0; + + InputSplit[] splits = new InputSplit[numSplits]; + + for (int index = 0; index < numSplits - 1; index++) { + splits[index] = new InMemInputSplit(id, splitSize, nextSeed()); + id += splitSize; + } + + // take care of the remainder + splits[numSplits - 1] = new InMemInputSplit(id, nbTrees - id, nextSeed()); + + return splits; + } + + /** + * Return the seed for the next InputSplit + * + * @return + */ + private Long nextSeed() { + if (seed == null) + return null; + else if (isSingleSeed) + return seed; + else + return rng.nextLong(); + } + + public static class InMemRecordReader implements + RecordReader { + + protected final InMemInputSplit split; + + protected int pos; + + public InMemRecordReader(InMemInputSplit split) { + this.split = split; + } + + @Override + public void close() throws IOException { + } + + @Override + public IntWritable createKey() { + return new IntWritable(); + } + + @Override + public NullWritable createValue() { + return NullWritable.get(); + } + + @Override + public long getPos() throws IOException { + return pos; + } + + @Override + public float getProgress() throws IOException { + if (pos == 0) + return 0f; + else + return (float) (pos - 1) / split.nbTrees; + } + + @Override + public boolean next(IntWritable key, NullWritable value) throws IOException { + if (pos < split.nbTrees) { + key.set(split.firstId + pos); + pos++; + return true; + } else { + return false; + } + } + + } + + /** + * Custom InputSplit that indicates how many trees are built by each mapper + */ + public static class InMemInputSplit implements InputSplit { + + /** Id of the first tree of this split */ + private int firstId; + + private int nbTrees; + + private Long seed; + + public InMemInputSplit() { + } + + public InMemInputSplit(int firstId, int nbTrees, Long seed) { + this.firstId = firstId; + this.nbTrees = nbTrees; + this.seed = seed; + } + + /** + * Return the Id of the first tree of this split + * + * @return + */ + public int getFirstId() { + return firstId; + } + + /** + * Return the number of trees + * + * @return + */ + public int getNbTrees() { + return nbTrees; + } + + /** + * Return the random seed + * + * @return null if no seed is available + */ + 
public Long getSeed() { + return seed; + } + + @Override + public long getLength() throws IOException { + return nbTrees; + } + + @Override + public String[] getLocations() throws IOException { + return new String[0]; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || !(obj instanceof InMemInputSplit)) + return false; + + InMemInputSplit split = (InMemInputSplit) obj; + + if (seed == null && split.seed != null) + return false; + + return firstId == split.firstId && nbTrees == split.nbTrees + && (seed == null || seed.equals(split.seed)); + } + + @Override + public String toString() { + return String.format("[firstId:%d, nbTrees:%d, seed:%d]", firstId, + nbTrees, seed); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readInt(); + nbTrees = in.readInt(); + boolean isSeed = in.readBoolean(); + seed = (isSeed) ? in.readLong() : null; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(firstId); + out.writeInt(nbTrees); + out.writeBoolean(seed != null); + if (seed != null) { + out.writeLong(seed); + } + } + + public static InMemInputSplit read(DataInput in) throws IOException { + InMemInputSplit split = new InMemInputSplit(); + split.readFields(in); + return split; + } + } + +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/InMemMapper.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.df.mapred.inmem; + +import java.io.IOException; +import java.util.Random; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.Mapper; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapred.Reporter; +import org.apache.hadoop.util.StringUtils; +import org.apache.mahout.df.Bagging; +import org.apache.mahout.df.callback.SingleTreePredictions; +import org.apache.mahout.df.data.Data; +import org.apache.mahout.df.data.DataLoader; +import org.apache.mahout.df.data.Dataset; +import org.apache.mahout.df.mapred.Builder; +import org.apache.mahout.df.mapred.MapredMapper; +import org.apache.mahout.df.mapred.inmem.InMemInputFormat.InMemInputSplit; +import org.apache.mahout.df.mapreduce.MapredOutput; +import org.apache.mahout.df.node.Node; + +/** + * In-memory mapper that grows the trees using a full copy of the data loaded + * in-memory. The number of trees to grow is determined by the current + * InMemInputSplit. + */ +public class InMemMapper extends MapredMapper implements + Mapper { + + protected Bagging bagging; + + protected Random rng; + + protected Data data; + + /** + * Load the training data + * + * @param conf + * @return + * @throws RuntimeException if the data could not be loaded + */ + protected Data loadData(JobConf conf, Dataset dataset) { + try { + Path dataPath = Builder.getDistributedCacheFile(conf, 1); + FileSystem fs = FileSystem.get(dataPath.toUri(), conf); + + return DataLoader.loadData(dataset, fs, dataPath); + } catch (Exception e) { + throw new RuntimeException("Exception caught while loading the data: " + + StringUtils.stringifyException(e)); + } + } + + @Override + public void configure(JobConf conf) { + super.configure(conf); + + log.info("Loading the data..."); + data = loadData(conf, getDataset()); + log.info("Data loaded : " + data.size() + " instances"); + + bagging = new Bagging(getTreeBuilder(), data); + } + + @Override + public void map(IntWritable key, NullWritable value, + OutputCollector output, Reporter reporter) + throws IOException { + map(key, output, (InMemInputSplit) reporter.getInputSplit()); + } + + public void map(IntWritable key, + OutputCollector output, InMemInputSplit split) + throws IOException { + + SingleTreePredictions callback = null; + int[] predictions = null; + + if (isOobEstimate() && !isNoOutput()) { + callback = new SingleTreePredictions(data.size()); + predictions = callback.predictions; + } + + initRandom(split); + + log.debug("Building..."); + Node tree = bagging.build(key.get(), rng, callback); + + if (!isNoOutput()) { + log.debug("Outputing..."); + MapredOutput mrOut = new MapredOutput(tree, predictions); + + output.collect(key, mrOut); + } + } + + protected void initRandom(InMemInputSplit split) { + if (rng == null) { // first execution of this mapper + Long seed = split.getSeed(); + log.debug("Initialising rng with seed : " + seed); + + if (seed == null) + rng = new Random(); + else + rng = new Random(seed); + } + } + +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/package.html URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/package.html?rev=819830&view=auto ============================================================================== --- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/package.html (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapred/inmem/package.html Tue Sep 29 04:56:25 2009 @@ -0,0 +1,40 @@ +<html> +<head> +<title>org.apache.mahout.df.mapred.inmem</title> +</head> +<body>
+In-memory mapreduce implementation of Random Decision Forests +
+
 
+
+Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory; it uses the reference implementation's code to build each tree and estimate the oob error.

+The dataset is distributed to the slave nodes using the DistributedCache. A custom InputFormat (InMemInputFormat) is configured with the desired number of trees and generates a number of InputSplits (InMemInputSplit) equal to the configured number of maps (mapred.map.tasks). For example, with 100 trees and 8 maps, each of the first seven splits grows 100/8 = 12 trees and the last split grows the remaining 16.

+There is no need for reducers: each map outputs (MapredOutput) the trees it built and, for each tree, the labels the tree predicted for each out-of-bag instance. This step has to be done in the mapper because only the mapper knows which instances are out-of-bag.

+The Forest builder (InMemBuilder) is responsible for configuring and launching the job. At the end of the job it parses the output files and builds the corresponding DecisionForest; for each tree prediction it calls the PredictionCallback (if one was provided), which allows the caller to compute any error measure it needs.
+
 
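As a usage sketch (not part of this commit): a driver for this implementation might look roughly like the following Java fragment. DefaultTreeBuilder and the two input paths are assumptions, and the callback body is illustrative only; the callback signature follows the processOutput() code above.

import org.apache.hadoop.fs.Path;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.mapred.inmem.InMemBuilder;

public class InMemExample {
  public static void main(String[] args) throws Exception {
    // Assumed locations: the training data and the serialized Dataset descriptor
    Path dataPath = new Path("glass.data");
    Path datasetPath = new Path("glass.info");

    // Any TreeBuilder works; DefaultTreeBuilder is an assumption here
    InMemBuilder builder =
        new InMemBuilder(new DefaultTreeBuilder(), dataPath, datasetPath);

    // Receives one call per (tree, out-of-bag instance) pair; pass null
    // instead to skip the oob estimation entirely
    PredictionCallback callback = new PredictionCallback() {
      public void prediction(int treeId, int instanceId, int prediction) {
        // accumulate the predictions here to compute an oob error estimate
      }
    };

    DecisionForest forest = builder.build(10, callback); // grow 10 trees
    System.out.println("built: " + forest);
  }
}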
+ + + Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,339 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapreduce; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.Job; +import org.apache.mahout.common.StringUtils; +import org.apache.mahout.df.DecisionForest; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.callback.PredictionCallback; +import org.apache.mahout.df.data.Dataset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.Comparator; + +/** + * Base class for Mapred DecisionForest builders. Takes care of storing the + * parameters common to the mapred implementations.
+ * The child classes must implement at least:
+ * <ul>
+ * <li>void configureJob(Job): to further configure the job before its
+ * launch; and</li>
+ * <li>DecisionForest parseOutput(Job, PredictionCallback): to convert the
+ * job outputs into a DecisionForest and its corresponding oob predictions</li>
+ * </ul>
+ * + */ +public abstract class Builder { + + private static final Logger log = LoggerFactory.getLogger(Builder.class); + + /** Tree Builder Component */ + protected final TreeBuilder treeBuilder; + + protected final Path dataPath; + + protected final Path datasetPath; + + protected final Long seed; + + private final Configuration conf; + + protected String outputDirName = "output"; + + protected int numTrees; + + /** + * Return the value of "mapred.map.tasks". In case the 'local' runner is + * detected, returns 1 + * + * @param conf configuration + * @return number of map tasks + */ + public static int getNumMaps(Configuration conf) { + // if we are in 'local' mode, correct the number of maps + // or the mappers won't be able to compute the right indexes + String tracker = conf.get("mapred.job.tracker", "local"); + if ("local".equals(tracker)) { + log + .warn("Hadoop running in 'local' mode, only one map task will be launched"); + return 1; + } + + return conf.getInt("mapred.map.tasks", -1); + } + + /** + * Used only for DEBUG purposes. if false, the mappers doesn't output + * anything, so the builder has nothing to process + * + * @param conf configuration + * @return true if the builder has to return output. false otherwise + */ + protected static boolean isOutput(Configuration conf) { + return conf.getBoolean("debug.mahout.rf.output", true); + } + + protected static boolean isOobEstimate(Configuration conf) { + return conf.getBoolean("mahout.rf.oob", false); + } + + protected static void setOobEstimate(Configuration conf, boolean value) { + conf.setBoolean("mahout.rf.oob", value); + } + + /** + * Returns the random seed + * + * @param conf configuration + * @return null if no seed is available + */ + public static Long getRandomSeed(Configuration conf) { + String seed = conf.get("mahout.rf.random.seed"); + if (seed == null) + return null; + + return Long.valueOf(seed); + } + + /** + * Sets the random seed value + * + * @param conf configuration + * @param seed random seed + */ + protected static void setRandomSeed(Configuration conf, long seed) { + conf.setLong("mahout.rf.random.seed", seed); + } + + public static TreeBuilder getTreeBuilder(Configuration conf) { + String string = conf.get("mahout.rf.treebuilder"); + if (string == null) + return null; + + return (TreeBuilder) StringUtils.fromString(string); + } + + protected static void setTreeBuilder(Configuration conf, + TreeBuilder treeBuilder) { + conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder)); + } + + /** + * Get the number of trees for the map-reduce job. + * + * @param conf configuration + * @return number of trees to build + */ + public static int getNbTrees(Configuration conf) { + return conf.getInt("mahout.rf.nbtrees", -1); + } + + /** + * Set the number of trees to grow for the map-reduce job + * + * @param conf configuration + * @param nbTrees number of trees to build + * @throws IllegalArgumentException if (nbTrees <= 0) + */ + public static void setNbTrees(Configuration conf, int nbTrees) { + if (nbTrees <= 0) + throw new IllegalArgumentException("nbTrees should be greater than 0"); + + conf.setInt("mahout.rf.nbtrees", nbTrees); + } + + /** + * Sets the Output directory name, will be creating in the working directory + * + * @param name output dir. name + */ + public void setOutputDirName(String name) { + outputDirName = name; + } + + /** + * Output Directory name + * + * @param conf configuration + * @return output dir. 
path (%WORKING_DIRECTORY%/OUTPUT_DIR_NAME%) + * @throws IOException if we cannot get the default FileSystem + */ + public Path getOutputPath(Configuration conf) throws IOException { + // the output directory is accessed only by this class, so use the default + // file system + FileSystem fs = FileSystem.get(conf); + return new Path(fs.getWorkingDirectory(), outputDirName); + } + + /** + * Helper method. Get a path from the DistributedCache + * + * @param conf configuration + * @param index index of the path in the DistributedCache files + * @return path from the DistributedCache + * @throws IOException if no path is found + */ + public static Path getDistributedCacheFile(Configuration conf, int index) + throws IOException { + URI[] files = DistributedCache.getCacheFiles(conf); + + if (files == null || files.length <= index) { + throw new IOException("path not found in the DistributedCache"); + } + + return new Path(files[index].getPath()); + } + + /** + * Helper method. Load a Dataset stored in the DistributedCache + * + * @param conf configuration + * @return loaded Dataset + * @throws IOException if we cannot retrieve the Dataset path from the DistributedCache, or the Dataset could not be loaded + */ + public static Dataset loadDataset(Configuration conf) throws IOException { + Path datasetPath = getDistributedCacheFile(conf, 0); + + return Dataset.load(conf, datasetPath); + } + + public Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, + Long seed, Configuration conf) { + this.treeBuilder = treeBuilder; + this.dataPath = dataPath; + this.datasetPath = datasetPath; + this.seed = seed; + this.conf = new Configuration(conf); + } + + /** + * Used by the inheriting classes to configure the job + * + * @param job Hadoop's Job + * @param nbTrees number of trees to grow + * @param oobEstimate true, if oob error should be estimated + * @throws IOException if anything goes wrong while configuring the job + */ + protected abstract void configureJob(Job job, int nbTrees, boolean oobEstimate) + throws IOException; + + /** + * Sequential implementation should override this method to simulate the job + * execution + * + * @param job Hadoop's job + * @return true is the job succeeded + * @throws Exception if the job encounters an error + */ + protected boolean runJob(Job job) throws Exception { + return job.waitForCompletion(true); + } + + /** + * Parse the output files to extract the trees and pass the predictions to the + * callback + * + * @param job Hadoop's job + * @param callback can be null + * @return Built DecisionForest + * @throws Exception if anything goes wrong while parsing the output + */ + protected abstract DecisionForest parseOutput(Job job, PredictionCallback callback) throws Exception; + + public DecisionForest build(int nbTrees, PredictionCallback callback) + throws Exception { + numTrees = getNbTrees(conf); + + Path outputPath = getOutputPath(conf); + FileSystem fs = outputPath.getFileSystem(conf); + + // check the output + if (fs.exists(outputPath)) + throw new RuntimeException("Ouput path already exists : " + outputPath); + + if (seed != null) + setRandomSeed(conf, seed); + setNbTrees(conf, nbTrees); + setTreeBuilder(conf, treeBuilder); + setOobEstimate(conf, callback != null); + + // put the dataset into the DistributedCache + DistributedCache.addCacheFile(datasetPath.toUri(), conf); + + Job job = new Job(conf, "decision forest builder"); + + log.debug("Configuring the job..."); + configureJob(job, nbTrees, callback != null); + + log.debug("Running the job..."); 
+ if (!runJob(job)) { + log.error("Job failed!"); + return null; + } + + if (isOutput(conf)) { + log.debug("Parsing the output..."); + DecisionForest forest = parseOutput(job, callback); + + // delete the output path + fs.delete(outputPath, true); + + return forest; + } + + return null; + } + + /** + * sort the splits into order based on size, so that the biggest go first.
+ * This is the same code used by Hadoop's JobClient. + * + * @param splits input splits + */ + public static void sortSplits(InputSplit[] splits) { + Arrays.sort(splits, new Comparator() { + public int compare(InputSplit a, InputSplit b) { + try { + long left = a.getLength(); + long right = b.getLength(); + if (left == right) { + return 0; + } else if (left < right) { + return 1; + } else { + return -1; + } + } catch (Exception ie) { + throw new RuntimeException("Problem getting input split size", ie); + } + } + }); + } + +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapreduce; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.data.Dataset; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for Mapred mappers. 
Loads common parameters from the job + */ +public class MapredMapper extends + Mapper { + + protected final static Logger log = LoggerFactory.getLogger(MapredMapper.class); + + private boolean noOutput; + + private boolean oobEstimate; + + private TreeBuilder treeBuilder; + + private Dataset dataset; + + /** + * + * @return if false, the mapper does not output + */ + protected boolean isOobEstimate() { + return oobEstimate; + } + + /** + * + * @return if false, the mapper does not estimate and output predictions + */ + protected boolean isNoOutput() { + return noOutput; + } + + protected TreeBuilder getTreeBuilder() { + return treeBuilder; + } + + protected Dataset getDataset() { + return dataset; + } + + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + + Configuration conf = context.getConfiguration(); + + try { + configure(!Builder.isOutput(conf), Builder.isOobEstimate(conf), Builder + .getTreeBuilder(conf), Builder.loadDataset(conf)); + } catch (IOException e) { + throw new RuntimeException( + "Exception caught while configuring the mapper: " + e.getMessage()); + } + } + + /** + * Useful for testing + * + * @param noOutput + * @param oobEstimate + * @param treeBuilder + * @param dataset + */ + protected void configure(boolean noOutput, boolean oobEstimate, + TreeBuilder treeBuilder, Dataset dataset) { + this.noOutput = noOutput; + this.oobEstimate = oobEstimate; + + if (treeBuilder == null) { + throw new RuntimeException("TreeBuilder not found in the Job parameters"); + } + this.treeBuilder = treeBuilder; + + this.dataset = dataset; + } +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredOutput.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.df.mapreduce; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.df.DFUtils; +import org.apache.mahout.df.node.Node; + +/** + * Used by various implementation to return the results of a build.
+ * Contains a grown tree and and its oob predictions. + */ +public class MapredOutput implements Writable { + + protected Node tree; + + protected int[] predictions; + + public Node getTree() { + return tree; + } + + public int[] getPredictions() { + return predictions; + } + + public MapredOutput() { + } + + public MapredOutput(Node tree, int[] predictions) { + this.tree = tree; + this.predictions = predictions; + } + + public MapredOutput(Node tree) { + this(tree, null); + } + + public MapredOutput(int[] predictions) { + this(null, predictions); + } + + public static MapredOutput read(DataInput in) throws IOException { + MapredOutput rfOutput = new MapredOutput(); + rfOutput.readFields(in); + return rfOutput; + } + + @Override + public void readFields(DataInput in) throws IOException { + boolean readTree = in.readBoolean(); + if (readTree) { + tree = Node.read(in); + } + + boolean readPredictions = in.readBoolean(); + if (readPredictions) { + predictions = DFUtils.readIntArray(in); + } + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeBoolean(tree != null); + if (tree != null) { + tree.write(out); + } + + out.writeBoolean(predictions != null); + if (predictions != null) { + DFUtils.writeArray(out, predictions); + } + } + + @Override + public MapredOutput clone() { + return new MapredOutput(tree, predictions); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || !(obj instanceof MapredOutput)) + return false; + + MapredOutput mo = (MapredOutput) obj; + + if (tree != null && tree.equals(mo.tree) == false) + return false; + + return Arrays.equals(predictions, mo.predictions); + } + + @Override + public String toString() { + return "{" + tree + " | " + Arrays.toString(predictions) + "}"; + } + +} Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java?rev=819830&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java Tue Sep 29 04:56:25 2009 @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.df.mapreduce.inmem; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile.Reader; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.df.DFUtils; +import org.apache.mahout.df.DecisionForest; +import org.apache.mahout.df.builder.TreeBuilder; +import org.apache.mahout.df.callback.PredictionCallback; +import org.apache.mahout.df.mapreduce.Builder; +import org.apache.mahout.df.mapreduce.MapredOutput; +import org.apache.mahout.df.node.Node; + +/** + * MapReduce implementation where each mapper loads a full copy of the data + * in-memory. The forest trees are splitted across all the mappers + */ +public class InMemBuilder extends Builder { + + public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, + Long seed, Configuration conf) { + super(treeBuilder, dataPath, datasetPath, seed, conf); + } + + public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath) { + this(treeBuilder, dataPath, datasetPath, null, new Configuration()); + } + + @Override + protected void configureJob(Job job, int nbTrees, boolean oobEstimate) + throws IOException { + Configuration conf = job.getConfiguration(); + + job.setJarByClass(InMemBuilder.class); + + FileOutputFormat.setOutputPath(job, getOutputPath(conf)); + + // put the data in the DistributedCache + DistributedCache.addCacheFile(dataPath.toUri(), conf); + + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(MapredOutput.class); + + job.setMapperClass(InMemMapper.class); + job.setNumReduceTasks(0); // no reducers + + job.setInputFormatClass(InMemInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + } + + @Override + protected DecisionForest parseOutput(Job job, PredictionCallback callback) + throws IOException { + Configuration conf = job.getConfiguration(); + + Map output = new HashMap(); + + Path outputPath = getOutputPath(conf); + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + // import the InMemOutputs + IntWritable key = new IntWritable(); + MapredOutput value = new MapredOutput(); + + for (Path path : outfiles) { + Reader reader = new Reader(fs, path, conf); + + try { + while (reader.next(key, value)) { + output.put(key.get(), value.clone()); + } + } finally { + reader.close(); + } + } + + return processOutput(output, callback); + } + + /** + * Process the output, extracting the trees and passing the predictions to the + * callback + * + * @param output + * @param callback + * @return + */ + protected DecisionForest processOutput(Map output, + PredictionCallback callback) { + List trees = new ArrayList(); + + for (Integer key : output.keySet()) { + MapredOutput value = output.get(key); + + trees.add(value.getTree()); + + if (callback != null) { + int[] predictions = value.getPredictions(); + for (int index = 0; index < predictions.length; index++) { + callback.prediction(key, index, predictions[index]); + } + } + } + + return new DecisionForest(trees); + } +} Added: 
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormat.java Tue Sep 29 04:56:25 2009
@@ -0,0 +1,290 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.inmem;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.mahout.df.mapreduce.Builder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Custom InputFormat that generates InputSplits given the desired number of
+ * trees.<br>
+ * Each input split contains a subset of the trees.<br>
+ * The number of splits is equal to the configured number of map tasks
+ * (mapred.map.tasks).
+ */
+public class InMemInputFormat extends InputFormat<IntWritable, NullWritable> {
+
+  protected static final Logger log = LoggerFactory
+      .getLogger(InMemInputFormat.class);
+
+  protected Random rng;
+
+  protected Long seed;
+
+  protected boolean isSingleSeed;
+
+  /**
+   * Used for DEBUG purposes only. If true and a seed is available, all the
+   * mappers use the same seed; thus all the mappers should take about the
+   * same time to build their trees.
+   *
+   * @param conf the job's configuration
+   * @return true if the single-seed debug mode is enabled
+   */
+  public static boolean isSingleSeed(Configuration conf) {
+    return conf.getBoolean("debug.mahout.rf.single.seed", false);
+  }
+
+  @Override
+  public RecordReader<IntWritable, NullWritable> createRecordReader(
+      InputSplit split, TaskAttemptContext context) throws IOException,
+      InterruptedException {
+    return new InMemRecordReader((InMemInputSplit) split);
+  }
+
+  @Override
+  public List<InputSplit> getSplits(JobContext context) throws IOException,
+      InterruptedException {
+    Configuration conf = context.getConfiguration();
+    int numSplits = conf.getInt("mapred.map.tasks", -1);
+
+    return getSplits(conf, numSplits);
+  }
+
+  public List<InputSplit> getSplits(Configuration conf, int numSplits)
+      throws IOException, InterruptedException {
+    int nbTrees = Builder.getNbTrees(conf);
+    int splitSize = nbTrees / numSplits;
+
+    seed = Builder.getRandomSeed(conf);
+    isSingleSeed = isSingleSeed(conf);
+
+    if (rng != null && seed != null) {
+      log.warn("getSplits() was called more than once and the 'seed' is set, "
+          + "this can lead to non-repeatable behavior");
+    }
+
+    rng = (seed == null || isSingleSeed) ? null : new Random(seed);
+
+    int id = 0;
+
+    List<InputSplit> splits = new ArrayList<InputSplit>(numSplits);
+
+    for (int index = 0; index < numSplits - 1; index++) {
+      splits.add(new InMemInputSplit(id, splitSize, nextSeed()));
+      id += splitSize;
+    }
+
+    // take care of the remainder
+    splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed()));
+
+    return splits;
+  }
+
+  /**
+   * Return the seed for the next InputSplit
+   *
+   * @return null if no seed is available
+   */
+  private Long nextSeed() {
+    if (seed == null)
+      return null;
+    else if (isSingleSeed)
+      return seed;
+    else
+      return rng.nextLong();
+  }
+
+  public static class InMemRecordReader extends
+      RecordReader<IntWritable, NullWritable> {
+
+    protected final InMemInputSplit split;
+
+    protected int pos;
+
+    protected IntWritable key;
+    protected NullWritable value;
+
+    public InMemRecordReader(InMemInputSplit split) {
+      this.split = split;
+    }
+
+    @Override
+    public float getProgress() throws IOException {
+      if (pos == 0)
+        return 0f;
+      else
+        return (float) (pos - 1) / split.nbTrees;
+    }
+
+    @Override
+    public IntWritable getCurrentKey() throws IOException, InterruptedException {
+      return key;
+    }
+
+    @Override
+    public NullWritable getCurrentValue() throws IOException, InterruptedException {
+      return value;
+    }
+
+    @Override
+    public void initialize(InputSplit inputSplit, TaskAttemptContext context)
+        throws IOException, InterruptedException {
+      key = new IntWritable();
+      value = NullWritable.get();
+    }
+
+    @Override
+    public boolean nextKeyValue() throws IOException, InterruptedException {
+      if (pos < split.nbTrees) {
+        key.set(split.firstId + pos);
+        pos++;
+        return true;
+      } else {
+        return false;
+      }
+    }
+
+    @Override
+    public void close() throws IOException {
+    }
+
+  }
+
+  /**
+   * Custom InputSplit that indicates how many trees are built by each mapper
+   */
+  public static class InMemInputSplit extends InputSplit implements Writable {
+
+    /** Id of the first tree of this split */
+    private int firstId;
+
+    private int nbTrees;
+
+    private Long seed;
+
+    public InMemInputSplit() {
+    }
+
+    public InMemInputSplit(int firstId, int nbTrees, Long seed) {
+      this.firstId = firstId;
+      this.nbTrees = nbTrees;
+      this.seed = seed;
+    }
+
+    /**
+     * Return the Id of the first tree of this split
+     *
+     * @return id of the first tree
+     */
+    public int getFirstId() {
+      return firstId;
+    }
+
+    /**
+     * Return the number of trees
+     *
+     * @return number of trees built by this split's mapper
+     */
+    public int getNbTrees() {
+      return nbTrees;
+    }
+
+    /**
+     * Return the random seed
+     *
+     * @return null if no seed is available
+     */
+    public Long getSeed() {
+      return seed;
+    }
+
+    @Override
+    public long getLength() throws IOException {
+      return nbTrees;
+    }
+
+    @Override
+    public String[] getLocations() throws IOException {
+      return new String[0];
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj)
+        return true;
+      if (!(obj instanceof InMemInputSplit))
+        return false;
+
+      InMemInputSplit split = (InMemInputSplit) obj;
+
+      if (seed == null && split.seed != null)
+        return false;
+
+      return firstId == split.firstId && nbTrees == split.nbTrees
+          && (seed == null || seed.equals(split.seed));
+    }
+
+    @Override
+    public int hashCode() {
+      // keep consistent with equals()
+      return firstId + nbTrees + (seed == null ? 0 : seed.hashCode());
+    }
+
+    @Override
+    public String toString() {
+      return String.format("[firstId:%d, nbTrees:%d, seed:%d]", firstId,
+          nbTrees, seed);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      firstId = in.readInt();
+      nbTrees = in.readInt();
+      boolean isSeed = in.readBoolean();
+      seed = isSeed ? in.readLong() : null;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeInt(firstId);
+      out.writeInt(nbTrees);
+      out.writeBoolean(seed != null);
+      if (seed != null) {
+        out.writeLong(seed);
+      }
+    }
+
+    public static InMemInputSplit read(DataInput in) throws IOException {
+      InMemInputSplit split = new InMemInputSplit();
+      split.readFields(in);
+      return split;
+    }
+
+  }
+
+}
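To make the split arithmetic in getSplits() concrete, here is a tiny standalone sketch (illustration only) that mirrors the loop above for 10 trees spread over 3 map tasks: every split gets nbTrees/numSplits trees and the last split absorbs the remainder.

    // Illustration only: mirrors the getSplits() loop above.
    int nbTrees = 10;
    int numSplits = 3;
    int splitSize = nbTrees / numSplits; // 3

    int id = 0;
    for (int index = 0; index < numSplits - 1; index++) {
      // prints: split 0 -> firstId:0, nbTrees:3  and  split 1 -> firstId:3, nbTrees:3
      System.out.printf("split %d -> firstId:%d, nbTrees:%d%n", index, id, splitSize);
      id += splitSize;
    }
    // the last split absorbs the remainder: split 2 -> firstId:6, nbTrees:4
    System.out.printf("split %d -> firstId:%d, nbTrees:%d%n", numSplits - 1, id, nbTrees - id);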
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java Tue Sep 29 04:56:25 2009
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.inmem;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.util.StringUtils;
+import org.apache.mahout.df.Bagging;
+import org.apache.mahout.df.callback.SingleTreePredictions;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.mapreduce.Builder;
+import org.apache.mahout.df.mapreduce.MapredMapper;
+import org.apache.mahout.df.mapreduce.MapredOutput;
+import org.apache.mahout.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.df.node.Node;
+
+/**
+ * In-memory mapper that grows the trees using a full copy of the data loaded
+ * in-memory. The number of trees to grow is determined by the current
+ * InMemInputSplit.
+ */
+public class InMemMapper extends
+    MapredMapper<IntWritable, NullWritable, IntWritable, MapredOutput> {
+
+  protected Bagging bagging;
+
+  protected Random rng;
+
+  protected Data data;
+
+  /**
+   * Load the training data
+   *
+   * @param conf the job's configuration
+   * @param dataset the dataset descriptor
+   * @return the loaded training data
+   * @throws RuntimeException if the data could not be loaded
+   */
+  protected Data loadData(Configuration conf, Dataset dataset) {
+    try {
+      Path dataPath = Builder.getDistributedCacheFile(conf, 1);
+      FileSystem fs = FileSystem.get(dataPath.toUri(), conf);
+
+      return DataLoader.loadData(dataset, fs, dataPath);
+    } catch (Exception e) {
+      throw new RuntimeException("Exception caught while loading the data: "
+          + StringUtils.stringifyException(e));
+    }
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException,
+      InterruptedException {
+    super.setup(context);
+
+    Configuration conf = context.getConfiguration();
+
+    log.info("Loading the data...");
+    data = loadData(conf, getDataset());
+    log.info("Data loaded : " + data.size() + " instances");
+
+    bagging = new Bagging(getTreeBuilder(), data);
+  }
+
+  @Override
+  protected void map(IntWritable key, NullWritable value, Context context)
+      throws IOException, InterruptedException {
+    map(key, context);
+  }
+
+  public void map(IntWritable key, Context context) throws IOException,
+      InterruptedException {
+
+    SingleTreePredictions callback = null;
+    int[] predictions = null;
+
+    if (isOobEstimate() && !isNoOutput()) {
+      callback = new SingleTreePredictions(data.size());
+      predictions = callback.predictions;
+    }
+
+    initRandom((InMemInputSplit) context.getInputSplit());
+
+    log.debug("Building...");
+    Node tree = bagging.build(key.get(), rng, callback);
+
+    if (!isNoOutput()) {
+      log.debug("Outputting...");
+      MapredOutput mrOut = new MapredOutput(tree, predictions);
+
+      context.write(key, mrOut);
+    }
+  }
+
+  protected void initRandom(InMemInputSplit split) {
+    if (rng == null) { // first execution of this mapper
+      Long seed = split.getSeed();
+      log.debug("Initialising rng with seed : " + seed);
+
+      if (seed == null)
+        rng = new Random();
+      else
+        rng = new Random(seed);
+    }
+  }
+
+}
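The mapper above only collects raw out-of-bag predictions (SingleTreePredictions); turning them into an error rate is left to the PredictionCallback supplied by the caller. A minimal callback sketch, illustration only: it assumes the caller has the true labels at hand, assumes the three-int prediction(treeId, instanceId, prediction) method that processOutput() in InMemBuilder invokes, and assumes a negative prediction marks an instance that was not out-of-bag for that tree.

    // Illustration only -- assumptions noted in the text above.
    public class OobErrorCallback implements PredictionCallback {

      private final int[] labels; // true labels, assumed available to the caller
      private int predicted;      // number of oob predictions seen
      private int errors;         // how many of them were wrong

      public OobErrorCallback(int[] labels) {
        this.labels = labels;
      }

      public void prediction(int treeId, int instanceId, int prediction) {
        if (prediction < 0) {
          return; // assumed convention: no oob prediction for this instance
        }
        predicted++;
        if (prediction != labels[instanceId]) {
          errors++;
        }
      }

      public double errorRate() {
        return predicted == 0 ? Double.NaN : (double) errors / predicted;
      }
    }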
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/package.html
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/package.html?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/package.html (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/package.html Tue Sep 29 04:56:25 2009
@@ -0,0 +1,40 @@
+<html>
+<head>
+  <title>org.apache.mahout.df.mapreduce.inmem</title>
+</head>
+<body>
+In-memory mapreduce implementation of Random Decision Forests +
+
 
+
+Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory; it uses the reference implementation's code to build each tree and to estimate the out-of-bag (oob) error.

+The dataset is distributed to the slave nodes using the DistributedCache. A custom InputFormat (InMemInputFormat) is configured with the desired number of trees and generates a number of InputSplits (InMemInputSplit) equal to the configured number of map tasks (mapred.map.tasks).

+There is no need for reducers: each map outputs (MapredOutput) the trees it built and, for each tree, the labels the tree predicted for each out-of-bag instance. This step has to be done in the mapper because only the mapper knows which instances are out-of-bag.

+The Forest builder (InMemBuilder) is responsible for configuring and launching the job. At the end of the job it parses the output files and builds the corresponding DecisionForest; for each tree prediction it calls (if available) a PredictionCallback that allows the caller to compute any error needed.
+
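As a concrete illustration of the paragraphs above (not part of this file), the two knobs that drive the splits are the tree count and the number of map tasks; the snippet assumes nothing beyond the configuration key and the Builder.setNbTrees() helper already used elsewhere in this patch.

    Configuration conf = new Configuration();
    Builder.setNbTrees(conf, 100);        // total number of trees to grow
    conf.setInt("mapred.map.tasks", 10);  // => InMemInputFormat creates 10 splits,
                                          //    each responsible for ~10 trees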
+
 
+
+</body>
+</html>

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputFormatTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputFormatTest.java?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputFormatTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputFormatTest.java Tue Sep 29 04:56:25 2009
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.inmem;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.df.mapred.Builder;
+import org.apache.mahout.df.mapred.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.df.mapred.inmem.InMemInputFormat.InMemRecordReader;
+
+public class InMemInputFormatTest extends TestCase {
+
+  public void testSplits() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = new Random();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      JobConf conf = new JobConf();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+
+      InputSplit[] splits = inputFormat.getSplits(conf, numSplits);
+
+      assertEquals(numSplits, splits.length);
+
+      int nbTreesPerSplit = nbTrees / numSplits;
+      int totalTrees = 0;
+      int expectedId = 0;
+
+      for (int index = 0; index < numSplits; index++) {
+        assertTrue(splits[index] instanceof InMemInputSplit);
+
+        InMemInputSplit split = (InMemInputSplit) splits[index];
+
+        assertEquals(expectedId, split.getFirstId());
+
+        if (index < numSplits - 1)
+          assertEquals(nbTreesPerSplit, split.getNbTrees());
+        else
+          assertEquals(nbTrees - totalTrees, split.getNbTrees());
+
+        totalTrees += split.getNbTrees();
+        expectedId += split.getNbTrees();
+      }
+    }
+  }
+
+  public void testRecordReader() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = new Random();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      JobConf conf = new JobConf();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+
+      InputSplit[] splits = inputFormat.getSplits(conf, numSplits);
+
+      for (int index = 0; index < numSplits; index++) {
+        InMemInputSplit split = (InMemInputSplit) splits[index];
+        InMemRecordReader reader = (InMemRecordReader) inputFormat
+            .getRecordReader(split, conf, null);
+
+        for (int tree = 0; tree < split.getNbTrees(); tree++) {
+          IntWritable key = reader.createKey();
+          NullWritable value = reader.createValue();
+
+          // reader.next() should return true until there is no tree left
+          assertEquals(tree < split.getNbTrees(), reader.next(key, value));
+
+          assertEquals(split.getFirstId() + tree, key.get());
+        }
+      }
+    }
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputSplitTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputSplitTest.java?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputSplitTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapred/inmem/InMemInputSplitTest.java Tue Sep 29 04:56:25 2009
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred.inmem;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.df.mapred.inmem.InMemInputFormat.InMemInputSplit;
+
+public class InMemInputSplitTest extends TestCase {
+
+  protected Random rng;
+
+  protected ByteArrayOutputStream byteOutStream;
+  protected DataOutput out;
+
+  @Override
+  protected void setUp() throws Exception {
+    rng = new Random();
+
+    byteOutStream = new ByteArrayOutputStream();
+    out = new DataOutputStream(byteOutStream);
+  }
+
+  /**
+   * Make sure that all the fields are processed correctly
+   *
+   * @throws IOException
+   */
+  public void testWritable() throws IOException {
+    InMemInputSplit split =
+        new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), rng.nextLong());
+
+    split.write(out);
+    assertEquals(split, readSplit());
+  }
+
+  /**
+   * Test the case seed == null
+   *
+   * @throws IOException
+   */
+  public void testNullSeed() throws IOException {
+    InMemInputSplit split =
+        new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), null);
+
+    split.write(out);
+    assertEquals(split, readSplit());
+  }
+
+  protected InMemInputSplit readSplit() throws IOException {
+    ByteArrayInputStream byteInStream =
+        new ByteArrayInputStream(byteOutStream.toByteArray());
+    DataInput in = new DataInputStream(byteInStream);
+    return InMemInputSplit.read(in);
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormatTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormatTest.java?rev=819830&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormatTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/inmem/InMemInputFormatTest.java Tue Sep 29 04:56:25 2009
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.inmem;
+
+import java.util.List;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.mahout.df.mapreduce.Builder;
+import org.apache.mahout.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.df.mapreduce.inmem.InMemInputFormat.InMemRecordReader;
+
+public class InMemInputFormatTest extends TestCase {
+
+  public void testSplits() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = new Random();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      Configuration conf = new Configuration();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+      List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+      assertEquals(numSplits, splits.size());
+
+      int nbTreesPerSplit = nbTrees / numSplits;
+      int totalTrees = 0;
+      int expectedId = 0;
+
+      for (int index = 0; index < numSplits; index++) {
+        assertTrue(splits.get(index) instanceof InMemInputSplit);
+
+        InMemInputSplit split = (InMemInputSplit) splits.get(index);
+
+        assertEquals(expectedId, split.getFirstId());
+
+        if (index < numSplits - 1)
+          assertEquals(nbTreesPerSplit, split.getNbTrees());
+        else
+          assertEquals(nbTrees - totalTrees, split.getNbTrees());
+
+        totalTrees += split.getNbTrees();
+        expectedId += split.getNbTrees();
+      }
+    }
+  }
+
+  public void testRecordReader() throws Exception {
+    int n = 1;
+    int maxNumSplits = 100;
+    int maxNbTrees = 1000;
+
+    Random rng = new Random();
+
+    for (int nloop = 0; nloop < n; nloop++) {
+      int numSplits = rng.nextInt(maxNumSplits) + 1;
+      int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+      Configuration conf = new Configuration();
+      Builder.setNbTrees(conf, nbTrees);
+
+      InMemInputFormat inputFormat = new InMemInputFormat();
+      List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+      for (int index = 0; index < numSplits; index++) {
+        InMemInputSplit split = (InMemInputSplit) splits.get(index);
+        InMemRecordReader reader = new InMemRecordReader(split);
+
+        reader.initialize(split, null);
+
+        for (int tree = 0; tree < split.getNbTrees(); tree++) {
+          // reader.nextKeyValue() should return true until there is no tree left
+          assertEquals(tree < split.getNbTrees(), reader.nextKeyValue());
+          assertEquals(split.getFirstId() + tree, reader.getCurrentKey().get());
+        }
+      }
+    }
+  }
+}
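Finally, a note on the debug switch exercised by InMemInputFormat.isSingleSeed(Configuration): it reads a plain boolean property, so (illustration only, reusing the placeholder names from the first sketch above) it could be enabled like this to give every mapper the same seed and thus comparable tree-building times.

    // Illustration only: enable the single-seed debug mode described in
    // InMemInputFormat.isSingleSeed(). The seed itself reaches the job through
    // the InMemBuilder constructor shown earlier.
    Configuration conf = new Configuration();
    conf.setBoolean("debug.mahout.rf.single.seed", true);
    InMemBuilder builder =
        new InMemBuilder(treeBuilder, dataPath, datasetPath, 42L, conf);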