Return-Path: Delivered-To: apmail-mahout-commits-archive@www.apache.org Received: (qmail 67202 invoked from network); 7 Apr 2011 21:11:00 -0000 Received: from hermes.apache.org (HELO mail.apache.org) (140.211.11.3) by minotaur.apache.org with SMTP; 7 Apr 2011 21:11:00 -0000 Received: (qmail 96740 invoked by uid 500); 7 Apr 2011 21:10:59 -0000 Delivered-To: apmail-mahout-commits-archive@mahout.apache.org Received: (qmail 96702 invoked by uid 500); 7 Apr 2011 21:10:59 -0000 Mailing-List: contact commits-help@mahout.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@mahout.apache.org Delivered-To: mailing list commits@mahout.apache.org Received: (qmail 96695 invoked by uid 99); 7 Apr 2011 21:10:59 -0000 Received: from athena.apache.org (HELO athena.apache.org) (140.211.11.136) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 07 Apr 2011 21:10:59 +0000 X-ASF-Spam-Status: No, hits=-2000.0 required=5.0 tests=ALL_TRUSTED X-Spam-Check-By: apache.org Received: from [140.211.11.4] (HELO eris.apache.org) (140.211.11.4) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 07 Apr 2011 21:10:49 +0000 Received: by eris.apache.org (Postfix, from userid 65534) id 316FD2388980; Thu, 7 Apr 2011 21:10:29 +0000 (UTC) Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit Subject: svn commit: r1090013 - in /mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1: ./ svd/ Date: Thu, 07 Apr 2011 21:10:29 -0000 To: commits@mahout.apache.org From: ssc@apache.org X-Mailer: svnmailer-1.0.8 Message-Id: <20110407211029.316FD2388980@eris.apache.org> Author: ssc Date: Thu Apr 7 21:10:28 2011 New Revision: 1090013 URL: http://svn.apache.org/viewvc?rev=1090013&view=rev Log: MAHOUT-657 Sample code to apply SVD to the KDD data Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java 
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track1; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EstimateConverter { + + private static final Logger log = LoggerFactory.getLogger(EstimateConverter.class); + + private EstimateConverter() {} + + /** Encodes a preference estimate as one byte: NaN (no estimate possible) becomes 0x7F, anything else becomes round(estimate * 2.55) clamped to [0, 255]. */ public static byte convert(double estimate, long userID, long itemID) { + if (Double.isNaN(estimate)) { + log.warn("Unable to compute estimate for user {}, item {}", userID, itemID); + return 0x7F; + } else { + long scaledEstimate = Math.round(estimate * 2.55); // round, don't truncate: 100 * 2.55 is 254.99999999999997 in double arithmetic and used to encode as 254; a long keeps +/-Infinity clamping correct + if (scaledEstimate > 255) { + scaledEstimate = 255; + } else if (scaledEstimate < 0) { + scaledEstimate = 0; + } + return (byte) scaledEstimate; // 128..255 wrap to negative byte values; presumably read back as unsigned + } + } +} Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java?rev=1090013&r1=1090012&r2=1090013&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java Thu Apr 7 21:10:28 2011 @@ -54,19 +54,7 @@ final class Track1Callable implements Ca log.warn("Unknown item {}; OK unless this is the real contest data", itemID); continue; } - - if (Double.isNaN(estimate)) { - log.warn("Unable to compute estimate for user {}, item
{}", userID, itemID); - result[i] = 0x7F; - } else { - int scaledEstimate = (int) (estimate * 2.55); - if (scaledEstimate > 255) { - scaledEstimate = 255; - } else if (scaledEstimate < 0) { - scaledEstimate = 0; - } - result[i] = (byte) scaledEstimate; - } + result[i] = EstimateConverter.convert(estimate, userID, itemID); } if (COUNT.incrementAndGet() % 10000 == 0) { Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; + +import java.util.ArrayList; +import java.util.List; + +/** + * Adapts a {@link DataModel} to {@link FactorizablePreferences} (e.g. to feed it into {@link ParallelArraysSGDFactorizer}) by copying its user IDs, the item IDs occurring in its preferences, and all preferences into memory once. + */ +public class DataModelFactorizablePreferences implements FactorizablePreferences { + + private final FastIDSet userIDs; + private final FastIDSet itemIDs; + + private final List<Preference> preferences; + + private final float minPreference; + private final float maxPreference; + + public DataModelFactorizablePreferences(DataModel dataModel) { + + minPreference = dataModel.getMinPreference(); + maxPreference = dataModel.getMaxPreference(); + + try { + userIDs = new FastIDSet(dataModel.getNumUsers()); + itemIDs = new FastIDSet(dataModel.getNumItems()); + preferences = new ArrayList<Preference>(); + + LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs(); + while (userIDsIterator.hasNext()) { + long userID = userIDsIterator.nextLong(); + userIDs.add(userID); + for (Preference preference : dataModel.getPreferencesFromUser(userID)) { + itemIDs.add(preference.getItemID()); + preferences.add(new GenericPreference(userID, preference.getItemID(), preference.getValue())); + } + } + } catch (Exception e) { + throw new IllegalStateException("Unable to create factorizable preferences!", e); + } + } + + @Override + public LongPrimitiveIterator getUserIDs() { + return userIDs.iterator(); + } + + @Override + public LongPrimitiveIterator getItemIDs() { + return itemIDs.iterator(); + } + + @Override + public Iterable<Preference> getPreferences() { + return preferences; + } + + @Override + public float getMinPreference() { + return minPreference; + } + + @Override + public float getMaxPreference()
{ + return maxPreference; + } + + @Override + public int numUsers() { + return userIDs.size(); + } + + @Override + public int numItems() { + return itemIDs.size(); + } + + @Override + public int numPreferences() { + return preferences.size(); + } +} + Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.model.Preference; + +/** + * Models the input consumed by {@link ParallelArraysSGDFactorizer}: the user and item IDs, the preferences themselves, and the range of preference values. + */ +public interface FactorizablePreferences { + + LongPrimitiveIterator getUserIDs(); + + LongPrimitiveIterator getItemIDs(); + + Iterable<Preference> getPreferences(); + + float getMinPreference(); + + float getMaxPreference(); + + int numUsers(); + + int numItems(); + + int numPreferences(); // ParallelArraysSGDFactorizer sizes its arrays with this, so it must equal the number of preferences returned by getPreferences() + +} Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,159 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.example.kddcup.DataFileIterator; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.model.Preference; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; + +/** Streams a KDD Cup track 1 data file as {@link FactorizablePreferences}; user and item IDs are enumerated as 0 .. numUsers() - 1 and 0 .. numItems() - 1. */ public class KDDCupFactorizablePreferences implements FactorizablePreferences { + + private final File dataFile; + + public KDDCupFactorizablePreferences(File dataFile) { + this.dataFile = dataFile; + } + + @Override + public LongPrimitiveIterator getUserIDs() { + return new FixedSizeLongIterator(numUsers()); + } + + @Override + public LongPrimitiveIterator getItemIDs() { + return new FixedSizeLongIterator(numItems()); + } + + @Override + public Iterable<Preference> getPreferences() { + return new Iterable<Preference>() { + @Override + public Iterator<Preference> iterator() { + try { + return new DataFilePreferencesIterator(new DataFileIterator(dataFile)); + } catch (IOException e) { + throw new IllegalStateException("Cannot iterate over datafile!", e); + } + } + }; + } + + @Override + public float getMinPreference() { + return 0; + } + + @Override + public float getMaxPreference() { + return 100; + } + + @Override + public int numUsers() { + return 1000990; + } + + @Override + public int numItems() { + return 624961; + } + + @Override + public int numPreferences() { + return 252800275; + } + + static class DataFilePreferencesIterator implements Iterator<Preference> { + + private final DataFileIterator dataFileIterator; + + private Iterator<Preference> currentUserPrefsIterator; + + public DataFilePreferencesIterator(DataFileIterator dataFileIterator) { + this.dataFileIterator = dataFileIterator; + } + + @Override + public boolean hasNext() { + while (currentUserPrefsIterator == null || !currentUserPrefsIterator.hasNext()) { // advance past users without preferences + if (!dataFileIterator.hasNext()) { return false; } + currentUserPrefsIterator = dataFileIterator.next().getFirst().iterator(); + } + return true;
+ } + + @Override + public Preference next() { + if (!hasNext()) { + throw new java.util.NoSuchElementException(); + } + return currentUserPrefsIterator.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + + static class FixedSizeLongIterator implements LongPrimitiveIterator { + + private long currentValue; + private final long maximum; + + public FixedSizeLongIterator(long maximum) { + this.maximum = maximum; + currentValue = 0; + } + + @Override + public long nextLong() { + return currentValue++; + } + + @Override + public long peek() { + return currentValue; + } + + @Override + public void skip(int n) { + currentValue += n; + } + + @Override + public boolean hasNext() { + return currentValue < maximum; + } + + @Override + public Long next() { + return currentValue++; // must yield the same sequence as nextLong(): 0 .. maximum - 1 + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + +} Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,257 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Random; + +/** + * {@link org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer} based on Simon Funk's famous article "Netflix Update: Try this at home" + * (see http://sifter.org/~simon/journal/20061211.html). + * + * Attempts to be as memory efficient as possible, only iterating once through the {@link FactorizablePreferences} or {@link DataModel} while + * copying everything to primitive arrays. Learning works in place on these data structures after that.
+ * + */ +public class ParallelArraysSGDFactorizer implements Factorizer { + + public static final double DEFAULT_LEARNING_RATE = 0.005; + public static final double DEFAULT_PREVENT_OVERFITTING = 0.02; + public static final double DEFAULT_RANDOM_NOISE = 0.005; + + private final int numFeatures; + private final int numIterations; + private final float minPreference; + private final float maxPreference; + + private final Random random; + private final double learningRate; + private final double preventOverfitting; + + private final FastByIDMap<Integer> userIDMapping; + private final FastByIDMap<Integer> itemIDMapping; + + private final double[][] userFeatures; + private final double[][] itemFeatures; + + private final int[] userIndexes; + private final int[] itemIndexes; + private final float[] values; + + private final double defaultValue; + private final double interval; + private final double[] cachedEstimates; // per preference: summed contribution of the features trained so far + + + private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class); + + public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) { + this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE, + DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE); + } + + public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate, + double preventOverfitting, double randomNoise) { + this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting, + randomNoise); + } + + public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) { + this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING, + DEFAULT_RANDOM_NOISE); + } + + public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures, + int numIterations, double learningRate, double
 preventOverfitting, double randomNoise) { + + this.numFeatures = numFeatures; + this.numIterations = numIterations; + minPreference = factorizablePreferences.getMinPreference(); + maxPreference = factorizablePreferences.getMaxPreference(); + + this.random = RandomUtils.getRandom(); + this.learningRate = learningRate; + this.preventOverfitting = preventOverfitting; + + int numUsers = factorizablePreferences.numUsers(); + int numItems = factorizablePreferences.numItems(); + int numPrefs = factorizablePreferences.numPreferences(); + + log.info("Mapping {} users...", numUsers); + userIDMapping = new FastByIDMap<Integer>(numUsers); + int index = 0; + LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs(); + while (userIterator.hasNext()) { + userIDMapping.put(userIterator.nextLong(), index++); + } + + log.info("Mapping {} items", numItems); + itemIDMapping = new FastByIDMap<Integer>(numItems); + index = 0; + LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs(); + while (itemIterator.hasNext()) { + itemIDMapping.put(itemIterator.nextLong(), index++); + } + + this.userIndexes = new int[numPrefs]; + this.itemIndexes = new int[numPrefs]; + this.values = new float[numPrefs]; + this.cachedEstimates = new double[numPrefs]; + + index = 0; + log.info("Loading {} preferences into memory", numPrefs); + RunningAverage average = new FullRunningAverage(); + for (Preference preference : factorizablePreferences.getPreferences()) { + userIndexes[index] = userIDMapping.get(preference.getUserID()); + itemIndexes[index] = itemIDMapping.get(preference.getItemID()); + values[index] = preference.getValue(); + cachedEstimates[index] = 0; + + average.addDatum(preference.getValue()); + + index++; + if (index % 1000000 == 0) { + log.info("Processed {} preferences", index); + } + } + log.info("Processed {} preferences, done.", index); + if (index != numPrefs) { throw new IllegalStateException("Expected " + numPrefs + " preferences but read " + index); } // unfilled slots would otherwise be trained as phantom zero-valued preferences + double averagePreference = average.getAverage(); + log.info("Average preference value is {}", averagePreference); + + double prefInterval =
 factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference(); + defaultValue = Math.sqrt(Math.max(0.0, averagePreference - (prefInterval * 0.1)) / numFeatures); // clamp: a negative radicand would make every initial feature NaN + interval = (prefInterval * 0.1) / numFeatures; + + userFeatures = new double[numUsers][numFeatures]; + itemFeatures = new double[numItems][numFeatures]; + + log.info("Initializing feature vectors..."); + for (int feature = 0; feature < numFeatures; feature++) { + for (int userIndex = 0; userIndex < numUsers; userIndex++) { + userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; + } + for (int itemIndex = 0; itemIndex < numItems; itemIndex++) { + itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; + } + } + } + + @Override + public Factorization factorize() throws TasteException { + for (int feature = 0; feature < numFeatures; feature++) { + log.info("Shuffling preferences..."); + shufflePreferences(); + log.info("Starting training of feature {} ...", feature); + for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) { + if (currentIteration != (numIterations - 1)) { + trainingIteration(feature); + } else { + double rmse = trainingIterationWithRmse(feature); + log.info("Finished training feature {} with RMSE {}", feature, rmse); + } + } + if (feature < numFeatures - 1) { + log.info("Updating cache..."); + for (int index = 0; index < userIndexes.length; index++) { + cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index], + false); + } + } + } + log.info("Factorization done"); + return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures); + } + + private void trainingIteration(int feature) { + for (int index = 0; index < userIndexes.length; index++) { + train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); + } + } + + private double
 trainingIterationWithRmse(int feature) { + double rmse = 0; + for (int index = 0; index < userIndexes.length; index++) { + double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); + rmse += (error * error); + } + return Math.sqrt(rmse / (double) userIndexes.length); + } + + private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) { + double sum = cachedEstimate; + sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature]; + if (trailing) { + sum += (numFeatures - feature - 1) * ((defaultValue + interval) * (defaultValue + interval)); + if (sum > maxPreference) { + sum = maxPreference; + } else if (sum < minPreference) { + sum = minPreference; + } + } + return sum; + } + + public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) { + double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true); + double[] userVector = userFeatures[userIndex]; + double[] itemVector = itemFeatures[itemIndex]; + double userFeature = userVector[feature]; // read before the in-place update so both gradient steps use the pre-step user value + userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userFeature); + itemVector[feature] += learningRate * (error * userFeature - preventOverfitting * itemVector[feature]); + + return error; + } + + protected void shufflePreferences() { + /* Durstenfeld (Fisher-Yates) shuffle, applied to all parallel arrays in lockstep via swapPreferences */ + for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) { + int swapPos = random.nextInt(currentPos + 1); + swapPreferences(currentPos, swapPos); + } + } + + private void swapPreferences(int posA, int posB) { + int tmpUserIndex = userIndexes[posA]; + int tmpItemIndex = itemIndexes[posA]; + float tmpValue = values[posA]; + double tmpEstimate = cachedEstimates[posA]; + + userIndexes[posA] = userIndexes[posB]; + itemIndexes[posA] = itemIndexes[posB]; + values[posA] = values[posB]; + cachedEstimates[posA] = cachedEstimates[posB]; + + userIndexes[posB] =
 tmpUserIndex; + itemIndexes[posB] = tmpItemIndex; + values[posB] = tmpValue; + cachedEstimates[posB] = tmpEstimate; + } +} Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java?rev=1090013&view=auto ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java (added) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java Thu Apr 7 21:10:28 2011 @@ -0,0 +1,121 @@ +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable; +import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; +import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStream; + +/** + * Runs an SVD factorization of the KDD Cup track 1 data, reports the RMSE on the validation set and writes byte-encoded test set estimates to the result file.
+ * + * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M + * + */ +public class Track1SVDRunner { + + private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class); + + public static void main(String[] args) throws Exception { + + if (args.length != 2) { + System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>"); + System.exit(-1); + } + + File dataFileDirectory = new File(args[0]); + if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) { + throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory); + } + + File resultFile = new File(args[1]); + + /* the knobs to turn */ + int numFeatures = 20; + int numIterations = 5; + double learningRate = 0.0001; + double preventOverfitting = 0.002; + double randomNoise = 0.0001; + + + KDDCupFactorizablePreferences factorizablePreferences = + new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory)); + + Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations, + learningRate, preventOverfitting, randomNoise); + + Factorization factorization = sgdFactorizer.factorize(); + + log.info("Estimating validation preferences..."); + int prefsProcessed = 0; + RunningAverage average = new FullRunningAverage(); + DataFileIterable validations = new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory)); + for (Pair<PreferenceArray,long[]> validationPair : validations) { + for (Preference validationPref : validationPair.getFirst()) { + double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(), + factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); + double error = validationPref.getValue() - estimate; + average.addDatum(error * error); + prefsProcessed++; + if (prefsProcessed % 100000 == 0) { + log.info("Computed {} estimations", prefsProcessed); + } + } + } + log.info("Computed {} estimations, done.", prefsProcessed); + +
 double rmse = Math.sqrt(average.getAverage()); + log.info("RMSE {}", rmse); + + log.info("Estimating test preferences..."); + OutputStream out = null; + try { + out = new BufferedOutputStream(new FileOutputStream(resultFile)); + + DataFileIterable tests = new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory)); + for (Pair<PreferenceArray,long[]> testPair : tests) { + for (Preference testPref : testPair.getFirst()) { + double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(), + factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); + byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID()); + out.write(result); + } + } + } finally { + // out stays null if the FileOutputStream could not be opened; BufferedOutputStream.close() flushes before closing + if (out != null) { out.close(); } + } + log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath()); + } + + static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference, + float maxPreference) throws NoSuchUserException, NoSuchItemException { + double[] userFeatures = factorization.getUserFeatures(userID); + double[] itemFeatures = factorization.getItemFeatures(itemID); + double estimate = 0; + for (int feature = 0; feature < userFeatures.length; feature++) { + estimate += userFeatures[feature] * itemFeatures[feature]; + } + if (estimate < minPreference) { + estimate = minPreference; + } else if (estimate > maxPreference) { + estimate = maxPreference; + } + return estimate; + } + +}