mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1002863 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/sgd/ examples/src/main/java/org/apache/mahout/classifier/sgd/
Date Wed, 29 Sep 2010 21:52:21 GMT
Author: tdunning
Date: Wed Sep 29 21:52:20 2010
New Revision: 1002863

URL: http://svn.apache.org/viewvc?rev=1002863&view=rev
Log:
Adjusted class structure and API for classifiers to allow a group key to be passed down into
the training.
Also broke out the gradient computation to allow variant training objectives such as AUC in
addition to logistic loss.
Added the gradient to serialized content
Made a new PolymorphicTypeAdapter to limit amount of repeated code

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java Wed Sep
29 21:52:20 2010
@@ -38,8 +38,30 @@ public interface OnlineLearner {
    * the original data record in a data file.
    *
    * @param trackingKey The tracking key for this training example.
+   * @param groupKey     An optional value that allows examples to be grouped in the computation
of
+   * the update to the model.
    * @param actual   The value of the target variable.  This value should be in the half-open
- *                 interval [0..n) where n is the number of target categories.
+   *                 interval [0..n) where n is the number of target categories.
+   * @param instance The feature vector for this example.
+   */
+  void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+  /**
+   * Updates the model using a particular target variable value and a feature vector.
+   * <p/>
+   * There may be an assumption that if multiple passes through the training data are necessary
that
+   * the tracking key for a record will be the same for each pass and that there will be
a
+   * relatively large number of distinct tracking keys and that the low-order bits of the
tracking
+   * keys will not correlate with any of the input variables.  This tracking key is used
to assign
+   * training examples to different test/training splits.
+   * <p/>
+   * Examples of useful tracking keys include id-numbers for the training records derived
from
+   * a database id for the base table from which the record is derived, or the offset
of
+   * the original data record in a data file.
+   *
+   * @param trackingKey The tracking key for this training example.
+   * @param actual   The value of the target variable.  This value should be in the half-open
+   *                 interval [0..n) where n is the number of target categories.
    * @param instance The feature vector for this example.
    */
   void train(long trackingKey, int actual, Vector instance);

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Wed Sep 29 21:52:20 2010
@@ -60,6 +60,9 @@ public abstract class AbstractOnlineLogi
   // can we ignore any further regularization when doing classification?
   private boolean sealed = false;
 
+  // by default we don't do any fancy training
+  private Gradient gradient = new DefaultGradient();
+
   /**
    * Chainable configuration option.
    *
@@ -149,11 +152,7 @@ public abstract class AbstractOnlineLogi
   }
 
   @Override
-  public void train(long trackingKey, int actual, Vector instance) {
-    train(actual, instance);
-  }
-
-  public void train(int actual, Vector instance) {
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
     unseal();
 
     double learningRate = currentLearningRate();
@@ -165,12 +164,9 @@ public abstract class AbstractOnlineLogi
     Vector v = classify(instance);
 
     // update each row of coefficients according to result
+    Vector gradient = this.gradient.apply(groupKey, actual, v);
     for (int i = 0; i < numCategories - 1; i++) {
-      double gradientBase = -v.getQuick(i);
-      // the use of i+1 instead of i here is what makes the 0-th category be the one without
coefficients
-      if ((i + 1) == actual) {
-        gradientBase += 1;
-      }
+      double gradientBase = gradient.get(i);
 
       // then we apply the gradientBase to the resulting element.
       Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
@@ -195,6 +191,16 @@ public abstract class AbstractOnlineLogi
 
   }
 
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
   public void regularize(Vector instance) {
     if (updateSteps == null || isSealed()) {
       return;
@@ -230,6 +236,10 @@ public abstract class AbstractOnlineLogi
     this.prior = prior;
   }
 
+  public void setGradient(Gradient gradient) {
+    this.gradient = gradient;
+  }
+
   public PriorFunction getPrior() {
     return prior;
   }
@@ -308,4 +318,25 @@ public abstract class AbstractOnlineLogi
     });
     return k < 1;
   }
+
+  public static class DefaultGradient implements Gradient {
+    /**
+     * Provides a default gradient computation useful for logistic regression.  This
+     * can be over-ridden to incorporate AUC driven learning.
+     * <p>
+     * See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
+     * @param groupKey     A grouping key to allow per-something AUC loss to be used for
training.
+     *@param actual       The target variable value.
+     * @param v            The current score vector.   @return
+     */
+    @Override
+    public final Vector apply(String groupKey, int actual, Vector v) {
+      Vector r = v.like();
+      if (actual != 0) {
+        r.setQuick(actual - 1, 1);
+      }
+      r.assign(v, Functions.MINUS);
+      return r;
+    }
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
Wed Sep 29 21:52:20 2010
@@ -69,19 +69,18 @@ public class AdaptiveLogisticRegression 
   private int currentStep = 1000;
   private int bufferSize = 1000;
 
-  // transient here is a signal to GSON not to serialize pending records
-  private transient List<TrainingExample> buffer = Lists.newArrayList();
+  private List<TrainingExample> buffer = Lists.newArrayList();
   private EvolutionaryProcess<Wrapper> ep;
   private State<Wrapper> best;
   private int threadCount = 20;
   private int poolSize = 20;
   private State<Wrapper> seed;
   private int numFeatures;
-  //private double averagingWindow;
 
   private boolean freezeSurvivors = true;
 
   // for GSON
+  @SuppressWarnings({"UnusedDeclaration"})
   private AdaptiveLogisticRegression() {
   }
 
@@ -96,7 +95,7 @@ public class AdaptiveLogisticRegression 
 
   @Override
   public void train(int actual, Vector instance) {
-    train(record, actual, instance);
+    train(record, null, actual, instance);
   }
 
   @Override
@@ -104,7 +103,7 @@ public class AdaptiveLogisticRegression 
     train(trackingKey, null, actual, instance);
   }
 
