mahout-commits mailing list archives

From: tdunn...@apache.org
Subject: svn commit: r986045 [1/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/classifier/ core/src/test/java/org/apache/mahout/classifier/sgd/ ma...
Date: Mon, 16 Aug 2010 16:56:47 GMT
Author: tdunning
Date: Mon Aug 16 16:56:46 2010
New Revision: 986045

URL: http://svn.apache.org/viewvc?rev=986045&view=rev
Log:
MAHOUT-228 - Initial SGD classifier release.

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/**
+ * This is a very fast, non-cryptographic hash suitable for general hash-based
+ * lookup.  See http://murmurhash.googlepages.com/ for more details.
+ * <p/>
+ * <p>The C version of MurmurHash 2.0 found at that site was ported
+ * to Java by Andrzej Bialecki (ab at getopt org).</p>
+ */
+public class MurmurHash {
+  /**
+   * Hashes bytes in an array.
+   * @param data The bytes to hash.
+   * @param seed The seed for the hash.
+   * @return The 32 bit hash of the bytes in question.
+   */
+  public static int hash(byte[] data, int seed) {
+    return hash(ByteBuffer.wrap(data), seed);
+  }
+
+  /**
+   * Hashes bytes in part of an array.
+   * @param data    The data to hash.
+   * @param offset  Where to start munging.
+   * @param length  How many bytes to process.
+   * @param seed    The seed to start with.
+   * @return        The 32-bit hash of the data in question.
+   */
+  public static int hash(byte[] data, int offset, int length, int seed) {
+    return hash(ByteBuffer.wrap(data, offset, length), seed);
+  }
+
+  /**
+   * Hashes the bytes in a buffer from the current position to the limit.
+   * @param buf    The bytes to hash.
+   * @param seed   The seed for the hash.
+   * @return       The 32 bit murmur hash of the bytes in the buffer.
+   */
+  public static int hash(ByteBuffer buf, int seed) {
+    // save byte order for later restoration
+    ByteOrder byteOrder = buf.order();
+    buf.order(ByteOrder.LITTLE_ENDIAN);
+
+    int m = 0x5bd1e995;
+    int r = 24;
+
+    int h = seed ^ buf.remaining();
+
+    int k;
+    while (buf.remaining() >= 4) {
+      k = buf.getInt();
+
+      k *= m;
+      k ^= k >>> r;
+      k *= m;
+
+      h *= m;
+      h ^= k;
+    }
+
+    if (buf.remaining() > 0) {
+      ByteBuffer finish = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
+      // for big-endian version, use this first:
+      // finish.position(4-buf.remaining());
+      finish.put(buf).rewind();
+      h ^= finish.getInt();
+      h *= m;
+    }
+
+    h ^= h >>> 13;
+    h *= m;
+    h ^= h >>> 15;
+
+    buf.order(byteOrder);
+    return h;
+  }
+
+
+  public static long hash64A(byte[] data, int seed) {
+    return hash64A(ByteBuffer.wrap(data), seed);
+  }
+
+  public static long hash64A(byte[] data, int offset, int length, int seed) {
+    return hash64A(ByteBuffer.wrap(data, offset, length), seed);
+  }
+
+  public static long hash64A(ByteBuffer buf, int seed) {
+    ByteOrder byteOrder = buf.order();
+    buf.order(ByteOrder.LITTLE_ENDIAN);
+
+    long m = 0xc6a4a7935bd1e995L;
+    int r = 47;
+
+    long h = seed ^ (buf.remaining() * m);
+
+    long k;
+    while (buf.remaining() >= 8) {
+      k = buf.getLong();
+
+      k *= m;
+      k ^= k >>> r;
+      k *= m;
+
+      h ^= k;
+      h *= m;
+    }
+
+    if (buf.remaining() > 0) {
+      ByteBuffer finish = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);
+      // for big-endian version, do this first:
+      // finish.position(8-buf.remaining());
+      finish.put(buf).rewind();
+      h ^= finish.getLong();
+      h *= m;
+    }
+
+    h ^= h >>> r;
+    h *= m;
+    h ^= h >>> r;
+
+    buf.order(byteOrder);
+    return h;
+  }
+
+  @Deprecated
+  public static long hashLong(byte[] bytes, int seed) {
+    return (((long) hash(bytes, seed ^ 120705477)) << 32) + hash(bytes, seed ^ 226137830);
+  }
+
+  @Deprecated
+  public static int hash_original(byte[] data, int seed) {
+    int m = 0x5bd1e995;
+    int r = 24;
+
+    int h = seed ^ data.length;
+
+    int len = data.length;
+    int len_4 = len >> 2;
+
+    int k;
+    for (int i = 0; i < len_4; i++) {
+      int i_4 = i << 2;
+      k = data[i_4];
+      k |= data[i_4 + 1] << 8;
+      k |= data[i_4 + 2] << 16;
+      k |= data[i_4 + 3] << 24;
+
+      k *= m;
+      k ^= k >>> r;
+      k *= m;
+
+      h *= m;
+      h ^= k;
+    }
+
+    int len_m = len_4 << 2;
+    int left = len - len_m;
+
+    switch (left) {
+      case 3:
+        h ^= (int) data[len_m + 2] << 16;
+      case 2:
+        h ^= (int) data[len_m + 1] << 8;
+      case 1:
+        h ^= (int) data[len_m];
+        h *= m;
+      default:
+    }
+
+    h ^= h >>> 13;
+    h *= m;
+    h ^= h >>> 15;
+
+    return h;
+  }
+
+  @Deprecated
+  public static long hash64A_original(byte[] data, int seed) {
+    long m = 0xc6a4a7935bd1e995L;
+    int r = 47;
+
+    int len = data.length;
+    int len_8 = len >> 3;
+
+    long h = seed ^ (len * m);
+
+    long k;
+    for (int i = 0; i < len_8; i++) {
+      int i_8 = i << 3;
+      k = (data[i_8 + 7] & 0xffL) << 56;
+      k |= (data[i_8 + 6] & 0xffL) << 48;
+      k |= (data[i_8 + 5] & 0xffL) << 40;
+      k |= (data[i_8 + 4] & 0xffL) << 32;
+      k |= (data[i_8 + 3] & 0xffL) << 24;
+      k |= (data[i_8 + 2] & 0xffL) << 16;
+      k |= (data[i_8 + 1] & 0xffL) << 8;
+      k |= (data[i_8] & 0xffL);
+
+      k *= m;
+      k ^= k >>> r;
+      k *= m;
+
+      h ^= k;
+      h *= m;
+    }
+
+    int len_m = len_8 << 3;
+    int left = len - len_m;
+
+    switch (left) {
+      case 7:
+        h ^= (data[len_m + 6] & 0xffL) << 48;
+      case 6:
+        h ^= (data[len_m + 5] & 0xffL) << 40;
+      case 5:
+        h ^= (data[len_m + 4] & 0xffL) << 32;
+      case 4:
+        h ^= (data[len_m + 3] & 0xffL) << 24;
+      case 3:
+        h ^= (data[len_m + 2] & 0xffL) << 16;
+      case 2:
+        h ^= (data[len_m + 1] & 0xffL) << 8;
+      case 1:
+        h ^= data[len_m] & 0xffL;
+        h *= m;
+      default:
+    }
+
+    h ^= h >>> r;
+    h *= m;
+    h ^= h >>> r;
+
+    return h;
+  }
+
+
+  /* Testing ...
+ static int NUM = 1000;
+
+ public static void main(String[] args) {
+   byte[] bytes = new byte[4];
+   for (int i = 0; i < NUM; i++) {
+     bytes[0] = (byte)(i & 0xff);
+     bytes[1] = (byte)((i & 0xff00) >> 8);
+     bytes[2] = (byte)((i & 0xff0000) >> 16);
+     bytes[3] = (byte)((i & 0xff000000) >> 24);
+     System.out.println(Integer.toHexString(i) + " " + Integer.toHexString(hash(bytes, 1)));
+   }
+ } */
+}
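
A rough usage sketch (not part of this commit; assumes java.nio.charset.Charset is imported) of the two hash entry points above, applied to a UTF-8 encoded key:

    byte[] key = "some key".getBytes(Charset.forName("UTF-8"));
    int h32 = MurmurHash.hash(key, 1);       // 32-bit hash with seed 1
    long h64 = MurmurHash.hash64A(key, 1);   // 64-bit variant, used below for feature hashing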

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.Iterator;
+
+/**
+ * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
+ * response to a feature vector.  This classifier uses 1 of n-1 coding where the 0-th category
+ * is not stored explicitly.
+ * <p/>
+ * TODO: implement symbolic input with string, overall cooccurrence and n-gram hash encoding
+ * TODO: implement reporter system to monitor progress
+ *
+ * Provides the basic SGD-based algorithm for learning a logistic regression, but omits all
+ * annealing of learning rates.  Any extension of this abstract class must define the
+ * overall and per-term annealing for itself.
+ */
+public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
+  // coefficients for the classification.  This is a dense matrix
+  // that is (numCategories-1) x numFeatures
+  protected Matrix beta;
+
+  // number of categories we are classifying.  This should be the number of rows of beta plus one.
+  protected int numCategories;
+
+  private int step = 0;
+
+  // information about how long since coefficient rows were updated.  This allows lazy regularization.
+  protected transient Vector updateSteps;
+
+  // information about how many updates we have had on a location.  This allows per-term
+  // annealing a la confidence weighted learning.
+  protected transient Vector updateCounts;
+
+  // weight of the prior on beta
+  private double lambda = 1e-5;
+  protected transient PriorFunction prior;
+
+  // can we ignore any further regularization when doing classification?
+  private boolean sealed = false;
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param lambda New value of lambda, the weighting factor for the prior distribution.
+   * @return This, so other configurations can be chained.
+   */
+  public AbstractOnlineLogisticRegression lambda(double lambda) {
+    this.lambda = lambda;
+    return this;
+  }
+
+  private Vector logisticLink(Vector v) {
+    double max = v.maxValue();
+    if (max < 40) {
+      v.assign(Functions.exp);
+      double sum = 1 + v.norm(1);
+      return v.divide(sum);
+    } else {
+      v.assign(Functions.minus(max)).assign(Functions.exp);
+      return v;
+    }
+  }
+
+  /**
+   * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
+   * category is 1 - sum(this result).
+   *
+   * @param instance A vector of features to be classified.
+   * @return A vector of probabilities, one for each of the first n-1 categories.
+   */
+  public Vector classify(Vector instance) {
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+
+    Vector v = beta.times(instance);
+    return logisticLink(v);
+  }
+
+  /**
+   * Returns a single scalar probability in the case where we have two categories.  Using this
+   * method avoids an extra vector allocation as opposed to calling classify() or an extra two
+   * vector allocations relative to classifyFull().
+   *
+   * @param instance The vector of features to be classified.
+   * @return The probability of the first of two categories.
+   * @throws IllegalArgumentException If the classifier doesn't have two categories.
+   */
+  public double classifyScalar(Vector instance) {
+    if (numCategories() != 2) {
+      throw new IllegalArgumentException("Can only call classifyScalar with two categories");
+    }
+
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+
+    // result is a vector with one element so we can just use dot product
+    double r = Math.exp(beta.getRow(0).dot(instance));
+    return r / (1 + r);
+  }
+
+  public void train(int actual, Vector instance) {
+    unseal();
+
+    double learningRate = currentLearningRate();
+
+    // push coefficients back to zero based on the prior
+    regularize(instance);
+
+    // what does the current model say?
+    Vector v = classify(instance);
+
+    // update each row of coefficients according to result
+    for (int i = 0; i < numCategories - 1; i++) {
+      double gradientBase = -v.getQuick(i);
+      // the use of i+1 instead of i here is what makes the 0-th category be the one without coefficients
+      if ((i + 1) == actual) {
+        gradientBase += 1;
+      }
+
+      // then we apply the gradientBase to the resulting element.
+      Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
+      while (nonZeros.hasNext()) {
+        Vector.Element updateLocation = nonZeros.next();
+        int j = updateLocation.index();
+        double newValue = beta.get(i, j) + learningRate * gradientBase * instance.get(j) * perTermLearningRate(j);
+        beta.set(i, j, newValue);
+      }
+    }
+
+    // remember that these elements got updated
+    Iterator<Vector.Element> i = instance.iterateNonZero();
+    while (i.hasNext()) {
+      Vector.Element element = i.next();
+      int j = element.index();
+      updateSteps.setQuick(j, getStep());
+      updateCounts.setQuick(j, updateCounts.getQuick(j) + 1);
+    }
+    nextStep();
+
+  }
+
+  public void regularize(Vector instance) {
+    if (updateSteps == null || isSealed()) {
+      return;
+    }
+
+    // anneal learning rate
+    double learningRate = currentLearningRate();
+
+    // here we lazily apply the prior to make up for our neglect
+    for (int i = 0; i < numCategories - 1; i++) {
+      Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
+      while (nonZeros.hasNext()) {
+        Vector.Element updateLocation = nonZeros.next();
+        int j = updateLocation.index();
+        double missingUpdates = getStep() - updateSteps.get(j);
+        if (missingUpdates > 0) {
+          double newValue = prior.age(beta.get(i, j), missingUpdates, getLambda() * learningRate * perTermLearningRate(j));
+          beta.set(i, j, newValue);
+        }
+      }
+    }
+  }
+
+  // these two abstract methods are how extensions can modify the basic learning behavior of this object.
+
+  public abstract double perTermLearningRate(int j);
+
+  public abstract double currentLearningRate();
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public Matrix getBeta() {
+    close();
+    return beta;
+  }
+
+  public void setBeta(int i, int j, double beta_ij) {
+    beta.set(i, j, beta_ij);
+  }
+
+  public int numCategories() {
+    return numCategories;
+  }
+
+  public int numFeatures() {
+    return beta.numCols();
+  }
+
+  public double getLambda() {
+    return lambda;
+  }
+
+  public int getStep() {
+    return step;
+  }
+
+  protected void nextStep() {
+    step++;
+  }
+
+  public boolean isSealed() {
+    return sealed;
+  }
+
+  protected void unseal() {
+    sealed = false;
+  }
+
+  private void regularizeAll() {
+    Vector all = new DenseVector(beta.numCols());
+    all.assign(1);
+    regularize(all);
+  }
+
+  public void close() {
+    if (!sealed) {
+      step++;
+      regularizeAll();
+      sealed = true;
+    }
+  }
+}
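
A hedged sketch (not part of this commit) of how a concrete subclass such as OnlineLogisticRegression, added later in this change, is typically driven; the vector size and parameter values are illustrative only:

    // two categories, 1000 hashed features, L1 prior
    OnlineLogisticRegression learner = new OnlineLogisticRegression(2, 1000, new L1());
    learner.lambda(1e-4);      // weight of the prior
    learner.learningRate(1);   // base learning rate before annealing

    Vector features = new DenseVector(1000);   // normally filled by a FeatureVectorEncoder
    learner.train(1, features);                // train on one example with target category 1
    learner.close();                           // applies any pending lazy regularization
    double p = learner.classifyScalar(features);   // scalar probability for the two-category case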

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,58 @@
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.random.NegativeBinomial;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * This is a meta-learner that maintains a pool of ordinary OnlineLogisticRegression learners. Each
+ * member of the pool has different learning rates.  Whichever of the learners in the pool falls
+ * behind in terms of average log-likelihood will be tossed out and replaced with variants of the
+ * survivors.  This will let us automatically derive an annealing schedule that optimizes learning
+ * speed.  Since on-line learners tend to be IO bound anyway, maintaining multiple learners in
+ * memory doesn't cost as much as it might seem.  Doing this adaptation on-line as we
+ * learn also decreases the number of learning rate parameters required and replaces the normal
+ * hyper-parameter search.
+ */
+public class AdaptiveAnnealedLogisticRegression implements OnlineLearner {
+  private int record = 0;
+  private List<CrossFoldLearner> pool = Lists.newArrayList();
+  private int evaluationInterval = 1000;
+
+  public AdaptiveAnnealedLogisticRegression(int poolSize, int numCategories, int numFeatures, PriorFunction prior) {
+    for (int i = 0; i < poolSize; i++) {
+      CrossFoldLearner model = new CrossFoldLearner(5, numCategories, numFeatures, prior);
+      pool.add(model);
+    }
+    NegativeBinomial nb = new NegativeBinomial(10, 0.1, new MersenneTwister());
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    for (CrossFoldLearner learner : pool) {
+      learner.train(actual, instance);
+    }
+    record++;
+    if (record % evaluationInterval == 0) {
+      Collections.sort(pool);
+      // TODO: pick a parent from the top half of the pool, weighted toward the top few
+
+    }
+  }
+
+  @Override
+  public void train(int trackingKey, int actual, Vector instance) {
+    train(actual, instance);
+  }
+
+  @Override
+  public void close() {
+    // nothing to do here yet
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Encodes words into vectors much as WordValueEncoder does, while maintaining
+ * an adaptive dictionary of values seen so far.  This allows weighting of terms
+ * without a pre-scan of all of the data.
+ */
+public class AdaptiveWordValueEncoder extends WordValueEncoder {
+  private Multiset<String> dictionary;
+
+  public AdaptiveWordValueEncoder(String name) {
+    super(name);
+    dictionary = HashMultiset.create();
+  }
+
+  /**
+   * Adds a value to a vector.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param data         The vector to which the value should be added.
+   */
+  @Override
+  public void addToVector(String originalForm, double weight, Vector data) {
+    dictionary.add(originalForm);
+    super.addToVector(originalForm, weight, data);
+  }
+
+  @Override
+  protected double weight(String originalForm) {
+    // the counts here are adjusted so that every observed value has an extra 0.5 count
+    // as does a hypothetical unobserved value.  This smooths our estimates a bit and
+    // allows the first word seen to have a non-zero weight of -log(1.5 / 2)
+    double thisWord = dictionary.count(originalForm) + 0.5;
+    double allWords = dictionary.size() + dictionary.elementSet().size() * 0.5 + 0.5;
+    return -Math.log(thisWord / allWords);
+  }
+
+  public Multiset<String> getDictionary() {
+    return dictionary;
+  }
+}
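
To make the smoothing in weight() above concrete: after encoding "cat" twice and "dog" once, dictionary.size() is 3 and there are two distinct elements, so allWords is 3 + 2 * 0.5 + 0.5 = 4.5; weight("dog") is -log(1.5 / 4.5) ≈ 1.10 and weight("cat") is -log(2.5 / 4.5) ≈ 0.59, so rarer words get larger weights.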

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * An encoder for a virtual bias (intercept) term; it adds a constant weight at a fixed set of hashed locations.
+ */
+public class ConstantValueEncoder extends FeatureVectorEncoder {
+  public ConstantValueEncoder(String name) {
+    super(name);
+  }
+
+  @Override
+  public void addToVector(String originalForm, double weight, Vector data) {
+    for (int i = 0; i < probes; i++) {
+      int n = hash(name, i, data.size());
+      trace(name, null, n);
+      data.set(n, data.get(n) + weight);
+    }
+  }
+
+  @Override
+  public String asString(String originalForm) {
+    return name;
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Continuous values are stored at a fixed set of randomized (hashed) locations in the feature vector.
+ */
+public class ContinuousValueEncoder extends FeatureVectorEncoder {
+  public ContinuousValueEncoder(String name) {
+    super(name);
+  }
+
+  /**
+   * Adds a value to a vector.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param data         The vector to which the value should be added.
+   */
+  @Override
+  public void addToVector(String originalForm, double weight, Vector data) {
+    for (int i = 0; i < probes; i++) {
+      int n = hash(name, CONTINUOUS_VALUE_HASH_SEED + i, data.size());
+      trace(name, null, n);
+      data.set(n, data.get(n) + weight * Double.parseDouble(originalForm));
+    }
+  }
+
+  /**
+   * Converts a value into a form that would help a human understand the internals of how the value
+   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
+   * associated weights (if any).
+   *
+   * @param originalForm The original form of the value as a string.
+   * @return A string that a human can read.
+   */
+  @Override
+  public String asString(String originalForm) {
+    return name + ":" + originalForm;
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,137 @@
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.util.List;
+
+/**
+ * Does cross-fold validation of log-likelihood and AUC on several online logistic regression
+ * models. Each record is passed to all but one of the models for training and to the remaining
+ * model for evaluation.  In order to maintain proper segregation between the different folds across
+ * training data iterations, data should either be passed to this learner in the same order each
+ * time the training data is traversed or a tracking key such as the file offset of the training
+ * record should be passed with each training example.
+ */
+class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Comparable<CrossFoldLearner> {
+  int record = 0;
+  OnlineAuc auc = new OnlineAuc();
+  double logLikelihood = 0;
+  List<OnlineLogisticRegression> models = Lists.newArrayList();
+
+  // lambda, learningRate, perTermOffset, perTermExponent
+  double[] parameters = new double[4];
+
+  CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+    for (int i = 0; i < folds; i++) {
+      OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
+      model.alpha(1).stepOffset(0).decayExponent(0);
+      models.add(model);
+    }
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, actual, instance);
+  }
+
+  @Override
+  public void train(int trackingKey, int actual, Vector instance) {
+    record++;
+    int k = 0;
+    for (OnlineLogisticRegression model : models) {
+      if (k == trackingKey % models.size()) {
+        Vector v = model.classifyFull(instance);
+        double score = v.get(actual);
+        logLikelihood += (Math.log(score) - logLikelihood) / record;
+        auc.addSample(actual, v.get(1));
+      } else {
+        model.train(actual, instance);
+      }
+      k = (k + 1) % models.size();
+    }
+  }
+
+  @Override
+  public void close() {
+    for (OnlineLogisticRegression m : models) {
+      m.close();
+    }
+  }
+
+  public void resetLineCounter() {
+    record = 0;
+  }
+
+  @Override
+  public int compareTo(CrossFoldLearner other) {
+    return Double.compare(this.logLikelihood, other.logLikelihood);
+  }
+
+  public CrossFoldLearner lambda(double v) {
+    for (OnlineLogisticRegression model : models) {
+      model.lambda(v);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner learningRate(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.learningRate(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner stepOffset(int x) {
+    for (OnlineLogisticRegression model : models) {
+      model.stepOffset(x);
+    }
+    return this;
+  }
+
+  public CrossFoldLearner decayExponent(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.decayExponent(x);
+    }
+    return this;
+  }
+
+  @Override
+  public int numCategories() {
+    return models.get(0).numCategories();
+  }
+
+  @Override
+  public Vector classify(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    double scale = 1.0 / models.size();
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classify(instance), Functions.plusMult(scale));
+    }
+    return r;
+  }
+
+  @Override
+  public double classifyScalar(Vector instance) {
+    double r = 0;
+    int n = 0;
+    for (OnlineLogisticRegression model : models) {
+      n++;
+      r += model.classifyScalar(instance);
+    }
+    return r / n;
+  }
+
+  public double auc() {
+    return auc.auc();
+  }
+
+  public double logLikelihood() {
+    return logLikelihood;
+  }
+}
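
As a concrete illustration of the fold routing in train() above: with 5 folds, the record whose trackingKey is 12 is scored (held out) by model 12 % 5 = 2 and used for training by the other four models.  Passing a stable tracking key such as a file offset keeps each record in the same fold on every pass over the training data.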

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Function;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.mahout.math.Vector;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Converts csv data lines to vectors.
+ *
+ * Use of this class proceeds in a few steps.
+ * <ul>
+ * <li> At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values.  At this point,
+ * the class cannot yet decode inputs because it doesn't know which fields are in the
+ * data records, nor their order.
+ * <li> Optionally, you tell the parser object about the possible values of the target
+ * variable.  If you don't do this then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li> Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs.  The most important internal state is the field numbers
+ * for various fields.  After this point, you can use the factory for decoding data.
+ * <li> To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide.  The factory also retains trace information so
+ * that it can approximately reverse engineer vectors later.
+ * <li> After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights.  In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically during encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */
+public class CsvRecordFactory implements RecordFactory {
+  private static final String INTERCEPT_TERM = "Intercept Term";
+
+  // crude CSV value splitter.  This will fail if any double quoted strings have
+  // commas inside.  Also, escaped quotes will not be unescaped.  Good enough for now.
+  private Splitter onComma = Splitter.on(",").trimResults(CharMatcher.is('"'));
+
+  private static final Map<String, Class<? extends FeatureVectorEncoder>> typeDictionary =
+          ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
+                  .put("continuous", ContinuousValueEncoder.class)
+                  .put("numeric", ContinuousValueEncoder.class)
+                  .put("n", ContinuousValueEncoder.class)
+                  .put("word", StaticWordValueEncoder.class)
+                  .put("w", StaticWordValueEncoder.class)
+                  .put("text", TextValueEncoder.class)
+                  .put("t", TextValueEncoder.class)
+                  .build();
+
+  private Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+
+  private int target;
+  private Dictionary targetDictionary;
+
+  private List<Integer> predictors;
+  private Map<Integer, FeatureVectorEncoder> predictorEncoders;
+  private int maxTargetValue = Integer.MAX_VALUE;
+  private String targetName;
+  private Map<String, String> typeMap;
+  private List<String> variableNames;
+  private boolean includeBiasTerm;
+
+  /**
+   * Construct a parser for CSV lines that encodes the parsed data in vector form.
+   * @param targetName            The name of the target variable.
+   * @param typeMap               A map describing the types of the predictor variables.
+   */
+  public CsvRecordFactory(String targetName, final Map<String, String> typeMap) {
+    this.targetName = targetName;
+    this.typeMap = typeMap;
+    targetDictionary = new Dictionary();
+  }
+
+  /**
+   * Defines the values and thus the encoding of values of the target variables.  Note
+   * that any values of the target variable not present in this list will be given the
+   * value of the last member of the list.
+   * @param values  The values the target variable can have.
+   */
+  @Override
+  public void defineTargetCategories(List<String> values) {
+    if (values.size() > maxTargetValue) {
+      throw new IllegalArgumentException("Must have less than or equal to " + maxTargetValue + " categories for target variable, but found " + values.size());
+    }
+
+    if (maxTargetValue == Integer.MAX_VALUE) {
+      maxTargetValue = values.size();
+    }
+
+    for (String value : values) {
+      targetDictionary.intern(value);
+    }
+  }
+
+  /**
+   * Defines the number of target variable categories, but allows this parser to
+   * pick encodings for them as they appear.
+   * @param max  The number of categories that will be expected.  Once this many have been
+   * seen, all others will get the encoding max-1.
+   */
+  @Override
+  public CsvRecordFactory maxTargetValue(int max) {
+    maxTargetValue = max;
+    return this;
+  }
+
+  @Override
+  public boolean usesFirstLineAsSchema() {
+    return true;
+  }
+
+  /**
+   * Processes the first line of a file (which should contain the variable names). The target and
+   * predictor column numbers are set from the names on this line.
+   *
+   * @param line       Header line for the file.
+   */
+  @Override
+  public void firstLine(String line) {
+    // read variable names, build map of name -> column
+    final Map<String, Integer> vars = Maps.newHashMap();
+    int column = 0;
+    variableNames = Lists.newArrayList(onComma.split(line));
+    for (String var : variableNames) {
+      vars.put(var, column++);
+    }
+
+    // record target column and establish dictionary for decoding target
+    target = vars.get(targetName);
+
+    // create list of predictor column numbers
+    predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {
+      @Override
+      public Integer apply(String from) {
+        Integer r = vars.get(from);
+        if (r == null) {
+          throw new IllegalArgumentException("Can't find variable " + from + ", only know about " + vars);
+        }
+        return r;
+      }
+    }));
+
+    if (includeBiasTerm) {
+      predictors.add(-1);
+    }
+    Collections.sort(predictors);
+
+    // and map from column number to type encoder for each column that is a predictor
+    predictorEncoders = Maps.newHashMap();
+    for (Integer predictor : predictors) {
+      String name;
+      Class<? extends FeatureVectorEncoder> c;
+      if (predictor != -1) {
+        name = variableNames.get(predictor);
+        c = typeDictionary.get(typeMap.get(name));
+      } else {
+        name = INTERCEPT_TERM;
+        c = ConstantValueEncoder.class;
+      }
+      try {
+        if (c == null) {
+          throw new IllegalArgumentException("Invalid type of variable " + typeMap.get(name) + " wanted one of " + typeDictionary.keySet());
+        }
+        Constructor<? extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
+        if (constructor == null) {
+          throw new IllegalArgumentException("Can't find correct constructor for " + typeMap.get(name));
+        }
+        FeatureVectorEncoder encoder = constructor.newInstance(name);
+        predictorEncoders.put(predictor, encoder);
+        encoder.setTraceDictionary(traceDictionary);
+      } catch (InstantiationException e) {
+        throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e);
+      } catch (IllegalAccessException e) {
+        throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e);
+      } catch (InvocationTargetException e) {
+        throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e);
+      } catch (NoSuchMethodException e) {
+        throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e);
+      }
+    }
+  }
+
+
+  /**
+   * Decodes a single line of csv data and records the target and predictor variables in a record.
+   * As a side effect, features are added into the featureVector.  Returns the value of the target
+   * variable.
+   *
+   * @param line          The raw data.
+   * @param featureVector Where to fill in the features.  Should be zeroed before calling
+   *                      processLine.
+   * @return The value of the target variable.
+   */
+  @Override
+  public int processLine(String line, Vector featureVector) {
+    List<String> values = Lists.newArrayList(onComma.split(line));
+
+    int targetValue = targetDictionary.intern(values.get(target));
+    if (targetValue >= maxTargetValue) {
+      targetValue = maxTargetValue - 1;
+    }
+
+    for (Integer predictor : predictors) {
+      String value;
+      if (predictor >= 0) {
+        value = values.get(predictor);
+      } else {
+        value = null;
+      }
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+    return targetValue;
+  }
+
+  /**
+   * Returns a list of the names of the predictor variables.
+   *
+   * @return A list of variable names.
+   */
+  @Override
+  public Iterable<String> getPredictors() {
+    return Lists.transform(predictors, new Function<Integer, String>() {
+      @Override
+      public String apply(Integer v) {
+        if (v >= 0) {
+          return variableNames.get(v);
+        } else {
+          return INTERCEPT_TERM;
+        }
+      }
+    });
+  }
+
+  @Override
+  public Map<String, Set<Integer>> getTraceDictionary() {
+    return traceDictionary;
+  }
+
+  @Override
+  public CsvRecordFactory includeBiasTerm(boolean useBias) {
+    includeBiasTerm = useBias;
+    return this;
+  }
+
+  @Override
+  public List<String> getTargetCategories() {
+    List<String> r = targetDictionary.values();
+    if (r.size() > maxTargetValue) {
+      r.subList(maxTargetValue, r.size()).clear();
+    }
+    return r;
+  }
+
+  private static class ImpossibleException extends RuntimeException {
+    private ImpossibleException(String message, Throwable cause) {
+      super(message, cause);
+    }
+  }
+
+}
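
A hedged usage sketch (not part of this commit; the field names, header line and vector size are made up for illustration) of the steps described in the class comment above:

    Map<String, String> types = ImmutableMap.of("age", "numeric", "country", "word");
    CsvRecordFactory csv = new CsvRecordFactory("clicked", types)
        .maxTargetValue(2)
        .includeBiasTerm(true);
    csv.firstLine("clicked,age,country");       // header line establishes column order
    Vector features = new DenseVector(1000);    // zeroed vector for one record
    int target = csv.processLine("1,37,us", features);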

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Maps;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Assigns integer codes to strings as they appear.
+ */
+public class Dictionary {
+  private Map<String, Integer> dict = Maps.newLinkedHashMap();
+
+  public int intern(String s) {
+    if (!dict.containsKey(s)) {
+      dict.put(s, dict.size());
+    }
+    return dict.get(s);
+  }
+
+  public List<String> values() {
+    // order of keySet is guaranteed to be insertion order
+    return new ArrayList<String>(dict.keySet());
+  }
+
+  public static Dictionary fromList(List<String> values) {
+    Dictionary dict = new Dictionary();
+    for (String value : values) {
+      dict.intern(value);
+    }
+    return dict;
+  }
+}
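
A quick sketch (the strings are arbitrary) of the insertion-order coding above:

    Dictionary dict = new Dictionary();
    dict.intern("spam");    // -> 0
    dict.intern("ham");     // -> 1
    dict.intern("spam");    // -> 0 again
    List<String> categories = dict.values();    // ["spam", "ham"]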

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+public class ElasticBandPrior extends PriorFunction {
+  private double alphaByLambda;
+  private L1 l1;
+  private L2 l2;
+
+  public ElasticBandPrior(double alphaByLambda) {
+    this.alphaByLambda = alphaByLambda;
+    l1 = new L1();
+    l2 = new L2(1);
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    oldValue = oldValue * Math.pow(1 - alphaByLambda * learningRate, generations);
+    double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+    if (newValue * oldValue < 0) {
+      // don't allow the value to change sign
+      return 0;
+    } else {
+      return newValue;
+    }
+  }
+
+  @Override
+  public double logP(double beta_ij) {
+    return l1.logP(beta_ij) + alphaByLambda * l2.logP(beta_ij);
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Sets;
+import org.apache.mahout.classifier.MurmurHash;
+import org.apache.mahout.math.Vector;
+
+import java.nio.charset.Charset;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * General interface for objects that record features into a feature vector.
+ * <p/>
+ * By convention, sub-classes should provide a constructor that accepts just a field name as well as
+ * setters to customize properties of the conversion such as adding tokenizers or a weight
+ * dictionary.
+ */
+public abstract class FeatureVectorEncoder {
+  protected static final int CONTINUOUS_VALUE_HASH_SEED = 1;
+  protected static final int WORD_LIKE_VALUE_HASH_SEED = 100;
+
+  protected String name;
+  protected int probes = 1;
+
+  private Map<String, Set<Integer>> traceDictionary = null;
+
+  public FeatureVectorEncoder(String name) {
+    this.name = name;
+  }
+
+  /**
+   * Adds a value expressed in string form to a vector.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param data         The vector to which the value should be added.
+   */
+  public void addToVector(String originalForm, Vector data) {
+    addToVector(originalForm, 1.0, data);
+  }
+
+  /**
+   * Adds a weighted value expressed in string form to a vector.  In some cases it is convenient to
+   * use this method to encode continuous values using the weight as the value.  In such cases, the
+   * string value should typically be set to null.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param weight       The weight to be applied to this feature.
+   * @param data         The vector to which the value should be added.
+   */
+  public abstract void addToVector(String originalForm, double weight, Vector data);
+
+  // ******* Utility functions used by most implementations
+
+  /**
+   * Hash a string and an integer into the range [0..numFeatures-1].
+   *
+   * @param term        The string.
+   * @param probe       An integer that modifies the resulting hash.
+   * @param numFeatures The range into which the resulting hash must fit.
+   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+   *         term and probe.
+   */
+  protected int hash(String term, int probe, int numFeatures) {
+    long r = MurmurHash.hash64A(term.getBytes(Charset.forName("UTF-8")), probe) % numFeatures;
+    if (r < 0) {
+      r += numFeatures;
+    }
+    return (int) r;
+  }
+
+  /**
+   * Hash two strings and an integer into the range [0..numFeatures-1].
+   *
+   * @param term1       The first string.
+   * @param term2       The second string.
+   * @param probe       An integer that modifies the resulting hash.
+   * @param numFeatures The range into which the resulting hash must fit.
+   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+   *         term and probe.
+   */
+  protected int hash(String term1, String term2, int probe, int numFeatures) {
+    long r = MurmurHash.hash64A(term1.getBytes(Charset.forName("UTF-8")), probe);
+    r = MurmurHash.hash64A(term2.getBytes(Charset.forName("UTF-8")), (int) r) % numFeatures;
+    if (r < 0) {
+      r += numFeatures;
+    }
+    return (int) r;
+  }
+
+  /**
+   * Converts a value into a form that would help a human understand the internals of how the value
+   * is being interpreted.  For text-like things, this is likely to be a list of the terms found
+   * with associated weights (if any).
+   *
+   * @param originalForm The original form of the value as a string.
+   * @return A string that a human can read.
+   */
+  public abstract String asString(String originalForm);
+
+  /**
+   * Sets the number of locations in the feature vector that a value should be in.
+   *
+   * @param probes Number of locations to increment.
+   */
+  public void setProbes(int probes) {
+    this.probes = probes;
+  }
+
+  public String getName() {
+    return name;
+  }
+
+  protected void trace(String name, String subName, int n) {
+    if (traceDictionary != null) {
+      String key = name;
+      if (subName != null) {
+        key = name + "=" + subName;
+      }
+      Set<Integer> trace = traceDictionary.get(key);
+      if (trace == null) {
+        trace = Sets.newHashSet(n);
+        traceDictionary.put(key, trace);
+      } else {
+        trace.add(n);
+      }
+    }
+  }
+
+  public void setTraceDictionary(Map<String, Set<Integer>> traceDictionary) {
+    this.traceDictionary = traceDictionary;
+  }
+}
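
A minimal sketch of how the hash-with-probes scheme above maps a term into a feature slot, assuming the MurmurHash class added by this commit is on the classpath; the class name HashSketch, the term "price", and the vector size of 100 are illustrative only:

import org.apache.mahout.classifier.MurmurHash;

import java.nio.charset.Charset;

public class HashSketch {
  public static void main(String[] args) {
    int numFeatures = 100;
    // Each probe perturbs the seed, so one term lands in several distinct slots.
    for (int probe = 0; probe < 2; probe++) {
      long r = MurmurHash.hash64A("price".getBytes(Charset.forName("UTF-8")), probe) % numFeatures;
      if (r < 0) {
        r += numFeatures;   // fold negative remainders back into [0..numFeatures-1]
      }
      System.out.println("probe " + probe + " -> slot " + r);
    }
  }
}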

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+/**
+ * Implements the Laplacian or bi-exponential prior.  This prior has a strong tendency to set coefficients
+ * to zero and thus is useful as an alternative to variable selection.  This version implements truncation,
+ * which prevents a coefficient from changing sign: if an update would change the sign, the coefficient is
+ * truncated to zero instead.
+ *
+ * Note that this distribution needs no scale parameter.  After taking the derivative of logP, a scale would
+ * only rescale the lambda coefficient that combines the prior with the observations, so it would have the
+ * same effect as changing lambda.
+ */
+public class L1 extends PriorFunction {
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+    if (newValue * oldValue < 0) {
+      // don't allow the value to change sign
+      return 0;
+    } else {
+      return newValue;
+    }
+  }
+
+  @Override
+  public double logP(double beta_ij) {
+    return -Math.abs(beta_ij);
+  }
+}
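
As a quick sanity check of the truncation behavior described in the class comment above, the following sketch (assuming the sgd package from this commit is on the classpath; the class name L1Sketch and the numbers are illustrative) shows a coefficient shrinking toward zero and then being clipped to exactly zero rather than crossing it:

import org.apache.mahout.classifier.sgd.L1;

public class L1Sketch {
  public static void main(String[] args) {
    L1 prior = new L1();
    // One generation of aging with learning rate 0.1 shrinks 0.5 to 0.4.
    System.out.println(prior.age(0.5, 1, 0.1));
    // 0.05 - 0.1 would flip the sign, so the coefficient is truncated to 0.
    System.out.println(prior.age(0.05, 1, 0.1));
  }
}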

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import static java.lang.Math.log;
+
+/**
+ * Implements the Gaussian prior.  This prior has a tendency to decrease large coefficients toward zero, but
+ * doesn't tend to set them to exactly zero.
+ */
+public class L2 extends PriorFunction {
+  private double s2;
+  private double s;
+
+  public L2(double scale) {
+    this.s = scale;
+    this.s2 = scale * scale;
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    return oldValue * Math.pow(1 - learningRate / s2, generations);
+  }
+
+  @Override
+  public double logP(double beta_ij) {
+    return -beta_ij * beta_ij / s2 / 2 - log(s) - log(2 * Math.PI) / 2;
+  }
+}
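
For contrast with L1, a sketch of the geometric shrinkage the Gaussian prior applies (again assuming this commit's classes are available; L2Sketch and the numbers are arbitrary):

import org.apache.mahout.classifier.sgd.L2;

public class L2Sketch {
  public static void main(String[] args) {
    L2 prior = new L2(1);
    // Ten generations at learning rate 0.01: 1.0 * 0.99^10, roughly 0.904.
    System.out.println(prior.age(1.0, 10, 0.01));
    // Small coefficients keep shrinking but are never snapped to exactly zero.
    System.out.println(prior.age(1e-6, 10, 0.01));
  }
}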

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Extends the basic on-line logistic regression learner with a specific set of learning
+ * rate annealing schedules.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression {
+  // these next two control decayFactor^steps exponential type of annealing
+  // learning rate and decay factor
+  private double mu_0 = 1;
+  private double decayFactor = 1 - 1e-3;
+
+
+  // these next two control 1/steps^forget type annealing
+  private int stepOffset = 10;
+  // -1 equals even weighting of all examples, 0 means only use exponential annealing
+  private double forgettingExponent = -0.5;
+
+  private int perTermAnnealingOffset = 20;
+
+  private OnlineLogisticRegression() {
+    // private constructor available for Gson, but not normal use
+  }
+
+  public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this.numCategories = numCategories;
+    this.prior = prior;
+
+    updateSteps = new DenseVector(numFeatures);
+    updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+    beta = new DenseMatrix(numCategories - 1, numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression alpha(double alpha) {
+    this.decayFactor = alpha;
+    return this;
+  }
+
+  @Override
+  public OnlineLogisticRegression lambda(double lambda) {
+    // we only override this to provide a more restrictive return type
+    super.lambda(lambda);
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression learningRate(double learningRate) {
+    this.mu_0 = learningRate;
+    return this;
+  }
+
+  public OnlineLogisticRegression stepOffset(int stepOffset) {
+    this.stepOffset = stepOffset;
+    return this;
+  }
+
+  public OnlineLogisticRegression decayExponent(double decayExponent) {
+    if (decayExponent > 0) {
+      decayExponent = -decayExponent;
+    }
+    this.forgettingExponent = decayExponent;
+    return this;
+  }
+
+
+  @Override
+  public double perTermLearningRate(int j) {
+    return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j));
+  }
+
+  @Override
+  public double currentLearningRate() {
+    return mu_0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + stepOffset, forgettingExponent);
+  }
+
+  @Override
+  public void train(int trackingKey, int actual, Vector instance) {
+    train(actual, instance);
+  }
+}
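
A sketch of how the chainable configuration above might be used; the class name OlrConfigSketch and all of the numeric settings are arbitrary examples rather than recommended values:

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class OlrConfigSketch {
  public static void main(String[] args) {
    // Two categories, 1000 hashed features, L1 (Laplacian) prior.
    OnlineLogisticRegression learner =
        new OnlineLogisticRegression(2, 1000, new L1())
            .learningRate(0.1)      // initial learning rate mu_0
            .alpha(1 - 1e-4)        // exponential decay of the learning rate
            .stepOffset(1000)
            .decayExponent(0.9)     // stored internally as -0.9
            .lambda(1e-5);          // amount of regularization

    Vector features = new DenseVector(1000);
    features.set(0, 1);             // e.g. a bias term in position 0
    learner.train(1, features);     // target category 1 for this example
  }
}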

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+/**
+ * A prior is used to regularize the learning algorithm.
+ */
+public abstract class PriorFunction {
+  /**
+   * Applies the regularization to a coefficient.
+   * @param oldValue        The previous value.
+   * @param generations     The number of generations.
+   * @param learningRate    The learning rate with lambda baked in.
+   * @return                The new coefficient value after regularization.
+   */
+  public abstract double age(double oldValue, double generations, double learningRate);
+
+  /**
+   * Returns the log of the probability of a particular coefficient value according to the prior.
+   * @param beta_ij         The coefficient.
+   * @return                The log probability.
+   */
+  public abstract double logP(double beta_ij);
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A RecordFactory turns a line of textual input (such as a line of CSV) into an encoded target
+ * value and a feature vector.  Implementations may use the first line of input as a schema and
+ * expose the target categories seen so far as well as a trace dictionary that records which
+ * vector locations each predictor touched.
+ */
+public interface RecordFactory {
+  void defineTargetCategories(List<String> values);
+
+  RecordFactory maxTargetValue(int max);
+
+  boolean usesFirstLineAsSchema();
+
+  int processLine(String line, Vector featureVector);
+
+  Iterable<String> getPredictors();
+
+  Map<String, Set<Integer>> getTraceDictionary();
+
+  RecordFactory includeBiasTerm(boolean useBias);
+
+  List<String> getTargetCategories();
+
+  void firstLine(String line);
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * Encodes categorical values with an unbounded vocabulary.  Values are encoded by incrementing a
+ * few locations in the output vector with a weight that either defaults to 1 or is looked up in a
+ * weight dictionary.  By default, two probes are used; this limits the damage done by hash
+ * collisions but makes more features non-zero, which can slow learning slightly.  If a large
+ * feature vector is used so that the probability of feature collisions is suitably small, the
+ * number of probes can be decreased to 1.  If a very small feature vector is used, it should
+ * probably be increased to 3.
+ */
+public class StaticWordValueEncoder extends WordValueEncoder {
+  private Map<String, Double> dictionary;
+  private double missingValueWeight = 1;
+
+  public StaticWordValueEncoder(String name) {
+    super(name);
+  }
+
+  /**
+   * Sets the weighting dictionary to be used by this encoder.  Also sets
+   * the missing value weight to be half the smallest weight in the dictionary.
+   * @param dictionary  The dictionary to use to look up weights.
+   */
+  public void setDictionary(Map<String, Double> dictionary) {
+    this.dictionary = dictionary;
+    missingValueWeight = Collections.min(dictionary.values()) / 2;
+  }
+
+  /**
+   * Sets the weight that is to be used for values that do not appear in the dictionary.
+   * @param missingValueWeight  The default weight for missing values.
+   */
+  public void setMissingValueWeight(double missingValueWeight) {
+    this.missingValueWeight = missingValueWeight;
+  }
+
+  @Override
+  protected double weight(String originalForm) {
+    double weight = missingValueWeight;
+    if (dictionary != null && dictionary.containsKey(originalForm)) {
+      weight = dictionary.get(originalForm);
+    }
+    return weight;
+  }
+}
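
A sketch of the dictionary-weighted encoding described above; WordEncoderSketch, the field name "color", and the weights are invented for illustration.  Words missing from the dictionary get half the smallest dictionary weight unless setMissingValueWeight is called:

import org.apache.mahout.classifier.sgd.StaticWordValueEncoder;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

import java.util.HashMap;
import java.util.Map;

public class WordEncoderSketch {
  public static void main(String[] args) {
    StaticWordValueEncoder encoder = new StaticWordValueEncoder("color");

    Map<String, Double> weights = new HashMap<String, Double>();
    weights.put("red", 2.0);
    weights.put("blue", 1.0);
    encoder.setDictionary(weights);   // missing-value weight becomes 0.5

    Vector v = new DenseVector(100);
    encoder.addToVector("red", v);    // adds 2.0 at each of two probed slots
    encoder.addToVector("green", v);  // unknown word, adds 0.5 at two slots
    System.out.println(v.zSum());     // 2 * 2.0 + 2 * 0.5 = 5.0
  }
}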

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import static java.lang.Math.log;
+import static org.apache.commons.math.special.Gamma.logGamma;
+
+/**
+ * Provides a t-distribution as a prior.
+ */
+public class TPrior extends PriorFunction {
+  private double df;
+
+  public TPrior(double df) {
+    this.df = df;
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    for (int i = 0; i < generations; i++) {
+      oldValue = oldValue - learningRate * oldValue * (df + 1) / (df + oldValue * oldValue);
+    }
+    return oldValue;
+  }
+
+  @Override
+  public double logP(double beta_ij) {
+    return logGamma((df + 1) / 2) - log(df * Math.PI) - logGamma(df / 2) - (df + 1) / 2 * log(1 + beta_ij * beta_ij);
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Splitter;
+import org.apache.mahout.math.Vector;
+
+import java.util.regex.Pattern;
+
+/**
+ * Encodes text by tokenizing it on non-alphanumeric separators.  Each word is encoded using a
+ * settable word encoder, which defaults to a StaticWordValueEncoder that gives all words the
+ * same weight.
+ */
+public class TextValueEncoder extends FeatureVectorEncoder {
+  private final Splitter onNonWord = Splitter.on(Pattern.compile("\\W+")).omitEmptyStrings();
+  private FeatureVectorEncoder wordEncoder;
+
+  public TextValueEncoder(String name) {
+    super(name);
+    wordEncoder = new StaticWordValueEncoder(name);
+    probes = 2;
+  }
+
+  /**
+   * Adds a value to a vector after tokenizing it by splitting on non-alphanum characters.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param data         The vector to which the value should be added.
+   */
+  @Override
+  public void addToVector(String originalForm, double weight, Vector data) {
+    for (String word : tokenize(originalForm)) {
+      wordEncoder.addToVector(word, weight, data);
+    }
+  }
+
+  private Iterable<String> tokenize(String originalForm) {
+    return onNonWord.split(originalForm);
+  }
+
+  /**
+   * Converts a value into a form that would help a human understand the internals of how the value
+   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
+   * associated weights (if any).
+   *
+   * @param originalForm The original form of the value as a string.
+   * @return A string that a human can read.
+   */
+  @Override
+  public String asString(String originalForm) {
+    StringBuilder r = new StringBuilder("[");
+    String sep = "";
+    for (String word : tokenize(originalForm)) {
+      r.append(sep);
+      r.append(wordEncoder.asString(word));
+      sep = ", ";
+    }
+    r.append("]");
+    return r.toString();
+  }
+
+  public void setWordEncoder(FeatureVectorEncoder wordEncoder) {
+    this.wordEncoder = wordEncoder;
+  }
+}
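
A small sketch of the tokenize-then-encode behavior, with asString used to show how the text is broken up (TextEncoderSketch and the sample strings are made up):

import org.apache.mahout.classifier.sgd.TextValueEncoder;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class TextEncoderSketch {
  public static void main(String[] args) {
    TextValueEncoder encoder = new TextValueEncoder("subject");

    Vector v = new DenseVector(200);
    encoder.addToVector("Re: cheap watches!!!", v);   // each token weighted 1 by default

    // Prints something like [subject:Re:1.0000, subject:cheap:1.0000, subject:watches:1.0000]
    System.out.println(encoder.asString("Re: cheap watches!!!"));
  }
}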

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+/**
+ * A uniform prior.  This is an improper prior that corresponds to no regularization at all.
+ */
+public class UniformPrior extends PriorFunction {
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    return oldValue;
+  }
+
+  @Override
+  public double logP(double beta_ij) {
+    return 0;
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Encodes words as sparse vector updates to a Vector.  Weighting is defined by a
+ * sub-class.
+ */
+public abstract class WordValueEncoder extends FeatureVectorEncoder {
+
+  public WordValueEncoder(String name) {
+    super(name);
+    probes = 2;
+  }
+
+  /**
+   * Adds a value to a vector.
+   *
+   * @param originalForm The original form of the value as a string.
+   * @param data         The vector to which the value should be added.
+   */
+  @Override
+  public void addToVector(String originalForm, double w, Vector data) {
+    double weight = w * weight(originalForm);
+    for (int i = 0; i < probes; i++) {
+      int n = hash(name, originalForm, WORD_LIKE_VALUE_HASH_SEED + i, data.size());
+      trace(name, originalForm, n);
+      data.set(n, data.get(n) + weight);
+    }
+  }
+
+  /**
+   * Converts a value into a form that would help a human understand the internals of how the value
+   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
+   * associated weights (if any).
+   *
+   * @param originalForm The original form of the value as a string.
+   * @return A string that a human can read.
+   */
+  @Override
+  public String asString(String originalForm) {
+    return String.format("%s:%s:%.4f", name, originalForm, weight(originalForm));
+  }
+
+  protected abstract double weight(String originalForm);
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html Mon Aug 16 16:56:46 2010
@@ -0,0 +1,20 @@
+<html>
+<body>
+SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms
+that make it relatively easy to build high speed on-line learning algorithms for a variety
+of problems, notably including supervised learning for classification.
+
+The primary class of interest in this package is CrossFoldLearner, which contains a
+number (typically 5) of sub-learners, each of which is given a different portion of the
+training data.  Each of these sub-learners can then be evaluated on the data it was not
+trained on.  This allows fully incremental learning while still getting cross-validated
+performance estimates.
+
+The CrossFoldLearner implements OnlineLearner and thus expects to be fed input in the form
+of a target variable and a feature vector.  The target variable is simply an integer in the
+half-open interval [0..numCategories) where numCategories is defined when the CrossFoldLearner
+is constructed.  The creation of feature vectors is facilitated by the classes that inherit
+from FeatureVectorEncoder.  These classes currently implement a form of feature hashing with
+multiple probes to limit feature ambiguity.
+</body>
+</html>
\ No newline at end of file
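
An end-to-end sketch of the workflow the package description outlines: hash a small amount of text into feature vectors and feed them, together with integer targets, to an on-line learner.  OnlineLogisticRegression is used here because its constructor appears in this commit; a CrossFoldLearner could be fed the same way through the OnlineLearner interface.  SgdPackageSketch, the sample strings, and all settings are illustrative:

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.TextValueEncoder;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class SgdPackageSketch {
  public static void main(String[] args) {
    int numFeatures = 1000;
    TextValueEncoder encoder = new TextValueEncoder("body");
    OnlineLogisticRegression learner =
        new OnlineLogisticRegression(2, numFeatures, new L1()).lambda(1e-5);

    // Tiny hand-made training set: target 1 = spam, target 0 = ham.
    String[] text = {"cheap watches buy now", "the meeting moved to tuesday"};
    int[] target = {1, 0};

    for (int pass = 0; pass < 10; pass++) {
      for (int i = 0; i < text.length; i++) {
        Vector features = new DenseVector(numFeatures);
        encoder.addToVector(text[i], features);
        learner.train(target[i], features);
      }
    }
  }
}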

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import junit.framework.TestCase;
+import org.junit.Test;
+
+import java.io.UnsupportedEncodingException;
+
+@SuppressWarnings({"deprecation"})
+public class MurmurHashTest extends TestCase {
+    @Test
+    public void testForLotsOfChange() throws UnsupportedEncodingException {
+        long h1 = MurmurHash.hashLong("abc".getBytes("UTF-8"), 0);
+        long h2 = MurmurHash.hashLong("abc ".getBytes("UTF-8"), 0);
+        int flipCount = Long.bitCount(h1 ^ h2);
+        assertTrue("Small changes should result in lots of bit flips, only found " + flipCount, flipCount > 25);
+    }
+
+    @Test
+    public void testHash64() throws UnsupportedEncodingException {
+        // test data generated by running MurmurHash2_64.cpp
+        assertEquals(0x9cc9c33498a95efbL, MurmurHash.hash64A("abc".getBytes("UTF-8"), 0));
+        assertEquals(0xd2c8c9b470122bddL, MurmurHash.hash64A("abc def ghi jkl ".getBytes("UTF-8"), 0));
+        assertEquals(0xcd37895736a81cbcL, MurmurHash.hash64A("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0));
+    }
+
+    @SuppressWarnings({"deprecation"})
+    @Test
+    public void testHash64original() throws UnsupportedEncodingException {
+        // test data generated by running MurmurHash2_64.cpp
+        assertEquals(0x9cc9c33498a95efbL, MurmurHash.hash64A_original("abc".getBytes("UTF-8"), 0));
+        assertEquals(0xd2c8c9b470122bddL, MurmurHash.hash64A_original("abc def ghi jkl ".getBytes("UTF-8"), 0));
+        assertEquals(0xcd37895736a81cbcL, MurmurHash.hash64A_original("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0));
+    }
+
+    @Test
+    public void testHash() throws UnsupportedEncodingException {
+        // test data generated by running MurmurHashNeutral2.cpp
+        assertEquals(0x13577c9b, MurmurHash.hash("abc".getBytes("UTF-8"), 0));
+        assertEquals(0x6fec441b, MurmurHash.hash("abc def ghi jkl ".getBytes("UTF-8"), 0));
+        assertEquals(0x7e953277, MurmurHash.hash("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0));
+    }
+
+    @Test
+    public void testHashNio() throws UnsupportedEncodingException {
+        // test data generated by running MurmurHashNeutral2.cpp
+        assertEquals(0x13577c9b, MurmurHash.hash_original("abc".getBytes("UTF-8"), 0));
+        assertEquals(0x6fec441b, MurmurHash.hash_original("abc def ghi jkl ".getBytes("UTF-8"), 0));
+        assertEquals(0x7e953277, MurmurHash.hash_original("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0));
+    }
+}


