mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1403497 - /mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
Date Mon, 29 Oct 2012 20:50:45 GMT
Author: ssc
Date: Mon Oct 29 20:50:45 2012
New Revision: 1403497

URL: http://svn.apache.org/viewvc?rev=1403497&view=rev
Log:
MAHOUT-1089 SGD matrix factorization for rating prediction with user and item biases

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java?rev=1403497&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java	(added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java	Mon Oct 29 20:50:45 2012
@@ -0,0 +1,224 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/** Matrix factorization with user and item biases for rating prediction, trained with plain
vanilla SGD  */
+public class RatingSGDFactorizer extends AbstractFactorizer {
+
+  /** Multiplicative decay factor for learning_rate */
+  protected final double learningRateDecay;
+  /** Learning rate (step size) */
+  protected final double learningRate;
+  /** Parameter used to prevent overfitting. */
+  protected final double preventOverfitting;
+  /** Number of features used to compute this factorization */
+  protected final int numFeatures;
+  protected final int featureOffset = 3;
+  /** Number of iterations */
+  private final int numIterations;
+  /** Standard deviation for random initialization of features */
+  protected final double randomNoise;
+  /** User features */
+  protected double[][] userVectors;
+  /** Item features */
+  protected double[][] itemVectors;
+  protected final DataModel dataModel;
+  private long[] cachedUserIDs;
+  private long[] cachedItemIDs;
+
+  protected double biasLearningRate = 0.5;
+  protected double biasReg = 0.1;
+
+  /** place in user vector where the bias is stored */
+  protected static final int USER_BIAS_INDEX = 1;
+  /** place in item vector where the bias is stored */
+  protected static final int ITEM_BIAS_INDEX = 2;
+
+  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws
TasteException {
+    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
+  }
+
+  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double
preventOverfitting,
+      double randomNoise, int numIterations, double learningRateDecay) throws TasteException
{
+    super(dataModel);
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures + featureOffset;
+    this.numIterations = numIterations;
+
+    this.learningRate = learningRate;
+    this.learningRateDecay = learningRateDecay;
+    this.preventOverfitting = preventOverfitting;
+    this.randomNoise = randomNoise;
+  }
+
+  protected void prepareTraining() throws TasteException {
+    Random random = RandomUtils.getRandom();
+    userVectors = new double[dataModel.getNumUsers()][numFeatures];
+    itemVectors = new double[dataModel.getNumItems()][numFeatures];
+
+    double globalAverage = getAveragePreference();
+    for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
+      userVectors[userIndex][0] = globalAverage;
+      userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias
+      userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains
item bias
+      for (int feature = featureOffset; feature < numFeatures; feature++) {
+        userVectors[userIndex][feature] = random.nextGaussian() * randomNoise;
+      }
+    }
+    for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) {
+      itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average
+      itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains
user bias
+      itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias
+      for (int feature = featureOffset; feature < numFeatures; feature++) {
+        itemVectors[itemIndex][feature] = random.nextGaussian() * randomNoise;
+      }
+    }
+
+    cachePreferences();
+    shufflePreferences();
+  }
+
+  private int countPreferences() throws TasteException {
+    int numPreferences = 0;
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    while (userIDs.hasNext()) {
+      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
+      numPreferences += preferencesFromUser.length();
+    }
+    return numPreferences;
+  }
+
+  private void cachePreferences() throws TasteException {
+    int numPreferences = countPreferences();
+    cachedUserIDs = new long[numPreferences];
+    cachedItemIDs = new long[numPreferences];
+
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    int index = 0;
+    while (userIDs.hasNext()) {
+      long userID = userIDs.nextLong();
+      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+      for (Preference preference : preferencesFromUser) {
+        cachedUserIDs[index] = userID;
+        cachedItemIDs[index] = preference.getItemID();
+        index++;
+      }
+    }
+  }
+
+  protected void shufflePreferences() {
+    Random random = RandomUtils.getRandom();
+    /* Durstenfeld shuffle */
+    for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) {
+      int swapPos = random.nextInt(currentPos + 1);
+      swapCachedPreferences(currentPos, swapPos);
+    }
+  }
+
+  private void swapCachedPreferences(int posA, int posB) {
+    long tmpUserIndex = cachedUserIDs[posA];
+    long tmpItemIndex = cachedItemIDs[posA];
+
+    cachedUserIDs[posA] = cachedUserIDs[posB];
+    cachedItemIDs[posA] = cachedItemIDs[posB];
+
+    cachedUserIDs[posB] = tmpUserIndex;
+    cachedItemIDs[posB] = tmpItemIndex;
+  }
+
+  @Override
+  public Factorization factorize() throws TasteException {
+    prepareTraining();
+    double currentLearningRate = learningRate;
+
+
+    for (int it = 0; it < numIterations; it++) {
+      for (int index = 0; index < cachedUserIDs.length; index++) {
+        long userId = cachedUserIDs[index];
+        long itemId = cachedItemIDs[index];
+        float rating = dataModel.getPreferenceValue(userId, itemId);
+        updateParameters(userId, itemId, rating, currentLearningRate);
+      }
+      currentLearningRate *= learningRateDecay;
+    }
+    return createFactorization(userVectors, itemVectors);
+  }
+
+  double getAveragePreference() throws TasteException {
+    RunningAverage average = new FullRunningAverage();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        average.addDatum(pref.getValue());
+      }
+    }
+    return average.getAverage();
+  }
+
+  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate)
+      throws TasteException {
+    int userIndex = userIndex(userID);
+    int itemIndex = itemIndex(itemID);
+
+    double[] userVector = userVectors[userIndex];
+    double[] itemVector = itemVectors[itemIndex];
+    double prediction = predictRating(userIndex, itemIndex);
+    double err = rating - prediction;
+
+    // adjust user bias
+    userVector[USER_BIAS_INDEX] +=
+        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);
+
+    // adjust item bias
+    itemVector[ITEM_BIAS_INDEX] +=
+        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);
+
+    // adjust features
+    for (int feature = featureOffset; feature < numFeatures; feature++) {
+      double userFeature = userVector[feature];
+      double itemFeature = itemVector[feature];
+
+      double deltaUserFeature = err * itemFeature - preventOverfitting * userFeature;
+      userVector[feature] += currentLearningRate * deltaUserFeature;
+
+      double deltaItemFeature = err * userFeature - preventOverfitting * itemFeature;
+      itemVector[feature] += currentLearningRate * deltaItemFeature;
+    }
+  }
+
+  private double predictRating(int userID, int itemID) {
+    double sum = 0;
+    for (int feature = 0; feature < numFeatures; feature++) {
+      sum += userVectors[userID][feature] * itemVectors[itemID][feature];
+    }
+    return sum;
+  }
+}
\ No newline at end of file



Mime
View raw message