mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject svn commit: r1143542 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/common/ examples/src/main/java/org/apache/mahout/classifier/sgd/ examples/src/test/java/org/apache/mahout/classifier/sgd/...
Date Wed, 06 Jul 2011 20:17:59 GMT
Author: srowen
Date: Wed Jul  6 20:17:58 2011
New Revision: 1143542

URL: http://svn.apache.org/viewvc?rev=1143542&view=rev
Log:
MAHOUT-696 add command lines for adaptive logistic

Added:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
    mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
    mahout/trunk/src/conf/driver.classes.props

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java Wed Jul  6 20:17:58 2011
@@ -70,7 +70,7 @@ public class CsvRecordFactory implements
 
   // crude CSV value splitter.  This will fail if any double quoted strings have
   // commas inside.  Also, escaped quotes will not be unescaped.  Good enough for now.
-  private final Splitter COMMA = Splitter.on(',').trimResults(CharMatcher.is('"'));
+  private static final Splitter COMMA = Splitter.on(',').trimResults(CharMatcher.is('"'));
 
   private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY =
           ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
@@ -87,6 +87,10 @@ public class CsvRecordFactory implements
 
   private int target;
   private final Dictionary targetDictionary;
+  
+  // Optional name of the column whose value identifies each CSV line (null if unused).
+  private String idName;
+  private int id = -1; // index of the idName column, or -1 if none has been resolved
 
   private List<Integer> predictors;
   private Map<Integer, FeatureVectorEncoder> predictorEncoders;
@@ -109,6 +113,11 @@ public class CsvRecordFactory implements
     targetDictionary = new Dictionary();
   }
 