-
+  @Override
   public void train(long trackingKey, String groupKey, int actual, Vector instance) {
     record++;
 
@@ -424,6 +423,7 @@ public class AdaptiveLogisticRegression 
     private Vector instance;
 
     // for GSON
+    @SuppressWarnings({"UnusedDeclaration"})
     private TrainingExample() {
     }
 

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Wed Sep 29 21:52:20 2010
@@ -90,7 +90,7 @@ public class CrossFoldLearner extends Ab
   // -------- training methods
   @Override
   public void train(int actual, Vector instance) {
-    train(record, actual, instance);
+    train(record, null, actual, instance);
   }
 
   @Override
@@ -98,6 +98,7 @@ public class CrossFoldLearner extends Ab
     train(trackingKey, null, actual, instance);
   }
 
+  @Override
   public void train(long trackingKey, String groupKey, int actual, Vector instance) {
     record++;
     int k = 0;
@@ -113,7 +114,7 @@ public class CrossFoldLearner extends Ab
           auc.addSample(actual, groupKey, v.get(1));
         }
       } else {
-        model.train(trackingKey, actual, instance);
+        model.train(trackingKey, groupKey, actual, instance);
       }
       k++;
     }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java?rev=1002863&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java Wed Sep
29 21:52:20 2010
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regression.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+  Vector apply(String groupKey, int actual, Vector v);
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
Wed Sep 29 21:52:20 2010
@@ -55,9 +55,10 @@ public final class ModelSerializer {
   static {
     final GsonBuilder gb = new GsonBuilder();
     gb.registerTypeAdapter(AdaptiveLogisticRegression.class, new AdaptiveLogisticRegressionTypeAdapter());
-    gb.registerTypeAdapter(Mapping.class, new MappingTypeAdapter());
-    gb.registerTypeAdapter(PriorFunction.class, new PriorTypeAdapter());
-    gb.registerTypeAdapter(OnlineAuc.class, new AucTypeAdapter());
+    gb.registerTypeAdapter(Mapping.class, new PolymorphicTypeAdapter<Mapping>());
+    gb.registerTypeAdapter(PriorFunction.class, new PolymorphicTypeAdapter<PriorFunction>());
+    gb.registerTypeAdapter(OnlineAuc.class, new PolymorphicTypeAdapter<OnlineAuc>());
+    gb.registerTypeAdapter(Gradient.class, new PolymorphicTypeAdapter<Gradient>());
     gb.registerTypeAdapter(CrossFoldLearner.class, new CrossFoldLearnerTypeAdapter());
     gb.registerTypeAdapter(Vector.class, new VectorTypeAdapter());
     gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
@@ -92,69 +93,22 @@ public final class ModelSerializer {
    * Reads a model in JSON format.
    *
    * @param in Where to read the model from.
-   * @param clazz
+   * @param clazz The class of the object we expect to read.
    * @return The LogisticModelParameters object that we read.
    */
   public static AdaptiveLogisticRegression loadJsonFrom(Reader in, Class<AdaptiveLogisticRegression>
clazz) {
     return gson().fromJson(in, clazz);
   }
 
-  private static class MappingTypeAdapter implements JsonDeserializer<Mapping>, JsonSerializer<Mapping>
{
+  private static class PolymorphicTypeAdapter<T> implements JsonDeserializer<T>,
JsonSerializer<T> {
     @Override
-    public Mapping deserialize(JsonElement jsonElement,
-                               Type type,
-                               JsonDeserializationContext jsonDeserializationContext) {
-      JsonObject x = jsonElement.getAsJsonObject();
-      try {
-        return jsonDeserializationContext.deserialize(x.get("value"), Class.forName(x.get("class").getAsString()));
-      } catch (ClassNotFoundException e) {
-        throw new IllegalStateException("Can't understand serialized data, found bad type:
"
-            + x.get("class").getAsString());
-      }
-    }
-
-    @Override
-    public JsonElement serialize(Mapping mapping, Type type, JsonSerializationContext jsonSerializationContext)
{
-      JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(mapping.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(mapping));
-      return r;
-    }
-  }
-
-  private static class PriorTypeAdapter implements JsonDeserializer<PriorFunction>,
JsonSerializer<PriorFunction> {
-    @Override
-    public PriorFunction deserialize(JsonElement jsonElement,
-                                     Type type,
-                                     JsonDeserializationContext jsonDeserializationContext)
{
-      JsonObject x = jsonElement.getAsJsonObject();
-      try {
-        return jsonDeserializationContext.deserialize(x.get("value"), Class.forName(x.get("class").getAsString()));
-      } catch (ClassNotFoundException e) {
-        throw new IllegalStateException("Can't understand serialized data, found bad type:
"
-            + x.get("class").getAsString());
-      }
-    }
-
-    @Override
-    public JsonElement serialize(PriorFunction priorFunction,
-                                 Type type,
-                                 JsonSerializationContext jsonSerializationContext) {
-      JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(priorFunction.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(priorFunction));
-      return r;
-    }
-  }
-
-  private static class AucTypeAdapter implements JsonDeserializer<OnlineAuc>, JsonSerializer<OnlineAuc>
{
-    @Override
-    public OnlineAuc deserialize(JsonElement jsonElement,
+    public T deserialize(JsonElement jsonElement,
                                      Type type,
                                      JsonDeserializationContext jsonDeserializationContext)
{
       JsonObject x = jsonElement.getAsJsonObject();
       try {
-        return jsonDeserializationContext.deserialize(x.get("value"), Class.forName(x.get("class").getAsString()));
+        //noinspection RedundantTypeArguments
+        return jsonDeserializationContext.<T>deserialize(x.get("value"), Class.forName(x.get("class").getAsString()));
       } catch (ClassNotFoundException e) {
         throw new IllegalStateException("Can't understand serialized data, found bad type:
"
             + x.get("class").getAsString());
@@ -162,12 +116,12 @@ public final class ModelSerializer {
     }
 
     @Override
-    public JsonElement serialize(OnlineAuc auc,
+    public JsonElement serialize(T x,
                                  Type type,
                                  JsonSerializationContext jsonSerializationContext) {
       JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(auc.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(auc));
+      r.add("class", new JsonPrimitive(x.getClass().getName()));
+      r.add("value", jsonSerializationContext.serialize(x));
       return r;
     }
   }
@@ -196,6 +150,59 @@ public final class ModelSerializer {
     }
   }
 
+  private static class AdaptiveLogisticRegressionTypeAdapter implements JsonSerializer<AdaptiveLogisticRegression>,
+    JsonDeserializer<AdaptiveLogisticRegression> {
+
+    @Override
+    public AdaptiveLogisticRegression deserialize(JsonElement element, Type type, JsonDeserializationContext
jdc) {
+      JsonObject x = element.getAsJsonObject();
+      AdaptiveLogisticRegression r =
+          new AdaptiveLogisticRegression(x.get("numCategories").getAsInt(),
+                                         x.get("numFeatures").getAsInt(),
+                                         jdc.<PriorFunction>deserialize(x.get("prior"),
PriorFunction.class));
+      Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
+      if (x.get("evaluationInterval")!=null) {
+        r.setInterval(x.get("evaluationInterval").getAsInt());
+      } else {
+        r.setInterval(x.get("minInterval").getAsInt(), x.get("minInterval").getAsInt());
+      }
+      r.setRecord(x.get("record").getAsInt());
+
+      Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
+      r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"),
epType));
+      r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"),
stateType));
+      if (x.get("best") != null) {
+        r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
stateType));
+      }
+
+      if (x.get("buffer") != null) {
+        r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
+          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+          }.getType()));
+      }
+      return r;
+    }
+
+    @Override
+    public JsonElement serialize(AdaptiveLogisticRegression x, Type type, JsonSerializationContext
jsc) {
+      JsonObject r = new JsonObject();
+      r.add("ep", jsc.serialize(x.getEp(),
+          new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType()));
+      r.add("minInterval", jsc.serialize(x.getMinInterval()));
+      r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
+      Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
+      r.add("best", jsc.serialize(x.getBest(), stateType));
+      r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
+      r.add("numCategories", jsc.serialize(x.getNumCategories()));
+      PriorFunction prior = x.getPrior();
+      JsonElement pf = jsc.serialize(prior, PriorFunction.class);
+      r.add("prior", pf);
+      r.add("record", jsc.serialize(x.getRecord()));
+      r.add("seed", jsc.serialize(x.getSeed(), stateType));
+      return r;
+    }
+  }
+
   /**
    * Tells GSON how to (de)serialize a Mahout matrix.  We assume on deserialization that
the matrix
    * is dense.
@@ -329,59 +336,6 @@ public final class ModelSerializer {
     }
   }
 
-  private static class AdaptiveLogisticRegressionTypeAdapter implements JsonSerializer<AdaptiveLogisticRegression>,
-    JsonDeserializer<AdaptiveLogisticRegression> {
-
-    @Override
-    public AdaptiveLogisticRegression deserialize(JsonElement element, Type type, JsonDeserializationContext
jdc) {
-      JsonObject x = element.getAsJsonObject();
-      AdaptiveLogisticRegression r =
-          new AdaptiveLogisticRegression(x.get("numCategories").getAsInt(),
-                                         x.get("numFeatures").getAsInt(),
-                                         jdc.<PriorFunction>deserialize(x.get("prior"),
PriorFunction.class));
-      Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
-      if (x.get("evaluationInterval")!=null) {
-        r.setInterval(x.get("evaluationInterval").getAsInt());
-      } else {
-        r.setInterval(x.get("minInterval").getAsInt(), x.get("minInterval").getAsInt());
-      }
-      r.setRecord(x.get("record").getAsInt());
-
-      Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
-      r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"),
epType));
-      r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"),
stateType));
-      if (x.get("best") != null) {
-        r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
stateType));
-      }
-
-      if (x.get("buffer") != null) {
-        r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
-          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
-          }.getType()));
-      }
-      return r;
-    }
-
-    @Override
-    public JsonElement serialize(AdaptiveLogisticRegression x, Type type, JsonSerializationContext
jsc) {
-      JsonObject r = new JsonObject();
-      r.add("ep", jsc.serialize(x.getEp(),
-          new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType()));
-      r.add("minInterval", jsc.serialize(x.getMinInterval()));
-      r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
-      Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
-      r.add("best", jsc.serialize(x.getBest(), stateType));
-      r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
-      r.add("numCategories", jsc.serialize(x.getNumCategories()));
-      PriorFunction prior = x.getPrior();
-      JsonElement pf = jsc.serialize(prior, PriorFunction.class);
-      r.add("prior", pf);
-      r.add("record", jsc.serialize(x.getRecord()));
-      r.add("seed", jsc.serialize(x.getSeed(), stateType));
-      return r;
-    }
-  }
-
   private static class EvolutionaryProcessTypeAdapter implements
     InstanceCreator<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
     JsonDeserializer<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
@@ -434,5 +388,4 @@ public final class ModelSerializer {
     }
     return params;
   }
-
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
Wed Sep 29 21:52:20 2010
@@ -66,7 +66,7 @@ public class LogisticModelParameters {
    * the model.  If the input isn't CSV, then calling setTargetCategories before calling
saveTo will
    * suffice.
    *
-   * @return
+   * @return The CsvRecordFactory.
    */
   public CsvRecordFactory getCsvRecordFactory() {
     if (csv == null) {
@@ -83,7 +83,7 @@ public class LogisticModelParameters {
   /**
    * Creates a logistic regression trainer using the parameters collected here.
    *
-   * @return
+   * @return The newly allocated OnlineLogisticRegression object
    */
   public OnlineLogisticRegression createRegression() {
     if (lr == null) {
@@ -113,7 +113,7 @@ public class LogisticModelParameters {
    * trainer and the dictionary for the target categories.
    *
    * @param out Where to write the model.
-   * @throws IOException
+   * @throws IOException If we can't write the model.
    */
   public void saveTo(Writer out) throws IOException {
     if (lr != null) {
@@ -180,7 +180,7 @@ public class LogisticModelParameters {
   /**
    * Sets the target variable.  If you don't use the CSV record factory, then this is irrelevant.
    *
-   * @param targetVariable
+   * @param targetVariable The name of the target variable.
    */
   public void setTargetVariable(String targetVariable) {
     this.targetVariable = targetVariable;
@@ -189,7 +189,7 @@ public class LogisticModelParameters {
   /**
    * Sets the number of target categories to be considered.
    *
-   * @param maxTargetCategories
+   * @param maxTargetCategories The number of target categories.
    */
   public void setMaxTargetCategories(int maxTargetCategories) {
     this.maxTargetCategories = maxTargetCategories;



Mime
View raw message