hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1397520 - /hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Date Fri, 12 Oct 2012 11:46:34 GMT
Author: tommaso
Date: Fri Oct 12 11:46:34 2012
New Revision: 1397520

URL: http://svn.apache.org/viewvc?rev=1397520&view=rev
Log:
[HAMA-651] - adjusting defaults a bit, plus fixing input reopening, theta initialization
and derivative aggregation

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1397520&r1=1397519&r2=1397520&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Fri Oct 12 11:46:34 2012
@@ -52,8 +52,8 @@ public class GradientDescentBSP extends 
   public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
     master = peer.getPeerIndex() == peer.getNumPeers() / 2;
     cost = Integer.MAX_VALUE;
-    threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.01f);
-    alpha = peer.getConfiguration().getFloat(ALPHA, 0.3f);
+    threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.1f);
+    alpha = peer.getConfiguration().getFloat(ALPHA, 0.003f);
     try {
       regressionModel = ((Class<? extends RegressionModel>) peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS,
LinearRegressionModel.class)).newInstance();
     } catch (Exception e) {
@@ -77,7 +77,6 @@ public class GradientDescentBSP extends 
       // read an input
       KeyValuePair<VectorWritable, DoubleWritable> kvp;
       while ((kvp = peer.readNext()) != null) {
-
         // calculate cost for given input
         double y = kvp.getValue().get();
         DoubleVector x = kvp.getKey().getVector();
@@ -89,15 +88,15 @@ public class GradientDescentBSP extends 
       }
 
       // cost is sent and aggregated by each
-      double totalCost = localCost;
-
       for (String peerName : peer.getAllPeerNames()) {
-        peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{localCost,
numRead})));
+        if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
+          peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{localCost,
numRead})));
+        }
       }
       peer.sync();
 
       // second superstep : aggregate cost calculation
-
+      double  totalCost = localCost;
       VectorWritable costResult;
       while ((costResult = peer.getCurrentMessage()) != null) {
         totalCost += costResult.getVector().get(0);
@@ -106,27 +105,23 @@ public class GradientDescentBSP extends 
 
       totalCost /= numRead; // TODO : remove this and incorporate the 1/m element in RegressionModel#calculateCostForItem
 
+      // cost check
       if (cost - totalCost < 0) {
         throw new RuntimeException("gradient descent failed to converge with alpha " + alpha);
-      } else if (totalCost == 0 || cost - totalCost < threshold) {
+      } else if (totalCost == 0 || totalCost < threshold) {
+        log.info(peer.getPeerName()+": finishing!");
         cost = totalCost;
         break;
       } else {
         cost = totalCost;
+        if (log.isInfoEnabled()) {
+          log.info(peer.getPeerName()+": cost is " + cost);
+        }
       }
 
-
-      if (log.isInfoEnabled()) {
-        log.info("cost is " + cost);
-      }
-
-
+      peer.reopenInput();
       peer.sync();
 
-      if (master) { // TODO : check if this has to be done only by the master
-        peer.reopenInput();
-      }
-
       double[] thetaDelta = new double[theta.getLength()];
 
       // third superstep : calculate partial derivatives' deltas in parallel
@@ -148,8 +143,8 @@ public class GradientDescentBSP extends 
 
       // fourth superstep : aggregate partial derivatives
       VectorWritable thetaDeltaSlice;
+      double[] newTheta = thetaDelta;
       while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
-        double[] newTheta = new double[theta.getLength()];
 
         for (int j = 0; j < theta.getLength(); j++) {
           newTheta[j] += thetaDeltaSlice.getVector().get(j);
@@ -158,17 +153,18 @@ public class GradientDescentBSP extends 
         for (int j = 0; j < theta.getLength(); j++) {
           newTheta[j] = theta.get(j) - newTheta[j] * alpha;
         }
+      }
+      theta = new DenseDoubleVector(newTheta);
 
-        theta = new DenseDoubleVector(newTheta);
-
-        if (log.isInfoEnabled()) {
-          log.info("new theta for cost " + cost + " is " + theta.toArray().toString());
-        }
-        // master writes down the output
-        if (master) {
-          peer.write(new VectorWritable(theta), new DoubleWritable(cost));
-        }
+      if (log.isInfoEnabled()) {
+        log.info(peer.getPeerName()+": new theta for cost " + cost + " is " + theta.toString());
+      }
+      // master writes down the output
+      if (master) {
+        peer.write(new VectorWritable(theta), new DoubleWritable(cost));
       }
+
+      peer.reopenInput();
       peer.sync();
 
     }
@@ -178,7 +174,7 @@ public class GradientDescentBSP extends 
   @Override
   public void cleanup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException {
     if (log.isInfoEnabled()) {
-        log.info("computation finished with cost " + cost + " for theta " + theta);
+        log.info(peer.getPeerName()+":computation finished with cost " + cost + " for theta
" + theta);
     }
     // master writes down the final output
     if (master) {
@@ -187,23 +183,28 @@ public class GradientDescentBSP extends 
   }
 
   public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
-    if (master && theta == null) {
-      int size = getXSize(peer);
-      theta = new DenseDoubleVector(size, peer.getConfiguration().getInt(INITIAL_THETA_VALUES,
10));
-      for (String peerName : peer.getAllPeerNames()) {
-        peer.send(peerName, new VectorWritable(theta));
-      }
-      peer.sync();
-    } else {
-      peer.sync();
-      VectorWritable vectorWritable = peer.getCurrentMessage();
-      theta = vectorWritable.getVector();
+    if (theta == null) {
+        if (master) {
+            int size = getXSize(peer);
+            theta = new DenseDoubleVector(size, peer.getConfiguration().getInt(INITIAL_THETA_VALUES,
10));
+            for (String peerName : peer.getAllPeerNames()) {
+                peer.send(peerName, new VectorWritable(theta));
+            }
+            log.info(peer.getPeerName() + ": sending theta");
+            peer.sync();
+        } else {
+            log.info(peer.getPeerName() + ": getting theta");
+            peer.sync();
+            VectorWritable vectorWritable = peer.getCurrentMessage();
+            theta = vectorWritable.getVector();
+        }
     }
   }
 
   private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException {
-    VectorWritable key = null;
-    peer.readNext(key, null);
+    VectorWritable key = new VectorWritable();
+    DoubleWritable value = new DoubleWritable();
+    peer.readNext(key, value);
     peer.reopenInput(); // reset input to start
     if (key == null) {
       throw new IOException("cannot read input vector size");



Mime
View raw message