mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1001974 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/sgd/ test/java/org/apache/mahout/classifier/sgd/
Date Tue, 28 Sep 2010 00:11:51 GMT
Author: tdunning
Date: Tue Sep 28 00:11:51 2010
New Revision: 1001974

URL: http://svn.apache.org/viewvc?rev=1001974&view=rev
Log:
Added variable evolutionary step size to improve annealing evaluation in each step.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java	Tue Sep 28 00:11:51 2010
@@ -63,7 +63,11 @@ public class AdaptiveLogisticRegression 
   private static final int SURVIVORS = 2;
 
   private int record;
-  private int evaluationInterval = 1000;
+  private int cutoff = 1000;
+  private int minInterval = 1000;
+  private int maxInterval = 1000;
+  private int currentStep = 1000;
+  private int bufferSize = 1000;
 
   // transient here is a signal to GSON not to serialize pending records
   private transient List<TrainingExample> buffer = Lists.newArrayList();
@@ -105,7 +109,7 @@ public class AdaptiveLogisticRegression 
     record++;
 
     buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
-    if (buffer.size() > evaluationInterval) {
+    if (buffer.size() > bufferSize) {
       trainWithBufferedExamples();
     }
   }
@@ -134,19 +138,52 @@ public class AdaptiveLogisticRegression 
     } catch (ExecutionException e) {
       throw new IllegalStateException(e);
     }
+    buffer.clear();
+
+    if (record > cutoff) {
+      cutoff = nextStep(record);
 
-    // evolve based on new fitness
-    ep.mutatePopulation(SURVIVORS);
+      // evolve based on new fitness
+      ep.mutatePopulation(SURVIVORS);
 
-    if (freezeSurvivors) {
-      // now grossly hack the top survivors so they stick around.  Set their
-      // mutation rates small and also hack their learning rate to be small
-      // as well.
-      for (State<Wrapper> state : ep.getPopulation().subList(0, SURVIVORS)) {
-        state.getPayload().freeze(state);
+      if (freezeSurvivors) {
+        // now grossly hack the top survivors so they stick around.  Set their
+        // mutation rates small and also hack their learning rate to be small
+        // as well.
+        for (State<Wrapper> state : ep.getPopulation().subList(0, SURVIVORS)) {
+          state.getPayload().freeze(state);
+        }
       }
     }
-    buffer.clear();
+
+  }
+
+  public int nextStep(int recordNumber) {
+    int stepSize = stepSize(recordNumber, 2.6);
+    if (stepSize < minInterval) {
+      stepSize = minInterval;
+    }
+
+    if (stepSize > maxInterval) {
+      stepSize = maxInterval;
+    }
+
+    int newCutoff = stepSize * (recordNumber / stepSize + 1);
+    if (newCutoff < cutoff + currentStep) {
+      newCutoff = cutoff + currentStep;
+    } else {
+      this.currentStep = stepSize;
+    }
+    return newCutoff;
+  }
+
+  public static int stepSize(int recordNumber, double multiplier) {
+    final int[] bumps = new int[]{1, 2, 5};
+    double log = Math.floor(multiplier * Math.log10(recordNumber));
+    int bump = bumps[(int) log % bumps.length];
+    int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+    return bump * scale;
   }
 
   @Override
@@ -173,7 +210,23 @@ public class AdaptiveLogisticRegression 
    * @param interval  Number of training examples to use in each epoch of optimization.
    */
   public void setInterval(int interval) {
-    this.evaluationInterval = interval;
+    this.minInterval = interval;
+    this.maxInterval = interval;
+    this.cutoff = interval * (record / interval + 1);
+  }
+
+  /**
+   * Starts optimization using the shorter interval and progresses to the longer using the specified
+   * number of steps per decade.  Note that values < 200 are not accepted.  Values even that small
+   * are unlikely to be useful.
+   *
+   * @param minInterval  The minimum epoch length for the evolutionary optimization
+   * @param maxInterval  The maximum epoch length
+   */
+  public void setInterval(int minInterval, int maxInterval) {
+    this.minInterval = Math.max(200, minInterval);
+    this.maxInterval = Math.max(200, maxInterval);
+    this.cutoff = minInterval * (record / minInterval + 1);
   }
 
   public void setPoolSize(int poolSize) {
@@ -234,8 +287,12 @@ public class AdaptiveLogisticRegression 
     this.record = record;
   }
 
-  public int getEvaluationInterval() {
-    return evaluationInterval;
+  public int getMinInterval() {
+    return minInterval;
+  }
+
+  public int getMaxInterval() {
+    return maxInterval;
   }
 
   public int getNumCategories() {
@@ -246,10 +303,6 @@ public class AdaptiveLogisticRegression 
     return seed.getPayload().getLearner().getPrior();
   }
 
-  public void setEvaluationInterval(int evaluationInterval) {
-    this.evaluationInterval = evaluationInterval;
-  }
-
   public void setBuffer(List<TrainingExample> buffer) {
     this.buffer = buffer;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java	Tue Sep 28 00:11:51 2010
@@ -340,7 +340,11 @@ public final class ModelSerializer {
                                          x.get("numFeatures").getAsInt(),
                                          jdc.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
       Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
-      r.setEvaluationInterval(x.get("evaluationInterval").getAsInt());
+      if (x.get("evaluationInterval")!=null) {
+        r.setInterval(x.get("evaluationInterval").getAsInt());
+      } else {
+        r.setInterval(x.get("minInterval").getAsInt(), x.get("minInterval").getAsInt());
+      }
       r.setRecord(x.get("record").getAsInt());
 
       Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
@@ -360,7 +364,8 @@ public final class ModelSerializer {
           new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType()));
       r.add("buffer", jsc.serialize(x.getBuffer(),
           new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {}.getType()));
-      r.add("evaluationInterval", jsc.serialize(x.getEvaluationInterval()));
+      r.add("minInterval", jsc.serialize(x.getMinInterval()));
+      r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
       Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
       r.add("best", jsc.serialize(x.getBest(), stateType));
       r.add("numFeatures", jsc.serialize(x.getNumFeatures()));

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java	(original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java	Tue Sep 28 00:11:51 2010
@@ -138,4 +138,44 @@ public final class AdaptiveLogisticRegre
     // make sure that the copy didn't lose anything
     assertEquals(auc1, w.getLearner().auc(), 0);
   }
+
+  @Test
+  public void stepSize() {
+    assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+    assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+    assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+    assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+  }
+
+  @Test
+  public void constantStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    lr.setInterval(5000);
+    assertEquals(20000, lr.nextStep(15000));
+    assertEquals(20000, lr.nextStep(15001));
+    assertEquals(20000, lr.nextStep(16500));
+    assertEquals(20000, lr.nextStep(19999));
+  }
+    
+
+  @Test
+  public void growingStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    lr.setInterval(2000, 10000);
+
+    // start with minimum step size
+    for (int i = 2000; i < 20000;i+=2000) {
+      assertEquals(i + 2000, lr.nextStep(i));
+    }
+
+    // then level up a bit
+    for (int i = 20000; i < 50000; i += 5000) {
+      assertEquals(i + 5000, lr.nextStep(i));
+    }
+
+    // and more, but we top out with this step size
+    for (int i = 50000; i < 500000; i += 10000) {
+      assertEquals(i + 10000, lr.nextStep(i));
+    }
+  }
 }



Mime
View raw message