+  public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap){ // as the two-argument form, plus an optional (nullable) id column name
+    this(targetName, typeMap);
+    this.idName = idName;
+  }
+
   /**
    * Defines the values and thus the encoding of values of the target variables.  Note
    * that any values of the target variable not present in this list will be given the
@@ -165,6 +174,11 @@ public class CsvRecordFactory implements
 
     // record target column and establish dictionary for decoding target
     target = vars.get(targetName);
+    
+    // record id column
+    if (idName != null){
+      id = vars.get(idName);
+    }
 
     // create list of predictor column numbers
     predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {
@@ -244,6 +258,69 @@ public class CsvRecordFactory implements
     }
     return targetValue;
   }
+  
+  /**
+   * Decodes one line of CSV data, adding its predictor features to {@code featureVector}
+   * and, only if {@code returnTarget} is true, decoding its target value as well. Pass
+   * {@code returnTarget = false} when classifying data that has no target value.
+   * A decoded target code is capped at {@code maxTargetValue - 1}.
+   * @param line The raw CSV line.
+   * @param featureVector Where to add the features.  Should be zeroed before calling
+   *                      processLine.
+   * @param returnTarget whether to decode and return the target value.
+   * @return The encoded target value, or -1 when {@code returnTarget} is false.
+   */
+  public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+    List<String> fields = Lists.newArrayList(COMMA.split(line));
+    int code = -1;
+    // the target dictionary is consulted only when the caller asks for the target
+    if (returnTarget) {
+      int interned = targetDictionary.intern(fields.get(target));
+      code = interned < maxTargetValue ? interned : maxTargetValue - 1;
+    }
+
+    for (int predictor : predictors) {
+      // a negative predictor index is passed to its encoder as a null value
+      String field = predictor < 0 ? null : fields.get(predictor);
+      predictorEncoders.get(predictor).addToVector(field, featureVector);
+    }
+    return code;
+  }
+  
+  /**
+   * Returns the raw, undecoded value in the target column of one CSV line.
+   * @param line one line of content read from the CSV file
+   * @return the raw target string from that line
+   */
+  public String getTargetString(CharSequence line) {
+    List<String> fields = Lists.newArrayList(COMMA.split(line));
+    String raw = fields.get(target);
+    return raw;
+  }
+
+  /**
+   * Returns the raw target label that was encoded as {@code code} during training,
+   * or null if no label in the target dictionary has that code.
+   */
+  public String getTargetLabel(int code) {
+    for (String candidate : targetDictionary.values()) {
+      int candidateCode = targetDictionary.intern(candidate);
+      if (candidateCode == code) {
+        return candidate;
+      }
+    }
+    return null;
+  }
+  
+  /** Returns the id column value of {@code line}; throws IllegalStateException if no id column is set. */
+  public String getIdString(CharSequence line) {
+    if (id < 0) {
+      // idName was null (or only set afterwards) when the column indices were recorded
+      throw new IllegalStateException("No id column configured (idName=" + idName + ')');
+    }
+    List<String> values = Lists.newArrayList(COMMA.split(line));
+    return values.get(id);
+  }
 
   /**
    * Returns a list of the names of the predictor variables.
@@ -284,4 +361,12 @@ public class CsvRecordFactory implements
     return r;
   }
 
+  public String getIdName() { // may be null when no id column is used
+    return idName;
+  }
+
+  public void setIdName(String idName) { // stores the name only; the resolved id column index is not updated here
+    this.idName = idName;
+  }
+
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java Wed Jul  6 20:17:58 2011
@@ -17,7 +17,12 @@
 
 package org.apache.mahout.common;
 
-import java.io.*;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
 import java.lang.reflect.Field;
 
 import com.google.common.base.Charsets;
@@ -106,15 +111,15 @@ public abstract class MahoutTestCase ext
    * Find a declared field in a class or one of it's super classes
    */
   private static Field findDeclaredField(Class<?> inClass, String fieldname) throws NoSuchFieldException {
-    if (Object.class.equals(inClass)) {
-      throw new NoSuchFieldException();
-    }
-    for (Field field : inClass.getDeclaredFields()) {
-      if (field.getName().equalsIgnoreCase(fieldname)) {
-        return field;
+    while (inClass != null && !Object.class.equals(inClass)) {
+      for (Field field : inClass.getDeclaredFields()) {
+        if (field.getName().equalsIgnoreCase(fieldname)) {
+          return field;
+        }
       }
+      inClass = inClass.getSuperclass();
     }
-    return findDeclaredField(inClass.getSuperclass(), fieldname);
+    throw new NoSuchFieldException(fieldname);
   }
 
   /**

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java?rev=1143542&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java Wed Jul  6 20:17:58 2011
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import com.google.common.io.Closeables;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.GroupedOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+/**
+ * Command-line parameters for training an {@link AdaptiveLogisticRegression}, plus binary
+ * (de)serialization of those parameters together with the model itself.
+ */
+public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
+
+  // Built lazily by createAdaptiveLogisticRegression() or restored by readFields().
+  private AdaptiveLogisticRegression alr;
+  // Passed to AdaptiveLogisticRegression#setInterval.
+  private int interval = 800;
+  // Passed to AdaptiveLogisticRegression#setAveragingWindow.
+  private int averageWindow = 500;
+  // Passed to AdaptiveLogisticRegression#setThreadCount.
+  private int threads = 4;
+  // Prior name, matched case-insensitively: L1, L2, UP, TP or EBP (see createPrior).
+  private String prior = "L1";
+  // Extra parameter required by the TP and EBP priors (see checkParameters).
+  private double priorOption = Double.NaN;
+  // AUC evaluator name, matched case-insensitively: GLOBAL or GROUPED; null selects none.
+  private String auc;
+
+  /**
+   * Returns the model, building it from these parameters on the first call. Later calls,
+   * and calls after {@link #readFields}, return the existing instance unchanged.
+   */
+  public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
+    if (alr == null) {
+      alr = new AdaptiveLogisticRegression(getMaxTargetCategories(),
+                                           getNumFeatures(), createPrior(prior, priorOption));
+      alr.setInterval(interval);
+      alr.setAveragingWindow(averageWindow);
+      alr.setThreadCount(threads);
+      OnlineAuc evaluator = createAUC(auc);
+      // NOTE(review): only override the evaluator when one was requested, instead of
+      // handing null to the model; presumably the model keeps its default -- confirm.
+      if (evaluator != null) {
+        alr.setAucEvaluator(evaluator);
+      }
+    }
+    return alr;
+  }
+
+  /**
+   * Checks that a prior which needs an extra option (TP or EBP) has been given one.
+   *
+   * @throws IllegalArgumentException if the TP or EBP prior is chosen without a prior option
+   */
+  public void checkParameters() {
+    if (prior != null) {
+      String priorName = normalize(prior);
+      boolean needsOption = "TP".equals(priorName) || "EBP".equals(priorName);
+      if (needsOption && Double.isNaN(priorOption)) {
+        throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
+      }
+    }
+  }
+
+  /**
+   * Creates the prior named by {@code cmd} (case-insensitive, surrounding whitespace ignored).
+   * Returns null when {@code cmd} is null or not a recognised name.
+   */
+  private static PriorFunction createPrior(String cmd, double priorOption) {
+    if (cmd == null) {
+      return null;
+    }
+    String name = normalize(cmd);
+    if ("L1".equals(name)) {
+      return new L1();
+    }
+    if ("L2".equals(name)) {
+      return new L2();
+    }
+    if ("UP".equals(name)) {
+      return new UniformPrior();
+    }
+    if ("TP".equals(name)) {
+      return new TPrior(priorOption);
+    }
+    if ("EBP".equals(name)) {
+      return new ElasticBandPrior(priorOption);
+    }
+    return null;
+  }
+
+  /**
+   * Creates the AUC evaluator named by {@code cmd} (case-insensitive, surrounding whitespace
+   * ignored). Returns null when {@code cmd} is null or not a recognised name.
+   */
+  private static OnlineAuc createAUC(String cmd) {
+    if (cmd == null) {
+      return null;
+    }
+    String name = normalize(cmd);
+    if ("GLOBAL".equals(name)) {
+      return new GlobalOnlineAuc();
+    }
+    if ("GROUPED".equals(name)) {
+      return new GroupedOnlineAuc();
+    }
+    return null;
+  }
+
+  /** Trims and upper-cases an option name without depending on the default locale. */
+  private static String normalize(String name) {
+    return name.trim().toUpperCase(Locale.ENGLISH);
+  }
+
+  /**
+   * Closes the model if one exists, copies the target categories seen by the CSV record
+   * factory into these parameters, then writes parameters and model to {@code out}.
+   */
+  @Override
+  public void saveTo(OutputStream out) throws IOException {
+    if (alr != null) {
+      alr.close();
+    }
+    setTargetCategories(getCsvRecordFactory().getTargetCategories());
+    DataOutputStream dataOut = new DataOutputStream(out);
+    write(dataOut);
+    dataOut.flush();
+  }
+
+  /**
+   * Writes the parameters followed by the model. A null {@code prior} or {@code auc} is
+   * written as an empty string, which {@link #readFields} maps back to null.
+   *
+   * @throws IllegalStateException if there is no model to write
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    if (alr == null) {
+      throw new IllegalStateException("No AdaptiveLogisticRegression to write; create or load one first");
+    }
+    out.writeUTF(getTargetVariable());
+    out.writeInt(getTypeMap().size());
+    for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
+      out.writeUTF(entry.getKey());
+      out.writeUTF(entry.getValue());
+    }
+    out.writeInt(getNumFeatures());
+    out.writeInt(getMaxTargetCategories());
+    out.writeInt(getTargetCategories().size());
+    for (String category : getTargetCategories()) {
+      out.writeUTF(category);
+    }
+
+    out.writeInt(interval);
+    out.writeInt(averageWindow);
+    out.writeInt(threads);
+    // writeUTF rejects null, so absent names are encoded as "".
+    out.writeUTF(nullToEmpty(prior));
+    out.writeDouble(priorOption);
+    out.writeUTF(nullToEmpty(auc));
+
+    // skip csv
+    alr.write(out);
+  }
+
+  /** Reads what {@link #write} produced; the model is restored as a new instance. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    setTargetVariable(in.readUTF());
+    int typeMapSize = in.readInt();
+    Map<String, String> typeMap = new HashMap<String, String>(typeMapSize);
+    for (int i = 0; i < typeMapSize; i++) {
+      String key = in.readUTF();
+      String value = in.readUTF();
+      typeMap.put(key, value);
+    }
+    setTypeMap(typeMap);
+
+    setNumFeatures(in.readInt());
+    setMaxTargetCategories(in.readInt());
+    int targetCategoriesSize = in.readInt();
+    List<String> targetCategories = new ArrayList<String>(targetCategoriesSize);
+    for (int i = 0; i < targetCategoriesSize; i++) {
+      targetCategories.add(in.readUTF());
+    }
+    setTargetCategories(targetCategories);
+
+    interval = in.readInt();
+    averageWindow = in.readInt();
+    threads = in.readInt();
+    prior = emptyToNull(in.readUTF());
+    priorOption = in.readDouble();
+    auc = emptyToNull(in.readUTF());
+
+    alr = new AdaptiveLogisticRegression();
+    alr.readFields(in);
+  }
+
+  /** Encodes null as "" for {@link DataOutput#writeUTF}, which does not accept null. */
+  private static String nullToEmpty(String s) {
+    return s == null ? "" : s;
+  }
+
+  /** Inverse of {@link #nullToEmpty}: decodes "" back to null. */
+  private static String emptyToNull(String s) {
+    return s.isEmpty() ? null : s;
+  }
+
+  private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
+    AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters();
+    result.readFields(new DataInputStream(in));
+    return result;
+  }
+
+  /** Reads parameters and model previously written by {@link #saveTo} from a file. */
+  public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
+    InputStream input = new FileInputStream(in);
+    try {
+      return loadFromStream(input);
+    } finally {
+      Closeables.closeQuietly(input);
+    }
+  }
+
+  public int getInterval() {
+    return interval;
+  }
+
+  public void setInterval(int interval) {
+    this.interval = interval;
+  }
+
+  public int getAverageWindow() {
+    return averageWindow;
+  }
+
+  public void setAverageWindow(int averageWindow) {
+    this.averageWindow = averageWindow;
+  }
+
+  public int getThreads() {
+    return threads;
+  }
+
+  public void setThreads(int threads) {
+    this.threads = threads;
+  }
+
+  public String getPrior() {
+    return prior;
+  }
+
+  public void setPrior(String prior) {
+    this.prior = prior;
+  }
+
+  public String getAuc() {
+    return auc;
+  }
+
+  public void setAuc(String auc) {
+    this.auc = auc;
+  }
+
+  public double getPriorOption() {
+    return priorOption;
+  }
+
+  public void setPriorOption(double priorOption) {
+    this.priorOption = priorOption;
+  }
+
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java Wed Jul  6 20:17:58 2011
@@ -216,6 +216,10 @@ public class LogisticModelParameters imp
     maxTargetCategories = targetCategories.size();
   }
 
