mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From adene...@apache.org
Subject svn commit: r820181 [3/3] - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/df/mapred/partial/ core/src/main/java/org/apache/mahout/df/mapreduce/partial/ core/src/test/java/org/apache/mahout/df/mapred/partial/ core/src/test/java/org/apach...
Date Wed, 30 Sep 2009 05:29:24 GMT
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step2MapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step2MapperTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step2MapperTest.java
(added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step2MapperTest.java
Wed Sep 30 05:29:22 2009
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.partial;
+
+import static org.apache.mahout.df.data.Utils.double2String;
+import static org.apache.mahout.df.data.Utils.randomDescriptor;
+import static org.apache.mahout.df.data.Utils.randomDoubles;
+
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.apache.mahout.df.mapreduce.partial.InterResults;
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+import org.apache.mahout.df.node.Leaf;
+import org.apache.mahout.df.node.Node;
+
+public class Step2MapperTest extends TestCase {
+
+  /**
+   * Special Step2Mapper that can be configured without using a Configuration
+   * 
+   */
+  private static class MockStep2Mapper extends Step2Mapper {
+    public MockStep2Mapper(int partition, Dataset dataset, TreeID[] keys,
+        Node[] trees, int numInstances) {
+      configure(partition, dataset, keys, trees, numInstances);
+    }
+
+  }
+
+  /** nb attributes per generated data instance */
+  protected final int nbAttributes = 4;
+
+  /** nb generated data instances */
+  protected final int nbInstances = 100;
+
+  /** nb trees to build */
+  protected final int nbTrees = 11;
+
+  /** nb mappers to use */
+  protected final int nbMappers = 5;
+
+  /**
+   * Simulates the second step job: each mapper classifies its partition's
+   * instances with the trees built by the other partitions, and must not
+   * output predictions for its own trees.
+   */
+  @SuppressWarnings("unchecked")
+  public void testMapper() throws Exception {
+    // fixed seed so a failure caused by a particular random dataset can be
+    // reproduced; an unseeded Random would make this test non-deterministic
+    Random rng = new Random(1L);
+
+    // prepare the data
+    String descriptor = randomDescriptor(rng, nbAttributes);
+    double[][] source = randomDoubles(rng, descriptor, nbInstances);
+    String[] sData = double2String(source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    String[][] splits = Utils.splitData(sData, nbMappers);
+
+    // prepare first step output
+    TreeID[] keys = new TreeID[nbTrees];
+    Node[] trees = new Node[nbTrees];
+    int[] sizes = new int[nbMappers];
+
+    int treeIndex = 0;
+    for (int partition = 0; partition < nbMappers; partition++) {
+      int nbMapTrees = Step1Mapper.nbTrees(nbMappers, nbTrees, partition);
+
+      for (int tree = 0; tree < nbMapTrees; tree++, treeIndex++) {
+        keys[treeIndex] = new TreeID(partition, treeIndex);
+        // put the partition in the leaf's label
+        // this way we can track the outputs
+        trees[treeIndex] = new Leaf(partition);
+      }
+
+      sizes[partition] = splits[partition].length;
+    }
+
+    // store the first step outputs in a file
+    FileSystem fs = FileSystem.getLocal(new Configuration());
+    Path forestPath = new Path("testdata/Step2MapperTest.forest");
+    InterResults.store(fs, forestPath, keys, trees, sizes);
+
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+
+    for (int partition = 0; partition < nbMappers; partition++) {
+      String[] split = splits[partition];
+
+      // number of trees that will be handled by the mapper
+      int nbConcerned = Step2Mapper.nbConcerned(nbMappers, nbTrees, partition);
+
+      MockContext context = new MockContext(new Step2Mapper(),
+          new Configuration(), new TaskAttemptID(), nbConcerned);
+
+      // load the current mapper's (key, tree) pairs
+      TreeID[] curKeys = new TreeID[nbConcerned];
+      Node[] curTrees = new Node[nbConcerned];
+      InterResults.load(fs, forestPath, nbMappers, nbTrees, partition, curKeys,
+          curTrees);
+
+      // simulate the job
+      MockStep2Mapper mapper = new MockStep2Mapper(partition, dataset, curKeys,
+          curTrees, split.length);
+
+      for (int index = 0; index < split.length; index++) {
+        key.set(index);
+        value.set(split[index]);
+        mapper.map(key, value, context);
+      }
+
+      mapper.cleanup(context);
+
+      // make sure the mapper did not return its own trees
+      assertEquals(nbConcerned, context.nbOutputs());
+
+      // check the returned results
+      int current = 0;
+      for (int index = 0; index < nbTrees; index++) {
+        if (keys[index].partition() == partition) {
+          // should not be part of the results
+          continue;
+        }
+
+        TreeID k = context.keys[current];
+
+        // the tree should receive the partition's index
+        assertEquals(partition, k.partition());
+
+        // make sure all the trees of the other partitions are handled in the
+        // correct order
+        assertEquals(index, k.treeId());
+
+        int[] predictions = context.values[current].getPredictions();
+
+        // all the instances of the partition should be classified
+        assertEquals(split.length, predictions.length);
+        assertEquals(
+            "at least one instance of the partition was not classified", -1,
+            ArrayUtils.indexOf(predictions, -1));
+
+        // the tree must not belong to the mapper's partition
+        int treePartition = predictions[0];
+        assertFalse("Step2Mapper returned a tree from its own partition",
+            partition == treePartition);
+
+        current++;
+      }
+    }
+  }
+}

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/TreeIDTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/TreeIDTest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/TreeIDTest.java
(added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/TreeIDTest.java
Wed Sep 30 05:29:22 2009
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce.partial;
+
+import java.util.Random;
+
+import org.apache.mahout.df.mapreduce.partial.TreeID;
+
+import junit.framework.TestCase;
+
+public class TreeIDTest extends TestCase {
+
+  /**
+   * Checks that a TreeID returns the same (partition, treeId) pair it was
+   * given, whether set through the constructor or through set().
+   */
+  public void testTreeID() {
+    int n = 1000000;
+    Random rng = new Random();
+    
+    for (int nloop = 0; nloop < n; nloop++) {
+      // nextInt(bound) always returns a non-negative value;
+      // Math.abs(rng.nextInt()) would be negative when nextInt() returns
+      // Integer.MIN_VALUE and could make the assertions fail spuriously
+      int partition = rng.nextInt(Integer.MAX_VALUE);
+      int treeId = rng.nextInt(TreeID.MAX_TREEID);
+      
+      TreeID t1 = new TreeID(partition, treeId);
+      
+      assertEquals(partition, t1.partition());
+      assertEquals(treeId, t1.treeId());
+      
+      TreeID t2 = new TreeID();
+      t2.set(partition, treeId);
+
+      assertEquals(partition, t2.partition());
+      assertEquals(treeId, t2.treeId());
+    }
+  }
+}

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java
(added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapred/BuildForest.java
Wed Sep 30 05:29:22 2009
@@ -0,0 +1,225 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapred;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.ErrorEstimate;
+import org.apache.mahout.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.df.callback.ForestPredictions;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.mapred.inmem.InMemBuilder;
+import org.apache.mahout.df.mapred.partial.PartialBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * Tool to builds a Random Forest using any given dataset (in UCI format). Can
+ * use either the in-mem mapred or partial mapred implementations
+ */
+public class BuildForest extends Configured implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
+
+  protected Path dataPath; // Data path
+
+  protected Path datasetPath; // Dataset path
+
+  int m; // Number of variables to select at each tree-node
+
+  int nbTrees; // Number of trees to grow
+
+  Long seed = null; // Random seed
+
+  boolean isPartial; // use partial data implementation
+
+  boolean isOob; // estimate oob error;
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option oobOpt = obuilder.withShortName("oob").withRequired(false)
+        .withDescription("Optional, estimate the out-of-bag error").create();
+
+    Option dataOpt = obuilder.withLongName("data").withShortName("d")
+        .withRequired(true).withArgument(
+            abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+        .withDescription("Data path").create();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+        .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+        .withDescription("Dataset path").create();
+
+    Option selectionOpt = obuilder.withLongName("selection")
+        .withShortName("sl").withRequired(true).withArgument(
+            abuilder.withName("m").withMinimum(1).withMaximum(1).create())
+        .withDescription("Number of variables to select randomly at each tree-node")
+        .create();
+
+    Option seedOpt = obuilder.withLongName("seed").withShortName("sd").withRequired(false)
+        .withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
+        .withDescription("Optional, seed value used to initialise the Random number generator")
+        .create();
+
+    Option partialOpt = obuilder.withLongName("partial").withShortName("p")
+        .withRequired(false).withDescription("Optional, use the Partial Data implementation").create();
+
+    Option nbtreesOpt = obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
+        .withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
+        .withDescription("Number of trees to grow").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription(
+        "Print out help").withShortName("h").create();
+
+    Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(
+        dataOpt).withOption(datasetOpt).withOption(selectionOpt).withOption(
+        seedOpt).withOption(partialOpt).withOption(nbtreesOpt).withOption(
+        helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption("help")) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      isPartial = cmdLine.hasOption(partialOpt);
+      isOob = cmdLine.hasOption(oobOpt);
+      String dataName = cmdLine.getValue(dataOpt).toString();
+      String datasetName = cmdLine.getValue(datasetOpt).toString();
+      m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
+      nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
+
+      if (cmdLine.hasOption(seedOpt)) {
+        seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
+      }
+
+      log.debug("data : " + dataName);
+      log.debug("dataset : " + datasetName);
+      log.debug("m : " + m);
+      log.debug("seed : " + seed);
+      log.debug("nbtrees : " + nbTrees);
+      log.debug("isPartial : " + isPartial);
+      log.debug("isOob : " + isOob);
+
+      dataPath = new Path(dataName);
+      datasetPath = new Path(datasetName);
+
+    } catch (OptionException e) {
+      System.err.println("Exception : " + e);
+      CommandLineUtil.printHelp(group);
+      return -1;
+    }
+
+    buildForest();
+
+    return 0;
+  }
+
+  private DecisionForest buildForest() throws Exception {
+    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
+    treeBuilder.setM(m);
+
+    Dataset dataset = Dataset.load(getConf(), datasetPath);
+
+    ForestPredictions callback = (isOob) ? new ForestPredictions(dataset
+        .nbInstances(), dataset.nblabels()) : null;
+
+    Builder forestBuilder;
+
+    if (isPartial) {
+      log.info("Partial Mapred implementation");
+      forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath,
+          seed, getConf());
+    } else {
+      log.info("InMem Mapred implementation");
+      forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath,
+          seed, getConf());
+    }
+
+    log.info("Building the forest...");
+    long time = System.currentTimeMillis();
+
+    DecisionForest forest = forestBuilder.build(nbTrees, callback);
+
+    time = System.currentTimeMillis() - time;
+    log.info("Build Time: " + DFUtils.elapsedTime(time));
+
+    if (isOob) {
+      Random rng;
+      if (seed != null)
+        rng = new Random(seed);
+      else
+        rng = new Random();
+
+      FileSystem fs = dataPath.getFileSystem(getConf());
+      int[] labels = Data.extractLabels(dataset, fs, dataPath);
+      
+      log.info("oob error estimate : "
+          + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
+    }
+
+    return forest;
+  }
+
+  protected Data loadData(Configuration conf, Path dataPath, Dataset dataset)
+      throws Exception {
+    log.info("Loading the data...");
+    FileSystem fs = dataPath.getFileSystem(conf);
+    Data data = DataLoader.loadData(dataset, fs, dataPath);
+    log.info("Data Loaded");
+
+    return data;
+  }
+
+  /**
+   * @param args
+   * @throws Exception
+   */
+  public static void main(String[] args) throws Exception {
+    int res = ToolRunner.run(new Configuration(), new BuildForest(), args);
+    System.exit(res);
+  }
+
+}

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=820181&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
(added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
Wed Sep 30 05:29:22 2009
@@ -0,0 +1,239 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce;
+
+import java.util.Random;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.ErrorEstimate;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.mapreduce.inmem.InMemBuilder;
+import org.apache.mahout.df.mapreduce.partial.PartialBuilder;
+import org.apache.mahout.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.df.callback.ForestPredictions;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tool to builds a Random Forest using any given dataset (in UCI format). Can
+ * use either the in-mem mapred or partial mapred implementations
+ */
+public class BuildForest extends Configured implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
+
+  protected Path dataPath; // Data path
+
+  protected Path datasetPath; // Dataset path
+
+  int m; // Number of variables to select at each tree-node
+
+  int nbTrees; // Number of trees to grow
+
+  Long seed = null; // Random seed
+
+  boolean isPartial; // use partial data implementation
+
+  boolean isOob; // estimate oob error;
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option oobOpt = obuilder.withShortName("oob").withRequired(false)
+        .withDescription("Optional, estimate the out-of-bag error").create();
+
+    Option dataOpt = obuilder.withLongName("data").withShortName("d")
+        .withRequired(true).withArgument(
+            abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+        .withDescription("Data path").create();
+
+    Option datasetOpt = obuilder
+        .withLongName("dataset")
+        .withShortName("ds")
+        .withRequired(true)
+        .withArgument(
+            abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+        .withDescription("Dataset path").create();
+
+    Option selectionOpt = obuilder.withLongName("selection")
+        .withShortName("sl").withRequired(true).withArgument(
+            abuilder.withName("m").withMinimum(1).withMaximum(1).create())
+        .withDescription(
+            "Number of variables to select randomly at each tree-node")
+        .create();
+
+    Option seedOpt = obuilder
+        .withLongName("seed")
+        .withShortName("sd")
+        .withRequired(false)
+        .withArgument(
+            abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
+        .withDescription(
+            "Optional, seed value used to initialise the Random number generator")
+        .create();
+
+    Option partialOpt = obuilder.withLongName("partial").withShortName("p")
+        .withRequired(false).withDescription(
+            "Optional, use the Partial Data implementation").create();
+
+    Option nbtreesOpt = obuilder
+        .withLongName("nbtrees")
+        .withShortName("t")
+        .withRequired(true)
+        .withArgument(
+            abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
+        .withDescription("Number of trees to grow").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription(
+        "Print out help").withShortName("h").create();
+
+   Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(
+        dataOpt).withOption(datasetOpt).withOption(selectionOpt).withOption(
+        seedOpt).withOption(partialOpt).withOption(nbtreesOpt).withOption(
+        helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption("help")) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      isPartial = cmdLine.hasOption(partialOpt);
+      isOob = cmdLine.hasOption(oobOpt);
+      String dataName = cmdLine.getValue(dataOpt).toString();
+      String datasetName = cmdLine.getValue(datasetOpt).toString();
+      m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
+      nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
+
+      if (cmdLine.hasOption(seedOpt)) {
+        seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
+      }
+
+      log.debug("data : " + dataName);
+      log.debug("dataset : " + datasetName);
+      log.debug("m : " + m);
+      log.debug("seed : " + seed);
+      log.debug("nbtrees : " + nbTrees);
+      log.debug("isPartial : " + isPartial);
+      log.debug("isOob : " + isOob);
+
+      dataPath = new Path(dataName);
+      datasetPath = new Path(datasetName);
+
+    } catch (OptionException e) {
+      System.err.println("Exception : " + e);
+      CommandLineUtil.printHelp(group);
+      return -1;
+    }
+
+    buildForest();
+
+    return 0;
+  }
+
+  private DecisionForest buildForest() throws Exception {
+    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
+    treeBuilder.setM(m);
+
+   Dataset dataset = Dataset.load(getConf(), datasetPath);
+
+    ForestPredictions callback = (isOob) ? new ForestPredictions(dataset
+        .nbInstances(), dataset.nblabels()) : null;
+
+    Builder forestBuilder;
+
+    if (isPartial) {
+      log.info("Partial Mapred implementation");
+      forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
+    } else {
+      log.info("InMem Mapred implementation");
+      forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath,
+          seed, getConf());
+    }
+    log.info("Building the forest...");
+    long time = System.currentTimeMillis();
+
+    DecisionForest forest = forestBuilder.build(nbTrees, callback);
+
+    time = System.currentTimeMillis() - time;
+    log.info("Build Time: " + DFUtils.elapsedTime(time));
+
+    if (isOob) {
+      Random rng;
+      if (seed != null)
+        rng = new Random(seed);
+      else
+        rng = new Random();
+
+      FileSystem fs = dataPath.getFileSystem(getConf());
+      int[] labels = Data.extractLabels(dataset, fs, dataPath);
+
+      log.info("oob error estimate : "
+          + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
+    }
+
+    return forest;
+  }
+
+  protected Data loadData(Configuration conf, Path dataPath, Dataset dataset)
+      throws Exception {
+    log.info("Loading the data...");
+    FileSystem fs = dataPath.getFileSystem(conf);
+    Data data = DataLoader.loadData(dataset, fs, dataPath);
+    log.info("Data Loaded");
+
+    return data;
+  }
+
+  /**
+   * @param args
+   * @throws Exception
+   */
+  public static void main(String[] args) throws Exception {
+    int res = ToolRunner.run(new Configuration(), new BuildForest(), args);
+    System.exit(res);
+  }
+
+}
+



Mime
View raw message