hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1397550 - in /hama/trunk: examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Date Fri, 12 Oct 2012 13:02:26 GMT
Author: tommaso
Date: Fri Oct 12 13:02:26 2012
New Revision: 1397550

URL: http://svn.apache.org/viewvc?rev=1397550&view=rev
Log:
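[HAMA-651] - added an iterations threshold (gd.iterations.threshold, default 10000) as an additional stopping condition for gradient descent; renamed the gd.threshold property to gd.cost.threshold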

Modified:
    hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Modified: hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java?rev=1397550&r1=1397549&r2=1397550&view=diff
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java (original)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java Fri Oct 12 13:02:26 2012
@@ -44,7 +44,8 @@ public class GradientDescentExample {
     // BSP job configuration
     HamaConfiguration conf = new HamaConfiguration();
     conf.setFloat(GradientDescentBSP.ALPHA, 0.002f);
-    conf.setFloat(GradientDescentBSP.THRESHOLD, 0.2f);
+    conf.setFloat(GradientDescentBSP.COST_THRESHOLD, 0.5f);
+    conf.setInt(GradientDescentBSP.ITERATIONS_THRESHOLD, 300);
 
     BSPJob bsp = new BSPJob(conf, GradientDescentExample.class);
     // Set the job name

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1397550&r1=1397549&r2=1397550&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Fri Oct 12 13:02:26 2012
@@ -38,21 +38,24 @@ public class GradientDescentBSP extends 
   private static final Logger log = LoggerFactory.getLogger(GradientDescentBSP.class);
   public static final String INITIAL_THETA_VALUES = "gd.initial.theta";
   public static final String ALPHA = "gd.alpha";
-  public static final String THRESHOLD = "gd.threshold";
+  public static final String COST_THRESHOLD = "gd.cost.threshold";
+  public static final String ITERATIONS_THRESHOLD = "gd.iterations.threshold";
   public static final String REGRESSION_MODEL_CLASS = "gd.regression.model";
 
   private boolean master;
   private DoubleVector theta;
   private double cost;
-  private double threshold;
+  private double costThreshold;
   private float alpha;
   private RegressionModel regressionModel;
+  private int iterationsThreshold;
 
   @Override
   public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
     master = peer.getPeerIndex() == peer.getNumPeers() / 2;
     cost = Integer.MAX_VALUE;
-    threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.1f);
+    costThreshold = peer.getConfiguration().getFloat(COST_THRESHOLD, 0.1f);
+    iterationsThreshold = peer.getConfiguration().getInt(ITERATIONS_THRESHOLD, 10000);
     alpha = peer.getConfiguration().getFloat(ALPHA, 0.003f);
     try {
       regressionModel = ((Class<? extends RegressionModel>) peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS,
LinearRegressionModel.class)).newInstance();
@@ -63,7 +66,7 @@ public class GradientDescentBSP extends 
 
   @Override
   public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
-
+    int iterations = 0;
     while (true) {
 
       getTheta(peer);
@@ -109,7 +112,7 @@ public class GradientDescentBSP extends 
       if (cost - totalCost < 0) {
         throw new RuntimeException(new StringBuilder("gradient descent failed to converge
with alpha ").
                 append(alpha).toString());
-      } else if (totalCost == 0 || totalCost < threshold) {
+      } else if (totalCost == 0 || totalCost < costThreshold || iterations >= iterationsThreshold)
{
         cost = totalCost;
         break;
       } else {
@@ -168,8 +171,8 @@ public class GradientDescentBSP extends 
       peer.reopenInput();
       peer.sync();
 
+      iterations++;
     }
-
   }
 
   @Override



Mime
View raw message