Return-Path: Delivered-To: apmail-mahout-commits-archive@www.apache.org Received: (qmail 95398 invoked from network); 16 Aug 2010 16:58:17 -0000 Received: from unknown (HELO mail.apache.org) (140.211.11.3) by 140.211.11.9 with SMTP; 16 Aug 2010 16:58:17 -0000 Received: (qmail 70145 invoked by uid 500); 16 Aug 2010 16:58:17 -0000 Delivered-To: apmail-mahout-commits-archive@mahout.apache.org Received: (qmail 70069 invoked by uid 500); 16 Aug 2010 16:58:17 -0000 Mailing-List: contact commits-help@mahout.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@mahout.apache.org Delivered-To: mailing list commits@mahout.apache.org Received: (qmail 70062 invoked by uid 99); 16 Aug 2010 16:58:16 -0000 Received: from nike.apache.org (HELO nike.apache.org) (192.87.106.230) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 16 Aug 2010 16:58:16 +0000 X-ASF-Spam-Status: No, hits=-2000.0 required=10.0 tests=ALL_TRUSTED X-Spam-Check-By: apache.org Received: from [140.211.11.4] (HELO eris.apache.org) (140.211.11.4) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 16 Aug 2010 16:58:08 +0000 Received: by eris.apache.org (Postfix, from userid 65534) id 24E51238890A; Mon, 16 Aug 2010 16:56:48 +0000 (UTC) Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit Subject: svn commit: r986045 [1/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/classifier/ core/src/test/java/org/apache/mahout/classifier/sgd/ ma... 
Date: Mon, 16 Aug 2010 16:56:47 -0000 To: commits@mahout.apache.org From: tdunning@apache.org X-Mailer: svnmailer-1.0.8 Message-Id: <20100816165648.24E51238890A@eris.apache.org> X-Virus-Checked: Checked by ClamAV on apache.org Author: tdunning Date: Mon Aug 16 16:56:46 2010 New Revision: 986045 URL: http://svn.apache.org/viewvc?rev=986045&view=rev Log: MAHOUT-228 - Initial SGD classifier release. Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/MurmurHash.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * This is a very fast, non-cryptographic hash suitable for general hash-based + * lookup. See http://murmurhash.googlepages.com/ for more details. + *

+ *

The C version of MurmurHash 2.0 found at that site was ported + * to Java by Andrzej Bialecki (ab at getopt org).

+ */ +public class MurmurHash { + /** + * Hashes bytes in an array. + * @param data The bytes to hash. + * @param seed The seed for the hash. + * @return The 32 bit hash of the bytes in question. + */ + public static int hash(byte[] data, int seed) { + return hash(ByteBuffer.wrap(data), seed); + } + + /** + * Hashes bytes in part of an array. + * @param data The data to hash. + * @param offset Where to start munging. + * @param length How many bytes to process. + * @param seed The seed to start with. + * @return The 32-bit hash of the data in question. + */ + public static int hash(byte[] data, int offset, int length, int seed) { + return hash(ByteBuffer.wrap(data, offset, length), seed); + } + + /** + * Hashes the bytes in a buffer from the current position to the limit. + * @param buf The bytes to hash. + * @param seed The seed for the hash. + * @return The 32 bit murmur hash of the bytes in the buffer. + */ + public static int hash(ByteBuffer buf, int seed) { + // save byte order for later restoration + ByteOrder byteOrder = buf.order(); + buf.order(ByteOrder.LITTLE_ENDIAN); + + int m = 0x5bd1e995; + int r = 24; + + int h = seed ^ buf.remaining(); + + int k; + while (buf.remaining() >= 4) { + k = buf.getInt(); + + k *= m; + k ^= k >>> r; + k *= m; + + h *= m; + h ^= k; + } + + if (buf.remaining() > 0) { + ByteBuffer finish = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + // for big-endian version, use this first: + // finish.position(4-buf.remaining()); + finish.put(buf).rewind(); + h ^= finish.getInt(); + h *= m; + } + + h ^= h >>> 13; + h *= m; + h ^= h >>> 15; + + buf.order(byteOrder); + return h; + } + + + public static long hash64A(byte[] data, int seed) { + return hash64A(ByteBuffer.wrap(data), seed); + } + + public static long hash64A(byte[] data, int offset, int length, int seed) { + return hash64A(ByteBuffer.wrap(data, offset, length), seed); + } + + public static long hash64A(ByteBuffer buf, int seed) { + ByteOrder byteOrder = buf.order(); + 
buf.order(ByteOrder.LITTLE_ENDIAN); + + long m = 0xc6a4a7935bd1e995L; + int r = 47; + + long h = seed ^ (buf.remaining() * m); + + long k; + while (buf.remaining() >= 8) { + k = buf.getLong(); + + k *= m; + k ^= k >>> r; + k *= m; + + h ^= k; + h *= m; + } + + if (buf.remaining() > 0) { + ByteBuffer finish = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN); + // for big-endian version, do this first: + // finish.position(8-buf.remaining()); + finish.put(buf).rewind(); + h ^= finish.getLong(); + h *= m; + } + + h ^= h >>> r; + h *= m; + h ^= h >>> r; + + buf.order(byteOrder); + return h; + } + + @Deprecated + public static long hashLong(byte[] bytes, int seed) { + return (((long) hash(bytes, seed ^ 120705477)) << 32) + hash(bytes, seed ^ 226137830); + } + + @Deprecated + public static int hash_original(byte[] data, int seed) { + int m = 0x5bd1e995; + int r = 24; + + int h = seed ^ data.length; + + int len = data.length; + int len_4 = len >> 2; + + int k; + for (int i = 0; i < len_4; i++) { + int i_4 = i << 2; + k = data[i_4]; + k |= data[i_4 + 1] << 8; + k |= data[i_4 + 2] << 16; + k |= data[i_4 + 3] << 24; + + k *= m; + k ^= k >>> r; + k *= m; + + h *= m; + h ^= k; + } + + int len_m = len_4 << 2; + int left = len - len_m; + + switch (left) { + case 3: + h ^= (int) data[len_m + 2] << 16; + case 2: + h ^= (int) data[len_m + 1] << 8; + case 1: + h ^= (int) data[len_m]; + h *= m; + default: + } + + h ^= h >>> 13; + h *= m; + h ^= h >>> 15; + + return h; + } + + @Deprecated + public static long hash64A_original(byte[] data, int seed) { + long m = 0xc6a4a7935bd1e995L; + int r = 47; + + int len = data.length; + int len_8 = len >> 3; + + long h = seed ^ (len * m); + + long k; + for (int i = 0; i < len_8; i++) { + int i_8 = i << 3; + k = (data[i_8 + 7] & 0xffL) << 56; + k |= (data[i_8 + 6] & 0xffL) << 48; + k |= (data[i_8 + 5] & 0xffL) << 40; + k |= (data[i_8 + 4] & 0xffL) << 32; + k |= (data[i_8 + 3] & 0xffL) << 24; + k |= (data[i_8 + 2] & 0xffL) << 16; + k |= 
(data[i_8 + 1] & 0xffL) << 8; + k |= (data[i_8] & 0xffL); + + k *= m; + k ^= k >>> r; + k *= m; + + h ^= k; + h *= m; + } + + int len_m = len_8 << 3; + int left = len - len_m; + + switch (left) { + case 7: + h ^= (data[len_m + 6] & 0xffL) << 48; + case 6: + h ^= (data[len_m + 5] & 0xffL) << 40; + case 5: + h ^= (data[len_m + 4] & 0xffL) << 32; + case 4: + h ^= (data[len_m + 3] & 0xffL) << 24; + case 3: + h ^= (data[len_m + 2] & 0xffL) << 16; + case 2: + h ^= (data[len_m + 1] & 0xffL) << 8; + case 1: + h ^= data[len_m] & 0xffL; + h *= m; + default: + } + + h ^= h >>> r; + h *= m; + h ^= h >>> r; + + return h; + } + + + /* Testing ... + static int NUM = 1000; + + public static void main(String[] args) { + byte[] bytes = new byte[4]; + for (int i = 0; i < NUM; i++) { + bytes[0] = (byte)(i & 0xff); + bytes[1] = (byte)((i & 0xff00) >> 8); + bytes[2] = (byte)((i & 0xff0000) >> 16); + bytes[3] = (byte)((i & 0xff000000) >> 24); + System.out.println(Integer.toHexString(i) + " " + Integer.toHexString(hash(bytes, 1))); + } + } */ +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +import java.util.Iterator; + +/** + * Generic definition of a 1 of n logistic regression classifier that returns probabilities in + * response to a feature vector. This classifier uses 1 of n-1 coding where the 0-th category + * is not stored explicitly. + *

+ * TODO: implement symbolic input with string, overall cooccurrence and n-gram hash encoding + * TODO: implement reporter system to monitor progress + * + * Provides the based SGD based algorithm for learning a logistic regression, but omits all + * annealing of learning rates. Any extension of this abstract class must define the overall + * and per-term annealing for themselves. + */ +public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner { + // coefficients for the classification. This is a dense matrix + // that is (numCategories-1) x numFeatures + protected Matrix beta; + + // number of categories we are classifying. This should the number of rows of beta plus one. + protected int numCategories; + + private int step = 0; + + // information about how long since coefficient rows were updated. This allows lazy regularization. + protected transient Vector updateSteps; + + // information about how many updates we have had on a location. This allows per-term + // annealing a la confidence weighted learning. + protected transient Vector updateCounts; + + // weight of the prior on beta + private double lambda = 1e-5; + protected transient PriorFunction prior; + + // can we ignore any further regularization when doing classification? + private boolean sealed = false; + + /** + * Chainable configuration option. + * + * @param lambda New value of lambda, the weighting factor for the prior distribution. + * @return This, so other configurations can be chained. + */ + public AbstractOnlineLogisticRegression lambda(double lambda) { + this.lambda = lambda; + return this; + } + + private Vector logisticLink(Vector v) { + double max = v.maxValue(); + if (max < 40) { + v.assign(Functions.exp); + double sum = 1 + v.norm(1); + return v.divide(sum); + } else { + v.assign(Functions.minus(max)).assign(Functions.exp); + return v; + } + } + + /** + * Returns n-1 probabilities, one for each category but the 0-th. 
The probability of the 0-th + * category is 1 - sum(this result). + * + * @param instance A vector of features to be classified. + * @return A vector of probabilities, one for each of the first n-1 categories. + */ + public Vector classify(Vector instance) { + // apply pending regularization to whichever coefficients matter + regularize(instance); + + Vector v = beta.times(instance); + return logisticLink(v); + } + + /** + * Returns a single scalar probability in the case where we have two categories. Using this + * method avoids an extra vector allocation as opposed to calling classify() or an extra two + * vector allocations relative to classifyFull(). + * + * @param instance The vector of features to be classified. + * @return The probability of the first of two categories. + * @throws IllegalArgumentException If the classifier doesn't have two categories. + */ + public double classifyScalar(Vector instance) { + if (numCategories() != 2) { + throw new IllegalArgumentException("Can only call classifyScalar with two categories"); + } + + // apply pending regularization to whichever coefficients matter + regularize(instance); + + // result is a vector with one element so we can just use dot product + double r = Math.exp(beta.getRow(0).dot(instance)); + return r / (1 + r); + } + + public void train(int actual, Vector instance) { + unseal(); + + double learningRate = currentLearningRate(); + + // push coefficients back to zero based on the prior + regularize(instance); + + // what does the current model say? + Vector v = classify(instance); + + // update each row of coefficients according to result + for (int i = 0; i < numCategories - 1; i++) { + double gradientBase = -v.getQuick(i); + // the use of i+1 instead of i here is what makes the 0-th category be the one without coefficients + if ((i + 1) == actual) { + gradientBase += 1; + } + + // then we apply the gradientBase to the resulting element. 
+ Iterator nonZeros = instance.iterateNonZero(); + while (nonZeros.hasNext()) { + Vector.Element updateLocation = nonZeros.next(); + int j = updateLocation.index(); + double newValue = beta.get(i, j) + learningRate * gradientBase * instance.get(j) * perTermLearningRate(j); + beta.set(i, j, newValue); + } + } + + // remember that these elements got updated + Iterator i = instance.iterateNonZero(); + while (i.hasNext()) { + Vector.Element element = i.next(); + int j = element.index(); + updateSteps.setQuick(j, getStep()); + updateCounts.setQuick(j, updateCounts.getQuick(j) + 1); + } + nextStep(); + + } + + public void regularize(Vector instance) { + if (updateSteps == null || isSealed()) { + return; + } + + // anneal learning rate + double learningRate = currentLearningRate(); + + // here we lazily apply the prior to make up for our neglect + for (int i = 0; i < numCategories - 1; i++) { + Iterator nonZeros = instance.iterateNonZero(); + while (nonZeros.hasNext()) { + Vector.Element updateLocation = nonZeros.next(); + int j = updateLocation.index(); + double missingUpdates = getStep() - updateSteps.get(j); + if (missingUpdates > 0) { + double newValue = prior.age(beta.get(i, j), missingUpdates, getLambda() * learningRate * perTermLearningRate(j)); + beta.set(i, j, newValue); + } + } + } + } + + // these two abstract methods are how extensions can modify the basic learning behavior of this object. 
+ + public abstract double perTermLearningRate(int j); + + public abstract double currentLearningRate(); + + public void setPrior(PriorFunction prior) { + this.prior = prior; + } + + public PriorFunction getPrior() { + return prior; + } + + public Matrix getBeta() { + close(); + return beta; + } + + public void setBeta(int i, int j, double beta_ij) { + beta.set(i, j, beta_ij); + } + + public int numCategories() { + return numCategories; + } + + public int numFeatures() { + return beta.numCols(); + } + + public double getLambda() { + return lambda; + } + + public int getStep() { + return step; + } + + protected void nextStep() { + step++; + } + + public boolean isSealed() { + return sealed; + } + + protected void unseal() { + sealed = false; + } + + private void regularizeAll() { + Vector all = new DenseVector(beta.numCols()); + all.assign(1); + regularize(all); + } + + public void close() { + if (!sealed) { + step++; + regularizeAll(); + sealed = true; + } + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,58 @@ +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Lists; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.jet.random.NegativeBinomial; +import org.apache.mahout.math.jet.random.engine.MersenneTwister; + +import java.util.Collections; +import java.util.List; + +/** + * This is a meta-learner that 
maintains a pool of ordinary OnlineLogisticRegression learners. Each + * member of the pool has different learning rates. Whichever of the learners in the pool falls + * behind in terms of average log-likelihood will be tossed out and replaced with variants of the + * survivors. This will let us automatically derive an annealing schedule that optimizes learning + * speed. Since on-line learners tend to be IO bound anyway, it doesn't cost as much as it might + * seem that it would to maintain multiple learners in memory. Doing this adaptation on-line as we + * learn also decreases the number of learning rate parameters required and replaces the normal + * hyper-parameter search. + */ +public class AdaptiveAnnealedLogisticRegression implements OnlineLearner { + private int record = 0; + private List pool = Lists.newArrayList(); + private int evaluationInterval = 1000; + + public AdaptiveAnnealedLogisticRegression(int poolSize, int numCategories, int numFeatures, PriorFunction prior) { + for (int i = 0; i < poolSize; i++) { + CrossFoldLearner model = new CrossFoldLearner(5, numCategories, numFeatures, prior); + pool.add(model); + } + NegativeBinomial nb = new NegativeBinomial(10, 0.1, new MersenneTwister()); + } + + @Override + public void train(int actual, Vector instance) { + for (CrossFoldLearner learner : pool) { + learner.train(actual, instance); + } + record++; + if (record % evaluationInterval == 0) { + Collections.sort(pool); + // pick a parent from the top half of the pool weighted toward the top few + + } + } + + @Override + public void train(int trackingKey, int actual, Vector instance) { + train(actual, instance); + } + + @Override + public void close() { + //To change body of implemented methods use File | Settings | File Templates. 
+ } + +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveWordValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import org.apache.mahout.math.Vector; + +/** + * Encodes words into vectors much as does WordValueEncoder while maintaining + * an adaptive dictionary of values seen so far. This allows weighting of terms + * without a pre-scan of all of the data. 
+ */ +public class AdaptiveWordValueEncoder extends WordValueEncoder { + private Multiset dictionary; + + public AdaptiveWordValueEncoder(String name) { + super(name); + dictionary = HashMultiset.create(); + } + + /** + * Adds a value to a vector. + * + * @param originalForm The original form of the value as a string. + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(String originalForm, double weight, Vector data) { + dictionary.add(originalForm); + super.addToVector(originalForm, weight, data); + } + + @Override + protected double weight(String originalForm) { + // the counts here are adjusted so that every observed value has an extra 0.5 count + // as does a hypothetical unobserved value. This smooths our estimates a bit and + // allows the first word seen to have a non-zero weight of -log(1.5 / 2) + double thisWord = dictionary.count(originalForm) + 0.5; + double allWords = dictionary.size() + dictionary.elementSet().size() * 0.5 + 0.5; + return -Math.log(thisWord / allWords); + } + + public Multiset getDictionary() { + return dictionary; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ConstantValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; + +/** + * An encoder that does the standard thing for a virtual bias term. + */ +public class ConstantValueEncoder extends FeatureVectorEncoder { + public ConstantValueEncoder(String name) { + super(name); + } + + @Override + public void addToVector(String originalForm, double weight, Vector data) { + for (int i = 0; i < probes; i++) { + int n = hash(name, i, data.size()); + trace(name, null, n); + data.set(n, data.get(n) + weight); + } + } + + @Override + public String asString(String originalForm) { + return name; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; + +/** + * Continuous values are stored in fixed randomized location in the feature vector. + */ +public class ContinuousValueEncoder extends FeatureVectorEncoder { + public ContinuousValueEncoder(String name) { + super(name); + } + + /** + * Adds a value to a vector. + * + * @param originalForm The original form of the value as a string. + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(String originalForm, double weight, Vector data) { + for (int i = 0; i < probes; i++) { + int n = hash(name, CONTINUOUS_VALUE_HASH_SEED + i, data.size()); + trace(name, null, n); + data.set(n, data.get(n) + weight * Double.parseDouble(originalForm)); + } + } + + /** + * Converts a value into a form that would help a human understand the internals of how the value + * is being interpreted. For text-like things, this is likely to be a list of the terms found with + * associated weights (if any). + * + * @param originalForm The original form of the value as a string. + * @return A string that a human can read. 
+ */ + @Override + public String asString(String originalForm) { + return name + ":" + originalForm; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,137 @@ +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Lists; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.OnlineAuc; + +import java.util.List; + +/** + * Does cross-fold validation of log-likelihood and AUC on several online logistic regression + * models. Each record is passed to all but one of the models for training and to the remaining + * model for evaluation. In order to maintain proper segregation between the different folds across + * training data iterations, data should either be passed to this learner in the same order each + * time the training data is traversed or a tracking key such as the file offset of the training + * record should be passed with each training example. 
+ */
+class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Comparable<CrossFoldLearner> {
+  // number of examples passed to train() since construction or the last resetLineCounter()
+  int record = 0;
+  // AUC accumulated from the held-out scores produced in train()
+  OnlineAuc auc = new OnlineAuc();
+  // running mean of log(score of the actual category) over the held-out scores
+  double logLikelihood = 0;
+  // one model per fold
+  List<OnlineLogisticRegression> models = Lists.newArrayList();
+
+  // lambda, learningRate, perTermOffset, perTermExponent
+  // NOTE(review): not read anywhere in this class -- confirm whether other code uses it
+  double[] parameters = new double[4];
+
+  /**
+   * Creates a cross-validating learner made of {@code folds} independent models.
+   *
+   * @param folds         Number of folds (one model per fold); must be positive.
+   * @param numCategories Number of target categories.
+   * @param numFeatures   Size of the feature vectors.
+   * @param prior         Prior used to regularize every model.
+   * @throws IllegalArgumentException if {@code folds} is not positive.
+   */
+  CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+    if (folds <= 0) {
+      throw new IllegalArgumentException("Must have at least one fold, but got " + folds);
+    }
+    for (int i = 0; i < folds; i++) {
+      OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
+      model.alpha(1).stepOffset(0).decayExponent(0);
+      models.add(model);
+    }
+  }
+
+  /**
+   * Trains on one example, using the running record count as the tracking key.
+   */
+  @Override
+  public void train(int actual, Vector instance) {
+    train(record, actual, instance);
+  }
+
+  /**
+   * Presents one example to every fold. The fold selected by {@code trackingKey} scores the
+   * example (updating the held-out log likelihood and AUC) instead of training on it; every
+   * other fold trains on it.
+   *
+   * @param trackingKey Selects the held-out fold; negative keys are allowed.
+   * @param actual      The actual target category.
+   * @param instance    The feature vector.
+   */
+  @Override
+  public void train(int trackingKey, int actual, Vector instance) {
+    record++;
+    int n = models.size();
+    // floor-mod: with a plain % a negative key would select no fold at all, so the record would be
+    // counted in the log-likelihood mean without ever being scored
+    int heldOut = ((trackingKey % n) + n) % n;
+    for (int k = 0; k < n; k++) {
+      OnlineLogisticRegression model = models.get(k);
+      if (k == heldOut) {
+        Vector v = model.classifyFull(instance);
+        double score = v.get(actual);
+        logLikelihood += (Math.log(score) - logLikelihood) / record;
+        // NOTE(review): assumes entry 1 of classifyFull() is the score of category 1,
+        // which makes this a binary-target AUC -- confirm
+        auc.addSample(actual, v.get(1));
+      } else {
+        model.train(actual, instance);
+      }
+    }
+  }
+
+  /**
+   * Closes every fold's model.
+   */
+  @Override
+  public void close() {
+    for (OnlineLogisticRegression m : models) {
+      m.close();
+    }
+  }
+
+  /**
+   * Resets the record counter, which is both the default tracking key and the divisor of the
+   * running log-likelihood mean.
+   */
+  public void resetLineCounter() {
+    record = 0;
+  }
+
+  /**
+   * Orders learners by held-out log likelihood, lowest first.
+   */
+  @Override
+  public int compareTo(CrossFoldLearner other) {
+    return Double.compare(this.logLikelihood, other.logLikelihood);
+  }
+
+  /**
+   * Sets the regularization weight of every fold's model.
+   *
+   * @return This, so other configurations can be chained.
+   */
+  public CrossFoldLearner lambda(double v) {
+    for (OnlineLogisticRegression model : models) {
+      model.lambda(v);
+    }
+    return this;
+  }
+
+  /**
+   * Sets the initial learning rate of every fold's model.
+   *
+   * @return This, so other configurations can be chained.
+   */
+  public CrossFoldLearner learningRate(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.learningRate(x);
+    }
+    return this;
+  }
+
+  /**
+   * Sets the annealing step offset of every fold's model.
+   *
+   * @return This, so other configurations can be chained.
+   */
+  public CrossFoldLearner stepOffset(int x) {
+    for (OnlineLogisticRegression model : models) {
+      model.stepOffset(x);
+    }
+    return this;
+  }
+
+  /**
+   * Sets the annealing decay exponent of every fold's model.
+   *
+   * @return This, so other configurations can be chained.
+   */
+  public CrossFoldLearner decayExponent(double x) {
+    for (OnlineLogisticRegression model : models) {
+      model.decayExponent(x);
+    }
+    return this;
+  }
+
+  /**
+   * @return The number of target categories, as reported by the first fold's model.
+   */
+  @Override
+  public int numCategories() {
+    return models.get(0).numCategories();
+  }
+
+  /**
+   * Averages the {@code classify} outputs of all folds.
+   *
+   * @return A vector with {@code numCategories() - 1} entries.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    double scale = 1.0 / models.size();
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classify(instance), Functions.plusMult(scale));
+    }
+    return r;
+  }
+
+  /**
+   * @return The mean of the folds' {@code classifyScalar} outputs.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double sum = 0;
+    for (OnlineLogisticRegression model : models) {
+      sum += model.classifyScalar(instance);
+    }
+    return sum / models.size();
+  }
+
+  /**
+   * @return The AUC of the held-out scores seen so far.
+   */
+  public double auc() {
+    return auc.auc();
+  }
+
+  /**
+   * @return The mean log likelihood of the held-out scores seen so far.
+   */
+  public double logLikelihood() {
+    return logLikelihood;
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Function; +import com.google.common.base.Splitter; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.mahout.math.Vector; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Converts csv data lines to vectors. + * + * Use of this class proceeds in a few steps. + *
+ * <ul>
+ * <li>At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values. At this point,
+ * the class cannot yet decode inputs because it doesn't know the fields that are in the
+ * data records, nor their order.
+ * <li>Optionally, you tell the parser object about the possible values of the target
+ * variable. If you don't do this then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li>Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs. The most important internal state is the field numbers
+ * for various fields. After this point, you can use the factory for decoding data.
+ * <li>To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide. The factory also retains trace information so
+ * that it can approximately reverse engineer vectors later.
+ * <li>After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights. In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically during encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */ +public class CsvRecordFactory implements RecordFactory { + private static final String INTERCEPT_TERM = "Intercept Term"; + + // crude CSV value splitter. This will fail if any double quoted strings have + // commas inside. Also, escaped quotes will not be unescaped. Good enough for now. + private Splitter onComma = Splitter.on(",").trimResults(CharMatcher.is('"')); + + private static final Map> typeDictionary = + ImmutableMap.>builder() + .put("continuous", ContinuousValueEncoder.class) + .put("numeric", ContinuousValueEncoder.class) + .put("n", ContinuousValueEncoder.class) + .put("word", StaticWordValueEncoder.class) + .put("w", StaticWordValueEncoder.class) + .put("text", TextValueEncoder.class) + .put("t", TextValueEncoder.class) + .build(); + + private Map> traceDictionary = Maps.newTreeMap(); + + private int target; + private Dictionary targetDictionary; + + private List predictors; + private Map predictorEncoders; + private int maxTargetValue = Integer.MAX_VALUE; + private String targetName; + private Map typeMap; + private List variableNames; + private boolean includeBiasTerm; + + /** + * Construct a parser for CSV lines that encodes the parsed data in vector form. + * @param targetName The name of the target variable. + * @param typeMap A map describing the types of the predictor variables. + */ + public CsvRecordFactory(String targetName, final Map typeMap) { + this.targetName = targetName; + this.typeMap = typeMap; + targetDictionary = new Dictionary(); + } + + /** + * Defines the values and thus the encoding of values of the target variables. Note + * that any values of the target variable not present in this list will be given the + * value of the last member of the list. + * @param values The values the target variable can have. 
+ */ + @Override + public void defineTargetCategories(List values) { + if (values.size() > maxTargetValue) { + throw new IllegalArgumentException("Must have less than or equal to " + maxTargetValue + " categories for target variable, but found " + values.size()); + } + + if (maxTargetValue == Integer.MAX_VALUE) { + maxTargetValue = values.size(); + } + + for (String value : values) { + targetDictionary.intern(value); + } + } + + /** + * Defines the number of target variable categories, but allows this parser to + * pick encodings for them as they appear. + * @param max The number of categories that will be excpeted. Once this many have been + * seen, all others will get the encoding max-1. + */ + @Override + public CsvRecordFactory maxTargetValue(int max) { + maxTargetValue = max; + return this; + } + + @Override + public boolean usesFirstLineAsSchema() { + return true; + } + + /** + * Processes the first line of a file (which should contain the variable names). The target and + * predictor column numbers are set from the names on this line. + * + * @param line Header line for the file. 
+ */ + @Override + public void firstLine(String line) { + // read variable names, build map of name -> column + final Map vars = Maps.newHashMap(); + int column = 0; + variableNames = Lists.newArrayList(onComma.split(line)); + for (String var : variableNames) { + vars.put(var, column++); + } + + // record target column and establish dictionary for decoding target + target = vars.get(targetName); + + // create list of predictor column numbers + predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function() { + @Override + public Integer apply(String from) { + Integer r = vars.get(from); + if (r == null) { + throw new IllegalArgumentException("Can't find variable " + from + ", only know about " + vars); + } + return r; + } + })); + + if (includeBiasTerm) { + predictors.add(-1); + } + Collections.sort(predictors); + + // and map from column number to type encoder for each column that is a predictor + predictorEncoders = Maps.newHashMap(); + for (Integer predictor : predictors) { + String name; + Class c; + if (predictor != -1) { + name = variableNames.get(predictor); + c = typeDictionary.get(typeMap.get(name)); + } else { + name = INTERCEPT_TERM; + c = ConstantValueEncoder.class; + } + try { + if (c == null) { + throw new IllegalArgumentException("Invalid type of variable " + typeMap.get(name) + " wanted on of " + typeDictionary.keySet()); + } + Constructor constructor = c.getConstructor(String.class); + if (constructor == null) { + throw new IllegalArgumentException("Can't find correct constructor for " + typeMap.get(name)); + } + FeatureVectorEncoder encoder = constructor.newInstance(name); + predictorEncoders.put(predictor, encoder); + encoder.setTraceDictionary(traceDictionary); + } catch (InstantiationException e) { + throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e); + } catch (IllegalAccessException e) { + throw new ImpossibleException("Unable to construct type converter... 
shouldn't be possible", e); + } catch (InvocationTargetException e) { + throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e); + } catch (NoSuchMethodException e) { + throw new ImpossibleException("Unable to construct type converter... shouldn't be possible", e); + } + } + } + + + /** + * Decodes a single line of csv data and records the target and predictor variables in a record. + * As a side effect, features are added into the featureVector. Returns the value of the target + * variable. + * + * @param line The raw data. + * @param featureVector Where to fill in the features. Should be zeroed before calling + * processLine. + * @return The value of the target variable. + */ + @Override + public int processLine(String line, Vector featureVector) { + List values = Lists.newArrayList(onComma.split(line)); + + int targetValue = targetDictionary.intern(values.get(target)); + if (targetValue >= maxTargetValue) { + targetValue = maxTargetValue - 1; + } + + for (Integer predictor : predictors) { + String value; + if (predictor >= 0) { + value = values.get(predictor); + } else { + value = null; + } + predictorEncoders.get(predictor).addToVector(value, featureVector); + } + return targetValue; + } + + /** + * Returns a list of the names of the predictor variables. + * + * @return A list of variable names. 
+ */ + @Override + public Iterable getPredictors() { + return Lists.transform(predictors, new Function() { + @Override + public String apply(Integer v) { + if (v >= 0) { + return variableNames.get(v); + } else { + return INTERCEPT_TERM; + } + } + }); + } + + @Override + public Map> getTraceDictionary() { + return traceDictionary; + } + + @Override + public CsvRecordFactory includeBiasTerm(boolean useBias) { + includeBiasTerm = useBias; + return this; + } + + @Override + public List getTargetCategories() { + List r = targetDictionary.values(); + if (r.size() > maxTargetValue) { + r.subList(maxTargetValue, r.size()).clear(); + } + return r; + } + + private static class ImpossibleException extends RuntimeException { + private ImpossibleException(String message, Throwable cause) { + super(message, cause); + } + } + +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Dictionary.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Maps; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** +* Assigns integer codes to strings as they appear. +*/ +public class Dictionary { + private Map dict = Maps.newLinkedHashMap(); + + public int intern(String s) { + if (!dict.containsKey(s)) { + dict.put(s, dict.size()); + } + return dict.get(s); + } + + public List values() { + // order of keySet is guaranteed to be insertion order + return new ArrayList(dict.keySet()); + } + + public static Dictionary fromList(List values) { + Dictionary dict = new Dictionary(); + for (String value : values) { + dict.intern(value); + } + return dict; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+/**
+ * A prior mixing an L1 (Laplacian) term with an L2 (Gaussian) term, the latter weighted by
+ * {@code alphaByLambda}.
+ */
+public class ElasticBandPrior extends PriorFunction {
+  private double alphaByLambda;
+  private L1 l1;
+  private L2 l2;
+
+  /**
+   * @param alphaByLambda Weight of the L2 component relative to the L1 component.
+   */
+  public ElasticBandPrior(double alphaByLambda) {
+    this.alphaByLambda = alphaByLambda;
+    l1 = new L1();
+    l2 = new L2(1);
+  }
+
+  /**
+   * Shrinks the coefficient multiplicatively, then steps it toward zero by
+   * {@code learningRate * generations}, truncating at zero instead of changing sign.
+   */
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    double shrunk = oldValue * Math.pow(1 - alphaByLambda * learningRate , generations);
+    double stepped = shrunk - Math.signum(shrunk) * learningRate * generations;
+    // crossing zero means the step overshot: clamp to exactly zero
+    return stepped * shrunk < 0 ? 0 : stepped;
+  }
+
+  /**
+   * @return The L1 log prior plus {@code alphaByLambda} times the unit-scale L2 log prior.
+   */
+  @Override
+  public double logP(double beta_ij) {
+    double laplace = l1.logP(beta_ij);
+    double gaussian = l2.logP(beta_ij);
+    return laplace + alphaByLambda * gaussian;
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/FeatureVectorEncoder.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Sets; +import org.apache.mahout.classifier.MurmurHash; +import org.apache.mahout.math.Vector; + +import java.nio.charset.Charset; +import java.util.Map; +import java.util.Set; + +/** + * General interface for objects that record features into a feature vector. + *
+ * <p>
+ * By convention, sub-classes should provide a constructor that accepts just a field name as well as + * setters to customize properties of the conversion such as adding tokenizers or a weight + * dictionary. + */ +public abstract class FeatureVectorEncoder { + protected static final int CONTINUOUS_VALUE_HASH_SEED = 1; + protected static final int WORD_LIKE_VALUE_HASH_SEED = 100; + + protected String name; + protected int probes = 1; + + private Map> traceDictionary = null; + + public FeatureVectorEncoder(String name) { + this.name = name; + } + + /** + * Adds a value expressed in string form to a vector. + * + * @param originalForm The original form of the value as a string. + * @param data The vector to which the value should be added. + */ + public void addToVector(String originalForm, Vector data) { + addToVector(originalForm, 1.0, data); + } + + /** + * Adds a weighted value expressed in string form to a vector. In some cases it is convenient to + * use this method to encode continuous values using the weight as the value. In such cases, the + * string value should typically be set to null. + * + * @param originalForm The original form of the value as a string. + * @param weight The weight to be applied to this feature. + * @param data The vector to which the value should be added. + */ + public abstract void addToVector(String originalForm, double weight, Vector data); + + // ******* Utility functions used by most implementations + + /** + * Hash a string and an integer into the range [0..numFeatures-1]. + * + * @param term The string. + * @param probe An integer that modifies the resulting hash. + * @param numFeatures The range into which the resulting hash must fit. + * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in + * term and probe. 
+ */ + protected int hash(String term, int probe, int numFeatures) { + long r = MurmurHash.hash64A(term.getBytes(Charset.forName("UTF-8")), probe) % numFeatures; + if (r < 0) { + r += numFeatures; + } + return (int) r; + } + + /** + * Hash two strings and an integer into the range [0..numFeatures-1]. + * + * @param term1 The first string. + * @param term2 The second string. + * @param probe An integer that modifies the resulting hash. + * @param numFeatures The range into which the resulting hash must fit. + * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in + * term and probe. + */ + protected int hash(String term1, String term2, int probe, int numFeatures) { + long r = MurmurHash.hash64A(term1.getBytes(Charset.forName("UTF-8")), probe); + r = MurmurHash.hash64A(term2.getBytes(Charset.forName("UTF-8")), (int) r) % numFeatures; + if (r < 0) { + r += numFeatures; + } + return (int) r; + } + + /** + * Converts a value into a form that would help a human understand the internals of how the value + * is being interpreted. For text-like things, this is likely to be a list of the terms found + * with associated weights (if any). + * + * @param originalForm The original form of the value as a string. + * @return A string that a human can read. + */ + public abstract String asString(String originalForm); + + /** + * Sets the number of locations in the feature vector that a value should be in. + * + * @param probes Number of locations to increment. 
+ */ + public void setProbes(int probes) { + this.probes = probes; + } + + public String getName() { + return name; + } + + protected void trace(String name, String subName, int n) { + if (traceDictionary != null) { + String key = name; + if (subName != null) { + key = name + "=" + subName; + } + Set trace = traceDictionary.get(key); + if (trace == null) { + trace = Sets.newHashSet(n); + traceDictionary.put(key, trace); + } else { + trace.add(n); + } + } + } + + public void setTraceDictionary(Map> traceDictionary) { + this.traceDictionary = traceDictionary; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L1.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +/** + * Implements the Laplacian or bi-exponential prior. 
This prior has a strong tendency to set coefficients to zero + * and thus is useful as an alternative to variable selection. This version implements truncation which prevents + * a coefficient from changing sign. If a correction would change the sign, the coefficient is truncated to zero. + * + * Note that it doesn't matter to have a scale for this distribution because after taking the derivative of the logP, + * the lambda coefficient used to combine the prior with the observations has the same effect. If we had a scale here, + * then it would be the same effect as just changing lambda. + */ +public class L1 extends PriorFunction { + @Override + public double age(double oldValue, double generations, double learningRate) { + double newValue = oldValue - Math.signum(oldValue) * learningRate * generations; + if (newValue * oldValue < 0) { + // don't allow the value to change sign + return 0; + } else { + return newValue; + } + } + + @Override + public double logP(double beta_ij) { + return - Math.abs(beta_ij) ; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/L2.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import static java.lang.Math.log; + +/** + * Implements the Gaussian prior. This prior has a tendency to decrease large coefficients toward zero, but + * doesn't tend to set them to exactly zero. + */ +public class L2 extends PriorFunction { + private double s2; + private double s; + + public L2(double scale) { + this.s = scale; + this.s2 = scale * scale; + } + + @Override + public double age(double oldValue, double generations, double learningRate) { + return oldValue * Math.pow(1 - learningRate / s2, generations); + } + + @Override + public double logP(double beta_ij) { + return -beta_ij * beta_ij / s2 / 2 - log(s) - log(2 * Math.PI) / 2; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Extends the basic on-line logistic regression learner with a specific set of learning
+ * rate annealing schedules.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression {
+  // these next two control decayFactor^steps exponential type of annealing
+  // learning rate and decay factor
+  private double mu_0 = 1;
+  private double decayFactor = 1 - 1e-3;
+
+
+  // these next two control 1/steps^forget type annealing
+  private int stepOffset = 10;
+  // -1 equals even weighting of all examples, 0 means only use exponential annealing
+  private double forgettingExponent = -0.5;
+
+  // initial value of every per-term update count; see perTermLearningRate()
+  private int perTermAnnealingOffset = 20;
+
+  private OnlineLogisticRegression() {
+    // private constructor available for Gson, but not normal use
+  }
+
+  /**
+   * @param numCategories Number of target categories.
+   * @param numFeatures   Size of the feature vectors.
+   * @param prior         Prior used to regularize the coefficients.
+   */
+  public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this.numCategories = numCategories;
+    this.prior = prior;
+
+    updateSteps = new DenseVector(numFeatures);
+    updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+    // numCategories - 1 rows: presumably one category is the reference with implicit zero weights -- confirm
+    beta = new DenseMatrix(numCategories - 1, numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression alpha(double alpha) {
+    this.decayFactor = alpha;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param lambda New regularization weight (stored by the superclass).
+   * @return This, so other configurations can be chained.
+   */
+  @Override
+  public OnlineLogisticRegression lambda(double lambda) {
+    // we only over-ride this to provide a more restrictive return type
+    super.lambda(lambda);
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression learningRate(double learningRate) {
+    this.mu_0 = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param stepOffset Offset added to the step count in the 1/steps^forget annealing term.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression stepOffset(int stepOffset) {
+    this.stepOffset = stepOffset;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param decayExponent Exponent of the 1/steps^forget annealing term. Positive values are
+   *                      negated, so 0.5 and -0.5 are equivalent.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression decayExponent(double decayExponent) {
+    if (decayExponent > 0) {
+      decayExponent = -decayExponent;
+    }
+    this.forgettingExponent = decayExponent;
+    return this;
+  }
+
+
+  /**
+   * @return {@code sqrt(perTermAnnealingOffset / updateCounts[j])}: 1 while term j's count is at its
+   *         initial value, smaller as the count grows (presumably incremented by the superclass on
+   *         each update of term j -- confirm).
+   */
+  @Override
+  public double perTermLearningRate(int j) {
+    return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j));
+  }
+
+  /**
+   * @return {@code mu_0 * decayFactor^step * (step + stepOffset)^forgettingExponent}.
+   */
+  @Override
+  public double currentLearningRate() {
+    return mu_0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + stepOffset, forgettingExponent);
+  }
+
+  /**
+   * Trains on the example; the tracking key is ignored.
+   */
+  @Override
+  public void train(int trackingKey, int actual, Vector instance) {
+    train(actual, instance);
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,39 @@
+/*
+ *
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +/** + * A prior is used to regularize the learning algorithm. + */ +public abstract class PriorFunction { + /** + * Applies the regularization to a coefficient. + * @param oldValue The previous value. + * @param generations The number of generations. + * @param learningRate The learning rate with lambda baked in. + * @return The new coefficient value after regularization. + */ + public abstract double age(double oldValue, double generations, double learningRate); + + /** + * Returns the log of the probability of a particular coefficient value according to the prior. + * @param beta_ij The coefficient. + * @return The log probability.
+ */ + public abstract double logP(double beta_ij); +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Converts lines of text input into feature vectors for the SGD learners. Whether the first + * line of input is treated as a schema is reported by usesFirstLineAsSchema().
+ * NOTE(review): processLine returns an int, presumably the target value of the parsed line; + * confirm against CsvRecordFactory. + */ +public interface RecordFactory { + void defineTargetCategories(List<String> values); + + RecordFactory maxTargetValue(int max); + + boolean usesFirstLineAsSchema(); + + int processLine(String line, Vector featureVector); + + Iterable<String> getPredictors(); + + Map<String, Set<Integer>> getTraceDictionary(); + + RecordFactory includeBiasTerm(boolean useBias); + + List<String> getTargetCategories(); + + void firstLine(String line); +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/StaticWordValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.util.Collections; +import java.util.Map; + +/** + * Encodes categorical values with an unbounded vocabulary.
Values are encoded by incrementing + * a few locations in the output vector with a weight that is either defaulted to 1 or that is + * looked up in a weight dictionary. By default, two probes are used which should be fine + * but could cause a decrease in the speed of learning because more features will be non-zero. + * If a large feature vector is used so that the probability of feature collisions is suitably + * small, then this can be decreased to 1. If a very small feature vector is used, the number + * of probes should probably be increased to 3. + */ +public class StaticWordValueEncoder extends WordValueEncoder { + private Map<String, Double> dictionary; + private double missingValueWeight = 1; + + public StaticWordValueEncoder(String name) { + super(name); + } + + /** + * Sets the weighting dictionary to be used by this encoder. If the dictionary is non-empty, also + * sets the missing value weight to be half the smallest weight in the dictionary; a null or empty + * dictionary leaves the missing value weight unchanged (Collections.min would otherwise throw). + * @param dictionary The dictionary to use to look up weights, or null for no lookups. + */ + public void setDictionary(Map<String, Double> dictionary) { + this.dictionary = dictionary; + if (dictionary != null && !dictionary.isEmpty()) { + missingValueWeight = Collections.min(dictionary.values()) / 2; + } + } + + /** + * Sets the weight that is to be used for values that do not appear in the dictionary. + * @param missingValueWeight The default weight for missing values.
+ */ + public void setMissingValueWeight(double missingValueWeight) { + this.missingValueWeight = missingValueWeight; + } + + /** + * Looks up the weight of a word with a single map access, falling back to the missing value + * weight when there is no dictionary, the word is absent, or the word maps to null. + * @param originalForm The word to weight. + * @return The weight to use for the word. + */ + @Override + protected double weight(String originalForm) { + Double weight = dictionary == null ? null : dictionary.get(originalForm); + return weight == null ? missingValueWeight : weight; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import static java.lang.Math.log; +import static org.apache.commons.math.special.Gamma.logGamma; + +/** + * Provides a t-distribution as a prior.
+ */ +public class TPrior extends PriorFunction { + private double df; + + public TPrior(double df) { + this.df = df; + } + + @Override + public double age(double oldValue, double generations, double learningRate) { + for (int i = 0; i < generations; i++) { + oldValue = oldValue - learningRate * oldValue * (df + 1) / (df + oldValue * oldValue); + } + return oldValue; + } + + @Override + public double logP(double beta_ij) { + return logGamma((df + 1) / 2) - log(df * Math.PI) - logGamma(df / 2) - (df + 1) / 2 * log(1 + beta_ij * beta_ij); + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/TextValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Splitter; +import org.apache.mahout.math.Vector; + +import java.util.regex.Pattern; + +/** + * Encodes text that is tokenized on runs of non-word characters (anything other than letters, + * digits and underscore). Each word is encoded using a settable encoder which is by default a + * StaticWordValueEncoder which gives all words the same weight. + */ +public class TextValueEncoder extends FeatureVectorEncoder { + // Splitter instances are immutable, so one shared instance is safe for all encoders. + private static final Splitter ON_NON_WORD = Splitter.on(Pattern.compile("\\W+")).omitEmptyStrings(); + private FeatureVectorEncoder wordEncoder; + + public TextValueEncoder(String name) { + super(name); + wordEncoder = new StaticWordValueEncoder(name); + // NOTE(review): addToVector delegates to wordEncoder, so this encoder's own probe count does + // not appear to be used by the code below -- confirm against FeatureVectorEncoder. + probes = 2; + } + + /** + * Adds a value to a vector after tokenizing it on non-word characters and encoding each token + * with the word encoder. + * + * @param originalForm The original form of the value as a string. + * @param weight The weight passed through to the word encoder for every token. + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(String originalForm, double weight, Vector data) { + for (String word : tokenize(originalForm)) { + wordEncoder.addToVector(word, weight, data); + } + } + + /** Splits text into non-empty tokens on runs of non-word characters. */ + private Iterable<String> tokenize(String originalForm) { + return ON_NON_WORD.split(originalForm); + } + + /** + * Converts a value into a form that would help a human understand the internals of how the value + * is being interpreted. For text-like things, this is likely to be a list of the terms found with + * associated weights (if any). + * + * @param originalForm The original form of the value as a string. + * @return A string that a human can read.
+ */ + @Override + public String asString(String originalForm) { + StringBuilder r = new StringBuilder("["); + String sep = ""; + for (String word : tokenize(originalForm)) { + r.append(sep); + r.append(wordEncoder.asString(word)); + sep = ", "; + } + r.append("]"); + return r.toString(); + } + + /** + * Replaces the encoder applied to each individual token, e.g. a StaticWordValueEncoder with a + * weight dictionary. + * @param wordEncoder The encoder to apply to each token. + */ + public void setWordEncoder(FeatureVectorEncoder wordEncoder) { + this.wordEncoder = wordEncoder; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +/** + * A uniform prior. This is an improper prior that corresponds to no regularization at all.
+ */ +public class UniformPrior extends PriorFunction { + /** Returns the coefficient unchanged: a uniform prior applies no shrinkage. */ + @Override + public double age(double oldValue, double generations, double learningRate) { + return oldValue; + } + + /** Returns 0 for every coefficient: the unnormalized log density of this improper prior is constant. */ + @Override + public double logP(double beta_ij) { + return 0; + } +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/WordValueEncoder.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; + +/** + * Encodes words as sparse vector updates to a Vector. Weighting is defined by a + * sub-class. + */ +public abstract class WordValueEncoder extends FeatureVectorEncoder { + + public WordValueEncoder(String name) { + super(name); + probes = 2; + } + + /** + * Adds a value to a vector by adding w times the word's weight at each of {@code probes} + * hashed locations (one hash seed per probe).
+ * + * @param originalForm The original form of the value as a string. + * @param w A multiplier applied to the weight returned by weight(originalForm). + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(String originalForm, double w, Vector data) { + double weight = w * weight(originalForm); + for (int i = 0; i < probes; i++) { + int n = hash(name, originalForm, WORD_LIKE_VALUE_HASH_SEED + i, data.size()); + trace(name, originalForm, n); + data.set(n, data.get(n) + weight); + } + } + + /** + * Converts a value into a form that would help a human understand the internals of how the value + * is being interpreted. For text-like things, this is likely to be a list of the terms found with + * associated weights (if any). The weight is formatted with a fixed locale so the output does not + * depend on the JVM's default locale (e.g. no decimal commas). + * + * @param originalForm The original form of the value as a string. + * @return A string that a human can read. + */ + @Override + public String asString(String originalForm) { + return String.format(java.util.Locale.ENGLISH, "%s:%s:%.4f", name, originalForm, weight(originalForm)); + } + + /** + * Returns the weight of a word; the weighting scheme is defined by subclasses. + * @param originalForm The word to weight. + * @return The weight of the word. + */ + protected abstract double weight(String originalForm); +} Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html (added) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/package.html Mon Aug 16 16:56:46 2010 @@ -0,0 +1,20 @@ + + +SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms +that make it relatively easy to build high speed on-line learning algorithms for a variety +of problems, notably including supervised learning for classification. + +The primary class of interest in this package is CrossFoldLearner which contains a +number (typically 5) of sub-learners, each of which is given a different portion of the +training data.
Each of these sub-learners can then be evaluated on the data it was not +trained on. This allows fully incremental learning while still getting cross-validated +performance estimates. + +The CrossFoldLearner implements OnlineLearner and thus expects to be fed input in the form +of a target variable and a feature vector. The target variable is simply an integer in the +half-open interval [0..numCategories) where numCategories is defined when the CrossFoldLearner +is constructed. The creation of feature vectors is facilitated by the classes that inherit +from FeatureVectorEncoder. These classes currently implement a form of feature hashing with +multiple probes to limit feature ambiguity. + + \ No newline at end of file Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/MurmurHashTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import junit.framework.TestCase; +import org.junit.Test; + +import java.io.UnsupportedEncodingException; + +/** + * Checks MurmurHash against reference values generated by the C++ implementations named in the + * test comments, plus a basic check that a small input change flips many output bits. + */ +@SuppressWarnings({"deprecation"}) +public class MurmurHashTest extends TestCase { + @Test + public void testForLotsOfChange() throws UnsupportedEncodingException { + long h1 = MurmurHash.hashLong("abc".getBytes("UTF-8"), 0); + long h2 = MurmurHash.hashLong("abc ".getBytes("UTF-8"), 0); + // A well-mixed 64-bit hash should flip roughly half its bits for a one-character change. + int flipCount = Long.bitCount(h1 ^ h2); + assertTrue("Small changes should result in lots of bit flips, only found " + flipCount, flipCount > 25); + } + + @Test + public void testHash64() throws UnsupportedEncodingException { + // test data generated by running MurmurHash2_64.cpp + assertEquals(0x9cc9c33498a95efbL, MurmurHash.hash64A("abc".getBytes("UTF-8"), 0)); + assertEquals(0xd2c8c9b470122bddL, MurmurHash.hash64A("abc def ghi jkl ".getBytes("UTF-8"), 0)); + assertEquals(0xcd37895736a81cbcL, MurmurHash.hash64A("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0)); + } + + // NOTE(review): this method-level @SuppressWarnings duplicates the class-level one. + @SuppressWarnings({"deprecation"}) + @Test + public void testHash64original() throws UnsupportedEncodingException { + // test data generated by running MurmurHash2_64.cpp + assertEquals(0x9cc9c33498a95efbL, MurmurHash.hash64A_original("abc".getBytes("UTF-8"), 0)); + assertEquals(0xd2c8c9b470122bddL, MurmurHash.hash64A_original("abc def ghi jkl ".getBytes("UTF-8"), 0)); + assertEquals(0xcd37895736a81cbcL, MurmurHash.hash64A_original("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0)); + } + + @Test + public void testHash() throws UnsupportedEncodingException { + // test data generated by running MurmurHashNeutral2.cpp + assertEquals(0x13577c9b, MurmurHash.hash("abc".getBytes("UTF-8"), 0)); + assertEquals(0x6fec441b, MurmurHash.hash("abc def ghi jkl ".getBytes("UTF-8"), 0)); + assertEquals(0x7e953277, MurmurHash.hash("abc def ghi jkl moreGoo".getBytes("UTF-8"),
0)); + } + + // Same reference values as testHash, checked against MurmurHash.hash_original. + @Test + public void testHashNio() throws UnsupportedEncodingException { + // test data generated by running MurmurHashNeutral2.cpp + assertEquals(0x13577c9b, MurmurHash.hash_original("abc".getBytes("UTF-8"), 0)); + assertEquals(0x6fec441b, MurmurHash.hash_original("abc def ghi jkl ".getBytes("UTF-8"), 0)); + assertEquals(0x7e953277, MurmurHash.hash_original("abc def ghi jkl moreGoo".getBytes("UTF-8"), 0)); + } +}