+  public List<String> getTargetCategories() { // returns the internal list itself, not a copy
+    return this.targetCategories;
+  }
+
   public void setUseBias(boolean useBias) {
     this.useBias = useBias;
   }
@@ -232,6 +236,10 @@ public class LogisticModelParameters imp
     return typeMap;
   }
 
+  public void setTypeMap(Map<String, String> map) { // keeps a reference to the caller's map, not a copy
+    this.typeMap = map;
+  }
+
   public int getNumFeatures() {
     return numFeatures;
   }

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java Wed Jul  6 20:17:58 2011
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Command-line tool that scores the records of a CSV file with a model saved as
+ * {@link AdaptiveLogisticModelParameters} and writes {@code id,label,score} rows to a file:
+ * one row per target label, or only the best-scoring label with {@code --maxscoreonly}.
+ */
+public final class RunAdaptiveLogistic {
+
+  private static String inputFile;
+  private static String modelFile;
+  private static String outputFile;
+  private static String idColumn;
+  private static boolean maxScoreOnly;
+
+  private RunAdaptiveLogistic() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    // autoflush: otherwise the printf progress messages may never reach stdout
+    mainToOutput(args, new PrintWriter(System.out, true));
+  }
+
+  /**
+   * Runs the tool with {@code args}, reporting progress and problems to {@code output}.
+   * Does nothing if the arguments cannot be parsed.
+   */
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+    if (!parseArgs(args)) {
+      return;
+    }
+    try {
+      run(output);
+    } finally {
+      // callers may pass a writer without autoflush
+      output.flush();
+    }
+  }
+
+  /** Loads the model, scores every input record and writes the results to the output file. */
+  private static void run(PrintWriter output) throws IOException {
+    AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+        .loadFromFile(new File(modelFile));
+
+    CsvRecordFactory csv = lmp.getCsvRecordFactory();
+    // presumably must precede firstLine(), where the id column is located in the header
+    csv.setIdName(idColumn);
+
+    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+    State<Wrapper, CrossFoldLearner> best = lr.getBest();
+    if (best == null) {
+      output.printf("%s\n",
+          "AdaptiveLogisticRegression has not been trained properly.");
+      return;
+    }
+    CrossFoldLearner learner = best.getPayload().getLearner();
+
+    int processed;
+    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
+    try {
+      String header = in.readLine();
+      if (header == null) {
+        output.printf(Locale.ENGLISH, "%s is empty; nothing to score.\n", inputFile);
+        return;
+      }
+      csv.firstLine(header);
+
+      BufferedWriter out = new BufferedWriter(new FileWriter(outputFile));
+      try {
+        processed = scoreRecords(in, out, csv, learner, lmp.getNumFeatures(), output);
+      } finally {
+        out.close();
+      }
+    } finally {
+      in.close();
+    }
+    output.printf(Locale.ENGLISH, "%d records processed totally.\n", processed);
+  }
+
+  /**
+   * Writes the {@code <idColumn>,target,score} header row, then scores each remaining line
+   * of {@code in} and writes one row per distinct label in its results. Returns the number
+   * of records scored.
+   */
+  private static int scoreRecords(BufferedReader in, BufferedWriter out, CsvRecordFactory csv,
+                                  CrossFoldLearner learner, int numFeatures, PrintWriter output)
+      throws IOException {
+    out.write(idColumn + ",target,score");
+    out.newLine();
+
+    Map<String, Double> results = new HashMap<String, Double>();
+    int k = 0;
+    for (String line = in.readLine(); line != null; line = in.readLine()) {
+      Vector v = new SequentialAccessSparseVector(numFeatures);
+      csv.processLine(line, v, false);
+      Vector scores = learner.classifyFull(v);
+      results.clear();
+      if (maxScoreOnly) {
+        results.put(csv.getTargetLabel(scores.maxValueIndex()),
+            scores.maxValue());
+      } else {
+        for (int i = 0; i < scores.size(); i++) {
+          results.put(csv.getTargetLabel(i), scores.get(i));
+        }
+      }
+
+      // every row written for this record carries the same id, so extract it once
+      String id = csv.getIdString(line);
+      for (Map.Entry<String, Double> entry : results.entrySet()) {
+        out.write(id + ',' + entry.getKey() + ',' + entry.getValue());
+        out.newLine();
+      }
+      k++;
+      if (k % 100 == 0) {
+        output.printf(Locale.ENGLISH, "%d records processed \n", k);
+      }
+    }
+    return k;
+  }
+
+  /**
+   * Parses {@code args} into the static option fields above. Returns false when the parser
+   * produces no command line (for example on invalid arguments), true otherwise.
+   */
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help")
+      .withDescription("print this list").create();
+
+    Option quiet = builder.withLongName("quiet")
+      .withDescription("be extra quiet").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder
+      .withLongName("input")
+      .withRequired(true)
+      .withArgument(
+          argumentBuilder.withName("input").withMaximum(1)
+            .create())
+      .withDescription("where to get the data to score").create();
+
+    Option modelFileOption = builder
+      .withLongName("model")
+      .withRequired(true)
+      .withArgument(
+          argumentBuilder.withName("model").withMaximum(1)
+            .create())
+      .withDescription("where to get the trained model").create();
+
+    Option outputFileOption = builder
+      .withLongName("output")
+      .withRequired(true)
+      .withDescription("the file path to output scores")
+      .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+      .create();
+
+    Option idColumnOption = builder
+      .withLongName("idcolumn")
+      .withRequired(true)
+      .withDescription("the name of the id column for each record")
+      .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
+      .create();
+
+    Option maxScoreOnlyOption = builder
+      .withLongName("maxscoreonly")
+      .withDescription("only output the target label with max scores")
+      .create();
+
+    Group normalArgs = new GroupBuilder()
+      .withOption(help).withOption(quiet)
+      .withOption(inputFileOption).withOption(modelFileOption)
+      .withOption(outputFileOption).withOption(idColumnOption)
+      .withOption(maxScoreOnlyOption)
+      .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = getStringArgument(cmdLine, inputFileOption);
+    modelFile = getStringArgument(cmdLine, modelFileOption);
+    outputFile = getStringArgument(cmdLine, outputFileOption);
+    idColumn = getStringArgument(cmdLine, idColumnOption);
+    maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
+    return true;
+  }
+
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java Wed Jul  6 20:17:58 2011
@@ -33,7 +33,7 @@ import org.apache.mahout.classifier.eval
 import java.io.BufferedReader;
 import java.io.File;
 import java.io.IOException;
