hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1402944 - in /hama/trunk/ml/src/main/java/org/apache/hama/ml/regression: CostFunction.java GradientDescentBSP.java LinearRegressionModel.java LogisticRegressionModel.java RegressionModel.java
Date Sun, 28 Oct 2012 07:14:28 GMT
Author: tommaso
Date: Sun Oct 28 07:14:27 2012
New Revision: 1402944

URL: http://svn.apache.org/viewvc?rev=1402944&view=rev
Log:
[HAMA-660] - total number of items to read is pre calculated and passed to the CostFunction

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java Sun Oct 28 07:14:27 2012
@@ -30,10 +30,11 @@ public interface CostFunction {
    *
    * @param x          the input vector
    * @param y          the learned output for x
+   * @param m          the number of existing items
    * @param theta      the parameters vector theta
    * @param hypothesis the hypothesis function to model the problem
    * @return the calculated cost for input x and output y
    */
-  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis);
+  public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta, HypothesisFunction hypothesis);
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sun Oct 28 07:14:27 2012
@@ -49,6 +49,7 @@ public class GradientDescentBSP extends 
   private float alpha;
   private RegressionModel regressionModel;
   private int iterationsThreshold;
+  private int m;
 
   @Override
   public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
@@ -66,6 +67,30 @@ public class GradientDescentBSP extends 
 
   @Override
   public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
+    // 0 superstep : count items
+
+    int itemCount = 0;
+    while (peer.readNext() != null) {
+      // increment counter
+      itemCount++;
+    }
+    for (String peerName : peer.getAllPeerNames()) {
+      if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
+        peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{itemCount})));
+      }
+    }
+    peer.sync();
+
+    // aggregate number of items
+    VectorWritable itemsResult;
+    while ((itemsResult = peer.getCurrentMessage()) != null) {
+      itemCount += itemsResult.getVector().get(0);
+    }
+
+    m = itemCount;
+
+    peer.reopenInput();
+
     int iterations = 0;
     while (true) {
 
@@ -75,25 +100,22 @@ public class GradientDescentBSP extends 
 
       double localCost = 0d;
 
-      int numRead = 0;
-
-      // read an input
+      // read an item
       KeyValuePair<VectorWritable, DoubleWritable> kvp;
       while ((kvp = peer.readNext()) != null) {
         // calculate cost for given input
         double y = kvp.getValue().get();
         DoubleVector x = kvp.getKey().getVector();
-        double costForX = regressionModel.calculateCostForItem(x, y, theta);
+        double costForX = regressionModel.calculateCostForItem(x, y, m, theta);
 
         // adds to local cost
         localCost += costForX;
-        numRead++;
       }
 
       // cost is sent and aggregated by each
       for (String peerName : peer.getAllPeerNames()) {
         if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
-          peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{localCost, numRead})));
+          peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{localCost})));
         }
       }
       peer.sync();
@@ -103,11 +125,8 @@ public class GradientDescentBSP extends 
       VectorWritable costResult;
       while ((costResult = peer.getCurrentMessage()) != null) {
         totalCost += costResult.getVector().get(0);
-        numRead += costResult.getVector().get(1);
       }
 
-      totalCost /= numRead; // TODO : remove this and incorporate the 1/m element in RegressionModel#calculateCostForItem
-
       // cost check
       if (cost - totalCost < 0) {
         throw new RuntimeException(new StringBuilder("gradient descent failed to converge with alpha ").

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java Sun Oct 28 07:14:27 2012
@@ -29,8 +29,8 @@ public class LinearRegressionModel imple
   public LinearRegressionModel() {
     costFunction = new CostFunction() {
       @Override
-      public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis) {
-        return y * Math.pow(applyHypothesis(theta, x) - y, 2) / 2;
+      public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta, HypothesisFunction hypothesis) {
+        return y * Math.pow(applyHypothesis(theta, x) - y, 2) / (2 * m);
       }
     };
   }
@@ -41,7 +41,7 @@ public class LinearRegressionModel imple
   }
 
   @Override
-  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta) {
-    return costFunction.calculateCostForItem(x, y, theta, this);
+  public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta) {
+    return costFunction.calculateCostForItem(x, y, m, theta, this);
   }
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java Sun Oct 28 07:14:27 2012
@@ -29,8 +29,8 @@ public class LogisticRegressionModel imp
   public LogisticRegressionModel() {
     costFunction = new CostFunction() {
       @Override
-      public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta, HypothesisFunction hypothesis) {
-        return -1 * y * Math.log(applyHypothesis(theta, x)) + (1 - y) * Math.log(1 - applyHypothesis(theta, x));
+      public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta, HypothesisFunction hypothesis) {
+        return (-1 * y * Math.log(applyHypothesis(theta, x)) + (1 - y) * Math.log(1 - applyHypothesis(theta, x))) / m;
       }
     };
   }
@@ -41,7 +41,7 @@ public class LogisticRegressionModel imp
   }
 
   @Override
-  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta) {
-    return costFunction.calculateCostForItem(x, y, theta, this);
+  public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta) {
+    return costFunction.calculateCostForItem(x, y, m, theta, this);
   }
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java Sun Oct 28 07:14:27 2012
@@ -30,9 +30,10 @@ public interface RegressionModel extends
    *
    * @param x     the input vector
    * @param y     the learned output for x
+   * @param m     the total number of existing items
    * @param theta the parameters vector theta
    * @return the calculated cost for input x and output y
    */
-  public double calculateCostForItem(DoubleVector x, double y, DoubleVector theta);
+  public double calculateCostForItem(DoubleVector x, double y, int m, DoubleVector theta);
 
 }



Mime
View raw message