mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1090013 - in /mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1: ./ svd/
Date Thu, 07 Apr 2011 21:10:29 GMT
Author: ssc
Date: Thu Apr  7 21:10:28 2011
New Revision: 1090013

URL: http://svn.apache.org/viewvc?rev=1090013&view=rev
Log:
MAHOUT-657 Sample code to apply SVD to the KDD data

Added:
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class EstimateConverter { // static helper that turns a preference estimate into one result byte
+
+  private static final Logger log = LoggerFactory.getLogger(EstimateConverter.class);
+
+  private EstimateConverter() {} // never instantiated; only the static convert() is offered
+
+  public static byte convert(double estimate, long userID, long itemID) { // IDs are only used for the warning
+    if (Double.isNaN(estimate)) {
+      log.warn("Unable to compute estimate for user {}, item {}", userID, itemID);
+      return 0x7F; // fixed value 127 is emitted whenever the estimate is NaN
+    } else {
+      int scaledEstimate = (int) (estimate * 2.55); // scale by 2.55; the cast truncates toward zero
+      if (scaledEstimate > 255) { // clamp into 0..255 before narrowing
+        scaledEstimate = 255;
+      } else if (scaledEstimate < 0) {
+        scaledEstimate = 0;
+      }
+      return (byte) scaledEstimate; // keeps the low 8 bits, so 128..255 come out as negative byte values
+    }
+  }
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java?rev=1090013&r1=1090012&r2=1090013&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
Thu Apr  7 21:10:28 2011
@@ -54,19 +54,7 @@ final class Track1Callable implements Ca
         log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
         continue;
       }
