mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1003753 - in /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd: AbstractOnlineLogisticRegression.java DefaultGradient.java Gradient.java MixedGradient.java RankingGradient.java
Date Sat, 02 Oct 2010 08:18:36 GMT
Author: tdunning
Date: Sat Oct  2 08:18:35 2010
New Revision: 1003753

URL: http://svn.apache.org/viewvc?rev=1003753&view=rev
Log:
Added ranking gradient implementation

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=1003753&r1=1003752&r2=1003753&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Sat Oct  2 08:18:35 2010
@@ -163,11 +163,8 @@ public abstract class AbstractOnlineLogi
     // push coefficients back to zero based on the prior
     regularize(instance);
 
-    // what does the current model say?
-    Vector v = classify(instance);
-
     // update each row of coefficients according to result
-    Vector gradient = this.gradient.apply(groupKey, actual, v);
+    Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
     for (int i = 0; i < numCategories - 1; i++) {
       double gradientBase = gradient.get(i);
 
@@ -177,7 +174,7 @@ public abstract class AbstractOnlineLogi
         Vector.Element updateLocation = nonZeros.next();
         int j = updateLocation.index();
 
-        double newValue = beta.getQuick(i, j) + learningRate * gradientBase * instance.get(j)
* perTermLearningRate(j);
+        double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j)
* instance.get(j);
         beta.setQuick(i, j, newValue);
       }
     }
@@ -324,24 +321,4 @@ public abstract class AbstractOnlineLogi
     return k < 1;
   }
 
