mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1070622 - in /mahout/trunk/core/src: main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
Date Mon, 14 Feb 2011 19:40:15 GMT
Author: ssc
Date: Mon Feb 14 19:40:15 2011
New Revision: 1070622

URL: http://svn.apache.org/viewvc?rev=1070622&view=rev
Log:
MAHOUT-606 Parallelize non-distributed ALSWRFactorizer

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1070622&r1=1070621&r2=1070622&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java Mon Feb 14 19:40:15 2011
@@ -34,6 +34,9 @@ import org.slf4j.LoggerFactory;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 
 /**
  * factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
@@ -61,75 +64,146 @@ public class ALSWRFactorizer extends Abs
     this.numIterations = numIterations;
   }
 
+  static class Features {
+
+    private final DataModel dataModel;
+    private final int numFeatures;
+
+    private double[][] M;
+    private double[][] U;
+
+    Features(ALSWRFactorizer factorizer) throws TasteException {
+      this.dataModel = factorizer.dataModel;
+      this.numFeatures = factorizer.numFeatures;
+      Random random = RandomUtils.getRandom();
+      M = new double[this.dataModel.getNumItems()][this.numFeatures];
+      LongPrimitiveIterator itemIDsIterator = this.dataModel.getItemIDs();
+      while (itemIDsIterator.hasNext()) {
+        long itemID = itemIDsIterator.nextLong();
+        int itemIDIndex = factorizer.itemIndex(itemID);
+        M[itemIDIndex][0] = averateRating(itemID);
+        for (int feature = 1; feature < this.numFeatures; feature++) {
+          M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+        }
+      }
+      U = new double[this.dataModel.getNumUsers()][this.numFeatures];
+    }
+
+    double[][] getM() {
+      return M;
+    }
+
+    double[][] getU() {
+      return U;
+    }
+
+    DenseVector getUserFeatureColumn(int index) {
+      return new DenseVector(U[index]);
+    }
+
+    DenseVector getItemFeatureColumn(int index) {
+      return new DenseVector(M[index]);
+    }
+
+    void setFeatureColumnInU(int idIndex, Vector vector) {
+      setFeatureColumn(U, idIndex, vector);
+    }
+
+    void setFeatureColumnInM(int idIndex, Vector vector) {
+      setFeatureColumn(M, idIndex, vector);
+    }
+
+    protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
+      for (int feature = 0; feature < numFeatures; feature++) {
+        matrix[idIndex][feature] = vector.get(feature);
+      }
+    }
+
+    protected double averateRating(long itemID) throws TasteException {
+      PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+      RunningAverage avg = new FullRunningAverage();
+      for (Preference pref : prefs) {
+        avg.addDatum(pref.getValue());
+      }
+      return avg.getAverage();
+    }
+  }
+
   @Override
   public Factorization factorize() throws TasteException {
     log.info("starting to compute the factorization...");
-    AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
-
-    double[][] M = initializeM();
-    double[][] U = null;
+    final AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
+    final Features features = new Features(this);
 
     for (int iteration = 0; iteration < numIterations; iteration++) {
       log.info("iteration {}", iteration);
 
       /* fix M - compute U */
-      U = new double[dataModel.getNumUsers()][numFeatures];
-
+      ExecutorService queue = createQueue();
       LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
