mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From adene...@apache.org
Subject svn commit: r1187953 - in /mahout/trunk: core/src/main/java/org/apache/mahout/df/builder/ core/src/main/java/org/apache/mahout/df/data/ core/src/main/java/org/apache/mahout/df/mapreduce/ core/src/main/java/org/apache/mahout/df/node/ core/src/main/java/...
Date Sun, 23 Oct 2011 19:26:20 GMT
Author: adeneche
Date: Sun Oct 23 19:26:19 2011
New Revision: 1187953

URL: http://svn.apache.org/viewvc?rev=1187953&view=rev
Log:
MAHOUT-840: the target attribute can now be numerical, although regression is still not supported

Removed:
    mahout/trunk/core/src/main/java/org/apache/mahout/df/node/MockLeaf.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Sun Oct 23 19:26:19 2011
@@ -69,6 +69,7 @@ public class DefaultTreeBuilder implemen
     
     if (selected == null) {
       selected = new boolean[data.getDataset().nbAttributes()];
+      selected[data.getDataset().getLabelId()] = true; // never select the label
     }
     
     if (data.isEmpty()) {
@@ -78,7 +79,7 @@ public class DefaultTreeBuilder implemen
       return new Leaf(data.majorityLabel(rng));
     }
     if (data.identicalLabel()) {
-      return new Leaf(data.get(0).getLabel());
+      return new Leaf(data.getDataset().getLabel(data.get(0)));
     }
     
     int[] attributes = randomAttributes(rng, selected, m);

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Data.java Sun Oct 23 19:26:19 2011
@@ -215,9 +215,9 @@ public class Data implements Cloneable {
       return true;
     }
     
