mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1054567 - in /mahout/trunk: core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ math/src/main/java/org/apache/mahout/math/als/ math/src/test/java/org/apach...
Date Mon, 03 Jan 2011 09:18:48 GMT
Author: ssc
Date: Mon Jan  3 09:18:46 2011
New Revision: 1054567

URL: http://svn.apache.org/viewvc?rev=1054567&view=rev
Log:
MAHOUT-572 Non-distributed implementation of ALS-WR matrix factorization

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/als/
    mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/als/
    mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
Removed:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
+ * the paper "Large-scale Collaborative Filtering for the Netflix Prize", available at
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf</a>.
+ *
+ * <p>Starting from an initial item feature matrix M, every iteration first holds M fixed and solves a regularized
+ * least-squares problem per user to compute the user feature matrix U, then holds U fixed and solves per item to
+ * compute a new M.</p>
+ */
+public class ALSWRFactorizer extends AbstractFactorizer {
+
+  private final DataModel dataModel;
+
+  /** number of features used to compute this factorization */
+  private final int numFeatures;
+  /** parameter to control the regularization */
+  private final double lambda;
+  /** number of iterations */
+  private final int numIterations;
+
+  private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);
+
+  /**
+   * @param dataModel the preferences to factorize
+   * @param numFeatures number of features per user and per item, must be at least 1
+   * @param lambda parameter to control the regularization
+   * @param numIterations number of alternating iterations, must be at least 1
+   * @throws TasteException propagated from the {@link DataModel}
+   * @throws IllegalArgumentException if {@code numFeatures} or {@code numIterations} is smaller than 1
+   */
+  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations) throws TasteException {
+    super(dataModel);
+    /* initializeM() writes feature 0 of every item, and factorize() only computes U inside the iteration loop */
+    if (numFeatures < 1) {
+      throw new IllegalArgumentException("numFeatures must be at least 1, but was " + numFeatures);
+    }
+    if (numIterations < 1) {
+      throw new IllegalArgumentException("numIterations must be at least 1, but was " + numIterations);
+    }
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures;
+    this.lambda = lambda;
+    this.numIterations = numIterations;
+  }
+
+  @Override
+  public Factorization factorize() throws TasteException {
+    log.info("starting to compute the factorization...");
+    AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
+
+    double[][] M = initializeM();
+    /* always assigned below, because the constructor guarantees at least one iteration */
+    double[][] U = null;
+
+    for (int iteration = 0; iteration < numIterations; iteration++) {
+      log.info("iteration {}", iteration);
+
+      /* fix M - compute U */
+      U = new double[dataModel.getNumUsers()][numFeatures];
+
+      LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+      while (userIDsIterator.hasNext()) {
+        long userID = userIDsIterator.nextLong();
+        PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
+        /* the item feature vectors must be collected in the same order as the ratings in ratingVector(userPrefs),
+           otherwise the solver pairs item features with the wrong ratings */
+        List<Vector> featureVectors = new ArrayList<Vector>(userPrefs.length());
+        for (Preference pref : userPrefs) {
+          featureVectors.add(new DenseVector(M[itemIndex(pref.getItemID())]));
+        }
+        Vector userFeatures = solver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
+        setFeatureColumn(U, userIndex(userID), userFeatures);
+      }
+
+      /* fix U - compute M */
+      M = new double[dataModel.getNumItems()][numFeatures];
+
+      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+      while (itemIDsIterator.hasNext()) {
+        long itemID = itemIDsIterator.nextLong();
+        PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
+        /* one user feature vector per rating, in the same order as ratingVector(itemPrefs) */
+        List<Vector> featureVectors = new ArrayList<Vector>(itemPrefs.length());
+        for (Preference pref : itemPrefs) {
+          featureVectors.add(new DenseVector(U[userIndex(pref.getUserID())]));
+        }
+        Vector itemFeatures = solver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
+        setFeatureColumn(M, itemIndex(itemID), itemFeatures);
+      }
+    }
+
+    log.info("finished computation of the factorization...");
+    return createFactorization(U, M);
+  }
+
+  /**
+   * Creates the initial item feature matrix: feature 0 of each item is its average rating, all other features are
+   * random values in [0, 0.1).
+   */
+  protected double[][] initializeM() throws TasteException {
+    Random random = RandomUtils.getRandom();
+    double[][] M = new double[dataModel.getNumItems()][numFeatures];
+
+    LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+    while (itemIDsIterator.hasNext()) {
+      long itemID = itemIDsIterator.nextLong();
+      int itemIDIndex = itemIndex(itemID);
+      M[itemIDIndex][0] = averageRating(itemID);
+      for (int feature = 1; feature < numFeatures; feature++) {
+        M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+      }
+    }
+    return M;
+  }
+
+  /** copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex} of {@code matrix} */
+  protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
+    for (int feature = 0; feature < numFeatures; feature++) {
+      matrix[idIndex][feature] = vector.get(feature);
+    }
+  }
+
+  /** creates a dense vector holding the preference values of {@code prefs}, in array order */
+  protected Vector ratingVector(PreferenceArray prefs) {
+    double[] ratings = new double[prefs.length()];
+    for (int n = 0; n < prefs.length(); n++) {
+      ratings[n] = prefs.get(n).getValue();
+    }
+    return new DenseVector(ratings);
+  }
+
+  /** computes the average of all preference values given for {@code itemID} */
+  protected double averageRating(long itemID) throws TasteException {
+    PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+    RunningAverage avg = new FullRunningAverage();
+    for (Preference pref : prefs) {
+      avg.addDatum(pref.getValue());
+    }
+    return avg.getAverage();
+  }
+
+  /** @deprecated misspelled, use {@link #averageRating(long)} instead */
+  @Deprecated
+  protected double averateRating(long itemID) throws TasteException {
+    return averageRating(itemID);
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * Base class for {@link Factorizer}s. Assigns every user ID and every item ID of the {@link DataModel} a
+ * consecutive, zero-based index (in the order in which the data model returns the IDs) that subclasses use as
+ * the row of the corresponding feature vector.
+ */
+public abstract class AbstractFactorizer implements Factorizer {
+
+  /** userID -> row index in the user feature matrix */
+  private final FastByIDMap<Integer> userIDMapping;
+  /** itemID -> row index in the item feature matrix */
+  private final FastByIDMap<Integer> itemIDMapping;
+
+  protected AbstractFactorizer(DataModel dataModel) throws TasteException {
+    int numUsers = dataModel.getNumUsers();
+    userIDMapping = createIDMapping(numUsers, dataModel.getUserIDs());
+    int numItems = dataModel.getNumItems();
+    itemIDMapping = createIDMapping(numItems, dataModel.getItemIDs());
+  }
+
+  /** bundles the given feature matrices with this factorizer's ID-to-index mappings */
+  protected Factorization createFactorization(double[][] userFeatures, double[][] itemFeatures) {
+    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+  }
+
+  /** @return the index assigned to {@code userID}, or {@code null} if the data model did not contain it */
+  protected Integer userIndex(long userID) {
+    return userIDMapping.get(userID);
+  }
+
+  /** @return the index assigned to {@code itemID}, or {@code null} if the data model did not contain it */
+  protected Integer itemIndex(long itemID) {
+    return itemIDMapping.get(itemID);
+  }
+
+  /** numbers the IDs produced by {@code idIterator} 0, 1, 2, ... in iteration order */
+  private FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator idIterator) {
+    FastByIDMap<Integer> idToIndex = new FastByIDMap<Integer>(size);
+    for (int nextIndex = 0; idIterator.hasNext(); nextIndex++) {
+      idToIndex.put(idIterator.nextLong(), nextIndex);
+    }
+    return idToIndex;
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,184 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Uses Singular Value Decomposition to find the main features of the data set. Thanks to Simon Funk for the hints
+ * in the implementation, see <a href="http://sifter.org/~simon/journal/20061211.html">
+ * http://sifter.org/~simon/journal/20061211.html</a>.
+ *
+ * <p>The features are trained so that the estimate for a (user, item) pair is the plain dot product of the user
+ * and item feature vectors, which is exactly how the resulting {@link Factorization} is consumed by
+ * {@link SVDRecommender}.</p>
+ */
+public class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
+
+  private final Random random;
+
+  private final double learningRate;
+  /** Parameter used to prevent overfitting. 0.02 is a good value. */
+  private final double preventOverfitting;
+
+  /** number of features used to compute this factorization */
+  private final int numFeatures;
+  /** number of iterations */
+  private final int numIterations;
+
+  /** user singular vectors */
+  private final double[][] leftVectors;
+
+  /** item singular vectors */
+  private final double[][] rightVectors;
+
+  private final DataModel dataModel;
+  private final List<Preference> cachedPreferences;
+
+  private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
+
+  public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int numFeatures, int numIterations)
+      throws TasteException {
+    /* use the default parameters from the old SVDRecommender implementation */
+    this(dataModel, numFeatures, 0.005, 0.02, 0.005, numIterations);
+  }
+
+  /**
+   * @param dataModel the preferences to factorize
+   * @param numFeatures number of features per user and per item
+   * @param learningRate step size of each gradient-descent update
+   * @param preventOverfitting regularization factor, 0.02 is a good value
+   * @param randomNoise width of the uniform noise added around the initial feature value
+   * @param numIterations number of training steps; every step trains each feature once on all preferences
+   * @throws TasteException propagated from the {@link DataModel}
+   */
+  public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int numFeatures, double learningRate,
+      double preventOverfitting, double randomNoise, int numIterations) throws TasteException {
+    super(dataModel);
+    random = RandomUtils.getRandom();
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures;
+    this.numIterations = numIterations;
+
+    this.learningRate = learningRate;
+    this.preventOverfitting = preventOverfitting;
+
+    leftVectors = new double[dataModel.getNumUsers()][numFeatures];
+    rightVectors = new double[dataModel.getNumItems()][numFeatures];
+
+    /* start every feature at the same value so that the initial estimate (a plain dot product) is roughly the
+       average preference; a non-positive or undefined (NaN) average cannot be reached that way, so start at 0 */
+    double average = getAveragePreference();
+    double defaultValue = average > 0.0 ? Math.sqrt(average / numFeatures) : 0.0;
+
+    for (int feature = 0; feature < numFeatures; feature++) {
+      for (int userIndex = 0; userIndex < dataModel.getNumUsers(); userIndex++) {
+        leftVectors[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * randomNoise;
+      }
+      for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
+        rightVectors[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * randomNoise;
+      }
+    }
+
+    cachedPreferences = new ArrayList<Preference>(dataModel.getNumUsers());
+  }
+
+  @Override
+  public Factorization factorize() throws TasteException {
+    log.info("starting to compute the factorization...");
+
+    cachePreferences();
+    for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+      log.info("iteration {}", currentIteration);
+      nextTrainStep();
+    }
+
+    log.info("finished computation of the factorization...");
+    return createFactorization(leftVectors, rightVectors);
+  }
+
+  /** loads all preferences of all users into {@link #cachedPreferences} */
+  void cachePreferences() throws TasteException {
+    cachedPreferences.clear();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        cachedPreferences.add(pref);
+      }
+    }
+  }
+
+  /** computes the average of all preference values in the data model */
+  double getAveragePreference() throws TasteException {
+    RunningAverage average = new FullRunningAverage();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        average.addDatum(pref.getValue());
+      }
+    }
+    return average.getAverage();
+  }
+
+  /** trains every feature once on all cached preferences, visited in a freshly shuffled order */
+  void nextTrainStep() {
+    Collections.shuffle(cachedPreferences, random);
+    for (int feature = 0; feature < numFeatures; feature++) {
+      for (Preference pref : cachedPreferences) {
+        train(userIndex(pref.getUserID()), itemIndex(pref.getItemID()), feature, pref.getValue());
+      }
+    }
+  }
+
+  /** the current estimate for the given user/item pair: the dot product of their feature vectors */
+  double getDotProduct(int userIndex, int itemIndex) {
+    double[] leftVector = leftVectors[userIndex];
+    double[] rightVector = rightVectors[itemIndex];
+    double result = 0.0;
+    for (int feature = 0; feature < numFeatures; feature++) {
+      result += leftVector[feature] * rightVector[feature];
+    }
+    return result;
+  }
+
+  /** performs one regularized gradient-descent step on a single feature of a single (user, item, value) triple */
+  void train(int userIndex, int itemIndex, int currentFeature, double value) {
+    double err = value - getDotProduct(userIndex, itemIndex);
+    double[] leftVector = leftVectors[userIndex];
+    double[] rightVector = rightVectors[itemIndex];
+    /* read both old values first so that both updates are computed from the same (pre-step) parameters */
+    double userFeature = leftVector[currentFeature];
+    double itemFeature = rightVector[currentFeature];
+    leftVector[currentFeature] += learningRate * (err * itemFeature - preventOverfitting * userFeature);
+    rightVector[currentFeature] += learningRate * (err * userFeature - preventOverfitting * itemFeature);
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+
+/**
+ * The result of factorizing the rating matrix: one feature vector per user and one per item, looked up by ID.
+ */
+public class Factorization {
+
+  /** userID -> row of {@link #userFeatures} */
+  private final FastByIDMap<Integer> userIDMapping;
+  /** itemID -> row of {@link #itemFeatures} */
+  private final FastByIDMap<Integer> itemIDMapping;
+
+  /** one row of features per user */
+  private final double[][] userFeatures;
+  /** one row of features per item */
+  private final double[][] itemFeatures;
+
+  /**
+   * Wraps the given mappings and feature matrices; nothing is copied.
+   */
+  public Factorization(FastByIDMap<Integer> userIDMapping, FastByIDMap<Integer> itemIDMapping, double[][] userFeatures,
+      double[][] itemFeatures) {
+    this.userIDMapping = userIDMapping;
+    this.itemIDMapping = itemIDMapping;
+    this.userFeatures = userFeatures;
+    this.itemFeatures = itemFeatures;
+  }
+
+  /**
+   * @return the feature vector of {@code userID} (the internal array, not a copy)
+   * @throws NoSuchUserException if no row is mapped to {@code userID}
+   */
+  public double[] getUserFeatures(long userID) throws NoSuchUserException {
+    Integer row = userIDMapping.get(userID);
+    if (row != null) {
+      return userFeatures[row];
+    }
+    throw new NoSuchUserException();
+  }
+
+  /**
+   * @return the feature vector of {@code itemID} (the internal array, not a copy)
+   * @throws NoSuchItemException if no row is mapped to {@code itemID}
+   */
+  public double[] getItemFeatures(long itemID) throws NoSuchItemException {
+    Integer row = itemIDMapping.get(itemID);
+    if (row != null) {
+      return itemFeatures[row];
+    }
+    throw new NoSuchItemException();
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * Strategy for decomposing a rating matrix into user and item feature vectors. Implementations must be able to
+ * create a {@link Factorization} of a rating matrix.
+ */
+public interface Factorizer {
+
+  /**
+   * Computes the factorization.
+   *
+   * @return the user and item feature vectors found
+   * @throws TasteException as thrown by the implementation, e.g. when accessing its
+   *         {@code DataModel} fails
+   */
+  Factorization factorize() throws TasteException;
+
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java?rev=1054567&r1=1054566&r2=1054567&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java Mon Jan  3 09:18:46 2011
@@ -17,171 +17,53 @@
 
 package org.apache.mahout.cf.taste.impl.recommender.svd;
 
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
-import java.util.Random;
 import java.util.concurrent.Callable;
 
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import com.google.common.base.Preconditions;
 import org.apache.mahout.cf.taste.common.Refreshable;
 import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
 import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
 import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
 import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
 import org.apache.mahout.cf.taste.impl.recommender.TopItems;
 import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
 import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
 import org.apache.mahout.cf.taste.recommender.IDRescorer;
 import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.common.RandomUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.base.Preconditions;
-
 /**
- * <p>
- * A {@link org.apache.mahout.cf.taste.recommender.Recommender} which uses Single Value Decomposition
- * to find the main features of the data set. Thanks to Simon Funk for the hints in the implementation.
+ * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that uses matrix factorization (a projection of users
+ * and items onto a feature space)
  */
 public final class SVDRecommender extends AbstractRecommender {
-  
-  private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);
-  private static final Random random = RandomUtils.getRandom();
-  
+
+  private Factorization factorization;
   private final RefreshHelper refreshHelper;
-  
-  /** Number of features */
-  private final int numFeatures;
-  
-  private final FastByIDMap<Integer> userMap;
-  private final FastByIDMap<Integer> itemMap;
-  private final ExpectationMaximizationSVD emSvd;
-  private final List<Preference> cachedPreferences;
-  
-  /**
-   * @param numFeatures
-   *          the number of features
-   * @param initialSteps
-   *          number of initial training steps
-   */
-  public SVDRecommender(DataModel dataModel,
-                        CandidateItemsStrategy candidateItemsStrategy,
-                        int numFeatures,
-                        int initialSteps) throws TasteException {
+
+  private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);
+
+  public SVDRecommender(DataModel dataModel, Factorizer factorizer) throws TasteException {
+    this(dataModel, factorizer, getDefaultCandidateItemsStrategy());
+  }
+
+  public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy)
+      throws TasteException {
     super(dataModel, candidateItemsStrategy);
-    
-    this.numFeatures = numFeatures;
-    
-    int numUsers = dataModel.getNumUsers();
-    userMap = new FastByIDMap<Integer>(numUsers);
-    
-    int idx = 0;
-    LongPrimitiveIterator userIterator = dataModel.getUserIDs();
-    while (userIterator.hasNext()) {
-      userMap.put(userIterator.nextLong(), idx++);
-    }
-    
-    int numItems = dataModel.getNumItems();
-    itemMap = new FastByIDMap<Integer>(numItems);
-    
-    idx = 0;
-    LongPrimitiveIterator itemIterator = dataModel.getItemIDs();
-    while (itemIterator.hasNext()) {
-      itemMap.put(itemIterator.nextLong(), idx++);
-    }
-    
-    double average = getAveragePreference();
-    double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
-    
-    emSvd = new ExpectationMaximizationSVD(numUsers, numItems, numFeatures, defaultValue);
-    cachedPreferences = new ArrayList<Preference>(numUsers);
-    recachePreferences();
-    
+    factorization = factorizer.factorize();
     refreshHelper = new RefreshHelper(new Callable<Object>() {
       @Override
       public Object call() throws TasteException {
-        recachePreferences();
         // TODO: train again
         return null;
       }
     });
-    refreshHelper.addDependency(dataModel);
-    
-    train(initialSteps);
-  }
-
-  public SVDRecommender(DataModel dataModel,
-                        int numFeatures,
-                        int initialSteps) throws TasteException {
-    this(dataModel, getDefaultCandidateItemsStrategy(), numFeatures, initialSteps);
-  }
-  
-  private void recachePreferences() throws TasteException {
-    cachedPreferences.clear();
-    DataModel dataModel = getDataModel();
-    LongPrimitiveIterator it = dataModel.getUserIDs();
-    while (it.hasNext()) {
-      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
-        cachedPreferences.add(pref);
-      }
-    }
-  }
-  
-  private double getAveragePreference() throws TasteException {
-    RunningAverage average = new FullRunningAverage();
-    DataModel dataModel = getDataModel();
-    LongPrimitiveIterator it = dataModel.getUserIDs();
-    while (it.hasNext()) {
-      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
-        average.addDatum(pref.getValue());
-      }
-    }
-    return average.getAverage();
-  }
-  
-  public void train(int steps) {
-    for (int i = 0; i < steps; i++) {
-      nextTrainStep();
-    }
-  }
-  
-  private void nextTrainStep() {
-    Collections.shuffle(cachedPreferences, random);
-    for (int i = 0; i < numFeatures; i++) {
-      for (Preference pref : cachedPreferences) {
-        int useridx = userMap.get(pref.getUserID());
-        int itemidx = itemMap.get(pref.getItemID());
-        emSvd.train(useridx, itemidx, i, pref.getValue());
-      }
-    }
-  }
-  
-  private float predictRating(int user, int item) {
-    return (float) emSvd.getDotProduct(user, item);
-  }
-  
-  @Override
-  public float estimatePreference(long userID, long itemID) throws TasteException {
-    Integer useridx = userMap.get(userID);
-    if (useridx == null) {
-      throw new NoSuchUserException();
-    }
-    Integer itemidx = itemMap.get(itemID);
-    if (itemidx == null) {
-      throw new NoSuchItemException();
-    }
-    return predictRating(useridx, itemidx);
+    refreshHelper.addDependency(getDataModel());
   }
-  
+
   @Override
   public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
     Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
@@ -189,37 +71,43 @@ public final class SVDRecommender extend
 
     FastIDSet possibleItemIDs = getAllOtherItems(userID);
 
-    TopItems.Estimator<Long> estimator = new Estimator(userID);
-
     List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
-      estimator);
-
+        new Estimator(userID));
     log.debug("Recommendations are: {}", topItems);
