mahout-commits mailing list archives

From p..@apache.org
Subject [32/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy into mahout-hdfs and mahout-mr, closes apache/mahout#86
Date Wed, 01 Apr 2015 18:08:03 GMT
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
new file mode 100644
index 0000000..0b2c41b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
+ * response to a feature vector.  This classifier uses 1 of n-1 coding where the 0-th category
+ * is not stored explicitly.
+ * <p/>
+ * Provides the SGD-based algorithm for learning a logistic regression, but omits all
+ * annealing of learning rates.  Any extension of this abstract class must define the overall
+ * and per-term annealing for itself.
+ */
+public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
+  // coefficients for the classification.  This is a dense matrix
+  // that is (numCategories-1) x numFeatures
+  protected Matrix beta;
+
+  // number of categories we are classifying.  This should be the number of rows of beta plus one.
+  protected int numCategories;
+
+  protected int step;
+
+  // records the step at which each coefficient column was last updated.  This allows lazy regularization.
+  protected Vector updateSteps;
+
+  // information about how many updates we have had on a location.  This allows per-term
+  // annealing a la confidence weighted learning.
+  protected Vector updateCounts;
+
+  // weight of the prior on beta
+  private double lambda = 1.0e-5;
+  protected PriorFunction prior;
+
+  // can we ignore any further regularization when doing classification?
+  private boolean sealed;
+
+  // by default we don't do any fancy training
+  private Gradient gradient = new DefaultGradient();
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param lambda New value of lambda, the weighting factor for the prior distribution.
+   * @return This, so other configurations can be chained.
+   */
+  public AbstractOnlineLogisticRegression lambda(double lambda) {
+    this.lambda = lambda;
+    return this;
+  }
+
+  /**
+   * Computes the inverse link function, by default the logistic link function.
+   *
+   * @param v The output of the linear combination in a GLM.  Note that v is
+   *          modified in place.
+   * @return A version of v with the link function applied.
+   */
+  public static Vector link(Vector v) {
+    double max = v.maxValue();
+    if (max >= 40) {
+      // if max >= 40, we subtract the large offset first
+      // the size of the max means that 1 + sum(exp(v)) = sum(exp(v)) to within round-off
+      v.assign(Functions.minus(max)).assign(Functions.EXP);
+      return v.divide(v.norm(1));
+    } else {
+      v.assign(Functions.EXP);
+      return v.divide(1 + v.norm(1));
+    }
+  }
+
+  /**
+   * Computes the binomial logistic inverse link function.
+   *
+   * @param r The value to transform.
+   * @return The logistic function (inverse logit) of r.
+   */
+  public static double link(double r) {
+    if (r < 0.0) {
+      double s = Math.exp(r);
+      return s / (1.0 + s);
+    } else {
+      double s = Math.exp(-r);
+      return 1.0 / (1.0 + s);
+    }
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+    return beta.times(instance);
+  }
+
+  public double classifyScalarNoLink(Vector instance) {
+    return beta.viewRow(0).dot(instance);
+  }
+
+  /**
+   * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
+   * category is 1 - sum(this result).
+   *
+   * @param instance A vector of features to be classified.
+   * @return A vector of probabilities, one for each of the first n-1 categories.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    return link(classifyNoLink(instance));
+  }
+
+  /**
+   * Returns a single scalar probability in the case where we have two categories.  Using this
+   * method avoids one vector allocation relative to classify() and two vector allocations
+   * relative to classifyFull().
+   *
+   * @param instance The vector of features to be classified.
+   * @return The probability of the first of two categories.
+   * @throws IllegalArgumentException If the classifier doesn't have two categories.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+
+    // result is a vector with one element so we can just use dot product
+    return link(classifyScalarNoLink(instance));
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    unseal();
+
+    double learningRate = currentLearningRate();
+
+    // push coefficients back to zero based on the prior
+    regularize(instance);
+
+    // update each row of coefficients according to result
+    Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
+    for (int i = 0; i < numCategories - 1; i++) {
+      double gradientBase = gradient.get(i);
+
+      // apply gradientBase to each non-zero feature of the instance
+      for (Element updateLocation : instance.nonZeroes()) {
+        int j = updateLocation.index();
+
+        double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);
+        beta.setQuick(i, j, newValue);
+      }
+    }
+
+    // remember that these elements got updated
+    for (Element element : instance.nonZeroes()) {
+      int j = element.index();
+      updateSteps.setQuick(j, getStep());
+      updateCounts.incrementQuick(j, 1);
+    }
+    nextStep();
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+  public void regularize(Vector instance) {
+    if (updateSteps == null || isSealed()) {
+      return;
+    }
+
+    // anneal learning rate
+    double learningRate = currentLearningRate();
+
+    // here we lazily apply the prior to make up for our neglect
+    for (int i = 0; i < numCategories - 1; i++) {
+      for (Element updateLocation : instance.nonZeroes()) {
+        int j = updateLocation.index();
+        double missingUpdates = getStep() - updateSteps.get(j);
+        if (missingUpdates > 0) {
+          double rate = getLambda() * learningRate * perTermLearningRate(j);
+          double newValue = prior.age(beta.get(i, j), missingUpdates, rate);
+          beta.set(i, j, newValue);
+          updateSteps.set(j, getStep());
+        }
+      }
+    }
+  }
+
+  // these two abstract methods are how extensions can modify the basic learning behavior of this object.
+
+  public abstract double perTermLearningRate(int j);
+
+  public abstract double currentLearningRate();
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
+
+  public void setGradient(Gradient gradient) {
+    this.gradient = gradient;
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public Matrix getBeta() {
+    close();
+    return beta;
+  }
+
+  public void setBeta(int i, int j, double betaIJ) {
+    beta.set(i, j, betaIJ);
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  public int numFeatures() {
+    return beta.numCols();
+  }
+
+  public double getLambda() {
+    return lambda;
+  }
+
+  public int getStep() {
+    return step;
+  }
+
+  protected void nextStep() {
+    step++;
+  }
+
+  public boolean isSealed() {
+    return sealed;
+  }
+
+  protected void unseal() {
+    sealed = false;
+  }
+
+  private void regularizeAll() {
+    Vector all = new DenseVector(beta.numCols());
+    all.assign(1);
+    regularize(all);
+  }
+
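+  /**
+   * Applies any pending lazy regularization to all coefficients and seals the model so
+   * that subsequent classification can skip regularization.
+   */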
+  @Override
+  public void close() {
+    if (!sealed) {
+      step++;
+      regularizeAll();
+      sealed = true;
+    }
+  }
+
+  public void copyFrom(AbstractOnlineLogisticRegression other) {
+    // number of categories we are classifying.  This should be the number of rows of beta plus one.
+    Preconditions.checkArgument(numCategories == other.numCategories,
+            "Can't copy unless number of target categories is the same");
+
+    beta.assign(other.beta);
+
+    step = other.step;
+
+    updateSteps.assign(other.updateSteps);
+    updateCounts.assign(other.updateCounts);
+  }
+
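+  /**
+   * Returns true if no coefficient is NaN or infinite, i.e. the model has not diverged.
+   */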
+  public boolean validModel() {
+    double k = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+      @Override
+      public double apply(double v) {
+        return Double.isNaN(v) || Double.isInfinite(v) ? 1 : 0;
+      }
+    });
+    return k < 1;
+  }
+
+}

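A minimal usage sketch of this class through its concrete subclass OnlineLogisticRegression
(from this same package); the hyper-parameter values below are illustrative only:

    OnlineLogisticRegression learner = new OnlineLogisticRegression(2, 3, new L1());
    learner.lambda(1.0e-4);        // weight of the prior on beta
    learner.learningRate(0.1);     // overall learning rate

    Vector x = new DenseVector(new double[]{1.0, 0.5, -0.2});
    learner.train(1, x);                   // one online SGD step toward category 1
    double p = learner.classifyScalar(x);  // probability of category 1
    learner.close();                       // apply pending regularization and seal
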
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
new file mode 100644
index 0000000..d00b021
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
@@ -0,0 +1,586 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.EvolutionaryProcess;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.ep.Payload;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * This is a meta-learner that maintains a pool of ordinary
+ * {@link org.apache.mahout.classifier.sgd.OnlineLogisticRegression} learners. Each
+ * member of the pool has different learning rates.  Whichever of the learners in the pool falls
+ * behind in terms of average log-likelihood will be tossed out and replaced with variants of the
+ * survivors.  This will let us automatically derive an annealing schedule that optimizes learning
+ * speed.  Since on-line learners tend to be I/O bound anyway, it doesn't cost as much as one
+ * might expect to maintain multiple learners in memory.  Doing this adaptation on-line as we
+ * learn also decreases the number of learning rate parameters required and replaces the normal
+ * hyper-parameter search.
+ * <p/>
+ * One wrinkle is that the pool of learners we maintain is actually a pool of
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} instances, each of which contains
+ * several OnlineLogisticRegression objects.  These pools allow estimation
+ * of performance on the fly even if we make many passes through the data.  This does, however,
+ * increase the cost of training since if we are using 5-fold cross-validation, each vector is used
+ * 4 times for training and once for classification.  If this becomes a problem, then we should
+ * probably use a 2-way unbalanced train/test split rather than full cross validation.  With the
+ * current default settings, we have 100 learners running.  This is better than the alternative of
+ * running hundreds of training passes to find good hyper-parameters because we only have to parse
+ * and feature-ize our inputs once.  If you already have good hyper-parameters, then you might
+ * prefer to just run one CrossFoldLearner with those settings.
+ * <p/>
+ * The fitness used here is AUC.  An alternative would be log-likelihood, but bogus values are
+ * much easier to get with log-likelihood than with AUC, and the two measures seem to accord
+ * pretty well.  It would be nice to allow the fitness function to be pluggable.  This use of AUC means that
+ * AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed
+ * before long by extending OnlineAuc to handle non-binary cases or by using a different fitness
+ * value in non-binary cases.
+ */
+public class AdaptiveLogisticRegression implements OnlineLearner, Writable {
+  public static final int DEFAULT_THREAD_COUNT = 20;
+  public static final int DEFAULT_POOL_SIZE = 20;
+  private static final int SURVIVORS = 2;
+
+  private int record;
+  private int cutoff = 1000;
+  private int minInterval = 1000;
+  private int maxInterval = 1000;
+  private int currentStep = 1000;
+  private int bufferSize = 1000;
+
+  private List<TrainingExample> buffer = Lists.newArrayList();
+  private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
+  private State<Wrapper, CrossFoldLearner> best;
+  private int threadCount = DEFAULT_THREAD_COUNT;
+  private int poolSize = DEFAULT_POOL_SIZE;
+  private State<Wrapper, CrossFoldLearner> seed;
+  private int numFeatures;
+
+  private boolean freezeSurvivors = true;
+
+  private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class);
+
+  public AdaptiveLogisticRegression() {}
+
+  /**
+   * Uses {@link #DEFAULT_THREAD_COUNT} and {@link #DEFAULT_POOL_SIZE}
+   * @param numCategories The number of categories (labels) to train on
+   * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+   * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+   *
+   * @see #AdaptiveLogisticRegression(int, int, org.apache.mahout.classifier.sgd.PriorFunction, int, int)
+   */
+  public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
+  }
+
+  /**
+   *
+   * @param numCategories The number of categories (labels) to train on
+   * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+   * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+   * @param threadCount The number of threads to use for training
+   * @param poolSize The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
+   */
+  public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
+      int poolSize) {
+    this.numFeatures = numFeatures;
+    this.threadCount = threadCount;
+    this.poolSize = poolSize;
+    seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);
+    Wrapper w = new Wrapper(numCategories, numFeatures, prior);
+    seed.setPayload(w);
+
+    Wrapper.setMappings(seed);
+    seed.setPayload(w);
+    setPoolSize(this.poolSize);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    record++;
+
+    buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
+    // don't train until we have enough examples
+    if (buffer.size() > bufferSize) {
+      trainWithBufferedExamples();
+    }
+  }
+
+  private void trainWithBufferedExamples() {
+    try {
+      this.best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+        @Override
+        public double apply(Payload<CrossFoldLearner> z, double[] params) {
+          Wrapper x = (Wrapper) z;
+          for (TrainingExample example : buffer) {
+            x.train(example);
+          }
+          if (x.getLearner().validModel()) {
+            if (x.getLearner().numCategories() == 2) {
+              return x.wrapped.auc();
+            } else {
+              return x.wrapped.logLikelihood();
+            }
+          } else {
+            return Double.NaN;
+          }
+        }
+      });
+    } catch (InterruptedException e) {
+      // ignore ... shouldn't happen
+      log.warn("Ignoring exception", e);
+    } catch (ExecutionException e) {
+      throw new IllegalStateException(e.getCause());
+    }
+    buffer.clear();
+
+    if (record > cutoff) {
+      cutoff = nextStep(record);
+
+      // evolve based on new fitness
+      ep.mutatePopulation(SURVIVORS);
+
+      if (freezeSurvivors) {
+        // now grossly hack the top survivors so they stick around.  Set their
+        // mutation rates small and also hack their learning rate to be small
+        // as well.
+        for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) {
+          Wrapper.freeze(state);
+        }
+      }
+    }
+  }
+
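+  /**
+   * Returns the record count at which the population should next be evolved: the next
+   * multiple of the current step size above recordNumber, but never less than one full
+   * step beyond the previous cutoff.
+   */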
+  public int nextStep(int recordNumber) {
+    int stepSize = stepSize(recordNumber, 2.6);
+    if (stepSize < minInterval) {
+      stepSize = minInterval;
+    }
+
+    if (stepSize > maxInterval) {
+      stepSize = maxInterval;
+    }
+
+    int newCutoff = stepSize * (recordNumber / stepSize + 1);
+    if (newCutoff < cutoff + currentStep) {
+      newCutoff = cutoff + currentStep;
+    } else {
+      this.currentStep = stepSize;
+    }
+    return newCutoff;
+  }
+
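+  /**
+   * Computes an epoch length from the record count using a 1-2-5 sequence so that epochs
+   * are spaced roughly logarithmically in the number of records seen.  For example,
+   * stepSize(5000, 2.6) gives log = floor(2.6 * log10(5000)) = 9, so bump = bumps[9 % 3] = 1
+   * and scale = 10^floor(9 / 3) = 1000, i.e. a step size of 1000.
+   */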
+  public static int stepSize(int recordNumber, double multiplier) {
+    int[] bumps = {1, 2, 5};
+    double log = Math.floor(multiplier * Math.log10(recordNumber));
+    int bump = bumps[(int) log % bumps.length];
+    int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+    return bump * scale;
+  }
+
+  @Override
+  public void close() {
+    trainWithBufferedExamples();
+    try {
+      ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+        @Override
+        public double apply(Payload<CrossFoldLearner> payload, double[] params) {
+          CrossFoldLearner learner = ((Wrapper) payload).getLearner();
+          learner.close();
+          return learner.logLikelihood();
+        }
+      });
+    } catch (InterruptedException e) {
+      log.warn("Ignoring exception", e);
+    } catch (ExecutionException e) {
+      throw new IllegalStateException(e);
+    } finally {
+      ep.close();
+    }
+  }
+
+  /**
+   * How often should the evolutionary optimization of learning parameters occur?
+   *
+   * @param interval Number of training examples to use in each epoch of optimization.
+   */
+  public void setInterval(int interval) {
+    setInterval(interval, interval);
+  }
+
+  /**
+   * Starts optimization using the shorter interval and progresses to the longer using the specified
+   * number of steps per decade.  Note that values below 200 are clamped to 200.  Values even that small
+   * are unlikely to be useful.
+   *
+   * @param minInterval The minimum epoch length for the evolutionary optimization
+   * @param maxInterval The maximum epoch length
+   */
+  public void setInterval(int minInterval, int maxInterval) {
+    this.minInterval = Math.max(200, minInterval);
+    this.maxInterval = Math.max(200, maxInterval);
+    this.cutoff = minInterval * (record / minInterval + 1);
+    this.currentStep = minInterval;
+    bufferSize = Math.min(minInterval, bufferSize);
+  }
+
+  public final void setPoolSize(int poolSize) {
+    this.poolSize = poolSize;
+    setupOptimizer(poolSize);
+  }
+
+  public void setThreadCount(int threadCount) {
+    this.threadCount = threadCount;
+    setupOptimizer(poolSize);
+  }
+
+  public void setAucEvaluator(OnlineAuc auc) {
+    seed.getPayload().setAucEvaluator(auc);
+    setupOptimizer(poolSize);
+  }
+
+  private void setupOptimizer(int poolSize) {
+    ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>(threadCount, poolSize, seed);
+  }
+
+  /**
+   * Returns the size of the internal feature vector.  Note that this is not the same as the number
+   * of distinct features, especially if feature hashing is being used.
+   *
+   * @return The internal feature vector size.
+   */
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  /**
+   * Returns the AUC for the current best member of the population.  If there is no best member,
+   * usually because no training has been done yet, the result is NaN.
+   *
+   * @return The AUC of the best member of the population or NaN if we can't figure that out.
+   */
+  public double auc() {
+    if (best == null) {
+      return Double.NaN;
+    } else {
+      Wrapper payload = best.getPayload();
+      return payload.getLearner().auc();
+    }
+  }
+
+  public State<Wrapper, CrossFoldLearner> getBest() {
+    return best;
+  }
+
+  public void setBest(State<Wrapper, CrossFoldLearner> best) {
+    this.best = best;
+  }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public int getMinInterval() {
+    return minInterval;
+  }
+
+  public int getMaxInterval() {
+    return maxInterval;
+  }
+
+  public int getNumCategories() {
+    return seed.getPayload().getLearner().numCategories();
+  }
+
+  public PriorFunction getPrior() {
+    return seed.getPayload().getLearner().getPrior();
+  }
+
+  public void setBuffer(List<TrainingExample> buffer) {
+    this.buffer = buffer;
+  }
+
+  public List<TrainingExample> getBuffer() {
+    return buffer;
+  }
+
+  public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
+    return ep;
+  }
+
+  public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
+    this.ep = ep;
+  }
+
+  public State<Wrapper, CrossFoldLearner> getSeed() {
+    return seed;
+  }
+
+  public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
+    this.seed = seed;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setAveragingWindow(int averagingWindow) {
+    seed.getPayload().getLearner().setWindowSize(averagingWindow);
+    setupOptimizer(poolSize);
+  }
+
+  public void setFreezeSurvivors(boolean freezeSurvivors) {
+    this.freezeSurvivors = freezeSurvivors;
+  }
+
+  /**
+   * Provides a shim between the EP optimization machinery and the CrossFoldLearner.  The most
+   * important interface has to do with the parameters of the optimization.  These are taken from
+   * the double[] params in the following order <ul> <li> regularization constant lambda
+   * <li> learningRate </ul>.  All other parameters are set so as to defeat annealing to the
+   * extent possible.  This lets the evolutionary algorithm handle the annealing.
+   * <p/>
+   * Note that per-coefficient annealing is still done, and the per-coefficient offset is not
+   * optimized.
+   */
+  public static class Wrapper implements Payload<CrossFoldLearner> {
+    private CrossFoldLearner wrapped;
+
+    public Wrapper() {
+    }
+
+    public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
+      wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
+    }
+
+    @Override
+    public Wrapper copy() {
+      Wrapper r = new Wrapper();
+      r.wrapped = wrapped.copy();
+      return r;
+    }
+
+    @Override
+    public void update(double[] params) {
+      int i = 0;
+      wrapped.lambda(params[i++]);
+      wrapped.learningRate(params[i]);
+
+      wrapped.stepOffset(1);
+      wrapped.alpha(1);
+      wrapped.decayExponent(0);
+    }
+
+    public static void freeze(State<Wrapper, CrossFoldLearner> s) {
+      // radically decrease learning rate
+      double[] params = s.getParams();
+      params[1] -= 10;
+
+      // and cause evolution to hold (almost)
+      s.setOmni(s.getOmni() / 20);
+      double[] step = s.getStep();
+      for (int i = 0; i < step.length; i++) {
+        step[i] /= 20;
+      }
+    }
+
+    public static void setMappings(State<Wrapper, CrossFoldLearner> x) {
+      int i = 0;
+      // set the range for regularization (lambda)
+      x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
+      // set the range for learning rate (mu)
+      x.setMap(i, Mapping.logLimit(1.0e-8, 1));
+    }
+
+    public void train(TrainingExample example) {
+      wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
+    }
+
+    public CrossFoldLearner getLearner() {
+      return wrapped;
+    }
+
+    @Override
+    public String toString() {
+      return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
+    }
+
+    public void setAucEvaluator(OnlineAuc auc) {
+      wrapped.setAucEvaluator(auc);
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      wrapped.write(out);
+    }
+
+    @Override
+    public void readFields(DataInput input) throws IOException {
+      wrapped = new CrossFoldLearner();
+      wrapped.readFields(input);
+    }
+  }
+
+  public static class TrainingExample implements Writable {
+    private long key;
+    private String groupKey;
+    private int actual;
+    private Vector instance;
+
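+    // no-argument constructor needed for Writable deserialization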
+    private TrainingExample() {
+    }
+
+    public TrainingExample(long key, String groupKey, int actual, Vector instance) {
+      this.key = key;
+      this.groupKey = groupKey;
+      this.actual = actual;
+      this.instance = instance;
+    }
+
+    public long getKey() {
+      return key;
+    }
+
+    public int getActual() {
+      return actual;
+    }
+
+    public Vector getInstance() {
+      return instance;
+    }
+
+    public String getGroupKey() {
+      return groupKey;
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeLong(key);
+      if (groupKey != null) {
+        out.writeBoolean(true);
+        out.writeUTF(groupKey);
+      } else {
+        out.writeBoolean(false);
+      }
+      out.writeInt(actual);
+      VectorWritable.writeVector(out, instance, true);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      key = in.readLong();
+      if (in.readBoolean()) {
+        groupKey = in.readUTF();
+      }
+      actual = in.readInt();
+      instance = VectorWritable.readVector(in);
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(record);
+    out.writeInt(cutoff);
+    out.writeInt(minInterval);
+    out.writeInt(maxInterval);
+    out.writeInt(currentStep);
+    out.writeInt(bufferSize);
+
+    out.writeInt(buffer.size());
+    for (TrainingExample example : buffer) {
+      example.write(out);
+    }
+
+    ep.write(out);
+
+    best.write(out);
+
+    out.writeInt(threadCount);
+    out.writeInt(poolSize);
+    seed.write(out);
+    out.writeInt(numFeatures);
+
+    out.writeBoolean(freezeSurvivors);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    record = in.readInt();
+    cutoff = in.readInt();
+    minInterval = in.readInt();
+    maxInterval = in.readInt();
+    currentStep = in.readInt();
+    bufferSize = in.readInt();
+
+    int n = in.readInt();
+    buffer = Lists.newArrayList();
+    for (int i = 0; i < n; i++) {
+      TrainingExample example = new TrainingExample();
+      example.readFields(in);
+      buffer.add(example);
+    }
+
+    ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>();
+    ep.readFields(in);
+
+    best = new State<Wrapper, CrossFoldLearner>();
+    best.readFields(in);
+
+    threadCount = in.readInt();
+    poolSize = in.readInt();
+    seed = new State<Wrapper, CrossFoldLearner>();
+    seed.readFields(in);
+
+    numFeatures = in.readInt();
+    freezeSurvivors = in.readBoolean();
+  }
+}
+

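A minimal training sketch, assuming a two-category problem with 1000 hashed features; the
interval and window settings are illustrative, and examples stands for any iterable of
labeled data:

    AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 1000, new L1());
    learner.setInterval(800);         // evolve roughly every 800 examples (minimum is 200)
    learner.setAveragingWindow(500);

    for (TrainingExample ex : examples) {
      learner.train(ex.getKey(), ex.getGroupKey(), ex.getActual(), ex.getInstance());
    }
    learner.close();

    // getBest() stays null until the first evolutionary epoch has completed
    CrossFoldLearner best = learner.getBest().getPayload().getLearner();
    double auc = best.auc();
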
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
new file mode 100644
index 0000000..36bcae0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
@@ -0,0 +1,334 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Does cross-fold validation of log-likelihood and AUC on several online logistic regression
+ * models. Each record is passed to all but one of the models for training and to the remaining
+ * model for evaluation.  In order to maintain proper segregation between the different folds across
+ * training data iterations, data should either be passed to this learner in the same order each
+ * time the training data is traversed or a tracking key such as the file offset of the training
+ * record should be passed with each training example.
+ */
+public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
+  private int record;
+  // minimum score to be used for computing log likelihood
+  private static final double MIN_SCORE = 1.0e-50;
+  private OnlineAuc auc = new GlobalOnlineAuc();
+  private double logLikelihood;
+  private final List<OnlineLogisticRegression> models = Lists.newArrayList();
+
+  // lambda, learningRate, perTermOffset, perTermExponent
+  private double[] parameters = new double[4];
+  private int numFeatures;
+  private PriorFunction prior;
+  private double percentCorrect;
+
+  private int windowSize = Integer.MAX_VALUE;
+
+  public CrossFoldLearner() {
+  }
+
+  public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+    this.numFeatures = numFeatures;
+    this.prior = prior;
+    for (int i = 0; i < folds; i++) {
+      OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
+      model.alpha(1).stepOffset(0).decayExponent(0);
+      models.add(model);
+    }
+  }
+
+  // -------- builder-like configuration methods
+
+  public CrossFoldLearner lambda(double v) {
+    for (OnlineLogisticRegression model : models) {
+      model.lambda(v);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner learningRate(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.learningRate(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner stepOffset(int x) {
+    for (OnlineLogisticRegression model : models) {
+      model.stepOffset(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner decayExponent(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.decayExponent(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner alpha(double alpha) {
+    for (OnlineLogisticRegression model : models) {
+      model.alpha(alpha);
+    }
+    return this;
+  }
+
+  // -------- training methods
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    record++;
+    int k = 0;
+    for (OnlineLogisticRegression model : models) {
+      if (k == mod(trackingKey, models.size())) {
+        Vector v = model.classifyFull(instance);
+        double score = Math.max(v.get(actual), MIN_SCORE);
+        logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
+
+        int correct = v.maxValueIndex() == actual ? 1 : 0;
+        percentCorrect += (correct - percentCorrect) / Math.min(record, windowSize);
+        if (numCategories() == 2) {
+          auc.addSample(actual, groupKey, v.get(1));
+        }
+      } else {
+        model.train(trackingKey, groupKey, actual, instance);
+      }
+      k++;
+    }
+  }
+
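+  // floor modulus: like x % y, but always in [0, y) so that negative tracking keys
+  // still select a valid fold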
+  private static long mod(long x, int y) {
+    long r = x % y;
+    return r < 0 ? r + y : r;
+  }
+
+  @Override
+  public void close() {
+    for (OnlineLogisticRegression m : models) {
+      m.close();
+    }
+  }
+
+  public void resetLineCounter() {
+    record = 0;
+  }
+
+  public boolean validModel() {
+    boolean r = true;
+    for (OnlineLogisticRegression model : models) {
+      r &= model.validModel();
+    }
+    return r;
+  }
+
+  // -------- classification methods
+
+  @Override
+  public Vector classify(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classify(instance), scale);
+    }
+    return r;
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classifyNoLink(instance), scale);
+    }
+    return r;
+  }
+
+  @Override
+  public double classifyScalar(Vector instance) {
+    double r = 0;
+    int n = 0;
+    for (OnlineLogisticRegression model : models) {
+      n++;
+      r += model.classifyScalar(instance);
+    }
+    return r / n;
+  }
+
+  // -------- status reporting methods
+  
+  @Override
+  public int numCategories() {
+    return models.get(0).numCategories();
+  }
+
+  public double auc() {
+    return auc.auc();
+  }
+
+  public double logLikelihood() {
+    return logLikelihood;
+  }
+
+  public double percentCorrect() {
+    return percentCorrect;
+  }
+
+  // -------- evolutionary optimization
+
+  public CrossFoldLearner copy() {
+    CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
+    r.models.clear();
+    for (OnlineLogisticRegression model : models) {
+      model.close();
+      OnlineLogisticRegression newModel =
+          new OnlineLogisticRegression(model.numCategories(), model.numFeatures(), model.prior);
+      newModel.copyFrom(model);
+      r.models.add(newModel);
+    }
+    return r;
+  }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public OnlineAuc getAucEvaluator() {
+    return auc;
+  }
+
+  public void setAucEvaluator(OnlineAuc auc) {
+    this.auc = auc;
+  }
+
+  public double getLogLikelihood() {
+    return logLikelihood;
+  }
+
+  public void setLogLikelihood(double logLikelihood) {
+    this.logLikelihood = logLikelihood;
+  }
+
+  public List<OnlineLogisticRegression> getModels() {
+    return models;
+  }
+
+  public void addModel(OnlineLogisticRegression model) {
+    models.add(model);
+  }
+
+  public double[] getParameters() {
+    return parameters;
+  }
+
+  public void setParameters(double[] parameters) {
+    this.parameters = parameters;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setNumFeatures(int numFeatures) {
+    this.numFeatures = numFeatures;
+  }
+
+  public void setWindowSize(int windowSize) {
+    this.windowSize = windowSize;
+    auc.setWindowSize(windowSize);
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(record);
+    PolymorphicWritable.write(out, auc);
+    out.writeDouble(logLikelihood);
+    out.writeInt(models.size());
+    for (OnlineLogisticRegression model : models) {
+      model.write(out);
+    }
+
+    for (double x : parameters) {
+      out.writeDouble(x);
+    }
+    out.writeInt(numFeatures);
+    PolymorphicWritable.write(out, prior);
+    out.writeDouble(percentCorrect);
+    out.writeInt(windowSize);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    record = in.readInt();
+    auc = PolymorphicWritable.read(in, OnlineAuc.class);
+    logLikelihood = in.readDouble();
+    int n = in.readInt();
+    for (int i = 0; i < n; i++) {
+      OnlineLogisticRegression olr = new OnlineLogisticRegression();
+      olr.readFields(in);
+      models.add(olr);
+    }
+    parameters = new double[4];
+    for (int i = 0; i < 4; i++) {
+      parameters[i] = in.readDouble();
+    }
+    numFeatures = in.readInt();
+    prior = PolymorphicWritable.read(in, PriorFunction.class);
+    percentCorrect = in.readDouble();
+    windowSize = in.readInt();
+  }
+}

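A sketch of direct use with five folds on a two-category problem; featurize and label are
hypothetical helpers, and the tracking key should be stable across passes (e.g. a file
offset) so that each record always lands in the same held-out fold:

    CrossFoldLearner learner = new CrossFoldLearner(5, 2, 1000, new L1());
    long offset = 0;
    for (String line : lines) {
      Vector v = featurize(line);             // hypothetical feature encoder
      learner.train(offset, label(line), v);  // hypothetical label lookup
      offset += line.length() + 1;            // file offset as a stable tracking key
    }
    learner.close();
    System.out.printf("auc=%.3f  logLikelihood=%.3f%n", learner.auc(), learner.logLikelihood());
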
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
new file mode 100644
index 0000000..b21860f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import org.apache.commons.csv.CSVUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Converts CSV data lines to vectors.
+ *
+ * Use of this class proceeds in a few steps.
+ * <ul>
+ * <li> At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values.  At this point,
+ * the class cannot yet decode inputs because it doesn't know which fields are in the
+ * data records, or their order.
+ * <li> Optionally, you tell the parser object about the possible values of the target
+ * variable.  If you don't do this then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li> Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs.  The most important internal state is the field numbers
+ * for various fields.  After this point, you can use the factory for decoding data.
+ * <li> To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide.  The factory also retains trace information so
+ * that it can approximately reverse engineer vectors later.
+ * <li> After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights.  In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically during encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */
+public class CsvRecordFactory implements RecordFactory {
+  private static final String INTERCEPT_TERM = "Intercept Term";
+
+  private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY =
+          ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
+                  .put("continuous", ContinuousValueEncoder.class)
+                  .put("numeric", ContinuousValueEncoder.class)
+                  .put("n", ContinuousValueEncoder.class)
+                  .put("word", StaticWordValueEncoder.class)
+                  .put("w", StaticWordValueEncoder.class)
+                  .put("text", TextValueEncoder.class)
+                  .put("t", TextValueEncoder.class)
+                  .build();
+
+  private final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+
+  private int target;
+  private final Dictionary targetDictionary;
+  
+  // Which column is used to identify a line of the CSV file
+  private String idName;
+  private int id = -1;
+
+  private List<Integer> predictors;
+  private Map<Integer, FeatureVectorEncoder> predictorEncoders;
+  private int maxTargetValue = Integer.MAX_VALUE;
+  private final String targetName;
+  private final Map<String, String> typeMap;
+  private List<String> variableNames;
+  private boolean includeBiasTerm;
+  private static final String CANNOT_CONSTRUCT_CONVERTER =
+      "Unable to construct type converter... shouldn't be possible";
+
+  /**
+   * Parse a single line of CSV-formatted text.
+   *
+   * Separated to make changing this functionality for the entire class easier
+   * in the future.
+   * @param line - CSV formatted text
+   * @return The fields of the line as a list of strings.
+   */
+  private List<String> parseCsvLine(String line) {
+    try {
+      return Arrays.asList(CSVUtils.parseLine(line));
+    } catch (IOException e) {
+      // if the line can't be parsed as CSV, treat the whole line as a single field
+      List<String> list = Lists.newArrayList();
+      list.add(line);
+      return list;
+    }
+  }
+
+  private List<String> parseCsvLine(CharSequence line) {
+    return parseCsvLine(line.toString());
+  }
+
+  /**
+   * Construct a parser for CSV lines that encodes the parsed data in vector form.
+   * @param targetName            The name of the target variable.
+   * @param typeMap               A map describing the types of the predictor variables.
+   */
+  public CsvRecordFactory(String targetName, Map<String, String> typeMap) {
+    this.targetName = targetName;
+    this.typeMap = typeMap;
+    targetDictionary = new Dictionary();
+  }
+
+  public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) {
+    this(targetName, typeMap);
+    this.idName = idName;
+  }
+
+  /**
+   * Defines the values, and thus the encoding, of the target variable.  Note
+   * that any values of the target variable not present in this list will be given the
+   * value of the last member of the list.
+   * @param values  The values the target variable can have.
+   */
+  @Override
+  public void defineTargetCategories(List<String> values) {
+    Preconditions.checkArgument(
+        values.size() <= maxTargetValue,
+        "Must have less than or equal to " + maxTargetValue + " categories for target variable, but found "
+            + values.size());
+    if (maxTargetValue == Integer.MAX_VALUE) {
+      maxTargetValue = values.size();
+    }
+
+    for (String value : values) {
+      targetDictionary.intern(value);
+    }
+  }
+
+  /**
+   * Defines the number of target variable categories, but allows this parser to
+   * pick encodings for them as they appear.
+   * @param max  The number of categories that will be expected.  Once this many have been
+   * seen, all others will get the encoding max-1.
+   */
+  @Override
+  public CsvRecordFactory maxTargetValue(int max) {
+    maxTargetValue = max;
+    return this;
+  }
+
+  @Override
+  public boolean usesFirstLineAsSchema() {
+    return true;
+  }
+
+  /**
+   * Processes the first line of a file (which should contain the variable names). The target and
+   * predictor column numbers are set from the names on this line.
+   *
+   * @param line       Header line for the file.
+   */
+  @Override
+  public void firstLine(String line) {
+    // read variable names, build map of name -> column
+    final Map<String, Integer> vars = Maps.newHashMap();
+    variableNames = parseCsvLine(line);
+    int column = 0;
+    for (String var : variableNames) {
+      vars.put(var, column++);
+    }
+
+    // record target column and establish dictionary for decoding target
+    target = vars.get(targetName);
+    
+    // record id column
+    if (idName != null) {
+      id = vars.get(idName);
+    }
+
+    // create list of predictor column numbers
+    predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {
+      @Override
+      public Integer apply(String from) {
+        Integer r = vars.get(from);
+        Preconditions.checkArgument(r != null, "Can't find variable %s, only know about %s", from, vars);
+        return r;
+      }
+    }));
+
+    if (includeBiasTerm) {
+      predictors.add(-1);
+    }
+    Collections.sort(predictors);
+
+    // and map from column number to type encoder for each column that is a predictor
+    predictorEncoders = Maps.newHashMap();
+    for (Integer predictor : predictors) {
+      String name;
+      Class<? extends FeatureVectorEncoder> c;
+      if (predictor == -1) {
+        name = INTERCEPT_TERM;
+        c = ConstantValueEncoder.class;
+      } else {
+        name = variableNames.get(predictor);
+        c = TYPE_DICTIONARY.get(typeMap.get(name));
+      }
+      try {
+        Preconditions.checkArgument(c != null, "Invalid type of variable %s,  wanted one of %s",
+          typeMap.get(name), TYPE_DICTIONARY.keySet());
+        Constructor<? extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
+        Preconditions.checkArgument(constructor != null, "Can't find correct constructor for %s", typeMap.get(name));
+        FeatureVectorEncoder encoder = constructor.newInstance(name);
+        predictorEncoders.put(predictor, encoder);
+        encoder.setTraceDictionary(traceDictionary);
+      } catch (InstantiationException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (IllegalAccessException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (InvocationTargetException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (NoSuchMethodException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      }
+    }
+  }
+
+
+  /**
+   * Decodes a single line of CSV data and records the target and predictor variables in a record.
+   * As a side effect, features are added into the featureVector.  Returns the value of the target
+   * variable.
+   *
+   * @param line          The raw data.
+   * @param featureVector Where to fill in the features.  Should be zeroed before calling
+   *                      processLine.
+   * @return The value of the target variable.
+   */
+  @Override
+  public int processLine(String line, Vector featureVector) {
+    List<String> values = parseCsvLine(line);
+
+    int targetValue = targetDictionary.intern(values.get(target));
+    if (targetValue >= maxTargetValue) {
+      targetValue = maxTargetValue - 1;
+    }
+
+    for (Integer predictor : predictors) {
+      String value;
+      if (predictor >= 0) {
+        value = values.get(predictor);
+      } else {
+        value = null;
+      }
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+    return targetValue;
+  }
+  
+  /**
+   * Decodes a single line of CSV data and records the target (if returnTarget is true)
+   * and predictor variables in a record.  As a side effect, features are added into the
+   * featureVector.  When classifying production data that carries no target value, call
+   * this method with returnTarget = false.
+   * @param line The raw data.
+   * @param featureVector Where to fill in the features.  Should be zeroed before calling
+   *                      processLine.
+   * @param returnTarget Whether to process and return the target value.
+   * @return The value of the target variable, or -1 if returnTarget is false.
+   */
+  public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+    List<String> values = parseCsvLine(line);
+    int targetValue = -1;
+    if (returnTarget) {
+      targetValue = targetDictionary.intern(values.get(target));
+      if (targetValue >= maxTargetValue) {
+        targetValue = maxTargetValue - 1;
+      }
+    }
+
+    for (Integer predictor : predictors) {
+      String value = predictor >= 0 ? values.get(predictor) : null;
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+    return targetValue;
+  }
+  
+  /**
+   * Extracts the raw target string from a line read from a CSV file.
+   * @param line the line of content read from the CSV file
+   * @return the raw target value in the corresponding column of the CSV line
+   */
+  public String getTargetString(CharSequence line) {
+    List<String> values = parseCsvLine(line);
+    return values.get(target);
+  }
+
+  /**
+   * Extracts the raw target label corresponding to an integer code.
+   * @param code the integer code assigned during the training process
+   * @return the raw target label
+   */
+  public String getTargetLabel(int code) {
+    for (String key : targetDictionary.values()) {
+      if (targetDictionary.intern(key) == code) {
+        return key;
+      }
+    }
+    return null;
+  }
+  
+  /**
+   * Extracts the id column value from a CSV record.
+   * @param line the line of content read from the CSV file
+   * @return the id value of the CSV record
+   */
+  public String getIdString(CharSequence line) {
+    List<String> values = parseCsvLine(line);
+    return values.get(id);
+  }
+
+  /**
+   * Returns a list of the names of the predictor variables.
+   *
+   * @return A list of variable names.
+   */
+  @Override
+  public Iterable<String> getPredictors() {
+    return Lists.transform(predictors, new Function<Integer, String>() {
+      @Override
+      public String apply(Integer v) {
+        if (v >= 0) {
+          return variableNames.get(v);
+        } else {
+          return INTERCEPT_TERM;
+        }
+      }
+    });
+  }
+
+  @Override
+  public Map<String, Set<Integer>> getTraceDictionary() {
+    return traceDictionary;
+  }
+
+  @Override
+  public CsvRecordFactory includeBiasTerm(boolean useBias) {
+    includeBiasTerm = useBias;
+    return this;
+  }
+
+  @Override
+  public List<String> getTargetCategories() {
+    List<String> r = targetDictionary.values();
+    if (r.size() > maxTargetValue) {
+      r.subList(maxTargetValue, r.size()).clear();
+    }
+    return r;
+  }
+
+  public String getIdName() {
+    return idName;
+  }
+
+  public void setIdName(String idName) {
+    this.idName = idName;
+  }
+
+}

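A usage sketch with hypothetical column names ("y", "age", "job"); the header line would
normally be the first line read from the data file:

    Map<String, String> typeMap = ImmutableMap.of("age", "numeric", "job", "word");
    CsvRecordFactory csv = new CsvRecordFactory("y", typeMap);
    csv.maxTargetValue(2).includeBiasTerm(true);

    csv.firstLine("y,age,job");                          // establishes column positions
    Vector v = new RandomAccessSparseVector(100);        // zeroed feature vector
    int target = csv.processLine("yes,42,engineer", v);  // fills v, returns target code
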
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
new file mode 100644
index 0000000..f81d8ce
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implements the basic logistic training law.
+ */
+public class DefaultGradient implements Gradient {
+  /**
+   * Provides a default gradient computation useful for logistic regression.  
+   *
+   * @param groupKey     A grouping key to allow per-group (for example, per-user) AUC loss to be used for training.
+   * @param actual       The target variable value.
+   * @param instance     The current feature vector to use for gradient computation
+   * @param classifier   The classifier that can compute scores
+   * @return  The gradient to be applied to beta
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    // what does the current model say?
+    Vector v = classifier.classify(instance);
+
+    Vector r = v.like();
+    if (actual != 0) {
+      r.setQuick(actual - 1, 1);
+    }
+    r.assign(v, Functions.MINUS);
+    return r;
+  }
+}
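
[Editor's note — not part of the commit] A small sketch of the shape of the returned
gradient, using a hypothetical untrained 3-category model; the feature values are
arbitrary. For n categories the classifier returns n-1 probabilities (category 0 is
implicit), so the gradient is simply (target indicator) - (predicted probability)
for each stored row of beta.

    import org.apache.mahout.classifier.sgd.DefaultGradient;
    import org.apache.mahout.classifier.sgd.L2;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class DefaultGradientSketch {
      public static void main(String[] args) {
        OnlineLogisticRegression model = new OnlineLogisticRegression(3, 4, new L2(1));
        Vector instance = new DenseVector(new double[] {1, 0, 2, 0.5});
        // actual = 1, so entry 0 is 1 - p(category 1) and entry 1 is -p(category 2)
        Vector gradient = new DefaultGradient().apply(null, 1, instance, model);
        System.out.println(gradient);
      }
    }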

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
new file mode 100644
index 0000000..8128370
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements a linear combination of L1 and L2 priors.  This can give an
+ * interesting mixture of sparsity and load-sharing between redundant predictors.
+ */
+public class ElasticBandPrior implements PriorFunction {
+  private double alphaByLambda;
+  private L1 l1;
+  private L2 l2;
+
+  // Exists for Writable
+  public ElasticBandPrior() {
+    this(0.0);
+  }
+
+  public ElasticBandPrior(double alphaByLambda) {
+    this.alphaByLambda = alphaByLambda;
+    l1 = new L1();
+    l2 = new L2(1);
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    oldValue *= Math.pow(1 - alphaByLambda * learningRate, generations);
+    double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+    if (newValue * oldValue < 0.0) {
+      // don't allow the value to change sign
+      return 0.0;
+    } else {
+      return newValue;
+    }
+  }
+
+  @Override
+  public double logP(double betaIJ) {
+    return l1.logP(betaIJ) + alphaByLambda * l2.logP(betaIJ);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(alphaByLambda);
+    l1.write(out);
+    l2.write(out);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    alphaByLambda = in.readDouble();
+    l1 = new L1();
+    l1.readFields(in);
+    l2 = new L2();
+    l2.readFields(in);
+  }
+}
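
[Editor's note — not part of the commit] A sketch of wiring this prior into an SGD
learner; the alphaByLambda value, the feature count, and the hyper-parameters are
arbitrary assumptions.

    import org.apache.mahout.classifier.sgd.ElasticBandPrior;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;

    public class ElasticBandSketch {
      public static void main(String[] args) {
        // alphaByLambda controls the relative weight of the L2 term inside the prior
        OnlineLogisticRegression learner =
            new OnlineLogisticRegression(2, 1000, new ElasticBandPrior(1.0))
                .lambda(1.0e-4)        // overall weight of the prior vs. the data
                .learningRate(0.1);
        System.out.println(learner.numCategories());
      }
    }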

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
new file mode 100644
index 0000000..524fc06
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regression.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+  Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
+}
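
[Editor's note — not part of the commit] A sketch of a custom Gradient that rescales
the default log-loss gradient; the 0.5 damping factor is an arbitrary assumption
chosen only to illustrate the injection point.

    import org.apache.mahout.classifier.AbstractVectorClassifier;
    import org.apache.mahout.classifier.sgd.DefaultGradient;
    import org.apache.mahout.classifier.sgd.Gradient;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.function.Functions;

    public class ScaledGradient implements Gradient {
      private final Gradient delegate = new DefaultGradient();

      @Override
      public Vector apply(String groupKey, int actual, Vector instance,
                          AbstractVectorClassifier classifier) {
        Vector gradient = delegate.apply(groupKey, actual, instance, classifier);
        gradient.assign(Functions.mult(0.5));   // damp every update by half
        return gradient;
      }
    }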

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
new file mode 100644
index 0000000..d158f4d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Sets;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that tries to minimize the label ranking hinge loss.
+ * Implements a gradient machine with one sigmoid hidden layer.
+ * It tries to minimize the ranking loss over a given set of labels,
+ * so it can be used for multi-class, multi-label
+ * or auto-encoding of sparse data (e.g. text).
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // the regularization term, a positive number that controls the size of the weight vector
+  private double regularization = 0.1;
+
+  // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
+  private double sparsity = 0.1;
+
+  // the sparsity learning rate.
+  private double sparsityLearningRate = 0.1;
+
+  // the number of features
+  private int numFeatures = 10;
+  // the number of hidden nodes
+  private int numHidden = 100;
+  // the number of output nodes
+  private int numOutput = 2;
+
+  // coefficients for the input to hidden layer.
+  // There are numHidden Vectors of dimension numFeatures.
+  private Vector[] hiddenWeights;
+
+  // coefficients for the hidden to output layer.
+  // There are numOutput Vectors of dimension numHidden.
+  private Vector[] outputWeights;
+
+  // hidden unit bias
+  private Vector hiddenBias;
+
+  // output unit bias
+  private Vector outputBias;
+
+  private final Random rnd;
+
+  public GradientMachine(int numFeatures, int numHidden, int numOutput) {
+    this.numFeatures = numFeatures;
+    this.numHidden = numHidden;
+    this.numOutput = numOutput;
+    hiddenWeights = new DenseVector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = new DenseVector(numFeatures);
+      hiddenWeights[i].assign(0);
+    }
+    hiddenBias = new DenseVector(numHidden);
+    hiddenBias.assign(0);
+    outputWeights = new DenseVector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = new DenseVector(numHidden);
+      outputWeights[i].assign(0);
+    }
+    outputBias = new DenseVector(numOutput);
+    outputBias.assign(0);
+    rnd = RandomUtils.getRandom();
+  }
+
+  /**
+   * Initialize weights.
+   *
+   * @param gen random number generator.
+   */
+  public void initWeights(Random gen) {
+    double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
+    for (int i = 0; i < numHidden; i++) {
+      for (int j = 0; j < numFeatures; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
+        hiddenWeights[i].setQuick(j, val);
+      }
+    }
+    double outputFanIn = 1.0 / Math.sqrt(numHidden);
+    for (int i = 0; i < numOutput; i++) {
+      for (int j = 0; j < numHidden; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
+        outputWeights[i].setQuick(j, val);
+      }
+    }
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param regularization A positive value that controls the weight vector size.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine regularization(double regularization) {
+    this.regularization = regularization;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsity A value between zero and one that controls the fraction of hidden units
+   *                 that are activated on average.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsity(double sparsity) {
+    this.sparsity = sparsity;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsityLearningRate New value of initial learning rate for sparsity.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+    this.sparsityLearningRate = sparsityLearningRate;
+    return this;
+  }
+
+  public void copyFrom(GradientMachine other) {
+    numFeatures = other.numFeatures;
+    numHidden = other.numHidden;
+    numOutput = other.numOutput;
+    learningRate = other.learningRate;
+    regularization = other.regularization;
+    sparsity = other.sparsity;
+    sparsityLearningRate = other.sparsityLearningRate;
+    hiddenWeights = new DenseVector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = other.hiddenWeights[i].clone();
+    }
+    hiddenBias = other.hiddenBias.clone();
+    outputWeights = new DenseVector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = other.outputWeights[i].clone();
+    }
+    outputBias = other.outputBias.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numOutput;
+  }
+
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  public int numHidden() {
+    return numHidden;
+  }
+
+  /**
+   * Feeds forward from the input to the hidden layer.
+   *
+   * @return Hidden unit activations.
+   */
+  public DenseVector inputToHidden(Vector input) {
+    DenseVector activations = new DenseVector(numHidden);
+    for (int i = 0; i < numHidden; i++) {
+      activations.setQuick(i, hiddenWeights[i].dot(input));
+    }
+    activations.assign(hiddenBias, Functions.PLUS);
+    activations.assign(Functions.min(40.0)).assign(Functions.max(-40.0));  // clamp to [-40, 40] so the sigmoid stays numerically stable
+    activations.assign(Functions.SIGMOID);
+    return activations;
+  }
+
+  /**
+   * Feeds forward from hidden to output
+   *
+   * @return Output unit activations.
+   */
+  public DenseVector hiddenToOutput(Vector hiddenActivation) {
+    DenseVector activations = new DenseVector(numOutput);
+    for (int i = 0; i < numOutput; i++) {
+      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+    }
+    activations.assign(outputBias, Functions.PLUS);
+    return activations;
+  }
+
+  /**
+   * Updates using ranking loss.
+   *
+   * @param hiddenActivation the hidden unit's activation
+   * @param goodLabels       the labels you want ranked above others.
+   * @param numTrials        how many times you want to search for the highest scoring bad label.
+   * @param gen              Random number generator.
+   */
+  public void updateRanking(Vector hiddenActivation,
+                            Collection<Integer> goodLabels,
+                            int numTrials,
+                            Random gen) {
+    // All the labels are good, do nothing.
+    if (goodLabels.size() >= numOutput) {
+      return;
+    }
+    for (Integer good : goodLabels) {
+      double goodScore = outputWeights[good].dot(hiddenActivation);
+      int highestBad = -1;
+      double highestBadScore = Double.NEGATIVE_INFINITY;
+      for (int i = 0; i < numTrials; i++) {
+        int bad = gen.nextInt(numOutput);
+        while (goodLabels.contains(bad)) {
+          bad = gen.nextInt(numOutput);
+        }
+        double badScore = outputWeights[bad].dot(hiddenActivation);
+        if (badScore > highestBadScore) {
+          highestBadScore = badScore;
+          highestBad = bad;
+        }
+      }
+      int bad = highestBad;
+      double loss = 1.0 - goodScore + highestBadScore;
+      if (loss < 0.0) {
+        continue;
+      }
+      // From the hinge loss above, dloss/dy is -1 for the good score and +1 for the
+      // bad score, where y = w' h + b is an output score and h is the hidden activation.
+      // dy/dh = w, so dloss/dh = -w_good + w_bad; that is accumulated in propHidden
+      // below and pushed back through the sigmoid gradient into the input weights.
+      // This implementation uses the same +/- w directions, scaled by the learning
+      // rate, as the update for the output weights themselves, folding the gradient
+      // lambda * w of the regularization term 0.5 * lambda * w' w into the scale factor.
+      // dy/db = 1, so the output biases simply move by +/- learningRate.
+      Vector gradGood = outputWeights[good].clone();
+      gradGood.assign(Functions.NEGATE);
+      Vector propHidden = gradGood.clone();
+      Vector gradBad = outputWeights[bad].clone();
+      propHidden.assign(gradBad, Functions.PLUS);
+      gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
+      outputWeights[good].assign(gradGood, Functions.PLUS);
+      gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
+      outputWeights[bad].assign(gradBad, Functions.PLUS);
+      outputBias.setQuick(good, outputBias.get(good) + learningRate);
+      outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
+      // Gradient of sigmoid is s * (1 -s).
+      Vector gradSig = hiddenActivation.clone();
+      gradSig.assign(Functions.SIGMOIDGRADIENT);
+      // Multiply by the change caused by the ranking loss.
+      for (int i = 0; i < numHidden; i++) {
+        gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
+      }
+      for (int i = 0; i < numHidden; i++) {
+        for (int j = 0; j < numFeatures; j++) {
+          double v = hiddenWeights[i].get(j);
+          v -= learningRate * (gradSig.get(i) + regularization * v);
+          hiddenWeights[i].setQuick(j, v);
+        }
+      }
+    }
+  }
+
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Find the max value's index.
+    int max = result.maxValueIndex();
+    result.assign(0);
+    result.setQuick(max, 1.0);
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    DenseVector hidden = inputToHidden(instance);
+    return hiddenToOutput(hidden);
+  }
+
+  @Override
+  public double classifyScalar(Vector instance) {
+    Vector output = classifyNoLink(instance);
+    if (output.get(0) > output.get(1)) {
+      return 0;
+    }
+    return 1;
+  }
+
+  public GradientMachine copy() {
+    close();
+    GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
+    r.copyFrom(this);
+    return r;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeDouble(regularization);
+    out.writeDouble(sparsity);
+    out.writeDouble(sparsityLearningRate);
+    out.writeInt(numFeatures);
+    out.writeInt(numHidden);
+    out.writeInt(numOutput);
+    VectorWritable.writeVector(out, hiddenBias);
+    for (int i = 0; i < numHidden; i++) {
+      VectorWritable.writeVector(out, hiddenWeights[i]);
+    }
+    VectorWritable.writeVector(out, outputBias);
+    for (int i = 0; i < numOutput; i++) {
+      VectorWritable.writeVector(out, outputWeights[i]);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version == WRITABLE_VERSION) {
+      learningRate = in.readDouble();
+      regularization = in.readDouble();
+      sparsity = in.readDouble();
+      sparsityLearningRate = in.readDouble();
+      numFeatures = in.readInt();
+      numHidden = in.readInt();
+      numOutput = in.readInt();
+      hiddenWeights = new DenseVector[numHidden];
+      hiddenBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numHidden; i++) {
+        hiddenWeights[i] = VectorWritable.readVector(in);
+      }
+      outputWeights = new DenseVector[numOutput];
+      outputBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numOutput; i++) {
+        outputWeights[i] = VectorWritable.readVector(in);
+      }
+    } else {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    Vector hiddenActivation = inputToHidden(instance);
+    hiddenToOutput(hiddenActivation);
+    Collection<Integer> goodLabels = Sets.newHashSet();
+    goodLabels.add(actual);
+    updateRanking(hiddenActivation, goodLabels, 2, rnd);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
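
[Editor's note — not part of the commit] A minimal usage sketch using only methods
shown in this diff; the layer sizes, hyper-parameters, and the toy instance are
arbitrary assumptions.

    import org.apache.mahout.classifier.sgd.GradientMachine;
    import org.apache.mahout.common.RandomUtils;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class GradientMachineSketch {
      public static void main(String[] args) {
        GradientMachine gm = new GradientMachine(10, 100, 2)   // features, hidden, outputs
            .learningRate(0.01)
            .regularization(1.0e-4);
        gm.initWeights(RandomUtils.getRandom());               // fan-in scaled random init
        Vector instance = new DenseVector(10);
        instance.setQuick(3, 1.0);
        gm.train(1, instance);                                 // one online ranking update toward label 1
        System.out.println(gm.classifyNoLink(instance));       // raw output-layer scores
      }
    }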

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
new file mode 100644
index 0000000..28a05f2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Laplacian or bi-exponential prior.  This prior has a strong tendency to set coefficients to zero
+ * and thus is useful as an alternative to variable selection.  This version implements truncation which prevents
+ * a coefficient from changing sign.  If a correction would change the sign, the coefficient is truncated to zero.
+ *
+ * Note that a scale parameter would be redundant for this distribution: after taking the derivative of
+ * logP, a scale would have exactly the same effect as the lambda coefficient that already combines the
+ * prior with the observations, so adding a scale here would be equivalent to just changing lambda.
+ */
+public class L1 implements PriorFunction {
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+    if (newValue * oldValue < 0) {
+      // don't allow the value to change sign
+      return 0;
+    } else {
+      return newValue;
+    }
+  }
+
+  @Override
+  public double logP(double betaIJ) {
+    return -Math.abs(betaIJ);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // stateless class has nothing to serialize
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    // stateless class has nothing to serialize
+  }
+}
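
[Editor's note — not part of the commit] A tiny worked example of the
sign-truncation behaviour of age(); the numbers are arbitrary.

    import org.apache.mahout.classifier.sgd.L1;

    public class L1AgeSketch {
      public static void main(String[] args) {
        L1 l1 = new L1();
        // 0.1 - signum(0.1) * 0.05 * 3 = -0.05 would flip sign, so it is truncated to 0.0
        System.out.println(l1.age(0.1, 3, 0.05));   // prints 0.0
        // 0.5 - 0.15 = 0.35 keeps its sign, so the shrunken value survives
        System.out.println(l1.age(0.5, 3, 0.05));   // prints approximately 0.35
      }
    }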

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
new file mode 100644
index 0000000..3dfb9fc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Gaussian prior.  This prior has a tendency to decrease large coefficients toward zero, but
+ * doesn't tend to set them to exactly zero.
+ */
+public class L2 implements PriorFunction {
+
+  private static final double HALF_LOG_2PI = Math.log(2.0 * Math.PI) / 2.0;
+
+  private double s2;
+  private double s;
+
+  public L2(double scale) {
+    s = scale;
+    s2 = scale * scale;
+  }
+
+  public L2() {
+    s = 1.0;
+    s2 = 1.0;
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    return oldValue * Math.pow(1.0 - learningRate / s2, generations);
+  }
+
+  @Override
+  public double logP(double betaIJ) {
+    return -betaIJ * betaIJ / s2 / 2.0 - Math.log(s) - HALF_LOG_2PI;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(s2);
+    out.writeDouble(s);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    s2 = in.readDouble();
+    s = in.readDouble();
+  }
+}
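
[Editor's note — not part of the commit] A worked contrast with L1: L2 aging is a
multiplicative decay, so a coefficient shrinks geometrically but never snaps to
exactly zero. The numbers are arbitrary.

    import org.apache.mahout.classifier.sgd.L2;

    public class L2AgeSketch {
      public static void main(String[] args) {
        L2 l2 = new L2(1);
        // 0.5 * (1 - 0.05 / 1)^3 = 0.5 * 0.857375, approximately 0.4287
        System.out.println(l2.age(0.5, 3, 0.05));
      }
    }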

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
new file mode 100644
index 0000000..a290b22
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * <p>Provides a stochastic mixture of ranking updates and normal logistic updates. This uses a
+ * combination of AUC driven learning to improve ranking performance and traditional log-loss driven
+ * learning to improve log-likelihood.</p>
+ *
+ * <p>See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf</p>
+ *
+ * <p>This implementation only makes sense for the binomial case.</p>
+ */
+public class MixedGradient implements Gradient {
+
+  private final double alpha;
+  private final RankingGradient rank;
+  private final Gradient basic;
+  private final Random random = RandomUtils.getRandom();
+  private boolean hasZero;
+  private boolean hasOne;
+
+  public MixedGradient(double alpha, int window) {
+    this.alpha = alpha;
+    this.rank = new RankingGradient(window);
+    this.basic = this.rank.getBaseGradient();
+  }
+
+  @Override
+  public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    if (random.nextDouble() < alpha) {
+      // one option is to apply a ranking update relative to our recent history
+      if (!hasZero || !hasOne) {
+        throw new IllegalStateException("Ranking updates require at least one prior example of each category");
+      }
+      return rank.apply(groupKey, actual, instance, classifier);
+    } else {
+      hasZero |= actual == 0;
+      hasOne |= actual == 1;
+      // the other option is a normal update, but we have to update our history on the way
+      rank.addToHistory(actual, instance);
+      return basic.apply(groupKey, actual, instance, classifier);
+    }
+  }
+}
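
[Editor's note — not part of the commit] A usage sketch. alpha = 0.5 and the
100-instance ranking window are arbitrary assumptions, and the try/catch reflects
the guard above: a ranking update drawn before both categories have entered the
history throws IllegalStateException.

    import org.apache.mahout.classifier.sgd.L2;
    import org.apache.mahout.classifier.sgd.MixedGradient;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class MixedGradientSketch {
      public static void main(String[] args) {
        OnlineLogisticRegression model = new OnlineLogisticRegression(2, 4, new L2(1));
        MixedGradient mixed = new MixedGradient(0.5, 100);
        Vector pos = new DenseVector(new double[] {1, 0, 1, 0});
        Vector neg = new DenseVector(new double[] {0, 1, 0, 1});
        for (int i = 0; i < 10; i++) {
          try {
            Vector g = mixed.apply("group-1", i % 2, i % 2 == 1 ? pos : neg, model);
            System.out.println(g);
          } catch (IllegalStateException e) {
            // the random draw picked a ranking update before the history was seeded; skip it
          }
        }
      }
    }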

