mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1000280 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/main/java/org/apache/mahout/math/stats/ core/src/test/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/math/stats/ math/sr...
Date Thu, 23 Sep 2010 00:12:09 GMT
Author: tdunning
Date: Thu Sep 23 00:12:09 2010
New Revision: 1000280

URL: http://svn.apache.org/viewvc?rev=1000280&view=rev
Log:
Added grouping key to AdaptiveLogisticRegression (ALR) and CrossFoldLearner (CFL).  Added AUC classes to support grouping.
Made pending records transient to save space in serialized models.
Extended learning rate lower limit.
Fixed rare corner case issue in OnlineSummarizer.

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
      - copied, changed from r1000189, mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
Thu Sep 23 00:12:09 2010
@@ -24,6 +24,7 @@ import org.apache.mahout.ep.Evolutionary
 import org.apache.mahout.ep.Mapping;
 import org.apache.mahout.ep.State;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineAuc;
 
 import java.util.List;
 import java.util.Locale;
@@ -63,7 +64,8 @@ public class AdaptiveLogisticRegression 
   private int record = 0;
   private int evaluationInterval = 1000;
 
-  private List<TrainingExample> buffer = Lists.newArrayList();
+  // transient here is a signal to GSON not to serialize pending records
+  private transient List<TrainingExample> buffer = Lists.newArrayList();
   private EvolutionaryProcess<Wrapper> ep;
   private State<Wrapper> best;
   private int threadCount = 20;
@@ -94,9 +96,14 @@ public class AdaptiveLogisticRegression 
 
   @Override
   public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
     record++;
 
-    buffer.add(new TrainingExample(trackingKey, actual, instance));
+    buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
     if (buffer.size() > evaluationInterval) {
       trainWithBufferedExamples();
     }
@@ -178,6 +185,11 @@ public class AdaptiveLogisticRegression 
     setupOptimizer(poolSize);
   }
 
+  public void setAucEvaluator(OnlineAuc auc) {
+    seed.getPayload().setAucEvaluator(auc);
+    setupOptimizer(poolSize);
+  }
+
   private void setupOptimizer(int poolSize) {
     ep = new EvolutionaryProcess<Wrapper>(threadCount, poolSize, seed);
   }
@@ -265,12 +277,9 @@ public class AdaptiveLogisticRegression 
     return numFeatures;
   }
 
-  public void setNumFeatures(int numFeatures) {
-    this.numFeatures = numFeatures;
-  }
-
   public void setAveragingWindow(int averagingWindow) {
     seed.getPayload().getLearner().setWindowSize(averagingWindow);
+    setupOptimizer(poolSize);
   }
 
   public void setFreezeSurvivors(boolean freezeSurvivors) {
@@ -333,11 +342,11 @@ public class AdaptiveLogisticRegression 
       // set the range for regularization (lambda)
       x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
       // set the range for learning rate (mu)
-      x.setMap(i, Mapping.softLimit(0.001, 10));
+      x.setMap(i, Mapping.logLimit(1e-8, 1));
     }
 
     public void train(TrainingExample example) {
-      wrapped.train(example.getKey(), example.getActual(), example.getInstance());
+      wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
     }
 
     public CrossFoldLearner getLearner() {
@@ -348,10 +357,15 @@ public class AdaptiveLogisticRegression 
     public String toString() {
       return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
     }
+
+    public void setAucEvaluator(OnlineAuc auc) {
+      wrapped.setAucEvaluator(auc);
+    }
   }
 
   public static class TrainingExample {
     private long key;
+    private String groupKey;
     private int actual;
     private Vector instance;
 
@@ -359,8 +373,9 @@ public class AdaptiveLogisticRegression 
     private TrainingExample() {
     }
 
-    public TrainingExample(long key, int actual, Vector instance) {
+    public TrainingExample(long key, String groupKey, int actual, Vector instance) {
       this.key = key;
+      this.groupKey = groupKey;
       this.actual = actual;
       this.instance = instance;
     }
@@ -376,6 +391,10 @@ public class AdaptiveLogisticRegression 
     public Vector getInstance() {
       return instance;
     }
+
+    public String getGroupKey() {
+      return groupKey;
+    }
   }
 }
 

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Thu Sep 23 00:12:09 2010
@@ -8,6 +8,7 @@ import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.BinaryFunction;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.stats.OnlineAuc;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
 
 import java.util.List;
 
