horn-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zjaf...@apache.org
Subject [3/3] incubator-horn git commit: HORN-8: Implement asynchronous parameter server
Date Tue, 02 Feb 2016 23:59:05 GMT
HORN-8: Implement asynchronous parameter server


Project: http://git-wip-us.apache.org/repos/asf/incubator-horn/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-horn/commit/91c0c796
Tree: http://git-wip-us.apache.org/repos/asf/incubator-horn/tree/91c0c796
Diff: http://git-wip-us.apache.org/repos/asf/incubator-horn/diff/91c0c796

Branch: refs/heads/master
Commit: 91c0c796e76303a0e3cf27606fbc10a03d05ed0e
Parents: 8f412c6
Author: Lee Dongjin <dongjin.lee.kr@gmail.com>
Authored: Tue Feb 2 00:15:22 2016 +0900
Committer: Lee Dongjin <dongjin.lee.kr@gmail.com>
Committed: Tue Feb 2 00:16:29 2016 +0900

----------------------------------------------------------------------
 .../org/apache/horn/bsp/ParameterMerger.java    |  10 ++
 .../apache/horn/bsp/ParameterMergerServer.java  |  97 +++++++++++
 .../bsp/SmallLayeredNeuralNetworkTrainer.java   | 173 ++++++-------------
 3 files changed, 162 insertions(+), 118 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/ParameterMerger.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMerger.java b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