-      while (userIDsIterator.hasNext()) {
-        long userID = userIDsIterator.nextLong();
-        List<Vector> featureVectors = new ArrayList<Vector>();
-        LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
-        while (itemIDsFromUser.hasNext()) {
-          long itemID = itemIDsFromUser.nextLong();
-          featureVectors.add(new DenseVector(M[itemIndex(itemID)]));
+      try {
+        while (userIDsIterator.hasNext()) {
+          final long userID = userIDsIterator.nextLong();
+          final LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
+          final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
+          queue.execute(new Runnable() {
+            @Override
+            public void run() {
+              List<Vector> featureVectors = new ArrayList<Vector>();
+              while (itemIDsFromUser.hasNext()) {
+                long itemID = itemIDsFromUser.nextLong();
+                featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
+              }
+              Vector userFeatures = solver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
+              features.setFeatureColumnInU(userIndex(userID), userFeatures);
+            }
+          });
+        }
+      } finally {
+        queue.shutdown();
+        try {
+          queue.awaitTermination(dataModel.getNumUsers(), TimeUnit.SECONDS);
+        } catch (InterruptedException e) {
+          throw new IllegalStateException("Error when computing user features", e);
         }
-        PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
-        Vector userFeatures = solver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
-        setFeatureColumn(U, userIndex(userID), userFeatures);
       }
 
       /* fix U - compute M */
-      M = new double[dataModel.getNumItems()][numFeatures];
-
+      queue = createQueue();
       LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
-      while (itemIDsIterator.hasNext()) {
-        long itemID = itemIDsIterator.nextLong();
-        List<Vector> featureVectors = new ArrayList<Vector>();
-        for (Preference pref : dataModel.getPreferencesForItem(itemID)) {
-          long userID = pref.getUserID();
-          featureVectors.add(new DenseVector(U[userIndex(userID)]));
+      try {
+        while (itemIDsIterator.hasNext()) {
+          final long itemID = itemIDsIterator.nextLong();
+          final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
+          queue.execute(new Runnable() {
+            @Override
+            public void run() {
+              List<Vector> featureVectors = new ArrayList<Vector>();
+              for (Preference pref : itemPrefs) {
+                long userID = pref.getUserID();
+                featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
+              }
+              Vector itemFeatures = solver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
+              features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
+            }
+          });
+        }
+      } finally {
+        queue.shutdown();
+        try {
+          queue.awaitTermination(dataModel.getNumItems(), TimeUnit.SECONDS);
+        } catch (InterruptedException e) {
+          throw new IllegalStateException("Error when computing item features", e);
         }
-        PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
-        Vector itemFeatures = solver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
-        setFeatureColumn(M, itemIndex(itemID), itemFeatures);
       }
     }
 
     log.info("finished computation of the factorization...");
-    return createFactorization(U, M);
+    return createFactorization(features.getU(), features.getM());
   }
 
-  protected double[][] initializeM() throws TasteException {
-    Random random = RandomUtils.getRandom();
-    double[][] M = new double[dataModel.getNumItems()][numFeatures];
-
-    LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
-    while (itemIDsIterator.hasNext()) {
-      long itemID = itemIDsIterator.nextLong();
-      int itemIDIndex = itemIndex(itemID);
-      M[itemIDIndex][0] = averateRating(itemID);
-      for (int feature = 1; feature < numFeatures; feature++) {
-        M[itemIDIndex][feature] = random.nextDouble() * 0.1;
-      }
-    }
-    return M;
-  }
-
-  protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
-    for (int feature = 0; feature < numFeatures; feature++) {
-      matrix[idIndex][feature] = vector.get(feature);
-    }
+  protected ExecutorService createQueue() {
+    return Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
   }
 
   protected Vector ratingVector(PreferenceArray prefs) {
@@ -139,14 +213,4 @@ public class ALSWRFactorizer extends Abs
     }
     return new DenseVector(ratings);
   }
-
-  protected double averateRating(long itemID) throws TasteException {
-    PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
-    RunningAverage avg = new FullRunningAverage();
-    for (Preference pref : prefs) {
-      avg.addDatum(pref.getValue());
-    }
-    return avg.getAverage();
-  }
-
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java?rev=1070622&r1=1070621&r2=1070622&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java Mon Feb 14 19:40:15 2011
@@ -79,11 +79,12 @@ public class ALSWRFactorizerTest extends
 
   @Test
   public void setFeatureColumn() throws Exception {
-    double[][] matrix = new double[3][3];
+    ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
     Vector vector = new DenseVector(new double[] { 0.5, 2.0, 1.5 });
     int index = 1;
 
-    factorizer.setFeatureColumn(matrix, index, vector);
+    features.setFeatureColumnInM(index, vector);
+    double[][] matrix = features.getM();
 
     assertEquals(vector.get(0), matrix[index][0], EPSILON);
     assertEquals(vector.get(1), matrix[index][1], EPSILON);
@@ -104,12 +105,14 @@ public class ALSWRFactorizerTest extends
 
   @Test
   public void averageRating() throws Exception {
-    assertEquals(2.5, factorizer.averateRating(3l), EPSILON);
+    ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
+    assertEquals(2.5, features.averateRating(3l), EPSILON);
   }
 
   @Test
   public void initializeM() throws Exception {
-    double[][] M = factorizer.initializeM();
+    ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
+    double[][] M = features.getM();
 
     assertEquals(3.333333333, M[0][0], EPSILON);
     assertEquals(5, M[1][0], EPSILON);



Mime
View raw message