-  public static class DefaultGradient implements Gradient {
-    /**
-     * Provides a default gradient computation useful for logistic regression.  This
-     * can be over-ridden to incorporate AUC driven learning.
-     * <p>
-     * See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
-     * @param groupKey     A grouping key to allow per-something AUC loss to be used for
training.
-     *@param actual       The target variable value.
-     * @param v            The current score vector.   @return
-     */
-    @Override
-    public final Vector apply(String groupKey, int actual, Vector v) {
-      Vector r = v.like();
-      if (actual != 0) {
-        r.setQuick(actual - 1, 1);
-      }
-      r.assign(v, Functions.MINUS);
-      return r;
-    }
-  }
 }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java Sat Oct  2 08:18:35 2010
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implements the basic logistic training law.
+ */
+public class DefaultGradient implements Gradient {
+  /**
+   * Provides a default gradient computation useful for logistic regression.  
+   *
+   * @param groupKey     A grouping key to allow per-something AUC loss to be used for training.
+   * @param actual       The target variable value.
+   * @param instance     The current feature vector to use for gradient computation
+   * @param classifier   The classifier that can compute scores
+   * @return  The gradient to be applied to beta
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier
classifier) {
+    // what does the current model say?
+    Vector v = classifier.classify(instance);
+
+    Vector r = v.like();
+    if (actual != 0) {
+      r.setQuick(actual - 1, 1);
+    }
+    r.assign(v, Functions.MINUS);
+    return r;
+  }
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java?rev=1003753&r1=1003752&r2=1003753&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java Sat Oct  2 08:18:35 2010
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.classifier.sgd;
 
+import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.math.Vector;
 
 /**
@@ -25,5 +26,5 @@ import org.apache.mahout.math.Vector;
  * normal loss function.
  */
 public interface Gradient {
-  Vector apply(String groupKey, int actual, Vector v);
+  /**
+   * Computes the update direction for one training example.
+   *
+   * @param groupKey   grouping key, e.g. to allow per-group AUC loss (implementations may ignore it)
+   * @param actual     the target variable value
+   * @param instance   the feature vector of the training example
+   * @param classifier the model being trained, which implementations may use to score instances
+   * @return the gradient; the training loop reads one element per category except category 0
+   */
+  Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
 }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java Sat Oct  2 08:18:35 2010
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * Provides a stochastic mixture of ranking updates and normal logistic updates. This uses
a
+ * combination of AUC driven learning to improve ranking performance and traditional log-loss
driven
+ * learning to improve log-likelihood.
+ * <p/>
+ * See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
+ */
+public class MixedGradient implements Gradient {
+  private double alpha;
+
+  private RankingGradient rank;
+  private Gradient basic;
+
+  Random random = RandomUtils.getRandom();
+
+  public MixedGradient(double alpha, int window) {
+    this.alpha = alpha;
+    this.rank = new RankingGradient(window);
+    this.basic = this.rank.getBaseGradient();
+  }
+
+  @Override
+  public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier
classifier) {
+    if (random.nextDouble() < alpha) {
+      // one option is to apply a ranking update relative to our recent history
+      return rank.apply(groupKey, actual, instance, classifier);
+    } else {
+      // the other option is a normal update, but we have to update our history on the way
+      rank.addToHistory(actual, instance);
+      return basic.apply(groupKey, actual, instance, classifier);
+    }
+  }
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java Sat Oct  2 08:18:35 2010
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * Uses the difference between this instance and recent history to get a
+ * gradient that optimizes ranking performance.  Essentially this is the
+ * same as directly optimizing AUC.  It isn't expected that this would
+ * be used alone, but rather that a MixedGradient would use it and a
+ * DefaultGradient together to combine both ranking and log-likelihood
+ * goals.
+ * <p>
+ * Only binary targets (0 and 1) are supported: each example is ranked against the
+ * remembered examples of the opposite target value.
+ */
+public class RankingGradient implements Gradient {
+  private static final Gradient BASIC = new DefaultGradient();
+
+  /** Maximum number of recent examples remembered for each target value. */
+  private final int window;
+
+  /** Recent examples indexed by target value (0 or 1), oldest first. */
+  private final List<Deque<Vector>> history = Lists.newArrayList();
+
+  /**
+   * @param window maximum number of recent examples remembered per target value; must be positive
+   * @throws IllegalArgumentException if {@code window} is not positive
+   */
+  public RankingGradient(int window) {
+    if (window < 1) {
+      throw new IllegalArgumentException("Window must be positive, was " + window);
+    }
+    this.window = window;
+    // one history per binary target value; without these every history.get() would throw
+    history.add(new ArrayDeque<Vector>(window + 1));
+    history.add(new ArrayDeque<Vector>(window + 1));
+  }
+
+  /**
+   * Records {@code instance} and returns the average log-loss gradient of the differences between
+   * it and each remembered example of the opposite target value.  When no such example has been
+   * seen yet there is nothing to rank against, so the ordinary log-loss gradient of
+   * {@code instance} is returned instead; the result is never null.
+   *
+   * @throws IllegalArgumentException if {@code actual} is not 0 or 1
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    addToHistory(actual, instance);
+
+    // now compute average gradient versus saved vectors from the other side
+    Deque<Vector> otherSide = history.get(1 - actual);
+    int n = otherSide.size();
+    if (n == 0) {
+      return BASIC.apply(groupKey, actual, instance, classifier);
+    }
+
+    double scale = 1.0 / n;
+    Vector r = null;
+    for (Vector other : otherSide) {
+      Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier);
+      if (r == null) {
+        // the first term must be scaled too, or it outweighs the others by a factor of n
+        r = g.times(scale);
+      } else {
+        r.assign(g, Functions.plusMult(scale));
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Remembers {@code instance} as a recent example of target value {@code actual}, discarding the
+   * oldest remembered examples so that at most {@code window} are kept per target value.
+   *
+   * @throws IllegalArgumentException if {@code actual} is not 0 or 1
+   */
+  public void addToHistory(int actual, Vector instance) {
+    checkTarget(actual);
+    Deque<Vector> ourSide = history.get(actual);
+    ourSide.addLast(instance);
+    while (ourSide.size() > window) {
+      ourSide.pollFirst();
+    }
+  }
+
+  /** Returns the plain log-loss gradient used to score individual differences. */
+  public Gradient getBaseGradient() {
+    return BASIC;
+  }
+
+  /** Rejects target values other than 0 and 1, which {@code 1 - actual} cannot pair up. */
+  private static void checkTarget(int actual) {
+    if (actual != 0 && actual != 1) {
+      throw new IllegalArgumentException("Only binary targets (0 or 1) are supported, got " + actual);
+    }
+  }
+}



Mime
View raw message