new file mode 100644
index 0000000..709331b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
@@ -0,0 +1,10 @@
+package org.apache.horn.bsp;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.ipc.VersionedProtocol;
+
+public interface ParameterMerger extends VersionedProtocol {
+	long versionID = 1L; // protocol version handed to the RPC layer by clients and servers
+	/* Merge one worker's weight updates and training error into the shared model; returns the merged model. */
+	SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates, DoubleMatrix[] prevWeightUpdates);
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
new file mode 100644
index 0000000..54caf2b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
@@ -0,0 +1,97 @@
+package org.apache.horn.bsp;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.mortbay.log.Log;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class ParameterMergerServer implements ParameterMerger { // master-side merger, served over RPC
+	/* The model every update is merged into; also the lock guarding all merge state below. */
+	protected SmallLayeredNeuralNetwork inMemoryModel;
+
+	/* Stop flag: set once the training error stops improving or mergeLimit is reached. */
+	protected AtomicBoolean isConverge;
+
+	/* The number of slave workers sending updates; each update is scaled by 1/SlaveCount. */
+	protected int SlaveCount;
+
+	/* After mergeLimit merges, stop whether or not the training error has converged. */
+	protected int mergeLimit;
+
+	/* The last n reported training errors; convergence is judged on their average. */
+	protected double[] trainingErrors;
+
+	/* Average of the previous full window; once a later window averages higher, training stops. */
+	protected double prevAvgTrainingError = Double.MAX_VALUE;
+
+	/* Next write position in trainingErrors. */
+	protected int curTrainingError = 0;
+
+	/* Number of merges applied so far. */
+	protected int mergeCount = 0;
+
+	public ParameterMergerServer(SmallLayeredNeuralNetwork inMemoryModel, AtomicBoolean isConverge,
+	                             int slaveCount, int mergeLimit, int convergenceCheckInterval) {
+		this.inMemoryModel = inMemoryModel;
+		this.isConverge = isConverge;
+		this.SlaveCount = slaveCount;
+		this.mergeLimit = mergeLimit;
+		this.trainingErrors = new double[convergenceCheckInterval];
+	}
+
+	@Override
+	public long getProtocolVersion(String s, long l) throws IOException {
+		return ParameterMerger.versionID;
+	}
+
+	@Override
+	public SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates,
+	                                              DoubleMatrix[] prevWeightUpdates) {
+		Preconditions.checkArgument(weightUpdates.length == prevWeightUpdates.length);
+
+		Log.info(String.format("Start merging: %d.\n", this.mergeCount));
+		// Scale this worker's contribution by 1/SlaveCount before taking the lock.
+		for (int i = 0; i < weightUpdates.length; ++i) {
+			weightUpdates[i] = weightUpdates[i].divide(this.SlaveCount);
+			prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.SlaveCount);
+		}
+
+		synchronized (inMemoryModel) {
+			if (!this.isConverge.get()) { // re-checked under the lock: nothing is applied after convergence
+				this.inMemoryModel.updateWeightMatrices(weightUpdates);
+				this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
+
+				// record this worker's training error in the current window
+				this.trainingErrors[this.curTrainingError++] = trainingError;
+
+				// window full: stop if its average error is worse than the previous window's
+				if (this.trainingErrors.length == this.curTrainingError) {
+					double curAvgTrainingError = 0.0;
+					for (int i = 0; i < this.curTrainingError; ++i) {
+						curAvgTrainingError += this.trainingErrors[i];
+					}
+					curAvgTrainingError /= this.trainingErrors.length;
+
+					if (prevAvgTrainingError < curAvgTrainingError) {
+						this.isConverge.set(true);
+					} else {
+						// still improving: remember this average and start a new window
+						prevAvgTrainingError = curAvgTrainingError;
+						this.curTrainingError = 0;
+					}
+				}
+
+				if (++this.mergeCount == this.mergeLimit) {
+					this.isConverge.set(true);
+				}
+			}
+			// Build the reply under the same lock so flag, weights and prev updates form one snapshot.
+			return new SmallLayeredNeuralNetworkMessage(
+					0, this.isConverge.get(), this.inMemoryModel.getWeightMatrices(),
+					this.inMemoryModel.getPrevMatricesUpdates());
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/91c0c796/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
index b4657f0..9e3d02f 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
@@ -29,9 +29,11 @@ import org.apache.hama.commons.io.VectorWritable;
 import org.apache.hama.commons.math.DenseDoubleMatrix;
 import org.apache.hama.commons.math.DoubleMatrix;
 import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.ipc.RPC;
 import org.mortbay.log.Log;
 
 import java.io.IOException;
+import java.net.InetSocketAddress;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -42,21 +44,26 @@ import java.util.concurrent.atomic.AtomicBoolean;
 public final class SmallLayeredNeuralNetworkTrainer
     extends
     BSP<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage>
{
-
+  /* On the master worker: the model that slave updates are merged into. */
+  /* On a slave worker: the local copy of the network being trained. */
   private SmallLayeredNeuralNetwork inMemoryModel;
+
+  /* Job configuration. */
   private Configuration conf;
+
   /* Default batch size */
   private int batchSize;
 
-  /* check the interval between intervals */
-  private double prevAvgTrainingError;
-  private double curAvgTrainingError;
-  private long convergenceCheckInterval;
-  private long iterations;
-  private long maxIterations;
+  /* Set to true once training should stop; on slaves it mirrors the flag returned by the master. */
   private AtomicBoolean isConverge;
 
-  private String modelPath;
+  /* On the master worker: RPC server hosting the asynchronous ParameterMergerServer. */
+  /* On a slave worker: null. */
+  private RPC.Server merger;
+
+  /* On the master worker: null. */
+  /* On a slave worker: RPC proxy to the master's ParameterMerger. */
+  private ParameterMerger proxy;
 
   /**
    * Returns true if this worker is master worker.
@@ -77,20 +84,37 @@ public final class SmallLayeredNeuralNetworkTrainer
     // At least one master & slave worker exist.
     Preconditions.checkArgument(peer.getNumPeers() >= 2);
 
+    this.conf = peer.getConfiguration(); // must precede every conf.get* call below
+    String modelPath = conf.get("modelPath");
+    this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
+    this.batchSize = conf.getInt("training.batch.size", 50);
+    this.isConverge = new AtomicBoolean(false);
+
+    int slaveCount = peer.getNumPeers() - 1;
+    int mergeLimit = conf.getInt("training.max.iterations", 100000);
+    int convergenceCheckInterval = peer.getNumPeers() * conf.getInt("convergence.check.interval",
+        2000);
+    String master = peer.getPeerName(0); // master's host (peer 0, assumed to match isMaster), not ours
+    String masterAddr = master.substring(0, master.indexOf(':'));
+    int port = conf.getInt("sync.server.port", 40042);
+
     if (isMaster(peer)) {
+      try {
+        this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel, isConverge, slaveCount,
+            mergeLimit, convergenceCheckInterval), masterAddr, port, conf);
+        merger.start();
+      } catch (IOException e) {
+        throw new IllegalStateException("Cannot start merger on " + masterAddr + ":" + port, e);
+      }
       Log.info("Begin to train");
+    } else {
+      InetSocketAddress addr = new InetSocketAddress(masterAddr, port);
+      try {
+        this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class, ParameterMerger.versionID, addr, conf);
+      } catch (IOException e) {
+        throw new IllegalStateException("Cannot connect to merger at " + addr, e);
+      }
     }
-    this.isConverge = new AtomicBoolean(false);
-    this.conf = peer.getConfiguration();
-    this.iterations = 0;
-    this.modelPath = conf.get("modelPath");
-    this.maxIterations = conf.getLong("training.max.iterations", 100000);
-    this.convergenceCheckInterval = conf.getLong("convergence.check.interval",
-        2000);
-    this.modelPath = conf.get("modelPath");
-    this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
-    this.prevAvgTrainingError = Integer.MAX_VALUE;
-    this.batchSize = conf.getInt("training.batch.size", 50);
   }
 
   @Override
@@ -102,8 +126,6 @@ public final class SmallLayeredNeuralNetworkTrainer
     // write model to modelPath
     if (isMaster(peer)) {
       try {
-        Log.info(String.format("End of training, number of iterations: %d.\n",
-            this.iterations));
         Log.info(String.format("Write model back to %s\n",
             inMemoryModel.getModelPath()));
         this.inMemoryModel.writeModelToFile();
@@ -117,21 +139,12 @@ public final class SmallLayeredNeuralNetworkTrainer
   public void bsp(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
       throws IOException, SyncException, InterruptedException {
-    while (this.iterations++ < maxIterations) {
-      // each slave-worker calculate the matrices updates according to local data
-      if (!isMaster(peer)) {
+    if (!isMaster(peer)) {
+      while (!this.isConverge.get()) {
+        // each slave worker trains on its local data and merges the updates with the master
         calculateUpdates(peer);
       }
-      peer.sync();
-
-      // master merge the updates model
-      if (isMaster(peer)) {
-        mergeUpdates(peer);
-      }
-      peer.sync();
-      if (this.isConverge.get()) {
-        break;
-      }
     }
+    peer.sync(); // barrier: master (and its merger RPC server) waits here until every slave is done
   }
 
@@ -144,20 +157,6 @@ public final class SmallLayeredNeuralNetworkTrainer
   private void calculateUpdates(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
       throws IOException {
-    // receive update information from master
-    if (peer.getNumCurrentMessages() != 0) {
-      SmallLayeredNeuralNetworkMessage inMessage = peer.getCurrentMessage();
-      DoubleMatrix[] newWeights = inMessage.getCurMatrices();
-      DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
-      this.inMemoryModel.setWeightMatrices(newWeights);
-      this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
-      this.isConverge.set(inMessage.isConverge());
-      // check converge
-      if (isConverge.get()) {
-        return;
-      }
-    }
-
     DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
         .size()];
     for (int i = 0; i < weightUpdates.length; ++i) {
@@ -187,76 +186,14 @@ public final class SmallLayeredNeuralNetworkTrainer
       weightUpdates[i] = weightUpdates[i].divide(batchSize);
     }
 
-    DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
-        .getPrevMatricesUpdates();
-    SmallLayeredNeuralNetworkMessage outMessage = new SmallLayeredNeuralNetworkMessage(
-        avgTrainingError, false, weightUpdates, prevWeightUpdates);
-    peer.send(peer.getPeerName(0), outMessage);
-  }
-
-  /**
-   * Merge the updates according to the updates of the grooms.
-   * 
-   * @param peer
-   * @throws IOException
-   */
-  private void mergeUpdates(
-      BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
-      throws IOException {
-    int numMessages = peer.getNumCurrentMessages();
-    boolean isConverge = false;
-    if (numMessages == 0) { // converges
-      this.isConverge.set(true);
-      return;
-    }
-
-    double avgTrainingError = 0;
-    DoubleMatrix[] matricesUpdates = null;
-    DoubleMatrix[] prevMatricesUpdates = null;
-
-    while (peer.getNumCurrentMessages() > 0) {
-      SmallLayeredNeuralNetworkMessage message = peer.getCurrentMessage();
-      if (matricesUpdates == null) {
-        matricesUpdates = message.getCurMatrices();
-        prevMatricesUpdates = message.getPrevMatrices();
-      } else {
-        SmallLayeredNeuralNetwork.matricesAdd(matricesUpdates,
-            message.getCurMatrices());
-        SmallLayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
-            message.getPrevMatrices());
-      }
-      avgTrainingError += message.getTrainingError();
-    }
-
-    if (numMessages != 1) {
-      avgTrainingError /= numMessages;
-      for (int i = 0; i < matricesUpdates.length; ++i) {
-        matricesUpdates[i] = matricesUpdates[i].divide(numMessages);
-        prevMatricesUpdates[i] = prevMatricesUpdates[i].divide(numMessages);
-      }
-    }
-    this.inMemoryModel.updateWeightMatrices(matricesUpdates);
-    this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
-
-    // check convergence
-    if (iterations % convergenceCheckInterval == 0) {
-      if (prevAvgTrainingError < curAvgTrainingError) {
-        // error cannot decrease any more
-        isConverge = true;
-      }
-      // update
-      prevAvgTrainingError = curAvgTrainingError;
-      curAvgTrainingError = 0;
-    }
-    curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
-
-    // broadcast updated weight matrices
-    for (String peerName : peer.getAllPeerNames()) {
-      SmallLayeredNeuralNetworkMessage msg = new SmallLayeredNeuralNetworkMessage(
-          0, isConverge, this.inMemoryModel.getWeightMatrices(),
-          this.inMemoryModel.getPrevMatricesUpdates());
-      peer.send(peerName, msg);
-    }
+    // push this batch's updates to the master and adopt the merged model it returns
+    SmallLayeredNeuralNetworkMessage inMessage = proxy.merge(avgTrainingError, weightUpdates,
+        this.inMemoryModel.getPrevMatricesUpdates());
+    DoubleMatrix[] newWeights = inMessage.getCurMatrices();
+    DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
+    this.inMemoryModel.setWeightMatrices(newWeights);
+    this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
+    this.isConverge.set(inMessage.isConverge());
   }
 
 }


Mime
View raw message