-import java.io.PrintStream;
+import java.io.PrintWriter;
 import java.util.Locale;
 
 public final class RunLogistic {
@@ -43,12 +43,15 @@ public final class RunLogistic {
   private static boolean showAuc;
   private static boolean showScores;
   private static boolean showConfusion;
-  static PrintStream output = System.out;
 
   private RunLogistic() {
   }
 
   public static void main(String[] args) throws IOException {
+    mainToOutput(args, new PrintWriter(System.out, true)); // autoflush, or printf output may never reach stdout
+  }
+
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
     if (parseArgs(args)) {
       if (!showAuc && !showConfusion && !showScores) {
         showAuc = true;

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java Wed Jul  6 20:17:58 2011
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.PrintWriter;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+
+public final class TrainAdaptiveLogistic {
+
+  private static String inputFile;
+  private static String outputFile;
+  private static AdaptiveLogisticModelParameters lmp;
+  private static int passes;
+  private static boolean showperf;
+  private static int skipperfnum = 99;
+  private static AdaptiveLogisticRegression model;
+
+  private TrainAdaptiveLogistic() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    mainToOutput(args, new PrintWriter(System.out));
+  }
+
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+    if (parseArgs(args)) {
+
+      CsvRecordFactory csv = lmp.getCsvRecordFactory();
+      model = lmp.createAdaptiveLogisticRegression();
+      State<Wrapper, CrossFoldLearner> best = null;
+      CrossFoldLearner learner = null;
+
+      int k = 0;
+      for (int pass = 0; pass < passes; pass++) {
+        BufferedReader in = open(inputFile);
+
+        // read variable names
+        csv.firstLine(in.readLine());
+
+        String line = in.readLine();
+
+        while (line != null) {
+          // for each new line, get target and predictors
+          Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+          int targetValue = csv.processLine(line, input);
+          // update model
+          model.train(targetValue, input);
+          k++;
+
+          if (showperf && (k % (skipperfnum + 1) == 0)) {
+
+            best = model.getBest();
+            if (best != null) {
+              learner = best.getPayload().getLearner();
+            }
+            if (learner != null) {
+              double averageCorrect = learner.percentCorrect();
+              double averageLL = learner.logLikelihood();
+              output.printf("%d\t%.3f\t%.2f\n",
+                            k, averageLL, averageCorrect * 100);
+            } else {
+              output.printf(Locale.ENGLISH,
+                            "%10d %2d %s\n", k, targetValue,
+                            "AdaptiveLogisticRegression has not found a good model ......");
+            }
+          }
+          line = in.readLine();
+        }
+        in.close();
+      }
+
+      best = model.getBest();
+      if (best != null) {
+        learner = best.getPayload().getLearner();
+      }
+      if (learner == null) {
+        output.printf(Locale.ENGLISH,
+                      "%s\n", "AdaptiveLogisticRegression has not successfully trained any model.");
+        return;
+      }
+
+
+      OutputStream modelOutput = new FileOutputStream(outputFile);
+      try {
+        lmp.saveTo(modelOutput);
+      } finally {
+        modelOutput.close();
+      }
+
+      OnlineLogisticRegression lr = learner.getModels().get(0);
+      output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures());
+      output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());
+      String sep = "";
+      for (String v : csv.getTraceDictionary().keySet()) {
+        double weight = predictorWeight(lr, 0, csv, v);
+        if (weight != 0) {
+          output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
+          sep = " + ";
+        }
+      }
+      output.printf("\n");
+
+      for (int row = 0; row < lr.getBeta().numRows(); row++) {
+        for (String key : csv.getTraceDictionary().keySet()) {
+          double weight = predictorWeight(lr, row, csv, key);
+          if (weight != 0) {
+            output.printf(Locale.ENGLISH, "%20s %.5f\n", key, weight);
+          }
+        }
+        for (int column = 0; column < lr.getBeta().numCols(); column++) {
+          output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
+        }
+        output.println();
+      }
+    }
+
+  }
+
+  private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+    double weight = 0;
+    for (Integer column : csv.getTraceDictionary().get(predictor)) {
+      weight += lr.getBeta().get(row, column);
+    }
+    return weight;
+  }
+
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help")
+        .withDescription("print this list").create();
+
+    Option quiet = builder.withLongName("quiet")
+        .withDescription("be extra quiet").create();
+    
+   
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option showperf = builder
+      .withLongName("showperf")
+      .withDescription("output performance measures during training")
+      .create();
+
+    Option inputFile = builder
+        .withLongName("input")
+        .withRequired(true)
+        .withArgument(
+            argumentBuilder.withName("input").withMaximum(1)
+                .create())
+        .withDescription("where to get training data").create();
+
+    Option outputFile = builder
+        .withLongName("output")
+        .withRequired(true)
+        .withArgument(
+            argumentBuilder.withName("output").withMaximum(1)
+                .create())
+        .withDescription("where to write the model content").create();
+
+    Option threads = builder.withLongName("threads")
+        .withArgument(
+            argumentBuilder.withName("threads").withDefault("4").create())
+        .withDescription("the number of threads AdaptiveLogisticRegression uses")
+        .create();
+
+
+    Option predictors = builder.withLongName("predictors")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("predictors").create())
+        .withDescription("a list of predictor variables").create();
+
+    Option types = builder
+        .withLongName("types")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("types").create())
+        .withDescription(
+            "a list of predictor variable types (numeric, word, or text)")
+        .create();
+
+    Option target = builder
+        .withLongName("target")
+        .withDescription("the name of the target variable")    
+        .withRequired(true)    
+        .withArgument(
+            argumentBuilder.withName("target").withMaximum(1)
+                .create())
+         .create();
+    
+    Option targetCategories = builder
+    .withLongName("categories")
+    .withDescription("the number of target categories to be considered")
+    .withRequired(true)
+    .withArgument(
+        argumentBuilder.withName("categories").withMaximum(1)
+            .create())        
+    .create();
+    
+
+    Option features = builder
+        .withLongName("features")
+        .withDescription("the number of internal hashed features to use")
+        .withArgument(
+            argumentBuilder.withName("numFeatures")
+                .withDefault("1000").withMaximum(1).create())        
+        .create();
+
+    Option passes = builder
+        .withLongName("passes")
+        .withDescription("the number of times to pass over the input data")
+        .withArgument(
+            argumentBuilder.withName("passes").withDefault("2")
+                .withMaximum(1).create())        
+        .create();
+
+    Option interval = builder.withLongName("interval")
+        .withArgument(
+            argumentBuilder.withName("interval").withDefault("500").create())
+        .withDescription("the interval property of AdaptiveLogisticRegression")
+        .create();
+
+    Option window = builder.withLongName("window")
+        .withArgument(
+            argumentBuilder.withName("window").withDefault("800").create())
+        .withDescription("the average propery of AdaptiveLogisticRegression")
+        .create();
+
+    Option skipperfnum = builder.withLongName("skipperfnum")
+        .withArgument(
+            argumentBuilder.withName("skipperfnum").withDefault("99").create())
+        .withDescription("show performance measures every (skipperfnum + 1) rows")
+        .create();
+
+    Option prior = builder.withLongName("prior")
+        .withArgument(
+            argumentBuilder.withName("prior").withDefault("L1").create())
+        .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
+        .create();
+
+    Option priorOption = builder.withLongName("prioroption")
+        .withArgument(
+            argumentBuilder.withName("prioroption").create())
+        .withDescription("constructor parameter for ElasticBandPrior and TPrior")
+        .create();
+
+    Option auc = builder.withLongName("auc")
+        .withArgument(
+            argumentBuilder.withName("auc").withDefault("global").create())
+        .withDescription("the auc to use: global or grouped")
+        .create();
+
+    
+
+    Group normalArgs = new GroupBuilder().withOption(help)
+        .withOption(quiet).withOption(inputFile).withOption(outputFile)
+        .withOption(target).withOption(targetCategories)
+        .withOption(predictors).withOption(types).withOption(passes)
+        .withOption(interval).withOption(window).withOption(threads)
+        .withOption(prior).withOption(features).withOption(showperf)
+        .withOption(skipperfnum).withOption(priorOption).withOption(auc)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile);
+    TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine,
+                                                         outputFile);
+
+    List<String> typeList = Lists.newArrayList();
+    for (Object x : cmdLine.getValues(types)) {
+      typeList.add(x.toString());
+    }
+
+    List<String> predictorList = Lists.newArrayList();
+    for (Object x : cmdLine.getValues(predictors)) {
+      predictorList.add(x.toString());
+    }
+
+    lmp = new AdaptiveLogisticModelParameters();
+    lmp.setTargetVariable(getStringArgument(cmdLine, target));
+    lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+    lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+    lmp.setInterval(getIntegerArgument(cmdLine, interval));
+    lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
+    lmp.setThreads(getIntegerArgument(cmdLine, threads));
+    lmp.setAuc(getStringArgument(cmdLine, auc));
+    lmp.setPrior(getStringArgument(cmdLine, prior));
+    if (cmdLine.getValue(priorOption) != null) {
+      lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
+    }
+    lmp.setTypeMap(predictorList, typeList);
+    TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf);
+    TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum);
+    TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes);
+
+    lmp.checkParameters();
+
+    return true;
+  }
+
+  private static String getStringArgument(CommandLine cmdLine,
+                                          Option inputFile) {
+    return (String) cmdLine.getValue(inputFile);
+  }
+
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  private static int getIntegerArgument(CommandLine cmdLine, Option features) {
+    return Integer.parseInt((String) cmdLine.getValue(features));
+  }
+
+  private static double getDoubleArgument(CommandLine cmdLine, Option op) {
+    return Double.parseDouble((String) cmdLine.getValue(op));
+  }
+
+  public static AdaptiveLogisticRegression getModel() {
+    return model;
+  }
+
+  public static LogisticModelParameters getParameters() {
+    return lmp;
+  }
+
+   static BufferedReader open(String inputFile) throws IOException {
+    InputStream in;
+    try {
+      in = Resources.getResource(inputFile).openStream();
+    } catch (IllegalArgumentException e) {
+      in = new FileInputStream(new File(inputFile));
+    }
+    return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
+  }
+   
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java Wed Jul  6 20:17:58 2011
@@ -40,7 +40,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.OutputStream;
-import java.io.PrintStream;
+import java.io.PrintWriter;
 import java.util.List;
 import java.util.Locale;
 