-
-      if (Double.isNaN(estimate)) {
-        log.warn("Unable to compute estimate for user {}, item {}", userID, itemID);
-        result[i] = 0x7F;
-      } else {
-        int scaledEstimate = (int) (estimate * 2.55);
-        if (scaledEstimate > 255) {
-          scaledEstimate = 255;
-        } else if (scaledEstimate < 0) {
-          scaledEstimate = 0;
-        }
-        result[i] = (byte) scaledEstimate;
-      }
+      result[i] = EstimateConverter.convert(estimate, userID, itemID);
     }
 
     if (COUNT.incrementAndGet() % 10000 == 0) {

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * can be used to drop {@link DataModel}s into {@link ParallelArraysSGDFactorizer}
+ */
+public class DataModelFactorizablePreferences implements FactorizablePreferences {
+
+  private final FastIDSet userIDs; // every user ID returned by dataModel.getUserIDs()
+  private final FastIDSet itemIDs; // every item ID that occurs in at least one copied preference
+
+  private final List<Preference> preferences; // in-memory copy of all preferences, grouped user by user
+
+  private final float minPreference;
+  private final float maxPreference;
+
+  public DataModelFactorizablePreferences(DataModel dataModel) { // copies the whole model in a single pass
+
+    minPreference = dataModel.getMinPreference();
+    maxPreference = dataModel.getMaxPreference();
+
+    try {
+      userIDs = new FastIDSet(dataModel.getNumUsers()); // sized from the counts the model reports
+      itemIDs = new FastIDSet(dataModel.getNumItems());
+      preferences = new ArrayList<Preference>();
+
+      LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+      while (userIDsIterator.hasNext()) {
+        long userID = userIDsIterator.nextLong();
+        userIDs.add(userID);
+        for (Preference preference : dataModel.getPreferencesFromUser(userID)) {
+          itemIDs.add(preference.getItemID());
+          preferences.add(new GenericPreference(userID, preference.getItemID(), preference.getValue())); // fresh copy
+        }
+      }
+    } catch (Exception e) { // any failure while reading the model is rethrown unchecked, cause preserved
+      throw new IllegalStateException("Unable to create factorizable preferences!", e);
+    }
+  }
+
+  @Override
+  public LongPrimitiveIterator getUserIDs() {
+    return userIDs.iterator(); // a new iterator over the collected IDs on every call
+  }
+
+  @Override
+  public LongPrimitiveIterator getItemIDs() {
+    return itemIDs.iterator();
+  }
+
+  @Override
+  public Iterable<Preference> getPreferences() {
+    return preferences; // the internal list itself is handed out, not a copy
+  }
+
+  @Override
+  public float getMinPreference() {
+    return minPreference;
+  }
+
+  @Override
+  public float getMaxPreference() {
+    return maxPreference;
+  }
+
+  @Override
+  public int numUsers() {
+    return userIDs.size(); // counts what was actually collected, not dataModel.getNumUsers()
+  }
+
+  @Override
+  public int numItems() {
+    return itemIDs.size();
+  }
+
+  @Override
+  public int numPreferences() {
+    return preferences.size();
+  }
+}
+

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.Preference;
+
+/**
+ * models the necessary input for {@link ParallelArraysSGDFactorizer}
+ */
+public interface FactorizablePreferences {
+
+  LongPrimitiveIterator getUserIDs(); // the factorizer assigns each returned ID the next user index
+
+  LongPrimitiveIterator getItemIDs(); // the factorizer assigns each returned ID the next item index
+
+  Iterable<Preference> getPreferences(); // iterated once by the factorizer to fill its primitive arrays
+
+  float getMinPreference(); // lower bound used when estimates are clamped
+
+  float getMaxPreference(); // upper bound used when estimates are clamped
+
+  int numUsers(); // sizes the user ID mapping and the user feature matrix
+
+  int numItems(); // sizes the item ID mapping and the item feature matrix
+
+  int numPreferences(); // sizes the per-preference arrays, so it must not be smaller than the real count
+
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.Preference;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+
+public class KDDCupFactorizablePreferences implements FactorizablePreferences { // streams preferences from a file
+
+  private final File dataFile;
+
+  public KDDCupFactorizablePreferences(File dataFile) {
+    this.dataFile = dataFile;
+  }
+
+  @Override
+  public LongPrimitiveIterator getUserIDs() {
+    return new FixedSizeLongIterator(numUsers()); // IDs are generated as 0 .. numUsers() - 1, not read from the file
+  }
+
+  @Override
+  public LongPrimitiveIterator getItemIDs() {
+    return new FixedSizeLongIterator(numItems()); // IDs are generated as 0 .. numItems() - 1
+  }
+
+  @Override
+  public Iterable<Preference> getPreferences() {
+    return new Iterable<Preference>() {
+      @Override
+      public Iterator<Preference> iterator() { // every call re-opens the data file from the start
+        try {
+          return new DataFilePreferencesIterator(new DataFileIterator(dataFile));
+        } catch (IOException e) {
+          throw new IllegalStateException("Cannot iterate over datafile!", e);
+        }
+      }
+    };
+  }
+
+  @Override
+  public float getMinPreference() {
+    return 0;
+  }
+
+  @Override
+  public float getMaxPreference() {
+    return 100;
+  }
+
+  @Override
+  public int numUsers() {
+    return 1000990; // hard-coded; presumably the user count of the KDD Cup track 1 data - TODO confirm
+  }
+
+  @Override
+  public int numItems() {
+    return 624961; // hard-coded; presumably the item count of the KDD Cup track 1 data - TODO confirm
+  }
+
+  @Override
+  public int numPreferences() {
+    return 252800275; // hard-coded; presumably the number of training ratings - TODO confirm
+  }
+
+  static class DataFilePreferencesIterator implements Iterator<Preference> { // flattens per-user arrays
+
+    private final DataFileIterator dataFileIterator;
+
+    Iterator<Preference> currentUserPrefsIterator; // preferences of the entry currently being drained
+
+    public DataFilePreferencesIterator(DataFileIterator dataFileIterator) {
+      this.dataFileIterator = dataFileIterator;
+    }
+
+    @Override
+    public boolean hasNext() {
+      if (currentUserPrefsIterator != null && currentUserPrefsIterator.hasNext()) {
+        return true;
+      } else {
+        return dataFileIterator.hasNext(); // current entry exhausted (or none yet): ask the file iterator
+      }
+    }
+
+    @Override
+    public Preference next() {
+      if (currentUserPrefsIterator == null || !currentUserPrefsIterator.hasNext()) {
+        currentUserPrefsIterator = dataFileIterator.next().getFirst().iterator(); // advance to the next entry
+      }
+      return currentUserPrefsIterator.next();
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  static class FixedSizeLongIterator implements LongPrimitiveIterator { // yields 0, 1, ..., maximum - 1
+
+    private long currentValue;
+    private final long maximum;
+
+    public FixedSizeLongIterator(long maximum) {
+      this.maximum = maximum;
+      currentValue = 0;
+    }
+
+    @Override
+    public long nextLong() {
+      return currentValue++; // no bounds check here; callers are expected to test hasNext() first
+    }
+
+    @Override
+    public long peek() {
+      return currentValue;
+    }
+
+    @Override
+    public void skip(int n) {
+      currentValue += n;
+    }
+
+    @Override
+    public boolean hasNext() {
+      return currentValue < maximum;
+    }
+
+    @Override
+    public Long next() {
+      return nextLong(); // must yield the same sequence as nextLong(); pre-incrementing here skipped 0
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,257 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * {@link org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer} based on Simon Funk's famous article "Netflix Update: Try this at home"
+ * (see <a href="http://sifter.org/~simon/journal/20061211.html">http://sifter.org/~simon/journal/20061211.html</a>).
+ *
+ * Attempts to be as memory efficient as possible, only iterating once through the {@link FactorizablePreferences} or {@link DataModel} while
+ * copying everything to primitive arrays. Learning works in place on these data structures after that.
+ *
+ */
+public class ParallelArraysSGDFactorizer implements Factorizer {
+
+  public static final double DEFAULT_LEARNING_RATE = 0.005;
+  public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
+  public static final double DEFAULT_RANDOM_NOISE = 0.005;
+
+  private final int numFeatures;
+  private final int numIterations;
+  private final float minPreference;
+  private final float maxPreference;
+
+  private final Random random;
+  private final double learningRate;
+  private final double preventOverfitting;
+
+  private final FastByIDMap<Integer> userIDMapping; // user ID -> row index into userFeatures
+  private final FastByIDMap<Integer> itemIDMapping; // item ID -> row index into itemFeatures
+
+  private final double[][] userFeatures; // [userIndex][feature]
+  private final double[][] itemFeatures; // [itemIndex][feature]
+
+  private final int[] userIndexes; // parallel arrays: position i in each array describes preference i
+  private final int[] itemIndexes;
+  private final float[] values;
+
+  private final double defaultValue; // starting value of every feature before noise is added
+  private final double interval;
+  private final double[] cachedEstimates; // per preference: summed contribution of the features trained so far
+
+
+  private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
+
+  public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
+    this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE,
+        DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
+  }
+
+  public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate,
+                                     double preventOverfitting, double randomNoise) {
+    this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting,
+        randomNoise);
+  }
+
+  public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
+    this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING,
+        DEFAULT_RANDOM_NOISE);
+  }
+
+  public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures,
+      int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
+
+    this.numFeatures = numFeatures;
+    this.numIterations = numIterations;
+    minPreference = factorizablePreferences.getMinPreference();
+    maxPreference = factorizablePreferences.getMaxPreference();
+
+    this.random = RandomUtils.getRandom();
+    this.learningRate = learningRate;
+    this.preventOverfitting = preventOverfitting;
+
+    int numUsers = factorizablePreferences.numUsers();
+    int numItems = factorizablePreferences.numItems();
+    int numPrefs = factorizablePreferences.numPreferences();
+
+    log.info("Mapping {} users...", numUsers);
+    userIDMapping = new FastByIDMap<Integer>(numUsers);
+    int index = 0;
+    LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
+    while (userIterator.hasNext()) {
+      userIDMapping.put(userIterator.nextLong(), index++); // indexes are handed out in iteration order
+    }
+
+    log.info("Mapping {} items", numItems);
+    itemIDMapping = new FastByIDMap<Integer>(numItems);
+    index = 0;
+    LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
+    while (itemIterator.hasNext()) {
+      itemIDMapping.put(itemIterator.nextLong(), index++);
+    }
+
+    this.userIndexes = new int[numPrefs];
+    this.itemIndexes = new int[numPrefs];
+    this.values = new float[numPrefs];
+    this.cachedEstimates = new double[numPrefs];
+
+    index = 0;
+    log.info("Loading {} preferences into memory", numPrefs);
+    RunningAverage average = new FullRunningAverage();
+    for (Preference preference : factorizablePreferences.getPreferences()) {
+      userIndexes[index] = userIDMapping.get(preference.getUserID()); // unboxing: an unmapped ID would throw here
+      itemIndexes[index] = itemIDMapping.get(preference.getItemID());
+      values[index] = preference.getValue();
+      cachedEstimates[index] = 0;
+
+      average.addDatum(preference.getValue());
+
+      index++;
+      if (index % 1000000 == 0) {
+        log.info("Processed {} preferences", index);
+      }
+    }
+    log.info("Processed {} preferences, done.", index);
+
+    double averagePreference = average.getAverage();
+    log.info("Average preference value is {}", averagePreference);
+
+    double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
+    defaultValue = Math.sqrt((averagePreference - (prefInterval * 0.1)) / numFeatures);
+    interval = (prefInterval * 0.1) / numFeatures;
+
+    userFeatures = new double[numUsers][numFeatures];
+    itemFeatures = new double[numItems][numFeatures];
+
+    log.info("Initializing feature vectors...");
+    for (int feature = 0; feature < numFeatures; feature++) {
+      for (int userIndex = 0; userIndex < numUsers; userIndex++) {
+        userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+      }
+      for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
+        itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+      }
+    }
+  }
+
+  @Override
+  public Factorization factorize() throws TasteException { // trains one feature at a time, numIterations passes each
+    for (int feature = 0; feature < numFeatures; feature++) {
+      log.info("Shuffling preferences...");
+      shufflePreferences();
+      log.info("Starting training of feature {} ...", feature);
+      for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+        if (currentIteration != (numIterations - 1)) {
+          trainingIteration(feature);
+        } else { // only the last pass over a feature also accumulates the RMSE for logging
+          double rmse = trainingIterationWithRmse(feature);
+          log.info("Finished training feature {} with RMSE {}", feature, rmse);
+        }
+      }
+      if (feature < numFeatures - 1) { // the cache is not needed any more after the last feature
+        log.info("Updating cache...");
+        for (int index = 0; index < userIndexes.length; index++) {
+          cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index],
+              false);
+        }
+      }
+    }
+    log.info("Factorization done");
+    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+  }
+
+  private void trainingIteration(int feature) { // one pass over all preferences for the given feature
+    for (int index = 0; index < userIndexes.length; index++) {
+      train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+    }
+  }
+
+  private double trainingIterationWithRmse(int feature) { // same pass, returning the root mean squared error
+    double rmse = 0;
+    for (int index = 0; index < userIndexes.length; index++) {
+      double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+      rmse += (error * error);
+    }
+    return Math.sqrt(rmse / (double) userIndexes.length);
+  }
+
+  private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
+    double sum = cachedEstimate; // contribution of the features before this one
+    sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
+    if (trailing) { // add a fixed guess for each feature after this one, then clamp to the preference range
+      sum += (numFeatures - feature - 1) * ((defaultValue + interval) * (defaultValue + interval));
+      if (sum > maxPreference) {
+        sum = maxPreference;
+      } else if (sum < minPreference) {
+        sum = minPreference;
+      }
+    }
+    return sum;
+  }
+
+  public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
+    double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
+    double[] userVector = userFeatures[userIndex];
+    double[] itemVector = itemFeatures[itemIndex];
+
+    userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]);
+    itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]);
+    // note: the item update above reads the user value that was already updated on the line before it
+    return error;
+  }
+
+  protected void shufflePreferences() {
+    /* Durstenfeld shuffle */
+    for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) {
+      int swapPos = random.nextInt(currentPos + 1);
+      swapPreferences(currentPos, swapPos);
+    }
+  }
+
+  private void swapPreferences(int posA, int posB) { // swaps all four parallel arrays so entries stay aligned
+    int tmpUserIndex = userIndexes[posA];
+    int tmpItemIndex = itemIndexes[posA];
+    float tmpValue = values[posA];
+    double tmpEstimate = cachedEstimates[posA];
+
+    userIndexes[posA] = userIndexes[posB];
+    itemIndexes[posA] = itemIndexes[posB];
+    values[posA] = values[posB];
+    cachedEstimates[posA] = cachedEstimates[posB];
+
+    userIndexes[posB] = tmpUserIndex;
+    itemIndexes[posB] = tmpItemIndex;
+    values[posB] = tmpValue;
+    cachedEstimates[posB] = tmpEstimate;
+  }
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java?rev=1090013&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
Thu Apr  7 21:10:28 2011
@@ -0,0 +1,121 @@
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+
+/**
+ * run an SVD factorization of the KDD track1 data.
+ *
+ * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M
+ *
+ */
+public class Track1SVDRunner {
+
+  private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class);
+
+  public static void main(String[] args) throws Exception { // factorize, log validation RMSE, write test estimates
+
+    if (args.length != 2) {
+      System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>");
+      System.exit(-1);
+    }
+
+    File dataFileDirectory = new File(args[0]);
+    if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+      throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+    }
+
+    File resultFile = new File(args[1]);
+
+    /* the knobs to turn */
+    int numFeatures = 20;
+    int numIterations = 5;
+    double learningRate = 0.0001;
+    double preventOverfitting = 0.002;
+    double randomNoise = 0.0001;
+
+
+    KDDCupFactorizablePreferences factorizablePreferences =
+        new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+
+    Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
+        learningRate, preventOverfitting, randomNoise);
+
+    Factorization factorization = sgdFactorizer.factorize();
+
+    log.info("Estimating validation preferences...");
+    int prefsProcessed = 0;
+    RunningAverage average = new FullRunningAverage(); // running mean of the squared errors
+    DataFileIterable validations = new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory));
+    for (Pair<PreferenceArray,long[]> validationPair : validations) {
+      for (Preference validationPref : validationPair.getFirst()) {
+        double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(),
+            factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
+        double error = validationPref.getValue() - estimate;
+        average.addDatum(error * error);
+        prefsProcessed++;
+        if (prefsProcessed % 100000 == 0) {
+          log.info("Computed {} estimations", prefsProcessed);
+        }
+      }
+    }
+    log.info("Computed {} estimations, done.", prefsProcessed);
+
+    double rmse = Math.sqrt(average.getAverage());
+    log.info("RMSE {}", rmse);
+
+    log.info("Estimating test preferences...");
+    // open the stream before the try so that a failed open never reaches finally with out == null
+    OutputStream out = new BufferedOutputStream(new FileOutputStream(resultFile));
+    try {
+      // one result byte is written per test preference, in the order the test file yields them
+      DataFileIterable tests = new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory));
+      for (Pair<PreferenceArray,long[]> testPair : tests) {
+        for (Preference testPref : testPair.getFirst()) {
+          double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(),
+              factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
+          byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID());
+          out.write(result);
+        }
+      }
+    } finally {
+      out.flush();
+      out.close();
+    }
+    log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
+  }
+
+  static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference,
+      float maxPreference) throws NoSuchUserException, NoSuchItemException {
+    double[] userFeatures = factorization.getUserFeatures(userID);
+    double[] itemFeatures = factorization.getItemFeatures(itemID);
+    double estimate = 0;
+    for (int feature = 0; feature < userFeatures.length; feature++) { // dot product of the two feature vectors
+      estimate += userFeatures[feature] * itemFeatures[feature];
+    }
+    if (estimate < minPreference) { // clamp into [minPreference, maxPreference]
+      estimate = minPreference;
+    } else if (estimate > maxPreference) {
+      estimate = maxPreference;
+    }
+    return estimate;
+  }
+
+}



Mime
View raw message