+
     return topItems;
   }
-  
-  @Override
-  public void refresh(Collection<Refreshable> alreadyRefreshed) {
-    refreshHelper.refresh(alreadyRefreshed);
-  }
-  
+
+  /**
+   * a preference is estimated by computing the dot-product of the user and item feature vectors
+   */
   @Override
-  public String toString() {
-    return "SVDRecommender[numFeatures:" + numFeatures + ']';
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    double[] userFeatures = factorization.getUserFeatures(userID);
+    double[] itemFeatures = factorization.getItemFeatures(itemID);
+    double estimate = 0;
+    for (int feature = 0; feature < userFeatures.length; feature++) {
+      estimate += userFeatures[feature] * itemFeatures[feature];
+    }
+    return (float) estimate;
   }
-  
+
   private final class Estimator implements TopItems.Estimator<Long> {
-    
+
     private final long theUserID;
-    
+
     private Estimator(long theUserID) {
       this.theUserID = theUserID;
     }
-    
+
     @Override
     public double estimate(Long itemID) throws TasteException {
       return estimatePreference(theUserID, itemID);
     }
   }
-  
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    refreshHelper.refresh(alreadyRefreshed);
+  }
 }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,164 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+/**
+ * Tests for {@link ALSWRFactorizer} on a small rating matrix whose expected values can be checked by hand.
+ */
+public class ALSWRFactorizerTest extends TasteTestCase {
+
+  /** number of features the factorizer under test is configured with */
+  private static final int NUM_FEATURES = 3;
+
+  private ALSWRFactorizer factorizer;
+  private DataModel dataModel;
+
+  /**
+   *  rating-matrix (user IDs 1-4 are the rows, item IDs 1-4 the columns)
+   *
+   *          burger  hotdog  berries  icecream
+   *  dog       5       5        2        -
+   *  rabbit    2       -        3        5
+   *  cow       -       5        -        3
+   *  donkey    3       -        -        5
+   */
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+
+    userData.put(1L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(1L, 1L, 5f),
+        new GenericPreference(1L, 2L, 5f),
+        new GenericPreference(1L, 3L, 2f) })));
+
+    userData.put(2L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(2L, 1L, 2f),
+        new GenericPreference(2L, 3L, 3f),
+        new GenericPreference(2L, 4L, 5f) })));
+
+    userData.put(3L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(3L, 2L, 5f),
+        new GenericPreference(3L, 4L, 3f) })));
+
+    userData.put(4L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(4L, 1L, 3f),
+        new GenericPreference(4L, 4L, 5f) })));
+
+    dataModel = new GenericDataModel(userData);
+    factorizer = new ALSWRFactorizer(dataModel, NUM_FEATURES, 0.065, 10);
+  }
+
+  /** setFeatureColumn must store the given vector at {@code matrix[index]}. */
+  @Test
+  public void setFeatureColumn() throws Exception {
+    double[][] matrix = new double[3][3];
+    Vector vector = new DenseVector(new double[] { 0.5, 2.0, 1.5 });
+    int index = 1;
+
+    factorizer.setFeatureColumn(matrix, index, vector);
+
+    assertEquals(vector.get(0), matrix[index][0], EPSILON);
+    assertEquals(vector.get(1), matrix[index][1], EPSILON);
+    assertEquals(vector.get(2), matrix[index][2], EPSILON);
+  }
+
+  /** The rating vector of user 1 must hold one entry per preference, carrying the preference values. */
+  @Test
+  public void ratingVector() throws Exception {
+    PreferenceArray prefs = dataModel.getPreferencesFromUser(1L);
+
+    Vector ratingVector = factorizer.ratingVector(prefs);
+
+    assertEquals(prefs.length(), ratingVector.getNumNondefaultElements());
+    assertEquals(prefs.get(0).getValue(), ratingVector.get(0), EPSILON);
+    assertEquals(prefs.get(1).getValue(), ratingVector.get(1), EPSILON);
+    assertEquals(prefs.get(2).getValue(), ratingVector.get(2), EPSILON);
+  }
+
+  /** Item 3 (berries) was rated 2 and 3, so its average rating is 2.5. */
+  @Test
+  public void averageRating() throws Exception {
+    /* NOTE(review): "averateRating" looks like a typo for "averageRating" in ALSWRFactorizer */
+    assertEquals(2.5, factorizer.averateRating(3L), EPSILON);
+  }
+
+  /**
+   * The first feature of every item must be the item's average rating; the remaining features are
+   * expected to lie in [0, 0.1].
+   */
+  @Test
+  public void initializeM() throws Exception {
+    double[][] M = factorizer.initializeM();
+
+    assertEquals(3.333333333, M[0][0], EPSILON);
+    assertEquals(5, M[1][0], EPSILON);
+    assertEquals(2.5, M[2][0], EPSILON);
+    assertEquals(4.333333333, M[3][0], EPSILON);
+
+    for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
+      for (int feature = 1; feature < NUM_FEATURES; feature++) {
+        assertTrue(M[itemIndex][feature] >= 0);
+        assertTrue(M[itemIndex][feature] <= 0.1);
+      }
+    }
+  }
+
+  /** Factorizing the toy matrix must reproduce the known ratings with an RMSE below 0.2. */
+  @Test
+  public void toyExample() throws Exception {
+
+    SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+    /* a hold out test would be better, but this is just a toy example so we only check that the
+     * factorization is close to the original matrix */
+    RunningAverage avg = new FullRunningAverage();
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    while (userIDs.hasNext()) {
+      long userID = userIDs.nextLong();
+      for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+        double rating = pref.getValue();
+        double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
+        double err = rating - estimate;
+        avg.addDatum(err * err);
+      }
+    }
+
+    double rmse = Math.sqrt(avg.getAverage());
+    assertTrue("RMSE of the factorization is too high: " + rmse, rmse < 0.2d);
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.easymock.classextension.EasyMock;
+import org.junit.Test;
+
+import java.util.List;
+
+/**
+ * Tests {@link SVDRecommender} against mocked factorizations, so that every expected estimate is a
+ * plain dot product that can be computed by hand.
+ */
+public class SVDRecommenderTest extends TasteTestCase {
+
+  /** The estimate for (user 1, item 5) is (0.4, 2) . (1, 0.3) = 0.4 + 0.6 = 1. */
+  @Test
+  public void estimatePreference() throws Exception {
+    DataModel model = EasyMock.createMock(DataModel.class);
+    Factorizer mockFactorizer = EasyMock.createMock(Factorizer.class);
+    Factorization mockFactorization = EasyMock.createMock(Factorization.class);
+
+    /* exactly one factorization, followed by one user-feature and one item-feature lookup */
+    EasyMock.expect(mockFactorizer.factorize()).andReturn(mockFactorization);
+    EasyMock.expect(mockFactorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+    EasyMock.expect(mockFactorization.getItemFeatures(5L)).andReturn(new double[] { 1, 0.3 });
+    EasyMock.replay(model, mockFactorizer, mockFactorization);
+
+    SVDRecommender recommender = new SVDRecommender(model, mockFactorizer);
+
+    assertEquals(1, recommender.estimatePreference(1L, 5L), EPSILON);
+
+    EasyMock.verify(model, mockFactorizer, mockFactorization);
+  }
+
+  /**
+   * Both candidate items are scored from the mocked features: item 3 gets
+   * (0.4, 2) . (2, 0.6) = 2 and item 5 gets 1, so item 3 must be recommended first.
+   */
+  @Test
+  public void recommend() throws Exception {
+    DataModel model = EasyMock.createMock(DataModel.class);
+    CandidateItemsStrategy strategy = EasyMock.createMock(CandidateItemsStrategy.class);
+    Factorizer mockFactorizer = EasyMock.createMock(Factorizer.class);
+    Factorization mockFactorization = EasyMock.createMock(Factorization.class);
+
+    FastIDSet candidates = new FastIDSet();
+    candidates.add(5L);
+    candidates.add(3L);
+
+    EasyMock.expect(mockFactorizer.factorize()).andReturn(mockFactorization);
+    EasyMock.expect(strategy.getCandidateItems(1L, model)).andReturn(candidates);
+    /* the user features are looked up once per estimated candidate item */
+    EasyMock.expect(mockFactorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+    EasyMock.expect(mockFactorization.getItemFeatures(5L)).andReturn(new double[] { 1, 0.3 });
+    EasyMock.expect(mockFactorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+    EasyMock.expect(mockFactorization.getItemFeatures(3L)).andReturn(new double[] { 2, 0.6 });
+
+    EasyMock.replay(model, strategy, mockFactorizer, mockFactorization);
+
+    SVDRecommender recommender = new SVDRecommender(model, mockFactorizer, strategy);
+
+    List<RecommendedItem> recommendations = recommender.recommend(1L, 5);
+    assertEquals(2, recommendations.size());
+
+    RecommendedItem best = recommendations.get(0);
+    assertEquals(3L, best.getItemID());
+    assertEquals(2f, best.getValue(), EPSILON);
+    RecommendedItem runnerUp = recommendations.get(1);
+    assertEquals(5L, runnerUp.getItemID());
+    assertEquals(1f, runnerUp.getValue(), EPSILON);
+
+    EasyMock.verify(model, strategy, mockFactorizer, mockFactorization);
+  }
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Solves the regularized least-squares problem of a single ALS-WR step: from the feature vectors
+ * that belong to one user's (or item's) ratings and those ratings, it computes the corresponding
+ * feature vector.
+ *
+ * @see <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">ALS-WR paper (AAIM 2008)</a>
+ */
+public class AlternateLeastSquaresSolver {
+
+  /**
+   * Computes ui = inverse(Ai) * Vi, where Ai = MiIi * t(MiIi) + lambda * nui * E and
+   * Vi = MiIi * t(R(i,Ii)); MiIi holds the given feature vectors as columns and nui is the number
+   * of ratings.
+   *
+   * @param featureVectors one feature vector per rating, in the iteration order of the non-zero
+   *        entries of {@code ratingVector}
+   * @param ratingVector sequential-access vector holding the ratings
+   * @param lambda regularization parameter
+   * @param numFeatures number of features; every feature vector needs at least this many entries
+   * @return the solution ui, a vector of length {@code numFeatures}
+   * @throws NullPointerException if {@code featureVectors} or {@code ratingVector} is null
+   * @throws IllegalArgumentException if {@code featureVectors} is empty, its size differs from the
+   *         number of ratings, or {@code ratingVector} is not sequential-access
+   */
+  public Vector solve(List<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) {
+    Preconditions.checkNotNull(featureVectors, "featureVectors must not be null");
+    Preconditions.checkArgument(!featureVectors.isEmpty(), "featureVectors must not be empty");
+    Preconditions.checkNotNull(ratingVector, "ratingVector must not be null");
+
+    /* nui: the number of ratings, which must match the number of feature vectors */
+    int nui = ratingVector.getNumNondefaultElements();
+    Preconditions.checkArgument(featureVectors.size() == nui,
+        "number of feature vectors (%s) must equal the number of ratings (%s)", featureVectors.size(), nui);
+
+    Matrix MiIi = createMiIi(featureVectors, numFeatures);
+    Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);
+
+    /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
+    Matrix Ai = addLambdaTimesNuiTimesE(MiIi.times(MiIi.transpose()), lambda, nui);
+    /* compute Vi = MiIi * t(R(i,Ii)) */
+    Matrix Vi = MiIi.times(RiIiMaybeTransposed);
+    /* compute ui = inverse(Ai) * Vi */
+    return solve(Ai, Vi);
+  }
+
+  /**
+   * Solves Ai * x = Vi by QR decomposition and returns the first column of x.
+   */
+  Vector solve(Matrix Ai, Matrix Vi) {
+    return new QRDecomposition(Ai).solve(Vi).getColumn(0);
+  }
+
+  /**
+   * Adds lambda * nui to every diagonal entry of the given square matrix, modifying it in place.
+   *
+   * @return the same matrix instance that was passed in
+   * @throws IllegalArgumentException if the matrix is not square
+   */
+  protected Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
+    Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "matrix must be square");
+    double lambdaTimesNui = lambda * nui;
+    for (int n = 0; n < matrix.numCols(); n++) {
+      matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
+    }
+    return matrix;
+  }
+
+  /**
+   * Builds the numFeatures x featureVectors.size() matrix whose n-th column holds the first
+   * numFeatures entries of the n-th feature vector.
+   */
+  protected Matrix createMiIi(List<Vector> featureVectors, int numFeatures) {
+    Matrix MiIi = new DenseMatrix(numFeatures, featureVectors.size());
+    for (int n = 0; n < featureVectors.size(); n++) {
+      Vector featureVector = featureVectors.get(n);
+      for (int m = 0; m < numFeatures; m++) {
+        MiIi.setQuick(m, n, featureVector.get(m));
+      }
+    }
+    return MiIi;
+  }
+
+  /**
+   * Copies the non-zero ratings, in iteration order, into a single-column matrix that has one row
+   * per non-default element of the vector.
+   *
+   * @throws IllegalArgumentException if the vector is not sequential-access
+   */
+  protected Matrix createRiIiMaybeTransposed(Vector ratingVector) {
+    Preconditions.checkArgument(ratingVector.isSequentialAccess(), "ratingVector must be sequential-access");
+    Matrix RiIiMaybeTransposed = new DenseMatrix(ratingVector.getNumNondefaultElements(), 1);
+    Iterator<Vector.Element> ratingsIterator = ratingVector.iterateNonZero();
+    int index = 0;
+    while (ratingsIterator.hasNext()) {
+      Vector.Element elem = ratingsIterator.next();
+      RiIiMaybeTransposed.setQuick(index++, 0, elem.get());
+    }
+    return RiIiMaybeTransposed;
+  }
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java?rev=1054567&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java Mon Jan  3 09:18:46 2011
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+/**
+ * Tests for {@link AlternateLeastSquaresSolver} and the intermediate matrices it builds.
+ */
+public class AlternateLeastSquaresSolverTest extends MahoutTestCase {
+
+  private AlternateLeastSquaresSolver solver;
+
+  @Before
+  public void setup() {
+    solver = new AlternateLeastSquaresSolver();
+  }
+
+  /** lambda * nui = 0.2 * 5 = 1.0 must be added to each diagonal entry of an all-zero matrix. */
+  @Test
+  public void addLambdaTimesNuiTimesE() {
+    int nui = 5;
+    double lambda = 0.2;
+    Matrix matrix = new SparseMatrix(new int[] { 5, 5 });
+
+    solver.addLambdaTimesNuiTimesE(matrix, lambda, nui);
+
+    for (int n = 0; n < 5; n++) {
+      assertEquals(1.0, matrix.getQuick(n, n), EPSILON);
+    }
+  }
+
+  /** Each feature vector must become one column of MiIi. */
+  @Test
+  public void createMiIi() {
+    Vector f1 = new DenseVector(new double[] { 1, 2, 3 });
+    Vector f2 = new DenseVector(new double[] { 4, 5, 6 });
+
+    Matrix miIi = solver.createMiIi(Arrays.asList(f1, f2), 3);
+
+    assertEquals(1.0, miIi.getQuick(0, 0), EPSILON);
+    assertEquals(2.0, miIi.getQuick(1, 0), EPSILON);
+    assertEquals(3.0, miIi.getQuick(2, 0), EPSILON);
+    assertEquals(4.0, miIi.getQuick(0, 1), EPSILON);
+    assertEquals(5.0, miIi.getQuick(1, 1), EPSILON);
+    assertEquals(6.0, miIi.getQuick(2, 1), EPSILON);
+  }
+
+  /** The non-zero ratings must be copied, in index order, into a single-column matrix. */
+  @Test
+  public void createRiIiMaybeTransposed() {
+    /* the cardinality must exceed the highest index set below */
+    Vector ratings = new SequentialAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    Matrix riIiMaybeTransposed = solver.createRiIiMaybeTransposed(ratings);
+    assertEquals(1, riIiMaybeTransposed.numCols());
+    assertEquals(3, riIiMaybeTransposed.numRows());
+
+    assertEquals(1.0, riIiMaybeTransposed.getQuick(0, 0), EPSILON);
+    assertEquals(3.0, riIiMaybeTransposed.getQuick(1, 0), EPSILON);
+    assertEquals(5.0, riIiMaybeTransposed.getQuick(2, 0), EPSILON);
+  }
+
+  /** Vectors without sequential access must be rejected. */
+  @Test(expected = IllegalArgumentException.class)
+  public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() {
+    Vector ratings = new RandomAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    solver.createRiIiMaybeTransposed(ratings);
+  }
+
+  /**
+   * With unit feature vectors MiIi is the identity, so Ai = (1 + lambda * nui) * E and the
+   * solution is the rating vector divided by (1 + lambda * nui).
+   */
+  @Test
+  public void solveWithUnitFeatureVectors() {
+    Vector f1 = new DenseVector(new double[] { 1, 0 });
+    Vector f2 = new DenseVector(new double[] { 0, 1 });
+    Vector ratings = new SequentialAccessSparseVector(2);
+    ratings.setQuick(0, 2.0);
+    ratings.setQuick(1, 4.0);
+
+    /* lambda * nui = 0.5 * 2 = 1.0, hence Ai = 2 * E and ui = Vi / 2 */
+    Vector u = solver.solve(Arrays.asList(f1, f2), ratings, 0.5, 2);
+
+    assertEquals(1.0, u.get(0), EPSILON);
+    assertEquals(2.0, u.get(1), EPSILON);
+  }
+
+}



Mime
View raw message