@@ -53,16 +53,18 @@ public final class TrainLogistic {
   private static String inputFile;
   private static String outputFile;
   private static LogisticModelParameters lmp;
-
   private static int passes;
   private static boolean scores;
   private static OnlineLogisticRegression model;
-  static PrintStream output = System.out;
 
   private TrainLogistic() {
   }
 
   public static void main(String[] args) throws IOException {
+    mainToOutput(args, new PrintWriter(System.out));
+  }
+
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
     if (parseArgs(args)) {
       double logPEstimate = 0;
       int samples = 0;

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java Wed Jul  6 20:17:58 2011
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/*
+ * AUC and average log-likelihood are always shown when possible. If the target has more than two
+ * categories, AUC and the entropy matrix are not shown, whatever --auc and --confusion flags the
+ * user passes, because the current implementation only supports them for binary targets.
+ */
+public final class ValidateAdaptiveLogistic {
+
+  private static String inputFile;
+  private static String modelFile;
+  private static boolean showAuc;
+  private static boolean showScores;
+  private static boolean showConfusion;
+
+  private ValidateAdaptiveLogistic() {
+  }
+
+  /**
+   * Command-line entry point; writes the validation report to standard out.
+   */
+  public static void main(String[] args) throws IOException {
+    // Auto-flush so the report reaches stdout: a plain PrintWriter buffers, and since
+    // nothing else flushes it the output would be silently lost.
+    PrintWriter output = new PrintWriter(System.out, true);
+    try {
+      mainToOutput(args, output);
+    } finally {
+      output.flush();
+    }
+  }
+
+  /**
+   * Parses {@code args}, loads the trained model and scores every record of the input file with
+   * the best learner. It prints log-likelihood statistics and, where enabled and supported,
+   * per-record scores, AUC, the confusion matrix and the entropy matrix.
+   *
+   * @param args   command-line arguments; see {@link #parseArgs(String[])}
+   * @param output destination for the report
+   * @throws IOException if the model or the input data cannot be read
+   */
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+    if (!parseArgs(args)) {
+      return;
+    }
+    // with no report flags given, default to AUC plus confusion matrix
+    if (!showAuc && !showConfusion && !showScores) {
+      showAuc = true;
+      showConfusion = true;
+    }
+
+    AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+        .loadFromFile(new File(modelFile));
+    CsvRecordFactory csv = lmp.getCsvRecordFactory();
+    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+    // AUC (and the entropy matrix derived from it) is only collected for binary targets
+    Auc collector = null;
+    if (lmp.getTargetCategories().size() <= 2) {
+      collector = new Auc();
+    }
+
+    OnlineSummarizer slh = new OnlineSummarizer();
+    ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), "unknown");
+
+    State<Wrapper, CrossFoldLearner> best = lr.getBest();
+    if (best == null) {
+      output.printf("%s\n",
+          "AdaptiveLogisticRegression has not been trained properly.");
+      return;
+    }
+    CrossFoldLearner learner = best.getPayload().getLearner();
+
+    BufferedReader in = TrainLogistic.open(inputFile);
+    try {
+      // the first line holds the variable names
+      String line = in.readLine();
+      csv.firstLine(line);
+      line = in.readLine();
+      if (showScores) {
+        output.printf(Locale.ENGLISH, "\"%s\", \"%s\", \"%s\", \"%s\"\n",
+            "target", "model-output", "log-likelihood", "average-likelihood");
+      }
+      while (line != null) {
+        Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+        // TODO: handle target values that never occurred during training.
+        int target = csv.processLine(line, v);
+        double likelihood = learner.logLikelihood(target, v);
+        double score = learner.classifyFull(v).maxValue();
+
+        slh.add(likelihood);
+        cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));
+
+        if (showScores) {
+          // reuse the likelihood computed above rather than scoring the record a second time
+          output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f\n", target,
+              score, likelihood, slh.getMean());
+        }
+        if (collector != null) {
+          collector.add(target, score);
+        }
+        line = in.readLine();
+      }
+    } finally {
+      // close even when a record fails to parse, so the handle is not leaked
+      in.close();
+    }
+
+    output.printf(Locale.ENGLISH, "\nLog-likelihood:");
+    output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f\n",
+        slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());
+
+    if (collector != null) {
+      output.printf(Locale.ENGLISH, "\nAUC = %.2f\n", collector.auc());
+    }
+
+    if (showConfusion) {
+      output.printf(Locale.ENGLISH, "\n%s\n\n", cm.toString());
+
+      if (collector != null) {
+        Matrix m = collector.entropy();
+        output.printf(Locale.ENGLISH,
+            "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]\n", m.get(0, 0),
+            m.get(1, 0), m.get(0, 1), m.get(1, 1));
+      }
+    }
+  }
+
+  /**
+   * Parses the command line into the static configuration of this class.
+   *
+   * @return false if parsing failed or help was requested (usage has already been printed)
+   */
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help")
+        .withDescription("print this list").create();
+
+    // accepted for command-line compatibility; not consulted by this tool
+    Option quiet = builder.withLongName("quiet")
+        .withDescription("be extra quiet").create();
+
+    Option auc = builder.withLongName("auc").withDescription("print AUC")
+        .create();
+    Option confusion = builder.withLongName("confusion")
+        .withDescription("print confusion matrix").create();
+
+    Option scores = builder.withLongName("scores")
+        .withDescription("print scores").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder
+        .withLongName("input")
+        .withRequired(true)
+        .withArgument(
+            argumentBuilder.withName("input").withMaximum(1)
+                .create())
+        .withDescription("where to get validate data").create();
+
+    Option modelFileOption = builder
+        .withLongName("model")
+        .withRequired(true)
+        .withArgument(
+            argumentBuilder.withName("model").withMaximum(1)
+                .create())
+        .withDescription("where to get the trained model").create();
+
+    Group normalArgs = new GroupBuilder().withOption(help)
+        .withOption(quiet).withOption(auc).withOption(scores)
+        .withOption(confusion).withOption(inputFileOption)
+        .withOption(modelFileOption).create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = getStringArgument(cmdLine, inputFileOption);
+    modelFile = getStringArgument(cmdLine, modelFileOption);
+    showAuc = getBooleanArgument(cmdLine, auc);
+    showScores = getBooleanArgument(cmdLine, scores);
+    showConfusion = getBooleanArgument(cmdLine, confusion);
+
+    return true;
+  }
+
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+}

