horn-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From edwardy...@apache.org
Subject [2/6] incubator-horn git commit: Add DistBeliefModelTrainer
Date Tue, 17 Nov 2015 08:22:44 GMT
Add DistBeliefModelTrainer

Project: http://git-wip-us.apache.org/repos/asf/incubator-horn/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-horn/commit/9141cede
Tree: http://git-wip-us.apache.org/repos/asf/incubator-horn/tree/9141cede
Diff: http://git-wip-us.apache.org/repos/asf/incubator-horn/diff/9141cede

Branch: refs/heads/master
Commit: 9141cede6ce44fda98531f4ff4e3c280e7a89349
Parents: 1187156
Author: Edward J. Yoon <edwardyoon@apache.org>
Authored: Mon Nov 9 17:12:56 2015 +0900
Committer: Edward J. Yoon <edwardyoon@apache.org>
Committed: Mon Nov 9 17:12:56 2015 +0900

----------------------------------------------------------------------
 .../horn/distbelief/DistBeliefModelTrainer.java | 87 ++++++++++++++++++++
 .../java/org/apache/horn/distbelief/Neuron.java |  5 ++
 .../org/apache/horn/distbelief/PropMessage.java |  3 +
 .../distbelief/TestDistBeliefModelTrainer.java  |  5 ++
 .../org/apache/horn/distbelief/TestNeuron.java  |  9 +-
 5 files changed, 106 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/9141cede/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java b/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java
