hama-commits mailing list archives

From: yxji...@apache.org
Subject: svn commit: r1562074 - in /hama/trunk: ./ core/ examples/src/main/java/org/apache/hama/examples/ examples/src/main/java/org/apache/hama/examples/util/ examples/src/test/java/org/apache/hama/examples/
Date: Tue, 28 Jan 2014 14:05:15 GMT
Author: yxjiang
Date: Tue Jan 28 14:05:15 2014
New Revision: 1562074

URL: http://svn.apache.org/r1562074
Log:
HAMA-859: Leverage commons cli2 to parse the input argument for NeuralNetwork Example

Added:
    hama/trunk/examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java
Modified:
    hama/trunk/CHANGES.txt
    hama/trunk/core/pom.xml
    hama/trunk/examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java
    hama/trunk/examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java
    hama/trunk/pom.xml

Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1562074&r1=1562073&r2=1562074&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Tue Jan 28 14:05:15 2014
@@ -25,6 +25,7 @@ Release 0.7.0 (unreleased changes)
 
   IMPROVEMENTS
 
+   HAMA-859: Leverage commons cli2 to parse the input argument for NeuralNetwork Example (Yexi Jiang)
    HAMA-853: Refactor Outgoing message manager (edwardyoon)
    HAMA-852: Add MessageClass property in BSPJob (Martin Illecker)
   HAMA-843: Message communication overhead between master aggregation and vertex computation supersteps (edwardyoon)

Modified: hama/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/hama/trunk/core/pom.xml?rev=1562074&r1=1562073&r2=1562074&view=diff
==============================================================================
--- hama/trunk/core/pom.xml (original)
+++ hama/trunk/core/pom.xml Tue Jan 28 14:05:15 2014
@@ -51,10 +51,6 @@
       <artifactId>commons-logging</artifactId>
     </dependency>
     <dependency>
-      <groupId>commons-cli</groupId>
-      <artifactId>commons-cli</artifactId>
-    </dependency>
-    <dependency>
       <groupId>commons-configuration</groupId>
       <artifactId>commons-configuration</artifactId>
     </dependency>
@@ -135,6 +131,10 @@
       <groupId>org.apache.zookeeper</groupId>
       <artifactId>zookeeper</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.mahout.commons</groupId>
+      <artifactId>commons-cli</artifactId>
+    </dependency>
   </dependencies>
 
   <build>

Modified: hama/trunk/examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java?rev=1562074&r1=1562073&r2=1562074&view=diff
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java (original)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/NeuralNetwork.java Tue Jan 28 14:05:15 2014
@@ -23,194 +23,288 @@ import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
 import java.net.URI;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.math.DenseDoubleVector;
 import org.apache.hama.commons.math.DoubleVector;
 import org.apache.hama.commons.math.FunctionFactory;
+import org.apache.hama.examples.util.ParserUtil;
 import org.apache.hama.ml.ann.SmallLayeredNeuralNetwork;
 
+import com.google.common.io.Closeables;
+
 /**
  * The example of using {@link SmallLayeredNeuralNetwork}, including the
  * training phase and labeling phase.
  */
 public class NeuralNetwork {
+  // either train or label
+  private static String mode;
 
-  public static void main(String[] args) throws Exception {
-    if (args.length < 3) {
-      printUsage();
-      return;
+  // arguments for labeling
+  private static String featureDataPath;
+  private static String resultDataPath;
+  private static String modelPath;
+
+  // arguments for training
+  private static String trainingDataPath;
+  private static int featureDimension;
+  private static int labelDimension;
+  private static List<Integer> hiddenLayerDimension;
+  private static int iterations;
+  private static double learningRate;
+  private static double momemtumWeight;
+  private static double regularizationWeight;
+
+  public static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    GroupBuilder groupBuilder = new GroupBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    // the feature data (unlabeled data) path argument
+    Option featureDataPathOption = optionBuilder
+        .withLongName("feature-data-path")
+        .withShortName("fp")
+        .withDescription("the path of the feature data (unlabeled data).")
+        .withArgument(
+            argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the result data path argument
+    Option resultDataPathOption = optionBuilder
+        .withLongName("result-data-path")
+        .withShortName("rp")
+        .withDescription("the path to store the result.")
+        .withArgument(
+            argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the path to store the model
+    Option modelPathOption = optionBuilder
+        .withLongName("model-data-path")
+        .withShortName("mp")
+        .withDescription("the path to store the trained model.")
+        .withArgument(
+            argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the path of the training data
+    Option trainingDataPathOption = optionBuilder
+        .withLongName("training-data-path")
+        .withShortName("tp")
+        .withDescription("the path to store the trained model.")
+        .withArgument(
+            argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the dimension of the features
+    Option featureDimensionOption = optionBuilder
+        .withLongName("feature dimension")
+        .withShortName("fd")
+        .withDescription("the dimension of the features.")
+        .withArgument(
+            argumentBuilder.withName("dimension").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the dimension of the hidden layers, at most two hidden layers
+    Option hiddenLayerOption = optionBuilder
+        .withLongName("hidden layer dimension(s)")
+        .withShortName("hd")
+        .withDescription("the dimension of the hidden layer(s).")
+        .withArgument(
+            argumentBuilder.withName("dimension").withMinimum(0).withMaximum(2)
+                .create()).withRequired(true).create();
+
+    // the dimension of the labels
+    Option labelDimensionOption = optionBuilder
+        .withLongName("label dimension")
+        .withShortName("ld")
+        .withDescription("the dimension of the label(s).")
+        .withArgument(
+            argumentBuilder.withName("dimension").withMinimum(1).withMaximum(1)
+                .create()).withRequired(true).create();
+
+    // the number of iterations for training
+    Option iterationOption = optionBuilder
+        .withLongName("iterations")
+        .withShortName("itr")
+        .withDescription("the iterations for training.")
+        .withArgument(
+            argumentBuilder.withName("iterations").withMinimum(1)
+                .withMaximum(1).withDefault(1000).create()).create();
+
+    // the learning rate
+    Option learningRateOption = optionBuilder
+        .withLongName("learning-rate")
+        .withShortName("l")
+        .withDescription("the learning rate for training, default 0.1.")
+        .withArgument(
+            argumentBuilder.withName("learning-rate").withMinimum(1)
+                .withMaximum(1).withDefault(0.1).create()).create();
+
+    // the momentum weight
+    Option momentumWeightOption = optionBuilder
+        .withLongName("momemtum-weight")
+        .withShortName("m")
+        .withDescription("the momemtum weight for training, default 0.1.")
+        .withArgument(
+            argumentBuilder.withName("momemtum weight").withMinimum(1)
+                .withMaximum(1).withDefault(0.1).create()).create();
+
+    // the regularization weight
+    Option regularizationWeightOption = optionBuilder
+        .withLongName("regularization-weight")
+        .withShortName("r")
+        .withDescription("the regularization weight for training, default 0.")
+        .withArgument(
+            argumentBuilder.withName("regularization weight").withMinimum(1)
+                .withMaximum(1).withDefault(0).create()).create();
+
+    // the parameters related to train mode
+    Group trainModeGroup = groupBuilder.withOption(trainingDataPathOption)
+        .withOption(modelPathOption).withOption(featureDimensionOption)
+        .withOption(labelDimensionOption).withOption(hiddenLayerOption)
+        .withOption(iterationOption).withOption(learningRateOption)
+        .withOption(momentumWeightOption)
+        .withOption(regularizationWeightOption).create();
+
+    // the parameters related to label mode
+    Group labelModeGroup = groupBuilder.withOption(modelPathOption)
+        .withOption(featureDataPathOption).withOption(resultDataPathOption)
+        .create();
+
+    Option trainModeOption = optionBuilder.withLongName("train")
+        .withShortName("train").withDescription("the train mode")
+        .withChildren(trainModeGroup).create();
+
+    Option labelModeOption = optionBuilder.withLongName("label")
+        .withShortName("label").withChildren(labelModeGroup)
+        .withDescription("the label mode").create();
+
+    Group normalGroup = groupBuilder.withOption(trainModeOption)
+        .withOption(labelModeOption).create();
+
+    Parser parser = new Parser();
+    parser.setGroup(normalGroup);
+    parser.setHelpFormatter(new HelpFormatter());
+    parser.setHelpTrigger("--help");
+    CommandLine cli = parser.parseAndHelp(args);
+    if (cli == null) {
+      return false;
     }
-    String mode = args[0];
-    if (mode.equalsIgnoreCase("label")) {
-      if (args.length < 4) {
-        printUsage();
-        return;
-      }
-      HamaConfiguration conf = new HamaConfiguration();
-
-      String featureDataPath = args[1];
-      String resultDataPath = args[2];
-      String modelPath = args[3];
-
-      SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
-
-      // process data in streaming approach
-      FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
-      BufferedReader br = new BufferedReader(new InputStreamReader(
-          fs.open(new Path(featureDataPath))));
-      Path outputPath = new Path(resultDataPath);
-      if (fs.exists(outputPath)) {
-        fs.delete(outputPath, true);
-      }
-      BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
-          fs.create(outputPath)));
 
-      String line = null;
+    // get the arguments
+    boolean hasTrainMode = cli.hasOption(trainModeOption);
+    boolean hasLabelMode = cli.hasOption(labelModeOption);
+    if (hasTrainMode && hasLabelMode) {
+      return false;
+    }
 
-      while ((line = br.readLine()) != null) {
-        if (line.trim().length() == 0) {
-          continue;
-        }
-        String[] tokens = line.trim().split(",");
-        double[] vals = new double[tokens.length];
-        for (int i = 0; i < tokens.length; ++i) {
-          vals[i] = Double.parseDouble(tokens[i]);
-        }
-        DoubleVector instance = new DenseDoubleVector(vals);
-        DoubleVector result = ann.getOutput(instance);
-        double[] arrResult = result.toArray();
-        StringBuilder sb = new StringBuilder();
-        for (int i = 0; i < arrResult.length; ++i) {
-          sb.append(arrResult[i]);
-          if (i != arrResult.length - 1) {
-            sb.append(",");
-          } else {
-            sb.append("\n");
-          }
-        }
-        bw.write(sb.toString());
-      }
+    mode = hasTrainMode ? "train" : "label";
+    if (mode.equals("train")) {
+      trainingDataPath = ParserUtil.getString(cli, trainingDataPathOption);
+      modelPath = ParserUtil.getString(cli, modelPathOption);
+      featureDimension = ParserUtil.getInteger(cli, featureDimensionOption);
+      labelDimension = ParserUtil.getInteger(cli, labelDimensionOption);
+      hiddenLayerDimension = ParserUtil.getInts(cli, hiddenLayerOption);
+      iterations = ParserUtil.getInteger(cli, iterationOption);
+      learningRate = ParserUtil.getDouble(cli, learningRateOption);
+      momemtumWeight = ParserUtil.getDouble(cli, momentumWeightOption);
+      regularizationWeight = ParserUtil.getDouble(cli,
+          regularizationWeightOption);
+    } else {
+      featureDataPath = ParserUtil.getString(cli, featureDataPathOption);
+      modelPath = ParserUtil.getString(cli, modelPathOption);
+      resultDataPath = ParserUtil.getString(cli, resultDataPathOption);
+    }
 
-      br.close();
-      bw.close();
-    } else if (mode.equals("train")) {
-      if (args.length < 5) {
-        printUsage();
-        return;
-      }
+    return true;
+  }
 
-      String trainingDataPath = args[1];
-      String trainedModelPath = args[2];
+  public static void main(String[] args) throws Exception {
+    if (parseArgs(args)) {
+      if (mode.equals("label")) {
+        HamaConfiguration conf = new HamaConfiguration();
+        SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
+
+        // process data in streaming approach
+        FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
+        BufferedReader br = new BufferedReader(new InputStreamReader(
+            fs.open(new Path(featureDataPath))));
+        Path outputPath = new Path(resultDataPath);
+        if (fs.exists(outputPath)) {
+          fs.delete(outputPath, true);
+        }
+        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
+            fs.create(outputPath)));
 
-      int featureDimension = Integer.parseInt(args[3]);
-      int labelDimension = Integer.parseInt(args[4]);
+        String line = null;
 
-      int iteration = 1000;
-      double learningRate = 0.4;
-      double momemtumWeight = 0.2;
-      double regularizationWeight = 0.01;
-
-      // parse parameters
-      if (args.length >= 6) {
-        try {
-          iteration = Integer.parseInt(args[5]);
-          System.out.printf("Iteration: %d\n", iteration);
-        } catch (NumberFormatException e) {
-          System.err
-              .println("MAX_ITERATION format invalid. It should be a positive number.");
-          return;
-        }
-      }
-      if (args.length >= 7) {
-        try {
-          learningRate = Double.parseDouble(args[6]);
-          System.out.printf("Learning rate: %f\n", learningRate);
-        } catch (NumberFormatException e) {
-          System.err
-              .println("LEARNING_RATE format invalid. It should be a positive double in range
(0, 1.0)");
-          return;
-        }
-      }
-      if (args.length >= 8) {
-        try {
-          momemtumWeight = Double.parseDouble(args[7]);
-          System.out.printf("Momemtum weight: %f\n", momemtumWeight);
-        } catch (NumberFormatException e) {
-          System.err
-              .println("MOMEMTUM_WEIGHT format invalid. It should be a positive double in
range (0, 1.0)");
-          return;
+        while ((line = br.readLine()) != null) {
+          if (line.trim().length() == 0) {
+            continue;
+          }
+          String[] tokens = line.trim().split(",");
+          double[] vals = new double[tokens.length];
+          for (int i = 0; i < tokens.length; ++i) {
+            vals[i] = Double.parseDouble(tokens[i]);
+          }
+          DoubleVector instance = new DenseDoubleVector(vals);
+          DoubleVector result = ann.getOutput(instance);
+          double[] arrResult = result.toArray();
+          StringBuilder sb = new StringBuilder();
+          for (int i = 0; i < arrResult.length; ++i) {
+            sb.append(arrResult[i]);
+            if (i != arrResult.length - 1) {
+              sb.append(",");
+            } else {
+              sb.append("\n");
+            }
+          }
+          bw.write(sb.toString());
         }
-      }
-      if (args.length >= 9) {
-        try {
-          regularizationWeight = Double.parseDouble(args[8]);
-          System.out
-              .printf("Regularization weight: %f\n", regularizationWeight);
-        } catch (NumberFormatException e) {
-          System.err
-              .println("REGULARIZATION_WEIGHT format invalid. It should be a positive double
in range (0, 1.0)");
-          return;
+
+        Closeables.close(br, true);
+        Closeables.close(bw, true);
+      } else { // train the model
+        SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
+        ann.setLearningRate(learningRate);
+        ann.setMomemtumWeight(momemtumWeight);
+        ann.setRegularizationWeight(regularizationWeight);
+        ann.addLayer(featureDimension, false,
+            FunctionFactory.createDoubleFunction("Sigmoid"));
+        if (hiddenLayerDimension != null) {
+          for (int dimension : hiddenLayerDimension) {
+            ann.addLayer(dimension, false,
+                FunctionFactory.createDoubleFunction("Sigmoid"));
+          }
         }
+        ann.addLayer(labelDimension, true,
+            FunctionFactory.createDoubleFunction("Sigmoid"));
+        ann.setCostFunction(FunctionFactory
+            .createDoubleDoubleFunction("CrossEntropy"));
+        ann.setModelPath(modelPath);
+
+        Map<String, String> trainingParameters = new HashMap<String, String>();
+        trainingParameters.put("tasks", "5");
+        trainingParameters.put("training.max.iterations", "" + iterations);
+        trainingParameters.put("training.batch.size", "300");
+        trainingParameters.put("convergence.check.interval", "1000");
+        ann.train(new Path(trainingDataPath), trainingParameters);
       }
-
-      // train the model
-      SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
-      ann.setLearningRate(learningRate);
-      ann.setMomemtumWeight(momemtumWeight);
-      ann.setRegularizationWeight(regularizationWeight);
-      ann.addLayer(featureDimension, false,
-          FunctionFactory.createDoubleFunction("Sigmoid"));
-      ann.addLayer(featureDimension, false,
-          FunctionFactory.createDoubleFunction("Sigmoid"));
-      ann.addLayer(labelDimension, true,
-          FunctionFactory.createDoubleFunction("Sigmoid"));
-      ann.setCostFunction(FunctionFactory
-          .createDoubleDoubleFunction("CrossEntropy"));
-      ann.setModelPath(trainedModelPath);
-
-      Map<String, String> trainingParameters = new HashMap<String, String>();
-      trainingParameters.put("tasks", "5");
-      trainingParameters.put("training.max.iterations", "" + iteration);
-      trainingParameters.put("training.batch.size", "300");
-      trainingParameters.put("convergence.check.interval", "1000");
-      ann.train(new Path(trainingDataPath), trainingParameters);
     }
-
-  }
-
-  private static void printUsage() {
-    System.out
-        .println("USAGE: <MODE> <INPUT_PATH> <OUTPUT_PATH> <MODEL_PATH>|<FEATURE_DIMENSION>
<LABEL_DIMENSION> [<MAX_ITERATION> <LEARNING_RATE> <MOMEMTUM_WEIGHT>
<REGULARIZATION_WEIGHT>]");
-    System.out
-        .println("\tMODE\t- train: train the model with given training data.");
-    System.out
-        .println("\t\t- label: obtain the result by feeding the features to the neural network.");
-    System.out
-        .println("\tINPUT_PATH\tin 'train' mode, it is the path of the training data; in
'label' mode, it is the path of the to be evaluated data that lacks the label.");
-    System.out
-        .println("\tOUTPUT_PATH\tin 'train' mode, it is where the trained model is stored;
in 'label' mode, it is where the labeled data is stored.");
-    System.out.println("\n\tConditional Parameters:");
-    System.out
-        .println("\tMODEL_PATH\tonly required in 'label' mode. It specifies where to load
the trained neural network model.");
-    System.out
-        .println("\tMAX_ITERATION\tonly used in 'train' mode. It specifies how many iterations
for the neural network to run. Default is 0.01.");
-    System.out
-        .println("\tLEARNING_RATE\tonly used to 'train' mode. It specifies the degree of
aggregation for learning, usually in range (0, 1.0). Default is 0.1.");
-    System.out
-        .println("\tMOMEMTUM_WEIGHT\tonly used to 'train' mode. It specifies the weight of
momemtum. Default is 0.");
-    System.out
-        .println("\tREGULARIZATION_WEIGHT\tonly required in 'train' model. It specifies the
weight of reqularization.");
-    System.out.println("\nExample:");
-    System.out
-        .println("Train a neural network with with feature dimension 8, label dimension 1
and default setting:\n\tneuralnets train hdfs://localhost:30002/training_data hdfs://localhost:30002/model
8 1");
-    System.out
-        .println("Train a neural network with with feature dimension 8, label dimension 1
and specify learning rate as 0.1, momemtum rate as 0.2, and regularization weight as 0.01:\n\tneuralnets.train
hdfs://localhost:30002/training_data hdfs://localhost:30002/model 8 1 0.1 0.2 0.01");
-    System.out
-        .println("Label the data with trained model:\n\tneuralnets evaluate hdfs://localhost:30002/unlabeled_data
hdfs://localhost:30002/result hdfs://localhost:30002/model");
   }
 
 }

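With this change, the example is driven by named cli2 options instead of positional arguments. A minimal sketch of a training run through the new interface follows; the flag names come from the diff above and are the same ones exercised by the test changes below, while the paths and values are illustrative placeholders:

    // Hypothetical invocation of the reworked example; paths and values are illustrative.
    NeuralNetwork.main(new String[] { "-train",
        "-tp", "/tmp/training.data",      // path of the training data
        "-mp", "/tmp/neuralnets.model",   // where to store the trained model
        "-fd", "8",                       // feature dimension
        "-ld", "1",                       // label dimension
        "-hd", "8",                       // optional hidden layer dimension(s)
        "-itr", "1000" });                // training iterations, default 1000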
Added: hama/trunk/examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java?rev=1562074&view=auto
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java (added)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/util/ParserUtil.java Tue Jan 28 14:05:15 2014
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.examples.util;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Option;
+
+/**
+ * Facilitates command-line argument parsing.
+ * 
+ */
+public class ParserUtil {
+
+  /**
+   * Parse and return the string parameter.
+   * 
+   * @param cli the parsed command line
+   * @param option the option whose value to read
+   * @return the value as a String, or null if the option is absent
+   */
+  public static String getString(CommandLine cli, Option option) {
+    Object val = cli.getValue(option);
+    if (val != null) {
+      return val.toString();
+    }
+    return null;
+  }
+
+  /**
+   * Parse and return the integer parameter.
+   * 
+   * @param cli the parsed command line
+   * @param option the option whose value to read
+   * @return the value parsed as an Integer, or null if the option is absent
+   */
+  public static Integer getInteger(CommandLine cli, Option option) {
+    Object val = cli.getValue(option);
+    if (val != null) {
+      return Integer.parseInt(val.toString());
+    }
+    return null;
+  }
+
+  /**
+   * Parse and return the long parameter.
+   * 
+   * @param cli the parsed command line
+   * @param option the option whose value to read
+   * @return the value parsed as a Long, or null if the option is absent
+   */
+  public static Long getLong(CommandLine cli, Option option) {
+    Object val = cli.getValue(option);
+    if (val != null) {
+      return Long.parseLong(val.toString());
+    }
+    return null;
+  }
+
+  /**
+   * Parse and return the double parameter.
+   * 
+   * @param cli the parsed command line
+   * @param option the option whose value to read
+   * @return the value parsed as a Double, or null if the option is absent
+   */
+  public static Double getDouble(CommandLine cli, Option option) {
+    Object val = cli.getValue(option);
+    if (val != null) {
+      return Double.parseDouble(val.toString());
+    }
+    return null;
+  }
+
+  /**
+   * Parse and return the boolean parameter. If the parameter is set, it is
+   * true, otherwise it is false.
+   * 
+   * @param cli the parsed command line
+   * @param option the option to test for
+   * @return true if the option is present, false otherwise
+   */
+  public static boolean getBoolean(CommandLine cli, Option option) {
+    return cli.hasOption(option);
+  }
+  
+  /**
+   * Parse and return the array parameters.
+   * @param cli the parsed command line
+   * @param option the option whose values to read
+   * @return the values as a list of Strings
+   */
+  public static List<String> getStrings(CommandLine cli, Option option) {
+    List<String> list = new ArrayList<String>();
+    for (Object obj : cli.getValues(option)) {
+      list.add(obj.toString());
+    }
+    return list;
+  }
+
+  /**
+   * Parse and return the array parameters as integers.
+   * @param cli the parsed command line
+   * @param option the option whose values to read
+   * @return the values parsed as a list of Integers
+   */
+  public static List<Integer> getInts(CommandLine cli, Option option) {
+    List<Integer> list = new ArrayList<Integer>();
+    for (String str : getStrings(cli, option)) {
+      list.add(Integer.parseInt(str));
+    }
+    return list;
+  }
+}
+

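ParserUtil is a thin wrapper over the cli2 CommandLine accessors. A minimal standalone sketch of the pattern, mirroring parseArgs in NeuralNetwork.java above; the "iterations" option here is hypothetical, and the imports are the same org.apache.commons.cli2 classes shown in that diff:

    // Hypothetical standalone use of ParserUtil; the option is illustrative.
    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    GroupBuilder groupBuilder = new GroupBuilder();
    Option iterationOption = optionBuilder.withLongName("iterations").withShortName("itr")
        .withArgument(argumentBuilder.withName("iterations").withMinimum(1)
            .withMaximum(1).create()).create();
    Parser parser = new Parser();
    parser.setGroup(groupBuilder.withOption(iterationOption).create());
    CommandLine cli = parser.parseAndHelp(new String[] { "-itr", "500" });
    Integer iterations = ParserUtil.getInteger(cli, iterationOption); // 500, or null if absent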
Modified: hama/trunk/examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java?rev=1562074&r1=1562073&r2=1562074&view=diff
==============================================================================
--- hama/trunk/examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java (original)
+++ hama/trunk/examples/src/test/java/org/apache/hama/examples/NeuralNetworkTest.java Tue Jan 28 14:05:15 2014
@@ -23,8 +23,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
-import junit.framework.TestCase;
-
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -33,32 +31,34 @@ import org.apache.hadoop.io.SequenceFile
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.io.VectorWritable;
 import org.apache.hama.commons.math.DenseDoubleVector;
+import org.junit.Before;
+import org.junit.Test;
 
 /**
  * Test the functionality of NeuralNetwork Example.
  * 
  */
-public class NeuralNetworkTest extends TestCase {
+public class NeuralNetworkTest {
   private Configuration conf = new HamaConfiguration();
   private FileSystem fs;
   private String MODEL_PATH = "/tmp/neuralnets.model";
   private String RESULT_PATH = "/tmp/neuralnets.txt";
   private String SEQTRAIN_DATA = "/tmp/test-neuralnets.data";
-  
-  @Override
-  protected void setUp() throws Exception {
-    super.setUp();
+
+  @Before
+  public void setup() throws Exception {
     fs = FileSystem.get(conf);
   }
 
+  @Test
   public void testNeuralnetsLabeling() throws IOException {
     this.neuralNetworkTraining();
 
     String dataPath = "src/test/resources/neuralnets_classification_test.txt";
-    String mode = "label";
+    String mode = "-label";
     try {
       NeuralNetwork
-          .main(new String[] { mode, dataPath, RESULT_PATH, MODEL_PATH });
+          .main(new String[] { mode, "-fp", dataPath, "-rp", RESULT_PATH, "-mp", MODEL_PATH
});
 
       // compare results with ground-truth
       BufferedReader groundTruthReader = new BufferedReader(new FileReader(
@@ -98,7 +98,7 @@ public class NeuralNetworkTest extends T
   }
 
   private void neuralNetworkTraining() {
-    String mode = "train";
+    String mode = "-train";
     String strTrainingDataPath = "src/test/resources/neuralnets_classification_training.txt";
     int featureDimension = 8;
     int labelDimension = 1;
@@ -130,8 +130,9 @@ public class NeuralNetworkTest extends T
     }
 
     try {
-      NeuralNetwork.main(new String[] { mode, SEQTRAIN_DATA,
-          MODEL_PATH, "" + featureDimension, "" + labelDimension });
+      NeuralNetwork.main(new String[] { mode, "-tp", SEQTRAIN_DATA, "-mp",
+          MODEL_PATH, "-fd", "" + featureDimension, "-hd",
+          "" + featureDimension, "-ld", "" + labelDimension, "-itr", "3000", "-m", "0.2",
"-l", "0.2" });
     } catch (Exception e) {
       e.printStackTrace();
     }

Modified: hama/trunk/pom.xml
URL: http://svn.apache.org/viewvc/hama/trunk/pom.xml?rev=1562074&r1=1562073&r2=1562074&view=diff
==============================================================================
--- hama/trunk/pom.xml (original)
+++ hama/trunk/pom.xml Tue Jan 28 14:05:15 2014
@@ -88,6 +88,7 @@
     <!-- Dependencies -->
     <commons-logging.version>1.1.1</commons-logging.version>
     <commons-cli.version>1.2</commons-cli.version>
+    <commons-cli2.version>2.0-mahout</commons-cli2.version>
     <commons-configuration>1.7</commons-configuration>
     <commons-lang>2.6</commons-lang>
     <commons-httpclient>3.0.1</commons-httpclient>
@@ -276,7 +277,12 @@
         <artifactId>jackson-mapper-asl</artifactId>
         <version>1.9.2</version>
       </dependency>
-         
+
+      <dependency>
+        <groupId>org.apache.mahout.commons</groupId>
+        <artifactId>commons-cli</artifactId>
+        <version>${commons-cli2.version}</version>
+      </dependency>
     </dependencies>
   </dependencyManagement>
 