Modified: mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java (original)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java Wed Jul  6 20:17:58 2011
@@ -28,15 +28,11 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Vector;
 import org.junit.Test;
 
-import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.FileInputStream;
-import java.io.IOException;
 import java.io.InputStream;
-import java.io.PrintStream;
-import java.lang.reflect.Field;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
+import java.io.PrintWriter;
+import java.io.StringWriter;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -47,16 +43,19 @@ public class TrainLogisticTest extends M
   public void example13_1() throws Exception {
     String outputFile = getTestTempFile("model").getAbsolutePath();
 
-    String trainOut = runMain(TrainLogistic.class, new String[]{
-      "--input", "donut.csv",
-      "--output", outputFile,
-      "--target", "color", "--categories", "2",
-      "--predictors", "x", "y",
-      "--types", "numeric",
-      "--features", "20",
-      "--passes", "100",
-      "--rate", "50"
-    });
+    StringWriter sw = new StringWriter();
+    PrintWriter pw = new PrintWriter(sw);
+    TrainLogistic.mainToOutput(new String[]{
+        "--input", "donut.csv",
+        "--output", outputFile,
+        "--target", "color", "--categories", "2",
+        "--predictors", "x", "y",
+        "--types", "numeric",
+        "--features", "20",
+        "--passes", "100",
+        "--rate", "50"
+    }, pw);
+    String trainOut = sw.toString();
     assertTrue(trainOut.contains("x -0.7"));
     assertTrue(trainOut.contains("y -0.4"));
 
@@ -87,20 +86,26 @@ public class TrainLogisticTest extends M
       Closeables.closeQuietly(in);
     }
 
