mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From s..@apache.org
Subject svn commit: r1455231 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Date Mon, 11 Mar 2013 16:55:10 GMT
Author: ssc
Date: Mon Mar 11 16:55:09 2013
New Revision: 1455231

URL: http://svn.apache.org/r1455231
Log:
MAHOUT-1093 CrossFoldLearner trains in all folds if tracking key is negative

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=1455231&r1=1455230&r2=1455231&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Mon Mar 11 16:55:09 2013
@@ -123,7 +123,7 @@ public class CrossFoldLearner extends Ab
     record++;
     int k = 0;
     for (OnlineLogisticRegression model : models) {
-      if (k == trackingKey % models.size()) {
+      if (k == mod(trackingKey, models.size())) {
         Vector v = model.classifyFull(instance);
         double score = Math.max(v.get(actual), MIN_SCORE);
         logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
@@ -140,6 +140,11 @@ public class CrossFoldLearner extends Ab
     }
   }
 
+  private long mod(long x, int y) {
+    long r = x % y;
+    return r < 0 ? r + y : r;
+  }
+
   @Override
   public void close() {
     for (OnlineLogisticRegression m : models) {



Mime
View raw message