mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1002728 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/sgd/ test/java/org/apache/mahout/classifier/sgd/
Date Wed, 29 Sep 2010 16:41:18 GMT
Author: tdunning
Date: Wed Sep 29 16:41:17 2010
New Revision: 1002728

URL: http://svn.apache.org/viewvc?rev=1002728&view=rev
Log:
Adjust bufferSize in ALR when setting custom step sizes
Got rid of buffer in saved model
Re-enabled adaptiveLogisticRegressionRoundTrip test.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java Wed Sep 29 16:41:17 2010
@@ -210,9 +210,7 @@ public class AdaptiveLogisticRegression 
    * @param interval  Number of training examples to use in each epoch of optimization.
    */
   public void setInterval(int interval) {
-    this.minInterval = interval;
-    this.maxInterval = interval;
-    this.cutoff = interval * (record / interval + 1);
+    setInterval(interval, interval);
   }
 
   /**
@@ -227,6 +225,8 @@ public class AdaptiveLogisticRegression 
     this.minInterval = Math.max(200, minInterval);
     this.maxInterval = Math.max(200, maxInterval);
     this.cutoff = minInterval * (record / minInterval + 1);
+    this.currentStep = minInterval;
+    bufferSize = Math.min(minInterval, bufferSize);
   }
 
   public void setPoolSize(int poolSize) {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Wed Sep 29 16:41:17 2010
@@ -350,10 +350,15 @@ public final class ModelSerializer {
       Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
       r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"), epType));
       r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"), stateType));
-      r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"), stateType));
+      if (x.get("best") != null) {
+        r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"), stateType));
+      }
 
-      r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
-                  new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {}.getType()));
+      if (x.get("buffer") != null) {
+        r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
+          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+          }.getType()));
+      }
       return r;
     }
 
@@ -362,8 +367,6 @@ public final class ModelSerializer {
       JsonObject r = new JsonObject();
       r.add("ep", jsc.serialize(x.getEp(),
           new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType()));
-      r.add("buffer", jsc.serialize(x.getBuffer(),
-          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {}.getType()));
       r.add("minInterval", jsc.serialize(x.getMinInterval()));
       r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
       Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java Wed Sep 29 16:41:17 2010
@@ -140,7 +140,7 @@ public final class ModelSerializerTest e
     assertTrue(auc2 > auc1);
   }
 
-//  @Test
+  @Test
   public void adaptiveLogisticRegressionRoundTrip() {
     AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
     learner.setInterval(200);



Mime
View raw message