-    String output = runMain(RunLogistic.class, new String[]{
+    sw = new StringWriter();
+    pw = new PrintWriter(sw);
+    RunLogistic.mainToOutput(new String[]{
         "--input", "donut.csv",
         "--model", outputFile,
         "--auc",
         "--confusion"
-    });
-    assertTrue(output.contains("AUC = 0.57"));
-    assertTrue(output.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
+    }, pw);
+    trainOut = sw.toString();
+    assertTrue(trainOut.contains("AUC = 0.57"));
+    assertTrue(trainOut.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
   }
 
   @Test
   public void example13_2() throws Exception {
     String outputFile = getTestTempFile("model").getAbsolutePath();
-    String trainOut = runMain(TrainLogistic.class, new String[]{
+
+    StringWriter sw = new StringWriter();
+    PrintWriter pw = new PrintWriter(sw);
+    TrainLogistic.mainToOutput(new String[]{
         "--input", "donut.csv",
         "--output", outputFile,
         "--target", "color",
@@ -110,59 +115,34 @@ public class TrainLogisticTest extends M
         "--features", "20",
         "--passes", "100",
         "--rate", "50"
-    });
+    }, pw);
 
+    String trainOut = sw.toString();
     assertTrue(trainOut.contains("a 0."));
     assertTrue(trainOut.contains("b -1."));
     assertTrue(trainOut.contains("c -25."));
 
-    String output = runMain(RunLogistic.class, new String[]{
+    sw = new StringWriter();
+    pw = new PrintWriter(sw);
+    RunLogistic.mainToOutput(new String[]{
         "--input", "donut.csv",
         "--model", outputFile,
         "--auc",
         "--confusion"
-    });
-    assertTrue(output.contains("AUC = 1.00"));
-
-    String heldout = runMain(RunLogistic.class, new String[]{
+    }, pw);
+    trainOut = sw.toString();
+    assertTrue(trainOut.contains("AUC = 1.00"));
+
+    sw = new StringWriter();
+    pw = new PrintWriter(sw);
+    RunLogistic.mainToOutput(new String[]{
         "--input", "donut-test.csv",
         "--model", outputFile,
         "--auc",
         "--confusion"
-    });
-    assertTrue(heldout.contains("AUC = 0.9"));
-  }
-
-  /**
-   * Runs a class with a public static void main method.  We assume that there is an accessible
-   * field named "output" that we can change to redirect output.
-   *
-   *
-   * @param clazz   contains the main method.
-   * @param args    contains the command line arguments
-   * @return The contents to standard out as a string.
-   * @throws IOException                   Not possible, but must be declared.
-   * @throws NoSuchFieldException          If there isn't an output field.
-   * @throws IllegalAccessException        If the output field isn't accessible by us.
-   * @throws NoSuchMethodException         If there isn't a main method.
-   * @throws InvocationTargetException     If the main method throws an exception.
-   */
-  private static String runMain(Class<?> clazz, String[] args)
-    throws NoSuchFieldException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
-    ByteArrayOutputStream trainOutput = new ByteArrayOutputStream();
-    PrintStream printStream = new PrintStream(trainOutput);
-
-    try {
-      Field outputField = clazz.getDeclaredField("output");
-      Method main = clazz.getMethod("main", args.getClass());
-
-      outputField.set(null, printStream);
-      Object[] argList = {args};
-      main.invoke(null, argList);
-      return new String(trainOutput.toByteArray(), Charsets.UTF_8);
-    } finally {
-      Closeables.closeQuietly(printStream);
-    }
+    }, pw);
+    trainOut = sw.toString();
+    assertTrue(trainOut.contains("AUC = 0.9"));
   }
 
   private static void verifyModel(LogisticModelParameters lmp,

Modified: mahout/trunk/src/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Wed Jul  6 20:17:58 2011
@@ -31,6 +31,9 @@ org.apache.mahout.cf.taste.hadoop.item.R
 org.apache.mahout.classifier.sgd.TrainLogistic = trainlogistic : Train a logistic regression using stochastic gradient descent
 org.apache.mahout.classifier.sgd.RunLogistic = runlogistic : Run a logistic regression model against CSV data
 org.apache.mahout.classifier.sgd.PrintResourceOrFile = cat : Print a file or resource as the logistic regression models would see it
+org.apache.mahout.classifier.sgd.TrainAdaptiveLogistic = trainAdaptiveLogistic : Train an AdaptiveLogisticRegression model
+org.apache.mahout.classifier.sgd.ValidateAdaptiveLogistic = validateAdaptiveLogistic : Validate an AdaptiveLogisticRegression model against hold-out data set
+org.apache.mahout.classifier.sgd.RunAdaptiveLogistic = runAdaptiveLogistic : Score new production data using a properly trained and validated AdaptiveLogisticRegression model
 org.apache.mahout.classifier.bayes.WikipediaXmlSplitter = wikipediaXMLSplitter : Reads wikipedia data and creates ch  
 org.apache.mahout.classifier.bayes.WikipediaDatasetCreatorDriver = wikipediaDataSetCreator : Splits data set of wikipedia wrt feature like country
 org.apache.mahout.math.hadoop.stochasticsvd.SSVDCli = ssvd : Stochastic SVD



Mime
View raw message