@@ -23,7 +24,7 @@ public class CrossFoldLearner extends Ab
   private int record;
   // minimum score to be used for computing log likelihood
   private static final double MIN_SCORE = 1e-50;
-  private OnlineAuc auc = new OnlineAuc();
+  private OnlineAuc auc = new GlobalOnlineAuc();
   private double logLikelihood;
   private final List<OnlineLogisticRegression> models = Lists.newArrayList();
 
@@ -94,6 +95,10 @@ public class CrossFoldLearner extends Ab
 
   @Override
   public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
     record++;
     int k = 0;
     for (OnlineLogisticRegression model : models) {
@@ -105,7 +110,7 @@ public class CrossFoldLearner extends Ab
         int correct = v.maxValueIndex() == actual ? 1 : 0;
         percentCorrect += (correct - percentCorrect) / Math.min(record, windowSize);
         if (numCategories() == 2) {
-          auc.addSample(actual, v.get(1));
+          auc.addSample(actual, groupKey, v.get(1));
         }
       } else {
         model.train(trackingKey, actual, instance);
@@ -206,11 +211,11 @@ public class CrossFoldLearner extends Ab
     this.record = record;
   }
 
-  public OnlineAuc getAuc() {
+  public OnlineAuc getAucEvaluator() {
     return auc;
   }
 
-  public void setAuc(OnlineAuc auc) {
+  public void setAucEvaluator(OnlineAuc auc) {
     this.auc = auc;
   }
 
@@ -248,6 +253,7 @@ public class CrossFoldLearner extends Ab
 
   public void setWindowSize(int windowSize) {
     this.windowSize = windowSize;
+    auc.setWindowSize(windowSize);
   }
 
   public PriorFunction getPrior() {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
Thu Sep 23 00:12:09 2010
@@ -36,7 +36,7 @@ import org.apache.mahout.math.DenseMatri
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.stats.OnlineAuc;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
 
 import java.io.FileWriter;
 import java.io.IOException;
@@ -154,7 +154,7 @@ public final class ModelSerializer {
       CrossFoldLearner r = new CrossFoldLearner();
       JsonObject x = jsonElement.getAsJsonObject();
       r.setRecord(x.get("record").getAsInt());
-      r.setAuc(jsonDeserializationContext.<OnlineAuc>deserialize(x.get("auc"), OnlineAuc.class));
+      r.setAucEvaluator(jsonDeserializationContext.<GlobalOnlineAuc>deserialize(x.get("auc"), GlobalOnlineAuc.class));
       r.setLogLikelihood(x.get("logLikelihood").getAsDouble());
 
       JsonArray models = x.get("models").getAsJsonArray();

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
(from r1000189, mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java&r1=1000189&r2=1000280&rev=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java Thu
Sep 23 00:12:09 2010
@@ -31,8 +31,9 @@ import java.util.Random;
  * batch fashion.  The probabilistic definition (the probability that a random element of one set
  * has a higher score than a random element of another set) gives us a way to estimate this
  * on-line.
+ * @see GroupedOnlineAuc
  */
-public class OnlineAuc {
+public class GlobalOnlineAuc implements OnlineAuc {
   enum ReplacementPolicy {
     FIFO, FAIR, RANDOM
   }
@@ -43,7 +44,7 @@ public class OnlineAuc {
   public static final int HISTORY = 10;
 
   // defines the exponential averaging window for results
-  private int windowSize=Integer.MAX_VALUE;
+  private int windowSize = Integer.MAX_VALUE;
 
   // FIFO has distinctly the best properties as a policy.  See OnlineAucTest for details
   private ReplacementPolicy policy = ReplacementPolicy.FIFO;
@@ -52,7 +53,7 @@ public class OnlineAuc {
   private final Vector averages;
   private final Vector samples;
 
-  public OnlineAuc() {
+  public GlobalOnlineAuc() {
     int numCategories = 2;
     scores = new DenseMatrix(numCategories, HISTORY);
     scores.assign(Double.NaN);
@@ -61,6 +62,13 @@ public class OnlineAuc {
     samples = new DenseVector(numCategories);
   }
 
+  @Override
+  @SuppressWarnings({"UnusedDeclaration"})
+  public double addSample(int category, String groupKey, double score) {
+    return addSample(category, score);
+  }
+  
+  @Override
   public double addSample(int category, double score) {
     int n = (int) samples.get(category);
     if (n < HISTORY) {
@@ -109,15 +117,22 @@ public class OnlineAuc {
     return auc();
   }
 
+  @Override
   public double auc() {
     // return an unweighted average of all averages.
     return (1 - averages.get(0) + averages.get(1)) / 2;
   }
 
+  public double value() {
+    return auc();
+  }
+
+  @Override
   public void setPolicy(ReplacementPolicy policy) {
     this.policy = policy;
   }
 
+  @Override
   public void setWindowSize(int windowSize) {
     this.windowSize = windowSize;
   }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java?rev=1000280&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java Thu
Sep 23 00:12:09 2010
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.collect.Maps;
+
+import java.util.Map;
+
+/**
+ * Implements a variant on AUC where the result returned is an average of several AUC measurements
+ * made on sub-groups of the overall data.  Controlling for the grouping factor allows the effects
+ * of the grouping factor on the model to be ignored.  This is useful, for instance, when using a
+ * classifier as a click prediction engine.  In that case you want AUC to refer only to the ranking
+ * of items for a particular user, not to the discrimination of users from each other.  Grouping by
+ * user (or user cluster) helps avoid optimizing for the wrong quality.
+ */
+public class GroupedOnlineAuc implements OnlineAuc {
+  private Map<String, OnlineAuc> map = Maps.newHashMap();
+  private GlobalOnlineAuc.ReplacementPolicy policy;
+  private int windowSize;
+
+  @Override
+  public double addSample(int category, String groupKey, double score) {
+    if (groupKey == null) {
+      addSample(category, score);
+    }
+    
+    OnlineAuc group = map.get(groupKey);
+    if (group == null) {
+      group = new GlobalOnlineAuc();
+      if (policy != null) {
+        group.setPolicy(policy);
+      }
+      if (windowSize > 0) {
+        group.setWindowSize(windowSize);
+      }
+      map.put(groupKey, group);
+    }
+    return group.addSample(category, score);
+  }
+
+  @Override
+  public double addSample(int category, double score) {
+    throw new UnsupportedOperationException("Can't add to " + this.getClass() + " without group key");
+  }
+
+  @Override
+  public double auc() {
+    double sum = 0;
+    for (OnlineAuc auc : map.values()) {
+      sum += auc.auc();
+    }
+    return sum / map.size();
+  }
+
+  @Override
+  public void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy) {
+    this.policy = policy;
+    for (OnlineAuc auc : map.values()) {
+      auc.setPolicy(policy);
+    }
+  }
+
+  @Override
+  public void setWindowSize(int windowSize) {
+    this.windowSize = windowSize;
+    for (OnlineAuc auc : map.values()) {
+      auc.setWindowSize(windowSize);
+    }
+  }
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Thu Sep 23
00:12:09 2010
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
@@ -17,108 +17,21 @@
 
 package org.apache.mahout.math.stats;
 
-import org.apache.mahout.math.DenseMatrix;
-import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.Vector;
-
-import java.util.Random;
-
 /**
- * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic).
- * <p/>
- * Since AUC is normally a global property of labeled scores, it is almost always computed in a
- * batch fashion.  The probabilistic definition (the probability that a random element of one set
- * has a higher score than a random element of another set) gives us a way to estimate this
- * on-line.
+ * Describes the generic outline of how to compute AUC.  Currently there are two
+ * implementations of this, one for computing a global estimate of AUC and the other
+ * for computing average grouped AUC.  Grouped AUC is useful when misusing a classifier
+ * as a recommendation system.
  */
-public class OnlineAuc {
-  enum ReplacementPolicy {
-    FIFO, FAIR, RANDOM
-  }
-
-  // increasing this to 100 causes very small improvements in accuracy.  Decreasing it to 2
-  // causes substantial degradation for the FAIR and RANDOM policies, but almost no change
-  // for the FIFO policy
-  public static final int HISTORY = 10;
-
-  // defines the exponential averaging window for results
-  private int windowSize=Integer.MAX_VALUE;
-
-  // FIFO has distinctly the best properties as a policy.  See OnlineAucTest for details
-  private ReplacementPolicy policy = ReplacementPolicy.FIFO;
-  private transient Random random = org.apache.mahout.common.RandomUtils.getRandom();
-  private final Matrix scores;
-  private final Vector averages;
-  private final Vector samples;
-
-  public OnlineAuc() {
-    int numCategories = 2;
-    scores = new DenseMatrix(numCategories, HISTORY);
-    scores.assign(Double.NaN);
-    averages = new DenseVector(numCategories);
-    averages.assign(0.5);
-    samples = new DenseVector(numCategories);
-  }
-
-  public double addSample(int category, double score) {
-    int n = (int) samples.get(category);
-    if (n < HISTORY) {
-      scores.set(category, n, score);
-    } else {
-      switch (policy) {
-        case FIFO:
-          scores.set(category, n % HISTORY, score);
-          break;
-        case FAIR:
-          int j1 = random.nextInt(n + 1);
-          if (j1 < HISTORY) {
-            scores.set(category, j1, score);
-          }
-          break;
-        case RANDOM:
-          int j2 = random.nextInt(HISTORY);
-          scores.set(category, j2, score);
-          break;
-      }
-    }
-
-    samples.set(category, n + 1);
-
-    if (samples.minValue() >= 1) {
-      // compare to previous scores for other category
-      Vector row = scores.viewRow(1 - category);
-      double m = 0.0;
-      double count = 0.0;
-      for (Vector.Element element : row) {
-        double v = element.get();
-        if (Double.isNaN(v)) {
-          continue;
-        }
-        count++;
-        if (score > v) {
-          m++;
-        } else if (score < v) {
-          // m += 0
-        } else if (score == v) {
-          m += 0.5;
-        }
-      }
-      averages.set(category, averages.get(category) + (m / count - averages.get(category)) / Math.min(windowSize, samples.get(category)));
-    }
-    return auc();
-  }
-
-  public double auc() {
-    // return an unweighted average of all averages.
-    return (1 - averages.get(0) + averages.get(1)) / 2;
-  }
-
-  public void setPolicy(ReplacementPolicy policy) {
-    this.policy = policy;
-  }
-
-  public void setWindowSize(int windowSize) {
-    this.windowSize = windowSize;
-  }
+public interface OnlineAuc {
+  @SuppressWarnings({"UnusedDeclaration"})
+  double addSample(int category, String groupKey, double score);
+
+  double addSample(int category, double score);
+
+  double auc();
+
+  void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy);
+
+  void setWindowSize(int windowSize);
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
(original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
Thu Sep 23 00:12:09 2010
@@ -84,7 +84,7 @@ public final class AdaptiveLogisticRegre
     if (gen.nextDouble() < p) {
       target = 1;
     }
-    return new AdaptiveLogisticRegression.TrainingExample(i, target, data);
+    return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data);
   }
 
   @Test

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
(original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Thu Sep 23 00:12:09 2010
@@ -28,6 +28,7 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
 import org.apache.mahout.math.stats.OnlineAuc;
 import org.junit.Test;
 
@@ -62,7 +63,7 @@ public final class ModelSerializerTest e
   @Test
   public void onlineAucRoundtrip() {
     RandomUtils.useTestSeed();
-    OnlineAuc auc1 = new OnlineAuc();
+    OnlineAuc auc1 = new GlobalOnlineAuc();
     Random gen = new Random(2);
     for (int i = 0; i < 10000; i++) {
       auc1.addSample(0, gen.nextGaussian());
@@ -73,7 +74,7 @@ public final class ModelSerializerTest e
     Gson gson = ModelSerializer.gson();
     String s = gson.toJson(auc1);
 
-    OnlineAuc auc2 = gson.fromJson(s, OnlineAuc.class);
+    OnlineAuc auc2 = gson.fromJson(s, GlobalOnlineAuc.class);
 
     assertEquals(auc1.auc(), auc2.auc(), 0);
 
@@ -148,7 +149,7 @@ public final class ModelSerializerTest e
     List<AdaptiveLogisticRegression.TrainingExample> x1 = Lists.newArrayList();
     for (int i = 0; i < 10; i++) {
       AdaptiveLogisticRegression.TrainingExample t =
-          new AdaptiveLogisticRegression.TrainingExample(i, i % 3, randomVector(gen, 5));
+          new AdaptiveLogisticRegression.TrainingExample(i, null, i % 3, randomVector(gen, 5));
       x1.add(t);
     }
 

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Thu Sep
23 00:12:09 2010
@@ -24,7 +24,7 @@ import org.junit.Test;
 
 import java.util.Random;
 
-import static org.apache.mahout.math.stats.OnlineAuc.ReplacementPolicy.*;
+import static org.apache.mahout.math.stats.GlobalOnlineAuc.ReplacementPolicy.*;
 
 public final class OnlineAucTest extends MahoutTestCase {
 
@@ -37,14 +37,14 @@ public final class OnlineAucTest extends
       stats[i] = new OnlineSummarizer();
     }
 
-    for (int i = 0; i < 500; i++) {
-      OnlineAuc a1 = new OnlineAuc();
+    for (int i = 0; i < 100; i++) {
+      OnlineAuc a1 = new GlobalOnlineAuc();
       a1.setPolicy(FAIR);
 
-      OnlineAuc a2 = new OnlineAuc();
+      OnlineAuc a2 = new GlobalOnlineAuc();
       a2.setPolicy(FIFO);
 
-      OnlineAuc a3 = new OnlineAuc();
+      OnlineAuc a3 = new GlobalOnlineAuc();
       a3.setPolicy(RANDOM);
 
       Auc a4 = new Auc();
@@ -72,7 +72,7 @@ public final class OnlineAucTest extends
     }
     
     int i = 0;
-    for (OnlineAuc.ReplacementPolicy policy : new OnlineAuc.ReplacementPolicy[]{FAIR, FIFO, RANDOM, null}) {
+    for (GlobalOnlineAuc.ReplacementPolicy policy : new GlobalOnlineAuc.ReplacementPolicy[]{FAIR, FIFO, RANDOM, null}) {
       OnlineSummarizer summary = stats[i++];
       System.out.printf("%s,%.4f (min = %.4f, 25%%-ile=%.4f, 75%%-ile=%.4f, max=%.4f)\n", policy, summary.getMean(),
         summary.getQuartile(0), summary.getQuartile(1), summary.getQuartile(2), summary.getQuartile(3));
@@ -94,4 +94,32 @@ public final class OnlineAucTest extends
     assertEquals(0.7603, stats[2].getQuartile(1), 0.006);
     assertEquals(0.7603, stats[2].getQuartile(1), 0.006);
   }
+
+  @Test(expected=UnsupportedOperationException.class)
+  public void mustNotOmitGroup() {
+    OnlineAuc x = new GroupedOnlineAuc();
+    x.addSample(0, 3.14);
+  }
+
+  @Test
+  public void groupedAuc() {
+    Random gen = RandomUtils.getRandom();
+    OnlineAuc x = new GroupedOnlineAuc();
+    OnlineAuc y = new GlobalOnlineAuc();
+
+    for (int i = 0; i < 10000; i++) {
+      x.addSample(0, "a", gen.nextGaussian());
+      x.addSample(1, "a", gen.nextGaussian() + 1);
+      x.addSample(0, "b", gen.nextGaussian() + 10);
+      x.addSample(1, "b", gen.nextGaussian() + 11);
+
+      y.addSample(0, "a", gen.nextGaussian());
+      y.addSample(1, "a", gen.nextGaussian() + 1);
+      y.addSample(0, "b", gen.nextGaussian() + 10);
+      y.addSample(1, "b", gen.nextGaussian() + 11);
+    }
+
+    assertEquals(0.7603, x.auc(), 0.01);
+    assertEquals((0.7603 + 0.5) / 2, y.auc(), 0.02);
+  }
 }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java?rev=1000280&r1=1000279&r2=1000280&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java Thu
Sep 23 00:12:09 2010
@@ -66,15 +66,17 @@ public class OnlineSummarizer {
 
     if (n < 100) {
       starter.add(sample);
-    } else if (n == 100) {
+    } else if (n == 100 && starter != null) {
+      // when we first reach 100 elements, we switch to incremental operation
       starter.add(sample);
-      q[0] = getMin();
-      q[1] = getQuartile(1);
-      q[2] = getQuartile(2);
-      q[3] = getQuartile(3);
-      q[4] = getMax();
+      for (int i = 0; i <= 4; i++) {
+        q[i] = getQuartile(i);
+      }
+      // this signals any invocations of getQuartile at exactly 100 elements that we have
+      // already switched to incremental operation
       starter = null;
     } else {
+      // n >= 100 && starter == null
       q[0] = Math.min(sample, q[0]);
       q[4] = Math.max(sample, q[4]);
 
@@ -106,11 +108,7 @@ public class OnlineSummarizer {
   }
 
   public double getMin() {
-    sort();
-    if (n == 0) {
-      throw new IllegalArgumentException("Must have at least one sample to estimate minimum value");
-    }
-    return n <= 100 ? starter.get(0) : q[0];
+    return getQuartile(0);
   }
 
   private void sort() {
@@ -121,35 +119,39 @@ public class OnlineSummarizer {
   }
 
   public double getMax() {
-    sort();
-    if (n == 0) {
-      throw new IllegalArgumentException("Must have at least one sample to estimate maximum value");
-    }
-    return n <= 100 ? starter.get(99) : q[4];
+    return getQuartile(4);
   }
 
   public double getQuartile(int i) {
-    sort();
-    switch (i) {
-      case 0:
-        return getMin();
-      case 1:
-      case 2:
-      case 3:
-        if (n > 100) {
-          return q[i];
-        } else if (n < 2) {
-          throw new IllegalArgumentException("Must have at least two samples to estimate quartiles");
-        } else {
-          double x = i * (n - 1) / 4.0;
-          int k = (int) Math.floor(x);
-          double u = x - k;
-          return starter.get(k) * (1 - u) + starter.get(k + 1) * u;
-        }
-      case 4:
-        return getMax();
-      default:
-        throw new IllegalArgumentException("Quartile number must be in the range [0..4] not " + i);
+    if (n > 100 || starter == null) {
+      return q[i];
+    } else {
+      sort();
+      switch (i) {
+        case 0:
+          if (n == 0) {
+            throw new IllegalArgumentException("Must have at least one sample to estimate minimum value");
+          }
+          return starter.get(0);
+        case 1:
+        case 2:
+        case 3:
+          if (n >= 2) {
+            double x = i * (n - 1) / 4.0;
+            int k = (int) Math.floor(x);
+            double u = x - k;
+            return starter.get(k) * (1 - u) + starter.get(k + 1) * u;
+          } else {
+            throw new IllegalArgumentException("Must have at least two samples to estimate quartiles");
+          }
+        case 4:
+          if (n == 0) {
+            throw new IllegalArgumentException("Must have at least one sample to estimate maximum value");
+          }
+          return starter.get(starter.size() - 1);
+        default:
+          throw new IllegalArgumentException("Quartile number must be in the range [0..4] not " + i);
+      }
     }
   }
 



Mime
View raw message