singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wan...@apache.org
Subject [08/13] incubator-singa git commit: SINGA-70 Refactor API of Layer, Worker, Server and Driver
Date Sun, 27 Sep 2015 14:34:31 GMT
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/neuron_layer.h b/include/neuralnet/neuron_layer.h
index 6c4647d..51ba304 100644
--- a/include/neuralnet/neuron_layer.h
+++ b/include/neuralnet/neuron_layer.h
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,9 +38,9 @@ class ConvolutionLayer : public NeuronLayer {
  public:
   ~ConvolutionLayer();
 
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
   const std::vector<Param*> GetParams() const override {
     std::vector<Param*> params{weight_, bias_};
     return params;
@@ -63,15 +63,15 @@ class ConvolutionLayer : public NeuronLayer {
  */
 class CConvolutionLayer : public ConvolutionLayer {
  public:
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
 
 class DropoutLayer : public NeuronLayer {
  public:
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
  protected:
   // drop probability
   float pdrop_;
@@ -90,9 +90,9 @@ class DropoutLayer : public NeuronLayer {
  * b_i, the neuron after normalization, N is the total num of kernels
  */
 class LRNLayer : public NeuronLayer {
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 
  protected:
   //! shape of the bottom layer feature
@@ -106,9 +106,9 @@ class LRNLayer : public NeuronLayer {
 
 class PoolingLayer : public NeuronLayer {
  public:
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 
  protected:
   int kernel_, pad_, stride_;
@@ -121,26 +121,26 @@ class PoolingLayer : public NeuronLayer {
  */
 class CPoolingLayer : public PoolingLayer {
  public:
-  void Setup(const LayerProto& proto, int npartitions);
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
  private:
   Blob<float> mask_;
 };
 
 class ReLULayer : public NeuronLayer {
  public:
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
 
 class InnerProductLayer : public NeuronLayer {
  public:
   ~InnerProductLayer();
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
   const std::vector<Param*> GetParams() const override {
     std::vector<Param*> params{weight_, bias_};
     return params;
@@ -159,9 +159,9 @@ class InnerProductLayer : public NeuronLayer {
  */
 class STanhLayer : public NeuronLayer {
  public:
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric *perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
 
 /**
@@ -174,19 +174,19 @@ class SigmoidLayer: public Layer {
   using Layer::ComputeFeature;
   using Layer::ComputeGradient;
 
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 };
 
 
 /**
  * Base layer for RBM models.
  */
-class RBMLayer: public Layer {
+class RBMLayer: virtual public Layer {
  public:
   virtual ~RBMLayer() {}
-  void Setup(const LayerProto& proto, int npartitions) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
   const Blob<float>& neg_data(const Layer* layer) {
     return neg_data_;
   }
@@ -218,12 +218,12 @@ class RBMLayer: public Layer {
 /**
  * RBM visible layer
  */
-class RBMVisLayer: public RBMLayer {
+class RBMVisLayer: public RBMLayer, public LossLayer {
  public:
   ~RBMVisLayer();
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 
  private:
   RBMLayer* hid_layer_;
@@ -235,9 +235,9 @@ class RBMVisLayer: public RBMLayer {
 class RBMHidLayer: public RBMLayer {
  public:
   ~RBMHidLayer();
-  void Setup(const LayerProto& proto, int npartitions) override;
-  void ComputeFeature(int flag, Metric* perf) override;
-  void ComputeGradient(int flag, Metric* perf) override;
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
 
  private:
   RBMLayer *vis_layer_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/server.h
----------------------------------------------------------------------
diff --git a/include/server.h b/include/server.h
new file mode 100644
index 0000000..4b75430
--- /dev/null
+++ b/include/server.h
@@ -0,0 +1,133 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_SERVER_H_
+#define SINGA_SERVER_H_
+
+#include <unordered_map>
+#include <vector>
+#include "comm/socket.h"
+#include "proto/job.pb.h"
+#include "utils/param.h"
+#include "utils/updater.h"
+
+namespace singa {
+
+/** Responds to workers' get/put/update requests, and periodically syncs
+ * with other servers.
+ *
+ * Normally, the Server creates a response message for each request, which
+ * will be sent back to the one who issued the request. However, if a request
+ * is not processed successfully, the original message is returned. The
+ * server does not know whether the returned message is a response or the
+ * original message; it just sends it to the router. The router decides
+ * whether to re-send the request to the server or send it to the worker.
+ */
+class Server {
+ public:
+  ~Server();
+  Server(int group_id, int server_id,
+      const JobProto& job_conf,
+      const std::vector<int>& slice2group,
+      const std::vector<int>& slice2server);
+  void Run();
+  inline int grp_id() const { return grp_id_; }
+  inline int id() const { return id_; }
+
+ protected:
+  /**
+   * Process GET request.
+   *
+   * @return the original message or a response message which contains the
+   * values of the Param with the requested version.
+   */
+  Msg* HandleGet(Msg** msg);
+  /**
+   * Process Update request.
+   *
+   * It waits until it has received the gradients from all workers of the same
+   * worker group. After updating, it responds to each sender with the new
+   * Param values. It may generate a sync message to the server group that
+   * maintains the global version of the updated Param (slice).
+   *
+   * Note: there is no counter for each worker group on the number of received
+   * update requests. Hence it is possible that the server would conduct the
+   * update when it receives x requests from group a and y requests from group
+   * b where x + y = group size. To avoid this problem, we can
+   * -# maintain a request list per group for each Param at the server side
+   * -# not span a worker group among multiple nodes. Then the updates from
+   * the same group would be locally aggregated on the worker node, and the
+   * server would conduct the update immediately after receiving the aggregated
+   * request.
+   * -# launch only one worker group.
+   *
+   * @return the original message or response messages
+   */
+  const std::vector<Msg*> HandleUpdate(Msg **msg);
+  /**
+   * Process PUT request.
+   *
+   * @return the original message or response message. If we don't want to
+   * acknowledge the put request, then return nullptr.
+   */
+  Msg* HandlePut(Msg **msg);
+  /**
+   * Handle sync request from other server groups.
+   *
+   * It adds updates of Param (slice) from other server groups directly to
+   * local Param (slice). Currently, each Param (slice) has a master group,
+   * i.e., slice2group_[sliceid], which would receive such requests from all
+   * other server groups for the Param object.
+   *
+   * @param msg request msg containing the parameter updates
+   * @return response msg that contains the fresh parameter values.
+   */
+  Msg* HandleSyncRequest(Msg** msg);
+  /**
+   * Handle sync response.
+   *
+   * The response msg includes the latest values of a Param object from the
+   * server group that maintains this Param object.
+   * The local Param values are replaced with the addition result of local
+   * updates since the sync request was sent and the received Param values.
+   *
+   * @param msg the sync response message
+   */
+  void HandleSyncResponse(Msg** msg);
+
+ protected:
+  int grp_id_ = -1;
+  int id_ = -1;
+  Updater* updater_ = nullptr;
+  //! map from slice ID to slice; the entries are deleted in the destructor
+  std::unordered_map<int, ParamEntry*> shard_;
+  std::vector<int> slice2group_, slice2server_;  // indexed by slice ID
+  //! num of updates since last sync with master server group, per param/slice
+  std::vector<int> n_updates_;
+  //! num of sync requests that have not been responded to
+  std::vector<int> n_pending_sync_;
+  std::vector<Blob<float>> last_sync_;  // NOTE(review): presumably per slice
+  std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
+};
+
+}  // namespace singa
+
+#endif  // SINGA_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/singa.h
----------------------------------------------------------------------
diff --git a/include/singa.h b/include/singa.h
index d4ee557..6c801ab 100644
--- a/include/singa.h
+++ b/include/singa.h
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,16 +22,15 @@
 #ifndef SINGA_SINGA_H_
 #define SINGA_SINGA_H_
 
-#include "communication/socket.h"
+#include "comm/socket.h"
 #include "neuralnet/neuralnet.h"
 #include "neuralnet/layer.h"
 #include "proto/job.pb.h"
 #include "proto/singa.pb.h"
-#include "trainer/trainer.h"
 #include "utils/common.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
 #include "utils/factory.h"
-#include "driver.h"
+#include "./driver.h"
 
 #endif  // SINGA_SINGA_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/stub.h
----------------------------------------------------------------------
diff --git a/include/stub.h b/include/stub.h
new file mode 100644
index 0000000..719f033
--- /dev/null
+++ b/include/stub.h
@@ -0,0 +1,109 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_STUB_H_
+#define SINGA_STUB_H_
+
+#include <queue>
+#include <unordered_map>
+#include <vector>
+#include <string>
+#include "comm/socket.h"
+#include "neuralnet/neuralnet.h"
+#include "proto/job.pb.h"
+#include "proto/singa.pb.h"
+#include "utils/factory.h"
+#include "utils/param.h"
+#include "utils/singleton.h"
+#include "./server.h"
+#include "./worker.h"
+
+namespace singa {
+
+//! Handles (e.g., forwards) messages among workers, servers and processes.
+class Stub {
+ public:
+  ~Stub();
+  /**
+   * Find an endpoint to bind.
+   */
+  void Setup();
+  /**
+   * Handle (e.g., forward) messages from workers and servers. The Stub
+   * instance runs this function in the main thread.
+   * @param[in] slice2server the k-th value is the ID of the server in charge
+   * of updating the Param slice with ID k. Large Param objects are sliced
+   * into subsets that are updated by different servers for load balance.
+   * @param[in] workers Worker instances whose messages are handled
+   * @param[in] servers Server instances whose messages are handled
+   */
+  void Run(const std::vector<int>& slice2server,
+      const std::vector<Worker*>& workers,
+      const std::vector<Server*>& servers);
+
+  const std::string& endpoint() const {
+    return endpoint_;
+  }
+
+ protected:
+  /**
+   * Create a socket to send msg to the specified process
+   * @param dst_procs the dst process (logical) ID
+   * @return the newly created socket
+   */
+  Dealer* CreateInterProcsDealer(int dst_procs);
+  /**
+   * Generate a request message to Get the parameter object.
+   */
+  const std::vector<Msg*> HandleGetRequest(ParamEntry* entry, Msg** msg);
+  /** Handle response msg from servers for the get requests. */
+  void HandleGetResponse(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Update the parameter object.
+   */
+  const std::vector<Msg*> HandleUpdateRequest(ParamEntry* entry, Msg** msg);
+  /**
+   * Handle response msg from servers for the update requests.
+   */
+  void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Put the parameter object.
+   */
+  const std::vector<Msg*> HandlePutRequest(ParamEntry* entry, Msg** msg);
+  /**
+   * Called by HandlePutRequest, HandleUpdateRequest and HandleGetRequest
+   * @param type message type
+   * @param version param version
+   * @param entry the ParamEntry to generate messages for
+   * @param msg the message being handled
+   * @param ret generated messages
+   */
+  void GenMsgs(int type, int version, ParamEntry* entry,
+    Msg* msg, std::vector<Msg*> *ret);
+
+  Router *router_ = nullptr;
+  std::string endpoint_;
+  std::vector<int> slice2server_;  //!< slice ID -> ID of server updating it
+};
+
+}  // namespace singa
+
+#endif  // SINGA_STUB_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
deleted file mode 100644
index 84b3a41..0000000
--- a/include/trainer/server.h
+++ /dev/null
@@ -1,132 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_SERVER_H_
-#define SINGA_TRAINER_SERVER_H_
-
-#include <unordered_map>
-#include <vector>
-#include "communication/socket.h"
-#include "proto/job.pb.h"
-#include "utils/param.h"
-#include "utils/updater.h"
-
-namespace singa {
-
- /* Repsond to worker's get/put/udpate request, and periodically syncing with
-  * other servers.
-  *
-  * Normally, the Server creates a response message for each request which
-  * will be sent back to the one who issued the request. However, if the request
-  * are not processed successfully, the original message will be returned. The
-  * sever does not know the returned message (response or the original message),
-  * it just sends it to the router. The router will decide to re-send the
-  * request to the server or send it to the worker.
-  */
-class Server {
- public:
-  Server(int group_id, int server_id);
-  ~Server();
-  void Setup(const UpdaterProto& proto, const std::vector<int>& slice2group,
-             const std::vector<int>& slice2server);
-  void Run();
-  inline int grp_id() const { return grp_id_; }
-  inline int id() const { return id_; }
-
- protected:
-  /**
-   * Process GET request.
-   *
-   * @return the orignal message or a response message which contains the values
-   * of the Param with the request version.
-   */
-  Msg* HandleGet(Msg** msg);
-  /**
-   * Process Update request.
-   *
-   * It waits until received the gradients from all workers from the same worker
-   * group. After updating, it responses to each sender with the new Param
-   * values. It may generate a sync message to the server group that maintains
-   * the global version of the updated Param (slice).
-   *
-   * Note: there is no counter for each worker group on the number of received
-   * update requests. Hence it is possible that the server would conduct the
-   * update when it receives x requests from group a and y requests from group
-   * b where x + y = group size. To avoid this problem, we can
-   * 1. maintain request list for each group for each Param at the server side
-   * 2. do not span a worker group among multiple nodes. then the updates from
-   * the same group would be locally aggregated on the worker node. And the
-   * server would conduct the update immediately after receiving the aggregated
-   * request.
-   * 3. launch only one worker group.
-   *
-   * @return the orignal message or response message
-   */
-  const std::vector<Msg*> HandleUpdate(Msg **msg);
-  /**
-   * Process PUT request.
-   *
-   * @return the original message or response message. If we don't want to
-   * acknowledge the put request, then return nullptr.
-   */
-  Msg* HandlePut(Msg **msg);
-  /**
-   * Handle sync request from other server groups.
-   *
-   * It adds updates of Param (slice) from other server groups directly to
-   * local Param (slice). Currently, each Param (slice) has a master group,
-   * i.e., slice2group_[sliceid], which would receive such requests from all
-   * other server groups for the Param object.
-   *
-   * @param msg request msg containing the parameter updates
-   * @return response msg that contains the fresh parameter values.
-   */
-  Msg* HandleSyncRequest(Msg** msg);
-  /**
-   * Handle sync response.
-   *
-   * The response msg includes the latest values of a Param object, for which
-   * this server sent the sync request to the master/maintainer group.
-   * The local Param values are replaced with the addition result of local
-   * udpates since the sync request was sent and the received Param values.
-   *
-   * @param response message
-   */
-  void HandleSyncResponse(Msg** msg);
-
- protected:
-  int grp_id_ = -1;
-  int id_ = -1;
-  Updater* updater_ = nullptr;
-  //!< map from slice ID to slice and deleted in the destructor
-  std::unordered_map<int, ParamEntry*> shard_;
-  std::vector<int> slice2group_, slice2server_;
-  //!< num of updates from last sync with master server group for a param/slice
-  std::vector<int> n_updates_;
-  //!< num of sync requests that have not been responded
-  std::vector<int> n_pending_sync_;
-  std::vector<Blob<float>> last_sync_;
-  std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
deleted file mode 100644
index 1c0e039..0000000
--- a/include/trainer/trainer.h
+++ /dev/null
@@ -1,163 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_TRAINER_H_
-#define SINGA_TRAINER_TRAINER_H_
-
-#include <queue>
-#include <unordered_map>
-#include <vector>
-#include "communication/socket.h"
-#include "neuralnet/neuralnet.h"
-#include "proto/job.pb.h"
-#include "proto/singa.pb.h"
-#include "trainer/server.h"
-#include "trainer/worker.h"
-#include "utils/factory.h"
-#include "utils/param.h"
-#include "utils/singleton.h"
-
-namespace singa {
-
-/**
- * Every running process has a training object which launches one or more
- * worker (and server) threads.
- *
- * The main thread runs a loop to forward messages between workers and servers.
- */
-class Trainer{
- public:
-  ~Trainer();
-  /**
-   * Entrance function which construct the workers and servers, and luanch
-   * one thread per worker/server.
-   *
-   * @param resume if true resume the training from the latest checkpoint files
-   * @param singaConf global singa configuration including zookeeper and
-   * @param jobConf job configuration, including cluster and model configuration
-   */
-  void Start(bool resume, const SingaProto& singaConf, JobProto* jobConf);
-
- protected:
-  /**
-   * Setting the checkpoint field of model configuration to resume training.
-   *
-   * The checkpoint folder will be searched to get the files for the latest
-   * checkpoint, which will be added into the checkpoint field. The workers
-   * would then load the values of params from the checkpoint files.
-   *
-   * @param jobConf job configuration
-   */
-  void Resume(JobProto* jobConf);
-  /**
-   * Create server instances.
-   * @param nthread total num of threads in current procs which is used to
-   * assign each thread a local thread ID. The number of workers is extracted
-   * from Cluster
-   * @param jobConf
-   * @return server instances
-   */
-  std::vector<Server*> CreateServers(const JobProto& jobConf);
-  /**
-   * Create workers instances.
-   * @param nthread total num of threads in current procs which is used to
-   * assign each thread a local thread ID. The number of workers is extracted
-   * from Cluster
-   * @param jobConf
-   * @return worker instances
-   */
-  std::vector<Worker*> CreateWorkers(const JobProto& jobConf);
-  /**
-   * Setup workers and servers.
-   *
-   * For each worker, create and assign a neuralnet to it.
-   * For each server, create and assign the param shard to it.
-   * Create the partition map from slice ID to server
-   * @param modelConf
-   * @param workers
-   * @param servers
-   */
-  void SetupWorkerServer(const JobProto& jobConf,
-                         const std::vector<Worker*>& workers,
-                         const std::vector<Server*>& servers);
-  void Run(const std::vector<Worker*>& workers,
-           const std::vector<Server*>& servers);
-  /**
-   * Display metrics to log (standard output)
-   */
-  void DisplayMetric(Msg** msg);
-  /**
-   * Create a socket to send msg to the specified process
-   * @param dst_procs the dst process (logical) ID
-   * @return the newly created socket
-   */
-  Dealer* CreateInterProcsDealer(int dst_procs);
-  /**
-   * Handle messages to local servers and local stub
-   */
-  void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
-  /**
-   * Generate a request message to Get the parameter object.
-   */
-  const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
-  void HandleGetResponse(ParamEntry* entry, Msg** msg);
-  /**
-   * Generate a request message to Update the parameter object.
-   */
-  const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
-  void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
-  /**
-   * Generate a request message to Put the parameter object.
-   */
-  const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
-  /**
-   * Called by HandlePut, HandleUpdate and HandleGet functions
-   * @param type message type
-   * @param version param version
-   * @param entry
-   * @param msg
-   * @param ret generated messages
-   */
-  void GenMsgs(int type, int version, ParamEntry* entry,
-    Msg* msg, std::vector<Msg*> *ret);
-  /**
-   * Get a hash id for a Param object from a group.
-   *
-   * Simple multiple group_id with a large prime number 997 (assuming there are
-   * no more than 997 worker groups) and plus owner param id.
-   */
-  inline int Hash(int grp_id, int param_id) {
-    return grp_id * 997 + param_id;
-  }
-
- protected:
-  int procs_id_ = -1;
-  Router *router_ = nullptr;
-  std::unordered_map<int, ParamEntry*> worker_shard_;
-  //!< map from slice to the server that updates it
-  std::vector<int> slice2server_;
-  // a buffer of created nets, will destroy them all in destructor
-  std::vector<NeuralNet*> nets_;
-};
-
-}  // namespace singa
-
-#endif  // SINGA_TRAINER_TRAINER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
deleted file mode 100644
index 66439ec..0000000
--- a/include/trainer/worker.h
+++ /dev/null
@@ -1,258 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#ifndef SINGA_TRAINER_WORKER_H_
-#define SINGA_TRAINER_WORKER_H_
-
-#include <string>
-#include "communication/socket.h"
-#include "neuralnet/neuralnet.h"
-#include "proto/job.pb.h"
-
-namespace singa {
-
-//!< sleep 5 milliseconds if the Param is not updated to the expected version
-const int kCollectSleepTime = 5;
-/**
- * The Worker class which runs the training algorithm.
- * The first worker group will initialize parameters of the Net,
- * and put them into the distributed memory/table.
- * The virtual function TrainOneBatch and TestOneBatch implement the
- * training and test algorithm for one mini-batch data.
- *
- * Child workers override the two functions to implement their training
- * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implements the BP/CD/BPTT
- * algorithm respectively.
- */
-class Worker {
- public:
-  static Worker* Create(const JobProto& proto);
-  /**
-   * @param thread_id local thread index within the procs
-   * @param grp_id global worker group ID
-   * @param id worker ID within the group
-   */
-  virtual void Init(int grp_id, int id);
-  virtual ~Worker();
-  /**
-   * Setup members
-   */
-  void Setup(const JobProto& job, NeuralNet* train_net, NeuralNet* valid_net,
-             NeuralNet* test_net);
-  /**
-   * Init all local params (i.e., params from layers resident in this worker).
-   *
-   * If the param is owned by the worker, then init it and put it to servers.
-   * Otherwise call Get() to get the param. The Get may not send get request.
-   * Because the param's own is in the same procs. Once the owner initializes
-   * the param, its version is visiable to all shares.
-   * If the training starts from scrath, the params are initialzed using random
-   * distributions, e.g., Gaussian distribution. After that, the worker may
-   * train for a couple of steps to warmup the params before put
-   * them to servers (warmup of JobProto controls this).
-   *
-   * If the owner param is available from checkpoint file, then its
-   * values are parsed from the checkpoint file instead of randomly initialized.
-   * For params who do not have checkpoints, randomly init them.
-   */
-  void InitLocalParams();
-  /**
-    * Main function of Worker.
-    *
-    * Train the neuralnet step by step, test/validation is done periodically.
-    */
-  void Run();
-  /**
-   * Checkpoint all params owned by the worker from the first group onto disk.
-   * The serialization is done using BlobProtos which includes the name, version
-   * and values of each Param.
-   * Different worker would generate different checkpoint files. The file path
-   * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin
-   * @param step training step of this worker
-   * @param net the training net whose params will be dumped.
-   */
-  void Checkpoint(int step, NeuralNet* net);
-  /**
-    * Test the perforance of the learned model on validation or test dataset.
-    * Test is done by the first group.
-    * @param net, neural network
-    */
-  void Test(int nsteps, Phase phase, NeuralNet* net);
-  /**
-    * Train one mini-batch.
-    * Test/Validation is done before training.
-    */
-  virtual void TrainOneBatch(int step, Metric* perf) = 0;
-  /**
-   * Test/validate one mini-batch.
-   */
-  virtual void TestOneBatch(int step, Phase phase, NeuralNet* net,
-                            Metric* perf) = 0;
-  /**
-   * Report performance to the stub.
-   *
-   * @param prefix display prefix, e.g., 'Train', 'Test'
-   * @param perf
-   */
-  void Report(const std::string& prefix, const Metric & perf);
-  /**
-   * Put Param to server.
-   * @param param
-   * @param step used as current param version for the put request
-   */
-  int Put(Param* param, int step);
-  /**
-   * Get Param with specific version from server
-   * If the current version >= the requested version, then return.
-   * Otherwise send a get request to stub who would forwards it to servers.
-   * @param param
-   * @param step requested param version
-   */
-  int Get(Param* param, int step);
-  /**
-   * Update Param
-   * @param param
-   * @param step training step used for updating (e.g., deciding learning rate)
-   */
-  int Update(Param* param, int step);
-  /**
-   * Block until the param is updated since sending the update request
-   *
-   * @param param
-   * @param step not used
-   */
-  int Collect(Param* param, int step);
-  /**
-   * Call Collect for every param of net
-   */
-  int CollectAll(NeuralNet* net, int step);
-  /**
-   * Receive blobs from other workers due to model partitions.
-   */
-  void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
-  /**
-   * Send blobs to other workers due to model partitions.
-   */
-  void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
-  /**
-   * Check is it time to display training info, e.g., loss and precison.
-   */
-  inline bool DisplayNow(int step) const {
-    return job_conf_.disp_freq() > 0
-           && step >= job_conf_.disp_after()
-           && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0);
-  }
-  /**
-   * Check is it time to display training info, e.g., loss and precison.
-   */
-  inline bool DisplayDebugInfo(int step) const {
-    return DisplayNow(step) && job_conf_.debug() && grp_id_ == 0;
-  }
-  /**
-   * Check is it time to stop
-   */
-  inline bool StopNow(int step) const {
-    return step >= job_conf_.train_steps();
-  }
-  /**
-   * Check is it time to do checkpoint.
-   */
-  inline bool CheckpointNow(int step) const {
-    return grp_id_ == 0
-           && job_conf_.checkpoint_freq() > 0
-           && step >= job_conf_.checkpoint_after()
-           && ((step - job_conf_.checkpoint_after())
-              % job_conf_.checkpoint_freq() == 0);
-  }
-  /**
-   * Check is it time to do test.
-   * @param step the ::Train() has been called this num times.
-   */
-  inline bool TestNow(int step) const {
-    return grp_id_ == 0
-           && job_conf_.test_freq() > 0
-           && job_conf_.test_steps() > 0
-           && step >= job_conf_.test_after()
-           && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0);
-  }
-  /**
-   * Check is it time to do validation.
-   * @param step the ::Train() has been called step times.
-   */
-  inline bool ValidateNow(int step) const {
-    return grp_id_ == 0
-           && job_conf_.valid_freq() > 0
-           && job_conf_.valid_steps() > 0
-           && step >= job_conf_.valid_after()
-           && ((step - job_conf_.valid_after()) % job_conf_.valid_freq() == 0);
-  }
-  /**
-   * @return group ID
-   */
-  int grp_id() const { return grp_id_; }
-  /**
-   * @reutrn worker ID within the worker group.
-   */
-  int id() const { return id_; }
-
- protected:
-  int grp_id_ = -1, id_ = -1;
-  int step_ = 0;
-  JobProto job_conf_;
-  NeuralNet* train_net_ = nullptr;
-  NeuralNet* test_net_ = nullptr;
-  NeuralNet* validation_net_ = nullptr;
-  Dealer* layer_dealer_ = nullptr;
-  Dealer* dealer_ = nullptr;
-};
-
-class BPWorker: public Worker {
- public:
-  void TrainOneBatch(int step, Metric* perf) override;
-  void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf)
-      override;
-  void Forward(int step, Phase phase, NeuralNet* net, Metric* perf);
-  void Backward(int step, NeuralNet* net);
-};
-
-class CDWorker: public Worker {
- public:
-  void TrainOneBatch(int step, Metric* perf) override;
-  void TestOneBatch(int step, Phase phase, NeuralNet* net, Metric* perf)
-      override;
-};
-
-inline int BlobTrgt(int grp, int layer) {
-  return (grp << 16) | layer;
-}
-
-inline int BlobGrp(int blob_trgt) {
-  return blob_trgt >> 16;
-}
-
-inline int BlobLayer(int blob_trgt) {
-  static int mask = (1 << 16) -1;
-  return blob_trgt & mask;
-}
-
-}  // namespace singa
-
-#endif  // SINGA_TRAINER_WORKER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index e6c8c7c..f690438 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -25,12 +25,13 @@
 #include <memory>
 #include <string>
 #include <vector>
-#include "communication/msg.h"
+
+#include "comm/msg.h"
 #include "proto/job.pb.h"
 #include "utils/blob.h"
 
 namespace singa {
-
+using std::vector;
 /**
  * Base parameter generator which intializes parameter values.
  */
@@ -92,7 +93,34 @@ class UniformSqrtFanInOutGen : public UniformGen {
  */
 class Param {
  public:
-  static Param* Create(const ParamProto& proto);
+  /**
+   * Create an instance of the Param class (or a subclass) based on the type
+   * given in the configuration.
+   *
+   * @param[in] conf configuration
+   * @return a pointer to the created instance
+   */
+  static Param* Create(const ParamProto& conf);
+
+  /**
+   * Try to slice the Param objects (from a neural net) into a given number of
+   * servers (groups) evenly. This is to achieve load-balance among servers.
+   *
+   * It does not change the Param objects, but just computes the length of each
+   * slice.
+   *
+   * @param[in] num number of servers (groups) maintaining the Param objects.
+   * @param[in] params all Param objects from a neural net.
+   * @return the length of each slice.
+   */
+  static const vector<int> ComputeSlices(int num, const vector<Param*>& params);
+  /**
+   * Compute the length of each slice and slice the Param objects by adding
+   * the slicing information into every Param object.
+   *
+   * @copydetails ComputeSlices()
+   */
+  static void SliceParams(int num, const vector<Param*>& params);
 
   Param() {}
   virtual ~Param() {}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/include/worker.h
----------------------------------------------------------------------
diff --git a/include/worker.h b/include/worker.h
new file mode 100644
index 0000000..58f02c4
--- /dev/null
+++ b/include/worker.h
@@ -0,0 +1,311 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_WORKER_H_
+#define SINGA_WORKER_H_
+
+#include <string>
+#include <vector>
+#include "comm/socket.h"
+#include "neuralnet/neuralnet.h"
+#include "proto/job.pb.h"
+
+namespace singa {
+
+/// Sleep 5 milliseconds if the Param is not updated to the expected version.
+const int kCollectSleepTime = 5;
+/**
+ * The Worker class which runs the training algorithm.
+ * The first worker group will initialize parameters of the Net,
+ * and put them into the distributed memory/table.
+ * The virtual functions TrainOneBatch and TestOneBatch implement the
+ * training and test algorithms for one mini-batch of data.
+ *
+ * Child workers override the two functions to implement their training
+ * algorithms, e.g., the BPWorker/CDWorker/BPTTWorker implement the BP/CD/BPTT
+ * algorithms respectively.
+ */
+class Worker {
+ public:
+  /**
+   * Create an instance of the subclass of Worker.
+   *
+   * @param[in] conf configuration of the TrainOneBatch algorithm. Different
+   * Worker subclasses implement different algorithms. Hence the creation is
+   * based on the TrainOneBatch algorithm type. Currently SINGA
+   * provides two algorithms:
+   * -# Back-propagation for the feed-forward models, e.g., CNN and MLP, and the
+   *  recurrent neural networks.
+   * -# Contrastive divergence for the energy models, e.g., RBM.
+   *
+   * @return a pointer to the instance of the Worker subclass.
+   */
+  static Worker* Create(const AlgProto& conf);
+  virtual ~Worker();
+  /**
+   * Set up members. @param[in] grp_id global worker group ID
+   * @param[in] id worker ID within the group
+   * @param[in] conf job configuration
+   * @param[in] train_net pointer to the training neural net, which could be
+   * shared with other workers from the same group. Different workers run over
+   * different subsets of layers.
+   * @param[in] val_net pointer to the validation neural net. Currently only the
+   * first worker from the first group would have validation neural net. All
+   * other workers receive nullptr for this argument.
+   * @param[in] test_net pointer to the test neural net. Currently only the
+   * first worker from the first group would have test neural net. All other
+   * workers receive nullptr for this argument.
+   */
+  virtual void Setup(int grp_id, int id, const JobProto& conf,
+      NeuralNet* train_net, NeuralNet* val_net, NeuralNet* test_net);
+
+  /**
+   * Main function of Worker.
+   *
+   * Train the neural net step by step; test/validation is done periodically.
+   */
+  void Run();
+
+  /**
+   * Init values of Param instances associated with local layers (i.e., layers
+   * dispatched to this worker).
+   *
+   * If one Param is owned by the worker, then it should be initialized and put
+   * to servers. Otherwise Get() should be called to get the Param. The Get()
+   * may not send get requests if the Param owner is in the same procs, in
+   * which case the memory space of the Param objects is shared. But if this
+   * worker and the Param owner worker run on different devices (e.g., GPUs),
+   * then the get request would be sent.
+   *
+   * If the training starts from scratch, every Param object is initialized
+   * using ParamGenerator. After that, the worker may
+   * train for a couple of steps to warm up the params before putting
+   * them to servers (warmup of JobProto controls this).
+   *
+   * If one Param object's name matches that of one Param object from the
+   * checkpoint files, its Param values would be loaded from checkpoint files.
+   *
+   * @param[in] job_conf job configuration which provides settings for
+   * checkpoint file paths, warmup steps and Param versions.
+   * @param[out] net pointer to a neural net whose Param values will be
+   * initialized.
+   */
+  void InitNetParams(const JobProto& job_conf, NeuralNet* net);
+
+  /**
+   * Checkpoint all Param objects owned by the worker onto disk.
+   * The serialization is done using BlobProtos which include the name, version
+   * and values of each Param object.
+   * Different workers would generate different checkpoint files. The file path
+   * is <workspace>/checkpoint-<jobname>-step<step>-worker<worker_id>.bin
+   * @param[in] step training step
+   * @param[in] folder directory to put the checkpoint file
+   * @param[in] net the training net whose Param objects will be dumped.
+   */
+  void Checkpoint(int step, const std::string& folder, NeuralNet* net);
+
+  /**
+   * Train one mini-batch.
+   * Test/Validation is done before training.
+   *
+   * @param[in] step training step.
+   * @param[in] net neural net to be trained.
+   */
+  virtual void TrainOneBatch(int step, NeuralNet* net) = 0;
+
+  /**
+   * Test/validate one mini-batch of data.
+   *
+   * @param[in] step test step.
+   * @param[in] phase test could be done for validation or test phase.
+   * @param[in] net neural net for test
+   */
+  virtual void TestOneBatch(int step, Phase phase, NeuralNet* net) = 0;
+
+  /**
+   * Display information from layers.
+   *
+   * @param flag could be a combination of multiple phases, e.g.,
+   * kTest|kForward. It is passed to Layer::ToString() of each layer to
+   * decide what to display.
+   * @param prefix display prefix, e.g., 'Train step 100', 'Test step 90'.
+   * @param net display layers from this neural net.
+   */
+  void Display(int flag, const std::string& prefix, NeuralNet* net);
+
+  /**
+   * Put Param values to server.
+   *
+   * @param param the Param object whose values are put.
+   * @param step used as current param version for the put request
+   */
+  int Put(int step, Param* param);
+
+  /**
+   * Get Param with specific version from servers.
+   * If the current version >= the requested version, then return.
+   * Otherwise send a get request to the stub, which forwards it to servers.
+   * @param param the Param object to get.
+   * @param step requested param version
+   */
+  int Get(int step, Param* param);
+
+  /**
+   * Update Param.
+   *
+   * @param param the Param object to update.
+   * @param step training step used for updating (e.g., deciding learning rate).
+   */
+  int Update(int step, Param* param);
+
+  /**
+   * Wait for the response of the update/get requests.
+   *
+   * @param param the Param object whose request is awaited.
+   * @param step not used now.
+   */
+  int Collect(int step, Param* param);
+
+  /**
+   * Call Collect() for every param of net.
+   */
+  int CollectAll(int step, NeuralNet* net);
+
+  /**
+   * Receive blobs from other workers due to model partitions.
+   */
+  void ReceiveBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+  /**
+   * Send blobs to other workers due to model partitions.
+   */
+  void SendBlobs(bool data, bool grad, BridgeLayer* layer, NeuralNet* net);
+
+
+  /**
+   * @param[in] step
+   * @return true if it is time to display training info, e.g., loss; otherwise
+   * false.
+   */
+  inline bool DisplayNow(int step) const {
+    return job_conf_.disp_freq() > 0
+           && step >= job_conf_.disp_after()
+           && ((step - job_conf_.disp_after()) % job_conf_.disp_freq() == 0);
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to finish the training; otherwise false.
+   */
+  inline bool StopNow(int step) const {
+    return step >= job_conf_.train_steps();
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to checkpoint Param objects; otherwise false.
+   */
+  inline bool CheckpointNow(int step) const {
+    return job_conf_.checkpoint_freq() > 0
+           && step >= job_conf_.checkpoint_after()
+           && ((step - job_conf_.checkpoint_after())
+              % job_conf_.checkpoint_freq() == 0);
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to test over the test dataset; otherwise false.
+   */
+  inline bool TestNow(int step) const {
+    return job_conf_.test_freq() > 0
+      && job_conf_.test_steps() > 0
+      && step >= job_conf_.test_after()
+      && ((step - job_conf_.test_after()) % job_conf_.test_freq() == 0);
+  }
+  /**
+   * @param[in] step
+   * @return true if it is time to test over the validation dataset.
+   */
+  inline bool ValidateNow(int step) const {
+    return job_conf_.validate_freq() > 0
+      && job_conf_.validate_steps() > 0
+      && step >= job_conf_.validate_after()
+      && ((step - job_conf_.validate_after()) % job_conf_.validate_freq() == 0);
+  }
+  /**
+   * @return {train, validation, test} nets; unused ones may be nullptr.
+   */
+  const std::vector<NeuralNet*> GetNets() const {
+    return std::vector<NeuralNet*> {train_net_, val_net_, test_net_};
+  }
+  /**
+   * @return training net.
+   */
+  inline NeuralNet* train_net() const {
+    return train_net_;
+  }
+  /**
+   * @return group ID
+   */
+  inline int grp_id() const { return grp_id_; }
+  /**
+   * @return worker ID within the worker group.
+   */
+  inline int id() const { return id_; }
+
+ protected:
+  int grp_id_ = -1, id_ = -1;
+  int step_ = 0;
+  JobProto job_conf_;
+  NeuralNet* train_net_ = nullptr;
+  NeuralNet* test_net_ = nullptr;
+  NeuralNet* val_net_ = nullptr;
+  Dealer* layer_dealer_ = nullptr;
+  Dealer* dealer_ = nullptr;
+};
+
+class BPWorker: public Worker {  // the back-propagation (BP) algorithm
+ public:
+  void TrainOneBatch(int step, NeuralNet* net) override;
+  void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+  void Forward(int step, Phase phase, NeuralNet* net);
+  void Backward(int step, NeuralNet* net);
+};
+
+class CDWorker: public Worker {  // the contrastive-divergence (CD) algorithm
+ public:
+  void TrainOneBatch(int step, NeuralNet* net) override;
+  void TestOneBatch(int step, Phase phase, NeuralNet* net) override;
+};
+
+inline int BlobTrgt(int grp, int layer) {  // layer must fit in 16 bits
+  return (grp << 16) | layer;
+}
+
+inline int BlobGrp(int blob_trgt) {  // group part of a BlobTrgt() value
+  return blob_trgt >> 16;
+}
+
+inline int BlobLayer(int blob_trgt) {  // layer part of a BlobTrgt() value
+  constexpr int kLayerMask = (1 << 16) - 1;  // low 16 bits hold the layer
+  return blob_trgt & kLayerMask;
+}
+
+}  // namespace singa
+
+#endif  // SINGA_WORKER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/msg.cc
----------------------------------------------------------------------
diff --git a/src/comm/msg.cc b/src/comm/msg.cc
new file mode 100644
index 0000000..2521c28
--- /dev/null
+++ b/src/comm/msg.cc
@@ -0,0 +1,215 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "comm/msg.h"
+
+#include <glog/logging.h>
+
+namespace singa {
+
+#ifdef USE_ZMQ
+Msg::~Msg() {
+  if (msg_)  // null after DumpToZmsg() handed the zmsg to its caller
+    zmsg_destroy(&msg_);
+  frame_ = nullptr;
+}
+
+Msg::Msg() : msg_(zmsg_new()) {}  // every Msg starts with an empty zmsg
+
+// Copy the header fields; the copy owns its own duplicate of the zmsg.
+Msg::Msg(const Msg& msg) : msg_(zmsg_dup(msg.msg_)) {
+  src_ = msg.src_;
+  dst_ = msg.dst_;
+  type_ = msg.type_;
+  trgt_val_ = msg.trgt_val_;
+  trgt_version_ = msg.trgt_version_;
+}
+
+// An empty message addressed from src to dst.
+Msg::Msg(int src, int dst) : Msg() {
+  src_ = src;
+  dst_ = dst;
+}
+
+void Msg::SwapAddr() {
+  const auto old_src = src_;
+  src_ = dst_;
+  dst_ = old_src;
+}
+
+int Msg::size() const {  // total content size of all frames in msg_
+  return zmsg_content_size(msg_);
+}
+
+void Msg::AddFrame(const void* addr, int nBytes) {  // append as a new frame
+  zmsg_addmem(msg_, addr, nBytes);
+}
+
+int Msg::FrameSize() {  // size of the frame selected by First/Next/LastFrame
+  return zframe_size(frame_);
+}
+
+void* Msg::FrameData() {  // data of the currently selected frame
+  return zframe_data(frame_);
+}
+
+char* Msg::FrameStr() {  // zframe_strdup copy; presumably caller must free it
+  return zframe_strdup(frame_);
+}
+bool Msg::NextFrame() {  // advance the frame cursor; false when none is left
+  frame_ = zmsg_next(msg_);
+  return frame_ != nullptr;
+}
+
+void Msg::FirstFrame() {  // move the frame cursor to the first frame
+  frame_ = zmsg_first(msg_);
+}
+
+void Msg::LastFrame() {  // move the frame cursor to the last frame
+  frame_ = zmsg_last(msg_);
+}
+
+// Adopt a received zmsg: decode its header frame and rewind to the payload.
+void Msg::ParseFromZmsg(zmsg_t* msg) {
+  char* header = CHECK_NOTNULL(zmsg_popstr(msg));
+  sscanf(header, "%d %d %d %d %d",
+         &src_, &dst_, &type_, &trgt_val_, &trgt_version_);
+  zstr_free(&header);  // the popped string is heap-allocated and ours
+  zmsg_destroy(&msg_);  // release the zmsg owned so far, e.g. from Msg()
+  msg_ = msg;
+  frame_ = zmsg_first(msg_);
+}
+
+// Prepend the header frame, then hand the zmsg over to the caller.
+zmsg_t* Msg::DumpToZmsg() {
+  zmsg_pushstrf(msg_, "%d %d %d %d %d",
+      src_, dst_, type_, trgt_val_, trgt_version_);
+  zmsg_t *tmp = msg_;
+  msg_ = nullptr;
+  return tmp;
+}
+
+// frame marker indicating this frame is serialized like printf
+#define FMARKER "*singa*"
+
+// capacity of the stack buffer that holds one formatted frame
+const int kMaxFrameLen = 2048;
+
+// Append a printf-like frame: 'i' int, 'f' float, 's' string, 'p' pointer
+// (each after a 1-byte tag) and untagged '1'/'2'/'4' unsigned ints.
+int Msg::AddFormatFrame(const char *format, ...) {
+  char dst[kMaxFrameLen];
+  int size = 0;
+  // append n bytes, checking the capacity before writing them
+  auto append = [&dst, &size](const void* bytes, int n) {
+    CHECK_LE(size + n, kMaxFrameLen);
+    memcpy(dst + size, bytes, n);
+    size += n;
+  };
+  append(FMARKER, strlen(FMARKER) + 1);  // marker including its '\0'
+  va_list argptr;
+  va_start(argptr, format);
+  for (; *format; ++format) {
+    if (*format == 'i') {
+      int x = va_arg(argptr, int);
+      append("i", 1);
+      append(&x, sizeof(x));
+    } else if (*format == 'f') {
+      float x = static_cast<float>(va_arg(argptr, double));
+      append("f", 1);
+      append(&x, sizeof(x));
+    } else if (*format == '1') {
+      uint8_t x = va_arg(argptr, int);
+      append(&x, sizeof(x));
+    } else if (*format == '2') {
+      uint16_t x = va_arg(argptr, int);
+      append(&x, sizeof(x));
+    } else if (*format == '4') {
+      uint32_t x = va_arg(argptr, uint32_t);
+      append(&x, sizeof(x));
+    } else if (*format == 's') {
+      const char* x = va_arg(argptr, char *);
+      append("s", 1);
+      append(x, strlen(x) + 1);  // characters plus the terminating '\0'
+    } else if (*format == 'p') {
+      void* x = va_arg(argptr, void *);
+      append("p", 1);
+      append(&x, sizeof(x));
+    } else {
+      LOG(ERROR) << "Unknown format " << *format;
+    }
+  }
+  va_end(argptr);
+  zmsg_addmem(msg_, dst, size);
+  return size;
+}
+
+// Decode the current frame (written by AddFormatFrame) into the pointers
+// after format; 's' needs a destination buffer large enough for the string.
+int Msg::ParseFormatFrame(const char *format, ...) {
+  CHECK_NOTNULL(frame_);
+  const int frame_len = zframe_size(frame_);
+  char* src = zframe_strdup(frame_);
+  CHECK_STREQ(FMARKER, src);
+  int size = strlen(FMARKER) + 1;
+  // tag checks and payload copies both refuse to read past the frame end
+  auto expect_tag = [src, frame_len, &size](char tag) {
+    CHECK_LT(size, frame_len);
+    CHECK_EQ(src[size++], tag);
+  };
+  auto take = [src, frame_len, &size](void* out, int n) {
+    CHECK_LE(size + n, frame_len);
+    memcpy(out, src + size, n);
+    size += n;
+  };
+  va_list argptr;
+  va_start(argptr, format);
+  for (; *format; ++format) {
+    if (*format == 'i') {
+      expect_tag('i');
+      take(va_arg(argptr, int *), sizeof(int));
+    } else if (*format == 'f') {
+      expect_tag('f');
+      take(va_arg(argptr, float *), sizeof(float));
+    } else if (*format == '1') {
+      take(va_arg(argptr, uint8_t *), sizeof(uint8_t));
+    } else if (*format == '2') {
+      take(va_arg(argptr, uint16_t *), sizeof(uint16_t));
+    } else if (*format == '4') {
+      take(va_arg(argptr, uint32_t *), sizeof(uint32_t));
+    } else if (*format == 's') {
+      char* x = va_arg(argptr, char *);
+      expect_tag('s');
+      take(x, strlen(src + size) + 1);  // characters plus the '\0'
+    } else if (*format == 'p') {
+      expect_tag('p');
+      take(va_arg(argptr, void **), sizeof(void*));
+    } else {
+      LOG(ERROR) << "Unknown format type " << *format;
+    }
+  }
+  va_end(argptr);
+  zstr_free(&src);  // zframe_strdup mallocs the copy: free it, never `delete`
+  return size;
+}
+#endif
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/comm/socket.cc
----------------------------------------------------------------------
diff --git a/src/comm/socket.cc b/src/comm/socket.cc
new file mode 100644
index 0000000..b9c7810
--- /dev/null
+++ b/src/comm/socket.cc
@@ -0,0 +1,180 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "comm/socket.h"
+
+#include <glog/logging.h>
+
+namespace singa {
+
+#ifdef USE_ZMQ
+Poller::Poller() {
+  poller_ = zpoller_new(nullptr);
+}
+
+Poller::Poller(SocketInterface* socket) : Poller() {
+  Add(socket);
+}
+
+// Watch socket and remember its wrapper so that Wait() can return it.
+void Poller::Add(SocketInterface* socket) {
+  zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
+  zpoller_add(poller_, zsock);
+  zsock2Socket_[zsock] = socket;
+}
+
+SocketInterface* Poller::Wait(int timeout) {
+  zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
+  if (sock == nullptr)
+    return nullptr;  // no socket became ready (e.g. timeout or interrupt)
+  auto it = zsock2Socket_.find(sock);  // unlike operator[], never inserts
+  return it == zsock2Socket_.end() ? nullptr : it->second;
+}
+
+bool Poller::Terminated() {
+  return zpoller_terminated(poller_);
+}
+
+
+Dealer::Dealer() : Dealer(-1) {}
+
+Dealer::Dealer(int id) : id_(id) {
+  dealer_ = zsock_new(ZMQ_DEALER);
+  CHECK_NOTNULL(dealer_);
+}
+
+Dealer::~Dealer() {
+  zsock_destroy(&dealer_);
+}
+
+// Connect to endpoint. An empty endpoint or a zsock_connect failure aborts
+// through glog CHECK, so the return value is always 1.
+int Dealer::Connect(const std::string& endpoint) {
+  CHECK_GT(endpoint.length(), 0);
+  CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0);
+  return 1;
+}
+
+// Send *msg, then delete it and reset the caller's pointer to nullptr.
+int Dealer::Send(Msg** msg) {
+  zmsg_t* zmsg = (*msg)->DumpToZmsg();
+  zmsg_send(&zmsg, dealer_);
+  delete *msg;
+  *msg = nullptr;
+  return 1;
+}
+
+Msg* Dealer::Receive() {
+  zmsg_t* zmsg = zmsg_recv(dealer_);
+  if (zmsg == nullptr)
+    return nullptr;
+  Msg* msg = new Msg();
+  msg->ParseFromZmsg(zmsg);
+  return msg;
+}
+
+void* Dealer::InternalID() const {
+  return dealer_;
+}
+
+Router::Router() : Router(100) {}
+
+Router::Router(int bufsize) {
+  nBufmsg_ = 0;
+  bufsize_ = bufsize;
+  router_ = zsock_new(ZMQ_ROUTER);
+  CHECK_NOTNULL(router_);
+  poller_ = zpoller_new(router_);
+  CHECK_NOTNULL(poller_);
+}
+
+Router::~Router() {
+  zpoller_destroy(&poller_);  // allocated in the constructor
+  zsock_destroy(&router_);
+  for (auto& it : id2addr_)
+    zframe_destroy(&it.second);
+  for (auto& it : bufmsg_)  // messages never delivered to their peer
+    for (auto*& zmsg : it.second)
+      zmsg_destroy(&zmsg);
+}
+
+int Router::Bind(const std::string& endpoint) {
+  int port = -1;
+  if (endpoint.length())
+    port = zsock_bind(router_, "%s", endpoint.c_str());
+  CHECK_NE(port, -1) << endpoint;
+  LOG(INFO) << "bind successfully to " << endpoint + ":" + std::to_string(port);
+  return port;
+}
+
+int Router::Send(Msg **msg) {
+  zmsg_t* zmsg = (*msg)->DumpToZmsg();
+  int dstid = (*msg)->dst();
+  if (id2addr_.find(dstid) != id2addr_.end()) {
+    // the connection has already been set up
+    zframe_t* addr = zframe_dup(id2addr_[dstid]);
+    zmsg_prepend(zmsg, &addr);
+    zmsg_send(&zmsg, router_);
+  } else {
+    // the connection is not ready, buffer the message; nBufmsg_ counts the
+    // messages buffered for all peers (Receive() subtracts flushed ones)
+    bufmsg_[dstid].push_back(zmsg);
+    ++nBufmsg_;
+    CHECK_LE(nBufmsg_, bufsize_);
+  }
+  delete *msg;
+  *msg = nullptr;
+  return 1;
+}
+
+Msg* Router::Receive() {
+  zmsg_t* zmsg = zmsg_recv(router_);
+  if (zmsg == nullptr) {
+    LOG(ERROR) << "Connection broken!";
+    exit(0);
+  }
+  zframe_t* dealer = zmsg_pop(zmsg);
+  Msg* msg = new Msg();
+  msg->ParseFromZmsg(zmsg);
+  if (id2addr_.find(msg->src()) == id2addr_.end()) {
+    // new peer: store its identity frame and flush messages buffered for it
+    id2addr_[msg->src()] = dealer;
+    auto buffered = bufmsg_.find(msg->src());
+    if (buffered != bufmsg_.end()) {
+      nBufmsg_ -= static_cast<int>(buffered->second.size());
+      for (auto& it : buffered->second) {
+        zframe_t* addr = zframe_dup(dealer);
+        zmsg_prepend(it, &addr);
+        zmsg_send(&it, router_);
+      }
+      bufmsg_.erase(buffered);
+    }
+  } else {
+    zframe_destroy(&dealer);
+  }
+  return msg;
+}
+
+void* Router::InternalID() const {
+  return router_;
+}
+#endif
+
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/msg.cc
----------------------------------------------------------------------
diff --git a/src/communication/msg.cc b/src/communication/msg.cc
deleted file mode 100644
index 6042057..0000000
--- a/src/communication/msg.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "communication/msg.h"
-
-#include <glog/logging.h>
-
-namespace singa {
-
-#ifdef USE_ZMQ
-Msg::~Msg() {
-  if (msg_ != nullptr)
-    zmsg_destroy(&msg_);
-  frame_ = nullptr;
-}
-
-Msg::Msg() {
-  msg_ = zmsg_new();
-}
-
-Msg::Msg(const Msg& msg) {
-  src_ = msg.src_;
-  dst_ = msg.dst_;
-  type_ = msg.type_;
-  trgt_val_ = msg.trgt_val_;
-  trgt_version_ = msg.trgt_version_;
-  msg_ = zmsg_dup(msg.msg_);
-}
-
-Msg::Msg(int src, int dst) {
-  src_ = src;
-  dst_ = dst;
-  msg_ = zmsg_new();
-}
-
-void Msg::SwapAddr() {
-  std::swap(src_, dst_);
-}
-
-int Msg::size() const {
-  return zmsg_content_size(msg_);
-}
-
-void Msg::AddFrame(const void* addr, int nBytes) {
-  zmsg_addmem(msg_, addr, nBytes);
-}
-
-int Msg::FrameSize() {
-  return zframe_size(frame_);
-}
-
-void* Msg::FrameData() {
-  return zframe_data(frame_);
-}
-
-char* Msg::FrameStr() {
-  return zframe_strdup(frame_);
-}
-bool Msg::NextFrame() {
-  frame_ = zmsg_next(msg_);
-  return frame_ != nullptr;
-}
-
-void Msg::FirstFrame() {
-  frame_ = zmsg_first(msg_);
-}
-
-void Msg::LastFrame() {
-  frame_ = zmsg_last(msg_);
-}
-
-void Msg::ParseFromZmsg(zmsg_t* msg) {
-  char* tmp = zmsg_popstr(msg);
-  sscanf(tmp, "%d %d %d %d %d",
-         &src_, &dst_, &type_, &trgt_val_, &trgt_version_);
-  frame_ = zmsg_first(msg);
-  msg_ = msg;
-}
-
-zmsg_t* Msg::DumpToZmsg() {
-  zmsg_pushstrf(msg_, "%d %d %d %d %d",
-      src_, dst_, type_, trgt_val_, trgt_version_);
-  zmsg_t *tmp = msg_;
-  msg_ = nullptr;
-  return tmp;
-}
-
-// frame marker indicating this frame is serialize like printf
-#define FMARKER "*singa*"
-
-#define kMaxFrameLen 2048
-
-int Msg::AddFormatFrame(const char *format, ...) {
-  va_list argptr;
-  va_start(argptr, format);
-  int size = strlen(FMARKER);
-  char dst[kMaxFrameLen];
-  memcpy(dst, FMARKER, size);
-  dst[size++] = 0;
-  while (*format) {
-    if (*format == 'i') {
-      int x = va_arg(argptr, int);
-      dst[size++] = 'i';
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else if (*format == 'f') {
-      float x = static_cast<float> (va_arg(argptr, double));
-      dst[size++] = 'f';
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else if (*format == '1') {
-      uint8_t x = va_arg(argptr, int);
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else if (*format == '2') {
-      uint16_t x = va_arg(argptr, int);
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else if (*format == '4') {
-      uint32_t x = va_arg(argptr, uint32_t);
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else if (*format == 's') {
-      char* x = va_arg(argptr, char *);
-      dst[size++] = 's';
-      memcpy(dst + size, x, strlen(x));
-      size += strlen(x);
-      dst[size++] = 0;
-    } else if (*format == 'p') {
-      void* x = va_arg(argptr, void *);
-      dst[size++] = 'p';
-      memcpy(dst + size, &x, sizeof(x));
-      size += sizeof(x);
-    } else {
-      LOG(ERROR) << "Unknown format " << *format;
-    }
-    format++;
-    CHECK_LE(size, kMaxFrameLen);
-  }
-  va_end(argptr);
-  zmsg_addmem(msg_, dst, size);
-  return size;
-}
-
-int Msg::ParseFormatFrame(const char *format, ...) {
-  va_list argptr;
-  va_start(argptr, format);
-  char* src = zframe_strdup(frame_);
-  CHECK_STREQ(FMARKER, src);
-  int size = strlen(FMARKER) + 1;
-  while (*format) {
-    if (*format == 'i') {
-      int *x = va_arg(argptr, int *);
-      CHECK_EQ(src[size++], 'i');
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else if (*format == 'f') {
-      float *x = va_arg(argptr, float *);
-      CHECK_EQ(src[size++], 'f');
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else if (*format == '1') {
-      uint8_t *x = va_arg(argptr, uint8_t *);
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else if (*format == '2') {
-      uint16_t *x = va_arg(argptr, uint16_t *);
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else if (*format == '4') {
-      uint32_t *x = va_arg(argptr, uint32_t *);
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else if (*format == 's') {
-      char* x = va_arg(argptr, char *);
-      CHECK_EQ(src[size++], 's');
-      int len = strlen(src + size);
-      memcpy(x, src + size, len);
-      x[len] = 0;
-      size += len + 1;
-    } else if (*format == 'p') {
-      void** x = va_arg(argptr, void **);
-      CHECK_EQ(src[size++], 'p');
-      memcpy(x, src + size, sizeof(*x));
-      size += sizeof(*x);
-    } else {
-      LOG(ERROR) << "Unknown format type " << *format;
-    }
-    format++;
-  }
-  va_end(argptr);
-  delete src;
-  return size;
-}
-#endif
-
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
deleted file mode 100644
index 60e1cc1..0000000
--- a/src/communication/socket.cc
+++ /dev/null
@@ -1,180 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-#include "communication/socket.h"
-
-#include <glog/logging.h>
-
-namespace singa {
-
-#ifdef USE_ZMQ
-Poller::Poller() {
-  poller_ = zpoller_new(nullptr);
-}
-
-Poller::Poller(SocketInterface* socket) {
-  poller_ = zpoller_new(nullptr);
-  Add(socket);
-}
-
-void Poller::Add(SocketInterface* socket) {
-  zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID());
-  zpoller_add(poller_, zsock);
-  zsock2Socket_[zsock] = socket;
-}
-
-SocketInterface* Poller::Wait(int timeout) {
-  zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout));
-  if (sock != nullptr)
-    return zsock2Socket_[sock];
-  else
-  return nullptr;
-}
-
-bool Poller::Terminated() {
-  return zpoller_terminated(poller_);
-}
-
-
-Dealer::Dealer() : Dealer(-1) {}
-
-Dealer::Dealer(int id) : id_(id) {
-  dealer_ = zsock_new(ZMQ_DEALER);
-  CHECK_NOTNULL(dealer_);
-}
-
-Dealer::~Dealer() {
-  zsock_destroy(&dealer_);
-}
-
-int Dealer::Connect(const std::string& endpoint) {
-  CHECK_GT(endpoint.length(), 0);
-  if (endpoint.length()) {
-    CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0);
-    return 1;
-  }
-  return 0;
-}
-
-int Dealer::Send(Msg** msg) {
-  zmsg_t* zmsg = (*msg)->DumpToZmsg();
-  zmsg_send(&zmsg, dealer_);
-  delete *msg;
-  *msg = nullptr;
-  return 1;
-}
-
-Msg* Dealer::Receive() {
-  zmsg_t* zmsg = zmsg_recv(dealer_);
-  if (zmsg == nullptr)
-    return nullptr;
-  Msg* msg = new Msg();
-  msg->ParseFromZmsg(zmsg);
-  return msg;
-}
-
-void* Dealer::InternalID() const {
-  return dealer_;
-}
-
-Router::Router() : Router(100) {}
-
-Router::Router(int bufsize) {
-  nBufmsg_ = 0;
-  bufsize_ = bufsize;
-  router_ = zsock_new(ZMQ_ROUTER);
-  CHECK_NOTNULL(router_);
-  poller_ = zpoller_new(router_);
-  CHECK_NOTNULL(poller_);
-}
-
-Router::~Router() {
-  zsock_destroy(&router_);
-  for (auto it : id2addr_)
-    zframe_destroy(&it.second);
-  for (auto it : bufmsg_) {
-    for (auto *msg : it.second)
-      zmsg_destroy(&msg);
-  }
-}
-int Router::Bind(const std::string& endpoint) {
-  int port = -1;
-  if (endpoint.length()) {
-    port = zsock_bind(router_, "%s", endpoint.c_str());
-  }
-  CHECK_NE(port, -1) << endpoint;
-  LOG(INFO) << "bind successfully to " << endpoint + ":" + std::to_string(port);
-  return port;
-}
-
-int Router::Send(Msg **msg) {
-  zmsg_t* zmsg = (*msg)->DumpToZmsg();
-  int dstid = (*msg)->dst();
-  if (id2addr_.find(dstid) != id2addr_.end()) {
-    // the connection has already been set up
-    zframe_t* addr = zframe_dup(id2addr_[dstid]);
-    zmsg_prepend(zmsg, &addr);
-    zmsg_send(&zmsg, router_);
-  } else {
-    // the connection is not ready, buffer the message
-    if (bufmsg_.size() == 0)
-      nBufmsg_ = 0;
-    bufmsg_[dstid].push_back(zmsg);
-    ++nBufmsg_;
-    CHECK_LE(nBufmsg_, bufsize_);
-  }
-  delete *msg;
-  *msg = nullptr;
-  return 1;
-}
-
-Msg* Router::Receive() {
-  zmsg_t* zmsg = zmsg_recv(router_);
-  if (zmsg == nullptr) {
-    LOG(ERROR) << "Connection broken!";
-    exit(0);
-  }
-  zframe_t* dealer = zmsg_pop(zmsg);
-  Msg* msg = new Msg();
-  msg->ParseFromZmsg(zmsg);
-  if (id2addr_.find(msg->src()) == id2addr_.end()) {
-    // new connection, store the sender's identfier and send buffered messages
-    // for it
-    id2addr_[msg->src()] = dealer;
-    if (bufmsg_.find(msg->src()) != bufmsg_.end()) {
-      for (auto& it : bufmsg_.at(msg->src())) {
-        zframe_t* addr = zframe_dup(dealer);
-        zmsg_prepend(it, &addr);
-        zmsg_send(&it, router_);
-      }
-      bufmsg_.erase(msg->src());
-    }
-  } else {
-    zframe_destroy(&dealer);
-  }
-  return msg;
-}
-
-void* Router::InternalID() const {
-  return router_;
-}
-#endif
-
-}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 6fa70ee..d3f0f3e 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -19,16 +19,17 @@
 *
 *************************************************************/
 
-#include "driver.h"
-
 #include <glog/logging.h>
+#include <set>
 #include <string>
 #include "neuralnet/layer.h"
-#include "trainer/trainer.h"
 #include "utils/common.h"
 #include "utils/tinydir.h"
+#include "utils/cluster.h"
+#include "./stub.h"
+#include "./driver.h"
 
-extern "C" void openblas_set_num_threads(int);
+extern "C" void openblas_set_num_threads(int num);
 
 namespace singa {
 
@@ -109,22 +110,192 @@ void Driver::Init(int argc, char **argv) {
 }
 
 
-void Driver::Submit(bool resume, const JobProto& jobConf) {
+void Driver::Train(bool resume, const JobProto& job_conf) {  // submit a job
+  Cluster::Setup(job_id_, singa_conf_, job_conf.cluster());  // Get() used below
   if (singa_conf_.has_log_dir())
-    SetupLog(singa_conf_.log_dir(), std::to_string(job_id_)
-             + "-" + jobConf.name());
+    SetupLog(singa_conf_.log_dir(),
+        std::to_string(job_id_) + "-" + job_conf.name());
   tinydir_dir workspace;
-  if (tinydir_open(&workspace, jobConf.cluster().workspace().c_str()) == -1)
-    LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace();
-  if (jobConf.num_openblas_threads() != 1)
-    LOG(WARNING) << "openblas with "
-                 << jobConf.num_openblas_threads() << " threads";
-  openblas_set_num_threads(jobConf.num_openblas_threads());
+  if (tinydir_open(&workspace, job_conf.cluster().workspace().c_str()) == -1)
+    LOG(FATAL) << "workspace not exist: " << job_conf.cluster().workspace();
+  tinydir_close(&workspace);  // opened only to check existence; release it
+  if (job_conf.num_openblas_threads() != 1)
+    LOG(WARNING) << "openblas launches "
+                 << job_conf.num_openblas_threads() << " threads";
+  openblas_set_num_threads(job_conf.num_openblas_threads());
   JobProto job;
-  job.CopyFrom(jobConf);
+  job.CopyFrom(job_conf);
+  if (resume)  // point job at the newest checkpoint files, if any
+    SetupForResume(&job);
   job.set_id(job_id_);
-  Trainer trainer;
-  trainer.Start(resume, singa_conf_, &job);
+  Train(job);  // blocks until all workers and servers have finished
 }
 
+void Driver::Train(const JobProto& job_conf) {  // run one job to completion
+  auto cluster = Cluster::Get();
+  int nserver_grps = cluster->nserver_groups();
+  int grp_size = cluster->nworkers_per_group();
+  Stub stub;
+  // no need to create Stub if there is only a single worker without servers,
+  // i.e., the training will be conducted by the single worker.
+  if (grp_size > 1 || nserver_grps > 0) {
+    stub.Setup();
+    // TODO(wangwei)  register endpoint to zookeeper if > 1 procs;
+    cluster->Register(getpid(), stub.endpoint());  // getpid() is from unistd.h
+  }
+
+  NeuralNet* net = NeuralNet::Create(job_conf.neuralnet(), kTrain, grp_size);
+  const vector<Worker*> workers = CreateWorkers(job_conf, net);
+  const vector<Server*> servers = CreateServers(job_conf, net);
+
+#ifdef USE_MPI
+  int nthreads = workers.size() + servers.size() + 1;
+  for (int i = 0; i < nthreads; i++)
+    MPIQueues.push_back(make_shared<SafeQueue>());
+#endif
+
+  vector<std::thread> threads;
+  for (auto server : servers)
+    threads.push_back(std::thread(&Server::Run, server));
+  for (auto worker : workers)
+    threads.push_back(std::thread(&Worker::Run, worker));
+  if (grp_size > 1 || nserver_grps > 0) {
+    int nservers_per_grp = cluster->nservers_per_group();
+    int lcm = LeastCommonMultiple(nservers_per_grp, nserver_grps);
+    auto slices = Param::ComputeSlices(lcm, net->params());
+    auto slice2server = PartitionSlices(nservers_per_grp, slices);
+    stub.Run(slice2server, workers, servers);
+  }
+
+  for (auto& thread : threads)  // wait until every executor has finished
+    thread.join();
+  for (auto server : servers)
+    delete server;
+  // free each distinct net once; all pointers are gathered before any delete
+  std::set<NeuralNet*> nets{net};
+  for (auto worker : workers) {
+    for (auto ptr : worker->GetNets())
+      nets.insert(ptr);  // workers of one group may share the same net
+    delete worker;
+  }
+  nets.erase(nullptr);  // GetNets() may report nullptr for absent val/test nets
+  for (auto ptr : nets)
+    delete ptr;
+}
+
+void Driver::SetupForResume(JobProto* job_conf) {  // use newest checkpoint
+  tinydir_dir dir;
+  std::string folder = Cluster::Get()->checkpoint_folder();
+  if (tinydir_open(&dir, folder.c_str()) == -1) {  // dir unusable on failure
+    LOG(WARNING) << "Cannot open checkpoint folder for resume: " << folder;
+    return;
+  }
+  int latest_step = 0;
+  vector<std::string> ck_files;  // several files (one per worker) per step
+  while (dir.has_next) {  // scan all files, keep those of the latest step
+    tinydir_file file;
+    tinydir_readfile(&dir, &file);
+    tinydir_next(&dir);
+    char* ch = strstr(file.name, "step");
+    if (ch == nullptr) {
+      if (file.name[0] != '.')
+        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
+      continue;
+    }
+    LOG(INFO) << "Add checkpoint file for resume: " << ch;
+    int step = atoi(ch+4);  // digits right after "step"; 0 if none
+    if (step > latest_step) {  // newer checkpoint: drop older files
+      latest_step = step;
+      ck_files.clear();
+    }
+    if (step == latest_step)
+      ck_files.push_back(file.name);
+  }
+  if (latest_step > 0) {
+    job_conf->set_step(latest_step);
+    if (!job_conf->has_reset_param_version())
+      job_conf->set_reset_param_version(false);
+    job_conf->clear_checkpoint_path();
+    for (auto ck_file : ck_files)
+      job_conf->add_checkpoint_path(folder + "/" + ck_file);
+  }
+  tinydir_close(&dir);
+}
+
+const vector<Worker*> Driver::CreateWorkers(const JobProto& job_conf,
+    NeuralNet* net) {  // build the Workers that run in this process
+  auto cluster = Cluster::Get();
+  vector<Worker*> workers;
+  if (!cluster->has_worker()) return workers;  // no workers here: empty
+  int wgrp_size = cluster->nworkers_per_group();
+  int nservers_per_grp = cluster->nservers_per_group();
+  int nserver_grps = cluster->nserver_groups();
+  int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
+  const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
+      cluster->nworkers_per_group(), cluster->nworkers_per_procs());
+  int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
+  for (int gid = gstart; gid < gend; gid++) {  // half-open [gstart, gend)
+    NeuralNet* train_net = nullptr, *test_net = nullptr, *val_net = nullptr;
+    if (gid == gstart) {  // first local group reuses the caller's net
+      train_net = net;
+      Param::SliceParams(lcm, train_net->params());
+      // test and validation are performed by the 1st group.
+      if (gid == 0 && job_conf.test_steps() > 0) {
+        test_net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1);
+        test_net->ShareParamsFrom(train_net);
+      }
+      if (gid == 0 && job_conf.validate_steps() > 0) {
+        val_net = NeuralNet::Create(job_conf.neuralnet(), kVal, 1);
+        val_net->ShareParamsFrom(train_net);
+      }
+    } else {  // other local groups get their own train net
+      train_net = NeuralNet::Create(job_conf.neuralnet(), kTrain, wgrp_size);
+      if (cluster->share_memory()) {
+        train_net->ShareParamsFrom(net);
+      } else {
+        Param::SliceParams(lcm, train_net->params());
+      }
+    }
+    for (int wid = wstart; wid < wend; wid++) {  // half-open [wstart, wend)
+      auto *worker = Worker::Create(job_conf.train_one_batch());
+      // TODO(wangwei) extend to test among workers in a grp
+      if (wid == 0)  // val/test nets are non-null only in group 0
+        worker->Setup(gid, wid, job_conf, train_net, val_net, test_net);
+      else
+        worker->Setup(gid, wid, job_conf, train_net, nullptr, nullptr);
+      workers.push_back(worker);
+    }
+  }
+  return workers;  // nets are freed by Driver::Train, not here
+}
+
+const vector<Server*> Driver::CreateServers(const JobProto& job_conf,
+    NeuralNet* net) {  // build this process's Servers; net is read, not kept
+  auto cluster = Cluster::Get();
+  vector<Server*> servers;
+  if (!cluster->has_server()) return servers;  // no servers here: empty
+  int nservers_per_grp = cluster->nservers_per_group();
+  int nserver_grps = cluster->nserver_groups();
+  int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
+  auto slices = Param::ComputeSlices(lcm, net->params());
+  // partition among server groups, each group maintains one sub-set for sync
+  auto slice2group = PartitionSlices(nserver_grps, slices);
+  // partition within one server group, each server updates for one sub-set
+  auto slice2server = PartitionSlices(nservers_per_grp, slices);
+
+  int server_procs = cluster->procs_id();
+  // if true, server procs (logical) id starts after worker procs
+  if (cluster->server_worker_separate())
+    server_procs -= cluster->nworker_procs();
+  const vector<int> rng = cluster->ExecutorRng(server_procs,
+      cluster->nservers_per_group(), cluster->nservers_per_procs());
+  int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
+  for (int gid = gstart; gid < gend; gid++) {  // half-open group range
+    for (int sid = start; sid < end; sid++) {  // half-open server-id range
+      auto server = new Server(gid, sid, job_conf, slice2group, slice2server);
+      servers.push_back(server);  // freed by Driver::Train after join
+    }
+  }
+  return servers;
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 5d2ab2f..99c91b8 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,9 +19,9 @@
 *
 *************************************************************/
 
-#include "singa.h"
+#include "./singa.h"
 /**
- * \file main.cc provides an example main func.
+ * \file main.cc provides an example main function.
  *
  * Like the main func of Hadoop, it prepares the job configuration and submit it
  * to the Driver which starts the training.
@@ -31,19 +31,17 @@
  * func must call Driver::Init at the beginning, and pass the job configuration
  * and resume option to the Driver for job submission.
  *
- * Optionally, users can register their own implemented classes, e.g., layer,
- * updater, through the registration func provided by the Driver.
+ * Optionally, users can register their own implemented subclasses of Layer,
+ * Updater, etc. through the registration function provided by the Driver.
  *
  * Users must pass at least one argument to the singa-run.sh, i.e., the job
  * configuration file which includes the cluster topology setting. Other fields
  * e.g, neuralnet, updater can be configured in main.cc.
  *
  * TODO
- * Add helper functions for users to generate their configurations easily.
- * e.g., AddLayer(layer_type, source_layers, meta_data),
- * or, MLP(layer1_size, layer2_size, tanh, loss);
+ * Add helper functions for users to generate configurations for popular models
+ * easily, e.g., MLP(layer1_size, layer2_size, tanh, loss);
  */
-
 int main(int argc, char **argv) {
   // must create driver at the beginning and call its Init method.
   singa::Driver driver;
@@ -58,7 +56,7 @@ int main(int argc, char **argv) {
   // get the job conf, and custmize it if need
   singa::JobProto jobConf = driver.job_conf();
 
-  // submit the job
-  driver.Submit(resume, jobConf);
+  // submit the job for training
+  driver.Train(resume, jobConf);
   return 0;
 }


Mime
View raw message