-    int label = get(0).getLabel();
+    int label = dataset.getLabel(get(0));
     for (int index = 1; index < size(); index++) {
-      if (get(index).getLabel() != label) {
+      if (dataset.getLabel(get(index)) != label) {
         return false;
       }
     }
@@ -278,7 +278,7 @@ public class Data implements Cloneable {
     int[] labels = new int[size()];
     
     for (int index = 0; index < labels.length; index++) {
-      labels[index] = get(index).getLabel();
+      labels[index] = dataset.getLabel(get(index));
     }
     
     return labels;
@@ -300,10 +300,12 @@ public class Data implements Cloneable {
     int[] labels = new int[dataset.nbInstances()];
     DataConverter converter = new DataConverter(dataset);
 
+    int labelId = dataset.getLabelId();
+    
     try {
       int index = 0;
       while (iterator.hasNext()) {
-        labels[index++] = converter.convert(0, iterator.next()).getLabel();
+        labels[index++] = (int) converter.convert(0, iterator.next()).get(labelId);
       }
     } finally {
       Closeables.closeQuietly(iterator);
@@ -322,7 +324,7 @@ public class Data implements Cloneable {
     int[] counts = new int[dataset.nblabels()];
     
     for (int index = 0; index < size(); index++) {
-      counts[get(index).getLabel()]++;
+      counts[dataset.getLabel(get(index))]++;
     }
     
     // find the label values that appears the most
@@ -337,7 +339,7 @@ public class Data implements Cloneable {
    */
   public void countLabels(int[] counts) {
     for (int index = 0; index < size(); index++) {
-      counts[get(index).getLabel()]++;
+      counts[dataset.getLabel(get(index))]++;
     }
   }
   

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataConverter.java Sun Oct 23 19:26:19 2011
@@ -43,8 +43,8 @@ public class DataConverter {
   }
   
   public Instance convert(int id, CharSequence string) {
-    // all attributes (categorical, numerical), ignored, label
-    int nball = dataset.nbAttributes() + dataset.getIgnored().length + 1;
+    // all attributes (categorical, numerical, label), ignored
+    int nball = dataset.nbAttributes() + dataset.getIgnored().length;
     
     String[] tokens = COMMA_SPACE.split(string);
     Preconditions.checkArgument(tokens.length == nball, "Wrong number of attributes in the string");
@@ -55,26 +55,28 @@ public class DataConverter {
     int aId = 0;
     int label = -1;
     for (int attr = 0; attr < nball; attr++) {
-      String token = tokens[attr].trim();
-      
       if (ArrayUtils.contains(dataset.getIgnored(), attr)) {
         continue; // IGNORED
       }
+
+      String token = tokens[attr].trim();
       
       if ("?".equals(token)) {
         // missing value
         return null;
       }
       
-      if (attr == dataset.getLabelId()) {
+      if (aId == dataset.getLabelId()) {
         label = dataset.labelCode(token);
         if (label == -1) {
           log.error("label token: {} dataset.labels: {}", token, Arrays.toString(dataset.labels()));
           throw new IllegalStateException("Label value (" + token + ") not known");
         }
-      } else if (dataset.isNumerical(aId)) {
+      } 
+      
+      if (dataset.isNumerical(aId)) {
         vector.set(aId++, Double.parseDouble(token));
-      } else {
+      } else { // CATEGORICAL/LABEL
         vector.set(aId, dataset.valueOf(aId, token));
         aId++;
       }
@@ -85,6 +87,6 @@ public class DataConverter {
       throw new IllegalStateException("Label not found!");
     }
     
-    return new Instance(id, vector, label);
+    return new Instance(id, vector);
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/DataLoader.java Sun Oct 23 19:26:19 2011
@@ -112,7 +112,7 @@ public final class DataLoader {
       throw new IllegalStateException("Label not found!");
     }
     
-    return new Instance(id, vector, label);
+    return new Instance(id, vector);
   }
   
   /**
@@ -188,12 +188,14 @@ public final class DataLoader {
    * 
    * @param descriptor
    *          attributes description
+   * @param regression
+   * 					if true, the label is numerical
    * @param fs
    *          file system
    * @param path
    *          data path
    */
-  public static Dataset generateDataset(String descriptor, FileSystem fs, Path path) throws DescriptorException,
+  public static Dataset generateDataset(String descriptor, boolean regression, FileSystem fs, Path path) throws DescriptorException,
                                                                                     IOException {
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
     
@@ -217,7 +219,7 @@ public final class DataLoader {
     
     scanner.close();
     
-    return new Dataset(attrs, values, id);
+    return new Dataset(attrs, values, id, regression);
   }
   
   /**
@@ -226,7 +228,7 @@ public final class DataLoader {
    * @param descriptor
    *          attributes description
    */
-  public static Dataset generateDataset(String descriptor, String[] data) throws DescriptorException {
+  public static Dataset generateDataset(String descriptor, boolean regression, String[] data) throws DescriptorException {
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
     
     // used to convert CATEGORICAL and LABEL attributes to Integer
@@ -243,7 +245,7 @@ public final class DataLoader {
       }
     }
     
-    return new Dataset(attrs, values, id);
+    return new Dataset(attrs, values, id, regression);
   }
 
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java Sun Oct 23 19:26:19 2011
@@ -68,16 +68,13 @@ public class Dataset implements Writable
   
   private Attribute[] attributes;
   
-  /** all distinct labels */
-  private String[] labels;
-  
   /** list of ignored attributes */
   private int[] ignored;
   
   /** distinct values (CATEGORIAL attributes only) */
   private String[][] values;
   
-  /** index of the label attribute in the original data */
+  /** index of the label attribute in the loaded data (without ignored attributed) */
   private int labelId;
   
   /** number of instances in the dataset */
@@ -94,7 +91,8 @@ public class Dataset implements Writable
    *          distinct values for all CATEGORICAL attributes
    * @param nbInstances
    */
-  protected Dataset(Attribute[] attrs, List<String>[] values, int nbInstances) {
+  protected Dataset(Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) {
+  	Preconditions.checkArgument(regression == false, "Regression Problems not supported");
     validateValues(attrs, values);
 
     int nbattrs = countAttributes(attrs);
@@ -102,7 +100,7 @@ public class Dataset implements Writable
     // the label values are set apart
     attributes = new Attribute[nbattrs];
     this.values = new String[nbattrs][];
-    ignored = new int[attrs.length - (nbattrs + 1)]; // nbignored = total - (nbattrs + label)
+    ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs
 
     labelId = -1;
     int ignoredId = 0;
@@ -117,11 +115,10 @@ public class Dataset implements Writable
         if (labelId != -1) {
           throw new IllegalStateException("Label found more than once");
         }
-        labelId = attr;
-        continue;
+        labelId = ind;
       }
 
-      if (attrs[attr].isCategorical()) {
+      if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
         this.values[ind] = new String[values[attr].size()];
         values[attr].toArray(this.values[ind]);
       }
@@ -133,24 +130,25 @@ public class Dataset implements Writable
       throw new IllegalStateException("Label not found");
     }
 
-    labels = new String[values[labelId].size()];
-    values[labelId].toArray(labels);
-
     this.nbInstances = nbInstances;
   }
   
   public String[] labels() {
-    return Arrays.copyOf(labels, labels.length);
+    return Arrays.copyOf(values[labelId], nblabels());
   }
   
   public int nblabels() {
-    return labels.length;
+    return values[labelId].length;
   }
   
   public int getLabelId() {
     return labelId;
   }
   
+  public int getLabel(Instance instance) {
+  	return (int) instance.get(getLabelId());
+  }
+
   public int nbInstances() {
     return nbInstances;
   }
@@ -163,12 +161,15 @@ public class Dataset implements Writable
    * @return label's code
    */
   public int labelCode(String label) {
-    return ArrayUtils.indexOf(labels, label);
+    return ArrayUtils.indexOf(values[labelId], label);
   }
   
-  public String getLabel(int code) {
-    // TODO should handle the case (prediction == -1)
-    return labels[code];
+  public String getLabelString(int code) {
+    // handle the case (prediction == -1)
+  	if (code == -1) {
+  		return "unknown";
+  	}
+    return values[labelId][code];
   }
   
   /**
@@ -189,15 +190,13 @@ public class Dataset implements Writable
 
   
   /**
-   * Counts the number of attributes, except IGNORED and LABEL
-   * 
-   * @return number of attributes that are not IGNORED or LABEL
+   * @return number of attributes that are not IGNORED
    */
   protected static int countAttributes(Attribute[] attrs) {
     int nbattrs = 0;
     
-    for (Attribute attr1 : attrs) {
-      if (attr1.isNumerical() || attr1.isCategorical()) {
+    for (Attribute attr : attrs) {
+      if (!attr.isIgnored()) {
         nbattrs++;
       }
     }
@@ -208,7 +207,7 @@ public class Dataset implements Writable
   private static void validateValues(Attribute[] attrs, List<String>[] values) {
     Preconditions.checkArgument(attrs.length == values.length,  "attrs.length != values.length");
     for (int attr = 0; attr < attrs.length; attr++) {
-      Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null,
+      Preconditions.checkArgument(!(attrs[attr].isCategorical() || attrs[attr].isLabel()) || values[attr] != null,
           "values not found for attribute " + attr);
     }
   }
@@ -246,10 +245,6 @@ public class Dataset implements Writable
       return false;
     }
     
-    if (!Arrays.equals(labels, dataset.labels)) {
-      return false;
-    }
-    
     for (int attr = 0; attr < nbAttributes(); attr++) {
       if (!Arrays.equals(values[attr], dataset.values[attr])) {
         return false;
@@ -265,10 +260,8 @@ public class Dataset implements Writable
     for (Attribute attr : attributes) {
       hashCode = 31 * hashCode + attr.hashCode();
     }
-    for (String label : labels) {
-      hashCode = 31 * hashCode + label.hashCode();
-    }
     for (String[] valueRow : values) {
+    	if (valueRow == null) continue;
       for (String value : valueRow) {
         hashCode = 31 * hashCode + value.hashCode();
       }
@@ -305,14 +298,12 @@ public class Dataset implements Writable
       attributes[attr] = Attribute.valueOf(name);
     }
     
-    labels = WritableUtils.readStringArray(in);
-    
     ignored = DFUtils.readIntArray(in);
     
-    // only CATEGORICAL attributes have values
+    // only CATEGORICAL/LABEL attributes have values
     values = new String[nbAttributes][];
     for (int attr = 0; attr < nbAttributes; attr++) {
-      if (attributes[attr].isCategorical()) {
+      if (attributes[attr].isCategorical() || attributes[attr].isLabel()) {
         values[attr] = WritableUtils.readStringArray(in);
       }
     }
@@ -328,8 +319,6 @@ public class Dataset implements Writable
       WritableUtils.writeString(out, attr.name());
     }
     
-    WritableUtils.writeStringArray(out, labels);
-    
     DFUtils.writeArray(out, ignored);
     
     // only CATEGORICAL attributes have values

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Instance.java Sun Oct 23 19:26:19 2011
@@ -29,12 +29,9 @@ public class Instance {
   /** attributes, except LABEL and IGNORED */
   private final Vector attrs;
   
-  private final int label;
-  
-  public Instance(int id, Vector attrs, int label) {
+  public Instance(int id, Vector attrs) {
     this.id = id;
     this.attrs = attrs;
-    this.label = label;
   }
   
   /**
@@ -70,26 +67,17 @@ public class Instance {
     
     Instance instance = (Instance) obj;
     
-    return id == instance.id && label == instance.label && attrs.equals(instance.attrs);
+    return id == instance.id && attrs.equals(instance.attrs);
     
   }
   
   @Override
   public int hashCode() {
-    return id + label + attrs.hashCode();
+    return id + attrs.hashCode();
   }
 
   /** instance unique id */
   public int getId() {
     return id;
   }
-
-  /**
-   * instance label code.<br>
-   * use Dataset.labels to get the real label value
-   *
-   */
-  public int getLabel() {
-    return label;
-  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sun Oct 23 19:26:19 2011
@@ -173,8 +173,8 @@ public class Classifier {
             ofile.writeChar('\n');
 
             if (analyzer != null) {
-              analyzer.addInstance(dataset.getLabel(key),
-                                   new ClassifierResult(dataset.getLabel(Integer.parseInt(value)), 1.0));
+              analyzer.addInstance(dataset.getLabelString(key),
+                                   new ClassifierResult(dataset.getLabelString(Integer.parseInt(value)), 1.0));
             }
           }
         }
@@ -204,6 +204,7 @@ public class Classifier {
     private final Random rng = RandomUtils.getRandom();
     private boolean first = true;
     private final Text lvalue = new Text();
+    private Dataset dataset;
 
     @Override
     protected void setup(Context context) throws IOException, InterruptedException {
@@ -216,8 +217,8 @@ public class Classifier {
       if (files == null || files.length < 2) {
         throw new IOException("not enough paths in the DistributedCache");
       }
-
-      Dataset dataset = Dataset.load(conf, new Path(files[0].getPath()));
+      
+      dataset = Dataset.load(conf, new Path(files[0].getPath()));
 
       converter = new DataConverter(dataset);
 
@@ -242,7 +243,7 @@ public class Classifier {
       if (!line.isEmpty()) {
         Instance instance = converter.convert(0, line);
         int prediction = forest.classify(rng, instance);
-        key.set(instance.getLabel());
+        key.set(dataset.getLabel(instance));
         lvalue.set(Integer.toString(prediction));
         context.write(key, lvalue);
       }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/node/Node.java Sun Oct 23 19:26:19 2011
@@ -30,7 +30,6 @@ import org.apache.mahout.df.data.Instanc
 public abstract class Node implements Writable {
   
   protected enum Type {
-    MOCKLEAF,
     LEAF,
     NUMERICAL,
     CATEGORICAL
@@ -60,9 +59,6 @@ public abstract class Node implements Wr
     Node node;
     
     switch (type) {
-      case MOCKLEAF:
-        node = new MockLeaf();
-        break;
       case LEAF:
         node = new Leaf();
         break;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/split/OptIgSplit.java Sun Oct 23 19:26:19 2011
@@ -22,6 +22,7 @@ import java.util.Arrays;
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.mahout.df.data.Data;
 import org.apache.mahout.df.data.DataUtils;
+import org.apache.mahout.df.data.Dataset;
 import org.apache.mahout.df.data.Instance;
 
 /**
@@ -52,11 +53,13 @@ public class OptIgSplit extends IgSplit 
     int[][] counts = new int[values.length][data.getDataset().nblabels()];
     int[] countAll = new int[data.getDataset().nblabels()];
     
+    Dataset dataset = data.getDataset();
+    
     // compute frequencies
     for (int index = 0; index < data.size(); index++) {
       Instance instance = data.get(index);
-      counts[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
-      countAll[instance.getLabel()]++;
+      counts[ArrayUtils.indexOf(values, instance.get(attr))][dataset.getLabel(instance)]++;
+      countAll[dataset.getLabel(instance)]++;
     }
     
     int size = data.size();
@@ -93,10 +96,12 @@ public class OptIgSplit extends IgSplit 
   }
   
   protected void computeFrequencies(Data data, int attr, double[] values) {
+  	Dataset dataset = data.getDataset();
+  	
     for (int index = 0; index < data.size(); index++) {
       Instance instance = data.get(index);
-      counts[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
-      countAll[instance.getLabel()]++;
+      counts[ArrayUtils.indexOf(values, instance.get(attr))][dataset.getLabel(instance)]++;
+      countAll[dataset.getLabel(instance)]++;
     }
   }
   

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java Sun Oct 23 19:26:19 2011
@@ -68,11 +68,14 @@ public final class Describe {
       abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
       "Path to generated descriptor file").create();
     
+    Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+        .create();
+    
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
     
     Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
-      descriptorOpt).withOption(helpOpt).create();
+      descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
     
     try {
       Parser parser = new Parser();
@@ -87,19 +90,21 @@ public final class Describe {
       String dataPath = cmdLine.getValue(pathOpt).toString();
       String descPath = cmdLine.getValue(descPathOpt).toString();
       List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+      boolean regression = cmdLine.hasOption(regOpt);
       
       log.debug("Data path : {}", dataPath);
       log.debug("Descriptor path : {}", descPath);
       log.debug("Descriptor : {}", descriptor);
+      log.debug("Regression : {}", regression);
       
-      runTool(dataPath, descriptor, descPath);
+      runTool(dataPath, descriptor, descPath, regression);
     } catch (OptionException e) {
       log.warn(e.toString());
       CommandLineUtil.printHelp(group);
     }
   }
   
-  private static void runTool(String dataPath, Iterable<String> description, String filePath)
+  private static void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
     throws DescriptorException, IOException {
     log.info("Generating the descriptor...");
     String descriptor = DescriptorUtils.generateDescriptor(description);
@@ -107,17 +112,17 @@ public final class Describe {
     Path fPath = validateOutput(filePath);
     
     log.info("generating the dataset...");
-    Dataset dataset = generateDataset(descriptor, dataPath);
+    Dataset dataset = generateDataset(descriptor, dataPath, regression);
     
     log.info("storing the dataset description");
     DFUtils.storeWritable(new Configuration(), fPath, dataset);
   }
   
-  private static Dataset generateDataset(String descriptor, String dataPath) throws IOException, DescriptorException {
+  private static Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException, DescriptorException {
     Path path = new Path(dataPath);
     FileSystem fs = path.getFileSystem(new Configuration());
     
-    return DataLoader.generateDataset(descriptor, fs, path);
+    return DataLoader.generateDataset(descriptor, regression, fs, path);
   }
   
   private static Path validateOutput(String filePath) throws IOException {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/FrequenciesJob.java Sun Oct 23 19:26:19 2011
@@ -163,12 +163,13 @@ public class FrequenciesJob {
     private LongWritable firstId;
     
     private DataConverter converter;
+    private Dataset dataset;
     
     @Override
     protected void setup(Context context) throws IOException, InterruptedException {
       Configuration conf = context.getConfiguration();
       
-      Dataset dataset = Builder.loadDataset(conf);
+      dataset = Builder.loadDataset(conf);
       setup(dataset);
     }
     
@@ -188,7 +189,7 @@ public class FrequenciesJob {
       
       Instance instance = converter.convert((int) key.get(), value.toString());
       
-      context.write(firstId, new IntWritable(instance.getLabel()));
+      context.write(firstId, new IntWritable(dataset.getLabel(instance)));
     }
     
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/UDistrib.java Sun Oct 23 19:26:19 2011
@@ -175,7 +175,7 @@ public final class UDistrib {
       
       // write the tuple in files[tuple.label]
       Instance instance = converter.convert(id++, line);
-      int label = instance.getLabel();
+      int label = dataset.getLabel(instance);
       files[currents[label]].writeBytes(line);
       files[currents[label]].writeChar('\n');
       

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java Sun Oct 23 19:26:19 2011
@@ -47,7 +47,7 @@ public final class InfiniteRecursionTest
 
     String[] source = Utils.double2String(dData);
     String descriptor = "N N N N N N N N L";
-    Dataset dataset = DataLoader.generateDataset(descriptor, source);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, source);
     Data data = DataLoader.loadData(dataset, source);
 
     builder.build(rng, data);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataConverterTest.java Sun Oct 23 19:26:19 2011
@@ -34,9 +34,9 @@ public final class DataConverterTest ext
     Random rng = RandomUtils.getRandom();
     
     String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
-    double[][] source = Utils.randomDoubles(rng, descriptor, INSTANCE_COUNT);
+    double[][] source = Utils.randomDoubles(rng, descriptor, false, INSTANCE_COUNT);
     String[] sData = Utils.double2String(source);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     Data data = DataLoader.loadData(dataset, sData);
     
     DataConverter converter = new DataConverter(dataset);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataLoaderTest.java Sun Oct 23 19:26:19 2011
@@ -49,10 +49,10 @@ public final class DataLoaderTest extend
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     // prepare the data
-    double[][] data = Utils.randomDoubles(rng, descriptor, datasize);
+    double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
     Collection<Integer> missings = Lists.newArrayList();
     String[] sData = prepareData(data, attrs, missings);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     Data loaded = DataLoader.loadData(dataset, sData);
 
     testLoadedData(data, attrs, missings, loaded);
@@ -73,12 +73,12 @@ public final class DataLoaderTest extend
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     // prepare the data
-    double[][] data = Utils.randomDoubles(rng, descriptor, datasize);
+    double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
     Collection<Integer> missings = Lists.newArrayList();
     String[] sData = prepareData(data, attrs, missings);
-    Dataset expected = DataLoader.generateDataset(descriptor, sData);
+    Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
 
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     
     assertEquals(expected, dataset);
   }
@@ -157,13 +157,13 @@ public final class DataLoaderTest extend
 
         if (attrs[attr].isNumerical()) {
           assertEquals(vector[attr], instance.get(aId++), EPSILON);
-        } else if (attrs[attr].isCategorical()) {
+        } else if (attrs[attr].isCategorical()||attrs[attr].isLabel()) {
           checkCategorical(data, missings, loaded, attr, aId, vector[attr],
               instance.get(aId));
           aId++;
-        } else if (attrs[attr].isLabel()) {
+        } /*else if (attrs[attr].isLabel()) {
           checkLabel(data, missings, loaded, attr, vector[attr]);
-        }
+        }*/
       }
       
       lind++;
@@ -192,7 +192,7 @@ public final class DataLoaderTest extend
 
       int aId = 0;
       for (int attr = 0; attr < nbAttributes; attr++) {
-        if (attrs[attr].isIgnored() || attrs[attr].isLabel()) {
+        if (attrs[attr].isIgnored()) {
           continue;
         }
 
@@ -220,10 +220,10 @@ public final class DataLoaderTest extend
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     // prepare the data
-    double[][] source = Utils.randomDoubles(rng, descriptor, datasize);
+    double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
     Collection<Integer> missings = Lists.newArrayList();
     String[] sData = prepareData(source, attrs, missings);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
 
     Path dataPath = Utils.writeDataToTestFile(sData);
     FileSystem fs = dataPath.getFileSystem(new Configuration());
@@ -246,15 +246,15 @@ public final class DataLoaderTest extend
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     // prepare the data
-    double[][] source = Utils.randomDoubles(rng, descriptor, datasize);
+    double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
     Collection<Integer> missings = Lists.newArrayList();
     String[] sData = prepareData(source, attrs, missings);
-    Dataset expected = DataLoader.generateDataset(descriptor, sData);
+    Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
 
     Path path = Utils.writeDataToTestFile(sData);
     FileSystem fs = path.getFileSystem(new Configuration());
     
-    Dataset dataset = DataLoader.generateDataset(descriptor, fs, path);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path);
     
     assertEquals(expected, dataset);
   }
@@ -304,6 +304,8 @@ public final class DataLoaderTest extend
                          Data loaded,
                          int labelInd,
                          double value) {
+  	Dataset dataset = loaded.getDataset();
+  	
     // label's code that corresponds to the value
     int code = loaded.getDataset().labelCode(Double.toString(value));
 
@@ -315,9 +317,9 @@ public final class DataLoaderTest extend
       }
 
       if (source[index][labelInd] == value) {
-        assertEquals(code, loaded.get(lind).getLabel());
+        assertEquals(code, dataset.getLabel(loaded.get(lind)));
       } else {
-        assertFalse(code == loaded.get(lind).getLabel());
+        assertFalse(code == dataset.getLabel(loaded.get(lind)));
       }
 
       lind++;

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DataTest.java Sun Oct 23 19:26:19 2011
@@ -39,7 +39,7 @@ public class DataTest extends MahoutTest
   public void setUp() throws Exception {
     super.setUp();
     rng = RandomUtils.getRandom();
-    data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+    data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
   }
 
   /**
@@ -82,7 +82,7 @@ public class DataTest extends MahoutTest
 
   @Test
   public void testValues() throws Exception {
-    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
 
     for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) {
       double[] values = data.values(attr);
@@ -108,14 +108,14 @@ public class DataTest extends MahoutTest
   @Test
   public void testIdenticalTrue() throws Exception {
     // generate a small data, only to get the dataset
-    Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, 1).getDataset();
+    Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
     
     // test empty data
     Data empty = new Data(dataset);
     assertTrue(empty.isIdentical());
 
     // test identical data, except for the labels
-    Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+    Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
     Instance model = identical.get(0);
     for (int index = 1; index < DATA_SIZE; index++) {
       for (int attr = 0; attr < identical.getDataset().nbAttributes(); attr++) {
@@ -131,7 +131,7 @@ public class DataTest extends MahoutTest
     int n = 10;
 
     for (int nloop = 0; nloop < n; nloop++) {
-      Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
+      Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
 
       // choose a random instance
       int index = rng.nextInt(DATA_SIZE);
@@ -148,7 +148,7 @@ public class DataTest extends MahoutTest
   @Test
   public void testIdenticalLabelTrue() throws Exception {
     // generate a small data, only to get a dataset
-    Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, 1).getDataset();
+    Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
     
     // test empty data
     Data empty = new Data(dataset);
@@ -156,11 +156,11 @@ public class DataTest extends MahoutTest
 
     // test identical labels
     String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
-    double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor,
+    double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
             DATA_SIZE, rng.nextInt());
     String[] sData = Utils.double2String(source);
     
-    dataset = DataLoader.generateDataset(descriptor, sData);
+    dataset = DataLoader.generateDataset(descriptor, false, sData);
     Data data = DataLoader.loadData(dataset, sData);
     
     assertTrue(data.identicalLabel());
@@ -173,7 +173,7 @@ public class DataTest extends MahoutTest
     for (int nloop = 0; nloop < n; nloop++) {
       String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
       int label = Utils.findLabel(descriptor);
-      double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor,
+      double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
               DATA_SIZE, rng.nextInt());
       // choose a random vector and change its label
       int index = rng.nextInt(DATA_SIZE);
@@ -181,7 +181,7 @@ public class DataTest extends MahoutTest
 
       String[] sData = Utils.double2String(source);
       
-      Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+      Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
       Data data = DataLoader.loadData(dataset, sData);
 
       assertFalse(data.identicalLabel());
@@ -237,8 +237,9 @@ public class DataTest extends MahoutTest
 
   @Test
   public void testCountLabel() throws Exception {
-    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, DATA_SIZE);
-    int[] counts = new int[data.getDataset().nblabels()];
+    Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+    Dataset dataset = data.getDataset();
+    int[] counts = new int[dataset.nblabels()];
 
     int n = 10;
 
@@ -247,7 +248,7 @@ public class DataTest extends MahoutTest
       data.countLabels(counts);
       
       for (int index=0;index<data.size();index++) {
-        counts[data.get(index).getLabel()]--;
+        counts[dataset.getLabel(data.get(index))]--;
       }
       
       for (int label = 0; label < data.getDataset().nblabels(); label++) {
@@ -264,11 +265,11 @@ public class DataTest extends MahoutTest
     int label = Utils.findLabel(descriptor);
 
     int label1 = rng.nextInt();
-    double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, 100,
+    double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100,
         label1);
     String[] sData = Utils.double2String(source);
     
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     Data data = DataLoader.loadData(dataset, sData);
 
     int code1 = dataset.labelCode(Double.toString(label1));
@@ -286,7 +287,7 @@ public class DataTest extends MahoutTest
       }
     }
     sData = Utils.double2String(source);
-    dataset = DataLoader.generateDataset(descriptor, sData);
+    dataset = DataLoader.generateDataset(descriptor, false, sData);
     data = DataLoader.loadData(dataset, sData);
     int code2 = dataset.labelCode(Double.toString(label2));
 

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/DatasetTest.java Sun Oct 23 19:26:19 2011
@@ -51,7 +51,7 @@ public final class DatasetTest extends M
     for (int nloop = 0; nloop < n; nloop++) {
       byteOutStream.reset();
       
-      Dataset dataset = Utils.randomData(rng, NUM_ATTRIBUTES, 1).getDataset();
+      Dataset dataset = Utils.randomData(rng, NUM_ATTRIBUTES, false, 1).getDataset();
       
       dataset.write(out);
       

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/data/Utils.java Sun Oct 23 19:26:19 2011
@@ -106,16 +106,17 @@ public final class Utils {
    * 
    * @param rng Random number generator
    * @param nbAttributes number of attributes
+   * @param regression true if the label is numerical
    * @param number of data lines to generate
    */
-  public static double[][] randomDoubles(Random rng, int nbAttributes,int number) throws DescriptorException {
+  public static double[][] randomDoubles(Random rng, int nbAttributes, boolean regression, int number) throws DescriptorException {
     String descriptor = randomDescriptor(rng, nbAttributes);
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     double[][] data = new double[number][];
 
     for (int index = 0; index < number; index++) {
-      data[index] = randomVector(rng, attrs);
+      data[index] = randomVector(rng, attrs, regression);
     }
 
     return data;
@@ -128,13 +129,13 @@ public final class Utils {
    * @param descriptor attributes description
    * @param number number of data lines to generate
    */
-  public static double[][] randomDoubles(Random rng, CharSequence descriptor, int number) throws DescriptorException {
+  public static double[][] randomDoubles(Random rng, CharSequence descriptor, boolean regression, int number) throws DescriptorException {
     Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
 
     double[][] data = new double[number][];
 
     for (int index = 0; index < number; index++) {
-      data[index] = randomVector(rng, attrs);
+      data[index] = randomVector(rng, attrs, regression);
     }
 
     return data;
@@ -145,13 +146,14 @@ public final class Utils {
    * 
    * @param rng Random number generator
    * @param nbAttributes number of attributes
+   * @param regression true if the label should be numerical
    * @param size data size
    */
-  public static Data randomData(Random rng, int nbAttributes, int size) throws DescriptorException {
+  public static Data randomData(Random rng, int nbAttributes, boolean regression, int size) throws DescriptorException {
     String descriptor = randomDescriptor(rng, nbAttributes);
-    double[][] source = randomDoubles(rng, descriptor, size);
+    double[][] source = randomDoubles(rng, descriptor, regression, size);
     String[] sData = double2String(source);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, regression, sData);
     
     return DataLoader.loadData(dataset, sData);
   }
@@ -168,7 +170,7 @@ public final class Utils {
    * 
    * @param attrs attributes description
    */
-  private static double[] randomVector(Random rng, Attribute[] attrs) {
+  private static double[] randomVector(Random rng, Attribute[] attrs, boolean regression) {
     double[] vector = new double[attrs.length];
 
     for (int attr = 0; attr < attrs.length; attr++) {
@@ -176,9 +178,14 @@ public final class Utils {
         vector[attr] = Double.NaN;
       } else if (attrs[attr].isNumerical()) {
         vector[attr] = rng.nextDouble();
-      } else {
-        // CATEGORICAL or LABEL
+      } else if (attrs[attr].isCategorical()){
         vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+      } else { // LABEL
+      	if (regression) {
+          vector[attr] = rng.nextDouble();
+      	} else {
+          vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+      	}
       }
     }
 
@@ -222,14 +229,15 @@ public final class Utils {
    * 
    * @param rng
    * @param descriptor
+   * @param regression true if the label should be numerical
    * @param number data size
    * @param value label value
    */
   public static double[][] randomDoublesWithSameLabel(Random rng,
-      String descriptor, int number, int value) throws DescriptorException {
+      String descriptor, boolean regression, int number, int value) throws DescriptorException {
     int label = findLabel(descriptor);
     
-    double[][] source = randomDoubles(rng, descriptor, number);
+    double[][] source = randomDoubles(rng, descriptor, regression, number);
     
     for (int index = 0; index < number; index++) {
       source[index][label] = value;

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java Sun Oct 23 19:26:19 2011
@@ -90,9 +90,9 @@ public final class Step1MapperTest exten
 
     // prepare the data
     String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
-    double[][] source = Utils.randomDoubles(rng, descriptor, NUM_INSTANCES);
+    double[][] source = Utils.randomDoubles(rng, descriptor, false, NUM_INSTANCES);
     String[] sData = Utils.double2String(source);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     String[][] splits = Utils.splitData(sData, NUM_MAPPERS);
 
     MockTreeBuilder treeBuilder = new MockTreeBuilder();

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/split/DefaultIgSplitTest.java Sun Oct 23 19:26:19 2011
@@ -38,9 +38,9 @@ public final class DefaultIgSplitTest ex
     int label = Utils.findLabel(descriptor);
 
     // all the vectors have the same label (0)
-    double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, 100, 0);
+    double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0);
     String[] sData = Utils.double2String(temp);
-    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
+    Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
     Data data = DataLoader.loadData(dataset, sData);
     DefaultIgSplit iG = new DefaultIgSplit();
 
@@ -53,7 +53,7 @@ public final class DefaultIgSplitTest ex
       temp[index][label] = 1.0;
     }
     sData = Utils.double2String(temp);
-    dataset = DataLoader.generateDataset(descriptor, sData);
+    dataset = DataLoader.generateDataset(descriptor, false, sData);
     data = DataLoader.loadData(dataset, sData);
     iG = new DefaultIgSplit();
     
@@ -67,7 +67,7 @@ public final class DefaultIgSplitTest ex
       temp[index][label] = 2.0;
     }
     sData = Utils.double2String(temp);
-    dataset = DataLoader.generateDataset(descriptor, sData);
+    dataset = DataLoader.generateDataset(descriptor, false, sData);
     data = DataLoader.loadData(dataset, sData);
     iG = new DefaultIgSplit();
     

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/split/OptIgSplitTest.java Sun Oct 23 19:26:19 2011
@@ -37,7 +37,7 @@ public final class OptIgSplitTest extend
     IgSplit opt = new OptIgSplit();
 
     Random rng = RandomUtils.getRandom();
-    Data data = Utils.randomData(rng, NUM_ATTRIBUTES, NUM_INSTANCES);
+    Data data = Utils.randomData(rng, NUM_ATTRIBUTES, false, NUM_INSTANCES);
 
     for (int nloop = 0; nloop < 100; nloop++) {
       int attr = rng.nextInt(data.getDataset().nbAttributes());

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Sun Oct 23 19:26:19 2011
@@ -54,7 +54,9 @@ public class BreimanExample extends Conf
   private static final Logger log = LoggerFactory.getLogger(BreimanExample.class);
   
   /** sum test error */
-  private double sumTestErr;
+  private double sumTestErrM;
+
+  private double sumTestErrOne;
   
   /** mean time to build a forest with m=log2(M)+1 */
   private long sumTimeM;
@@ -113,9 +115,12 @@ public class BreimanExample extends Conf
     // compute the test set error (Selection Error), and mean tree error (One Tree Error),
     int[] testLabels = test.extractLabels();
     int[] predictions = new int[test.size()];
+    
     forestM.classify(test, predictions);
+    sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions);
     
-    sumTestErr += ErrorEstimate.errorRate(testLabels, predictions);
+    forestOne.classify(test, predictions);
+    sumTestErrOne += ErrorEstimate.errorRate(testLabels, predictions);
   }
   
   public static void main(String[] args) throws Exception {
@@ -194,7 +199,8 @@ public class BreimanExample extends Conf
     }
     
     log.info("********************************************");
-    log.info("Selection error : {}", sumTestErr / nbIterations);
+    log.info("Random Input Test Error : {}", sumTestErrM / nbIterations);
+    log.info("Single Input Test Error : {}", sumTestErrOne / nbIterations);
     log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
     log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
     log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=1187953&r1=1187952&r2=1187953&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sun Oct 23 19:26:19 2011
@@ -261,8 +261,8 @@ public class TestForest extends Configur
         }
 
         if (analyzer != null) {
-          analyzer.addInstance(dataset.getLabel(instance.getLabel()),
-                               new ClassifierResult(dataset.getLabel(prediction), 1.0));
+          analyzer.addInstance(dataset.getLabelString(dataset.getLabel(instance)),
+                               new ClassifierResult(dataset.getLabelString(prediction), 1.0));
         }
       }
 



Mime
View raw message