mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r998242 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Date Fri, 17 Sep 2010 18:41:22 GMT
Author: tdunning
Date: Fri Sep 17 18:41:21 2010
New Revision: 998242

URL: http://svn.apache.org/viewvc?rev=998242&view=rev
Log:
Separated classifier from link function.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=998242&r1=998241&r2=998242&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java	Fri Sep 17 18:41:21 2010
@@ -74,15 +74,38 @@ public abstract class AbstractOnlineLogi
     return this;
   }
 
-  private Vector logisticLink(Vector v) {
+  /**
+   * Computes the inverse link function, by default the logistic link function.
+   *
+   * @param v  The output of the linear combination in a GLM.  Note that the value
+   * of v is modified in place.
+   * @return A version of v with the link function applied.
+   */
+  public Vector link(Vector v) {
     double max = v.maxValue();
-    if (max < 40) {
+    if (max >= 40) {
+      // if max >= 40, subtract max first so every exponent is <= 0 and exp() cannot overflow;
+      // with a max this large, 1+sum(exp(v)) = sum(exp(v)) to within round-off
+      v.assign(Functions.minus(max)).assign(Functions.EXP);
+      return v.divide(v.norm(1));
+    } else {
       v.assign(Functions.EXP);
-      double sum = 1 + v.norm(1);
-      return v.divide(sum);
+      return v.divide(1 + v.norm(1));
+    }
+  }
+
+  /**
+   * Computes the binomial logistic inverse link function.
+   * @param r  The value to transform.
+   * @return   The logistic function of r, 1 / (1 + exp(-r)), i.e. the inverse of the logit.
+   */
+  public double link(double r){
+    if (r < 0) {
+      double s = Math.exp(r);
+      return s / (1 + s);
     } else {
-      v.assign(Functions.minus(max)).assign(Functions.EXP);
-      return v;
+      double s = Math.exp(-r);
+      return 1 / (1 + s);
     }
   }
 
@@ -92,6 +115,10 @@ public abstract class AbstractOnlineLogi
     return beta.times(instance);
   }
 
+  public double classifyScalarNoLink(Vector instance) {
+    return beta.getRow(0).dot(instance);
+  }
+
   /**
    * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
    * category is 1 - sum(this result).
@@ -100,7 +127,7 @@ public abstract class AbstractOnlineLogi
    * @return A vector of probabilities, one for each of the first n-1 categories.
    */
   public Vector classify(Vector instance) {
-    return logisticLink(classifyNoLink(instance));
+    return link(classifyNoLink(instance));
   }
 
   /**
@@ -121,8 +148,7 @@ public abstract class AbstractOnlineLogi
     regularize(instance);
 
     // result is a vector with one element so we can just use dot product
-    double r = Math.exp(beta.getRow(0).dot(instance));
-    return r / (1 + r);
+    return link(classifyScalarNoLink(instance));
   }
 
   @Override



Mime
View raw message