new file mode 100644
index 0000000..a19139c
--- /dev/null
+++ b/src/main/java/org/apache/horn/distbelief/DistBeliefModelTrainer.java
@@ -0,0 +1,87 @@
+package org.apache.horn.distbelief;
+
+import java.io.IOException;
+
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+
+/**
+ * BSP task that repeatedly fetches parameters, trains on one mini-batch, and pushes parameters.
+ */
+public class DistBeliefModelTrainer extends BSP { // NOTE(review): BSP/BSPPeer are used as raw types
+
+  private boolean isConverge = false; // NOTE(review): never set to true in this class, so the early exit in bsp() never fires
+  private int iterations; // incremented by every evaluation of the loop condition in bsp(); never reset
+  private int maxIterations; // NOTE(review): never assigned, so it stays 0 and bsp() runs no passes; TODO set it in setup()
+  
+  @Override
+  public final void setup(BSPPeer peer) {
+    // TODO: load this peer's subset of the neural-network model replica into memory (not implemented yet).
+  }
+  
+  @Override
+  public void bsp(BSPPeer peer) throws IOException, SyncException,
+      InterruptedException {
+
+    // Iterate until maxIterations passes have run or isConverge becomes true.
+    while (this.iterations++ < maxIterations) {
+      
+      // Fetch the latest parameters.
+      fetchParameters(peer);
+      
+      // Train on one mini-batch.
+      doMinibatch(peer);
+      
+      // Push the updated parameters.
+      pushParameters(peer);
+      
+      if (this.isConverge) {
+        break;
+      }
+    }
+    
+  }
+
+  /**
+   * Trains on one mini-batch. Currently a no-op; the intended algorithm is sketched in the comment below.
+   * @param peer the BSP peer running this task (currently unused)
+   */
+  private void doMinibatch(BSPPeer peer) {
+    double avgTrainingError = 0.0; // NOTE(review): currently unused
+    // 1. Load a set of mini-batch instances from the assigned splits into memory.
+    
+    // 2. Train incrementally on the mini-batch. The sketch below refers to names not defined in this class.
+    /*
+    for (Instance trainingInstance : MiniBatchSet) {
+      
+      // 2.1 upward propagation (start from the input layer)
+      for (Neuron neuron : neurons) {  
+        neuron.upward(msg);
+        sync();
+      }
+        
+      // 2.2 calculate the total error
+      sync();
+      
+      // 2.3 downward propagation (start from the total error)
+      for (Neuron neuron : neurons) {  
+        neuron.downward(msg);
+        sync();
+      }
+    
+      // calculate the average training error
+    }
+    */
+    
+  }
+  
+  private void fetchParameters(BSPPeer peer) {
+    // TODO fetch the latest weights from the parameter server
+  }
+
+  private void pushParameters(BSPPeer peer) {
+    // TODO push the updated weights
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/9141cede/src/main/java/org/apache/horn/distbelief/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/distbelief/Neuron.java b/src/main/java/org/apache/horn/distbelief/Neuron.java
index ce67cf2..fadb522 100644
--- a/src/main/java/org/apache/horn/distbelief/Neuron.java
+++ b/src/main/java/org/apache/horn/distbelief/Neuron.java
@@ -23,6 +23,10 @@ public abstract class Neuron<M extends Writable> implements NeuronInterface<M>
{
   double output;
   double weight;
 
+  public void propagate(double gradient) {
+    // TODO not implemented yet: currently a no-op that ignores its argument.
+  }
+
   public void setOutput(double output) {
     this.output = output;
   }
@@ -32,6 +36,7 @@ public abstract class Neuron<M extends Writable> implements NeuronInterface<M>
{
   }
 
   public void push(double weight) {
+    // Stores the given weight on this neuron.
     this.weight = weight;
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/9141cede/src/main/java/org/apache/horn/distbelief/PropMessage.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/distbelief/PropMessage.java b/src/main/java/org/apache/horn/distbelief/PropMessage.java
index dd6f2b1..029cd6a 100644
--- a/src/main/java/org/apache/horn/distbelief/PropMessage.java
+++ b/src/main/java/org/apache/horn/distbelief/PropMessage.java
@@ -37,6 +37,9 @@ public class PropMessage<M extends Writable, W extends Writable> implements
     this.weight = weight;
   }
 
+  /**
+   * @return the message payload: an activation in the upward pass or an error in the downward pass
+   */
   public M getMessage() {
     return message;
   }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/9141cede/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java b/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java
new file mode 100644
index 0000000..5bbd90c
--- /dev/null
+++ b/src/test/java/org/apache/horn/distbelief/TestDistBeliefModelTrainer.java
@@ -0,0 +1,5 @@
+package org.apache.horn.distbelief;
+
+public class TestDistBeliefModelTrainer {
+  // TODO: placeholder; no test cases yet.
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/9141cede/src/test/java/org/apache/horn/distbelief/TestNeuron.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/horn/distbelief/TestNeuron.java b/src/test/java/org/apache/horn/distbelief/TestNeuron.java
index 37e8fd6..9af1315 100644
--- a/src/test/java/org/apache/horn/distbelief/TestNeuron.java
+++ b/src/test/java/org/apache/horn/distbelief/TestNeuron.java
@@ -28,6 +28,8 @@ import org.apache.hama.commons.math.Sigmoid;
 
 public class TestNeuron extends TestCase {
   private static double learningRate = 0.1;
+  private static double bias = -1;
+  private static double theta = 0.8;
 
   public static class MyNeuron extends
       Neuron<PropMessage<DoubleWritable, DoubleWritable>> {
@@ -40,24 +42,24 @@ public class TestNeuron extends TestCase {
       for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
         sum += m.getMessage().get() * m.getWeight().get();
       }
-      sum += (-1 * 0.8);
+      sum += (bias * theta);
 
       double output = new Sigmoid().apply(sum);
       this.setOutput(output);
+      this.propagate(output);
     }
 
     @Override
     public void downward(
         Iterable<PropMessage<DoubleWritable, DoubleWritable>> messages)
         throws IOException {
-
       for (PropMessage<DoubleWritable, DoubleWritable> m : messages) {
         // Calculates error gradient for each neuron
         double gradient = this.getOutput() * (1 - this.getOutput())
             * m.getMessage().get() * m.getWeight().get();
 
         // Propagates to lower layer
-        System.out.println(gradient);
+        this.propagate(gradient);
 
         // Weight corrections
         double weight = learningRate * this.getOutput() * m.getMessage().get();
@@ -84,4 +86,5 @@ public class TestNeuron extends TestCase {
     n.downward(x);
     assertEquals(-0.006688234848481696, n.getUpdate());
   }
+  
 }


Mime
View raw message