singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wan...@apache.org
Subject incubator-singa git commit: SINGA-57 Improve Distributed Hogwild
Date Tue, 15 Sep 2015 10:06:52 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master d5d817e14 -> ed9e37369


SINGA-57 Improve Distributed Hogwild

The ClusterProto::sync_freq field controls the frequency of sync between
server groups.
After updating a Param (slice), the server checks the number of updates
since the last sync. It also checks the number of pending syncs (i.e.,
requests that haven't received responses) to avoid sending too many msgs to
stopped servers (the msgs would occupy the memory of the sending buffer).
The server responds to every sync request with the latest Param values.

Note: the current implementation does not support (there is a bug) multiple
worker groups in one process for the distributed hogwild framework. We
recommend replacing this cluster topology with in-memory hogwild, i.e.,
launching one worker group with multiple workers and one server group.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/ed9e3736
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/ed9e3736
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/ed9e3736

Branch: refs/heads/master
Commit: ed9e37369c69dd76078e8285bc33d6b04ba60e9f
Parents: d5d817e
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Tue Sep 15 15:51:28 2015 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Tue Sep 15 16:03:37 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/job.conf |   2 +-
 include/trainer/server.h  |  62 +++++++++++----
 include/trainer/trainer.h |  10 ---
 include/utils/cluster.h   |  16 ++--
 src/proto/common.proto    |   1 -
 src/proto/job.proto       |   7 +-
 src/trainer/server.cc     | 167 ++++++++++++++++++-----------------------
 src/trainer/trainer.cc    | 124 ++++++++----------------------
 src/utils/cluster.cc      |  21 ++++++
 9 files changed, 190 insertions(+), 220 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/examples/cifar10/job.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf
index b36c45a..343d969 100644
--- a/examples/cifar10/job.conf
+++ b/examples/cifar10/job.conf
@@ -2,7 +2,7 @@ name: "cifar10-convnet"
 train_steps: 1000
 test_steps: 100
 test_freq:300
-disp_freq:30
+disp_freq: 30
 train_one_batch {
   alg: kBP
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index 8cc37c5..869d10a 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -23,8 +23,8 @@ class Server{
   Server(int thread_id, int group_id, int server_id);
   virtual ~Server();
   void Setup(const UpdaterProto& proto,
-      std::unordered_map<int, ParamEntry*>* shard,
-      const std::vector<int>& slice2group);
+      const std::vector<int>& slice2group,
+      const std::vector<int>& slice2server);
   void Run();
   const int grp_id() const {
     return grp_id_;
@@ -38,13 +38,30 @@ class Server{
  	/**
 	 * Process GET request.
    *
-   * @return the orignal message or response message
+   * @return the orignal message or a response message which contains the values
+   * of the Param with the request version.
    */
 	virtual Msg* HandleGet(Msg** msg);
 
 	/**
 	 * Process Update request.
    *
+   * It waits until received the gradients from all workers from the same worker
+   * group. After updating, it responses to each sender with the new Param
+   * values. It may generate a sync message to the server group that maintains
+   * the global version of the updated Param (slice).
+   *
+   * Note: there is no counter for each worker group on the number of received
+   * update requests. Hence it is possible that the server would conduct the
+   * update when it receives x requests from group a and y requests from group
+   * b where x + y = group size. To avoid this problem, we can
+   * 1. maintain request list for each group for each Param at the server side
+   * 2. do not span a worker group among multiple nodes. then the updates from
+   * the same group would be locally aggregated on the worker node. And the
+   * server would conduct the update immediately after receiving the aggregated
+   * request.
+   * 3. launch only one worker group.
+   *
    * @return the orignal message or response message
    */
   const std::vector<Msg*> HandleUpdate(Msg **msg);
@@ -52,30 +69,47 @@ class Server{
 	/**
 	 * Process PUT request.
    *
-   * @return the original message or response message. If we don't want need to
+   * @return the original message or response message. If we don't want to
    * acknowledge the put request, then return nullptr.
 	 */
 	virtual Msg* HandlePut(Msg **msg);
 
 	/**
-   * TODO Process SYNC request.
-	 */
+   * Handle sync request from other server groups.
+   *
+   * It adds updates of Param (slice) from other server groups directly to
+   * local Param (slice). Currently, each Param (slice) has a master group,
+   * i.e., slice2group_[sliceid], which would receive such requests from all
+   * other server groups for the Param object.
+   *
+   * @param msg request msg containing the parameter updates
+   * @return response msg that contains the fresh parameter values.
+   */
 	virtual Msg* HandleSyncRequest(Msg** msg);
 
   /**
-   * Generate sync message which sends local mastered Param slice to other
-   * server groups
-   * @param param slice to be sync with others
-   * @return sync messages
+   * Handle sync response.
+   *
+   * The response msg includes the latest values of a Param object, for which
+   * this server sent the sync request to the master/maintainer group.
+   * The local Param values are replaced with the addition result of local
+   * udpates since the sync request was sent and the received Param values.
+   *
+   * @param response message
    */
-  const std::vector<Msg*> GenSyncMsgs(Param* param);
+  void HandleSyncResponse(Msg** msg);
 
  protected:
   int thread_id_,grp_id_, id_;
   Updater* updater_;
-  std::unordered_map<int, ParamEntry*> *shard_;
-  std::vector<int> slice2group_;
-  std::unordered_map<int, std::shared_ptr<Blob<float>>> last_data_;
+  //!< map from slice ID to slice and deleted in the destructor
+  std::unordered_map<int, ParamEntry*> shard_;
+  std::vector<int> slice2group_, slice2server_;
+  //!< num of updates from last sync with master server group for a param/slice
+  std::vector<int> nUpdates_;
+  //!< num of sync requests that have not been responded
+  std::vector<int> nPendingSync_;
+  std::vector<Blob<float>> last_sync_;
   std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
 };
 } /* Server */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 8be5269..ed50705 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -80,14 +80,6 @@ class Trainer{
 
   void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
   /**
-   * Generate msg to trigger synchronization with other server groups.
-   *
-   * @param server the local server index whom the message is sent to
-   * @param servers all local servers
-   * @return sync msg
-   */
-  Msg* GenSyncReminderMsg(int server, const vector<Server*>& servers);
-  /**
    * Display metrics to log (standard output)
    */
   void DisplayMetric(Msg** msg);
@@ -143,8 +135,6 @@ class Trainer{
   int procs_id_;
   Router *router_;
   std::unordered_map<int, ParamEntry*> worker_shard_;
-  //!< map from slice ID to slice, used by servers and deleted in the destructor
-  std::unordered_map<int, ParamEntry*> server_shard_;
   //!< map from slice to the server that updates it
   vector<int> slice2server_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/include/utils/cluster.h
----------------------------------------------------------------------
diff --git a/include/utils/cluster.h b/include/utils/cluster.h
index be0e0de..73474af 100644
--- a/include/utils/cluster.h
+++ b/include/utils/cluster.h
@@ -90,12 +90,8 @@ class Cluster {
   const int worker_timeout() const { return cluster_.worker_timeout(); }
   const int server_timeout() const { return cluster_.server_timeout(); }
   */
-  inline bool server_update() const { return cluster_.server_update(); }
   inline bool share_memory() const { return cluster_.share_memory(); }
-  /**
-   * bandwidth Bytes/s
-   */
-  inline int bandwidth() const { return cluster_.bandwidth(); }
+  inline int sync_freq() const { return cluster_.sync_freq(); }
   inline int poll_time() const { return cluster_.poll_time(); }
   ClusterRuntime* runtime() const { return cluster_rt_; }
 
@@ -106,6 +102,16 @@ class Cluster {
     return procs_ids_.at(Hash(group_id, id, flag));
   }
   inline std::string hostip() const { return hostip_; }
+
+  /**
+   * @param pid, processs ID
+   * @param group_size, num of executors in a group
+   * @param procs_size, num of executors in a procs
+   *
+   * @return a vector with 4 integers:
+   * [group start, group end), [start executor, end executor)
+   */
+  const std::vector<int> ExecutorRng(int pid, int group_size, int procs_size);
   /**
    * Register this process.
    *

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/src/proto/common.proto
----------------------------------------------------------------------
diff --git a/src/proto/common.proto b/src/proto/common.proto
index 671d6ad..e166299 100644
--- a/src/proto/common.proto
+++ b/src/proto/common.proto
@@ -13,7 +13,6 @@ enum MsgType {
   kRUpdate = 9;
   kConnect = 10;
   kMetric = 11;
-  kSyncReminder = 12;
 };
 
 enum EntityType {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 7861eae..80998e1 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -129,15 +129,14 @@ message ClusterProto {
   // servers and workers in different processes?
   optional bool server_worker_separate = 20 [default = false];
 
+  // sync frequency between server groups
+  optional int32 sync_freq = 21 [default = 1];
+
   // port number used by ZeroMQ
   optional int32 start_port = 60 [default = 6723];
-  // conduct updates at server side; otherwise do it at worker side
-  optional bool server_update = 61 [default = true];
   // share memory space between worker groups in one procs
   optional bool share_memory = 62 [default = true];
 
-  // bandwidth of ethernet, Bytes per second, default is 1 Gbps
-  optional int32 bandwidth = 80 [default = 134217728];
   // poll time in milliseconds
   optional int32 poll_time = 81 [default = 100];
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index b4c386f..601a837 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -18,15 +18,22 @@ Server::Server(int thread_id,int group_id, int server_id):
 }
 
 void Server::Setup(const UpdaterProto& proto,
-    std::unordered_map<int, ParamEntry*>* shard,
-    const vector<int>& slice2group) {
+    const vector<int>& slice2group,
+    const vector<int>& slice2server) {
   updater_ = Updater::Create(proto);
-  shard_ = shard;
   slice2group_ = slice2group;
+  slice2server_ = slice2server;
+  nUpdates_.resize(slice2group_.size(), 0);
+  nPendingSync_.resize(slice2group_.size(), 0);
+  last_sync_.resize(slice2group_.size());
 }
 
 Server::~Server() {
   delete updater_;
+  // free Params (i.e., slices) in server shard
+  for (auto entry : shard_)
+    for (auto param : entry.second->shares)
+      delete param;
 }
 
 void Stop(void * running) {
@@ -35,6 +42,7 @@ void Stop(void * running) {
 
 void Server::Run() {
   LOG(ERROR) << "Server (group = " << grp_id_ <<", id = " << id_
<< ") start";
+
   auto dealer = new Dealer(2*thread_id_);
   CHECK(dealer->Connect(kInprocRouterEndpoint));
   Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
@@ -44,13 +52,10 @@ void Server::Run() {
   auto cluster = Cluster::Get();
   bool running = true;
   CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running));
-
-  int nserver_grps = cluster->nserver_groups();
-  vector<Param*> master_params;
-  size_t syncEntry=0;
   Poller poll(dealer);
   // start recv loop and process requests
   while (running) {
+    // must use poller here; otherwise Receive() gets stuck after workers stop.
     auto *sock = poll.Wait(cluster->poll_time());
     if (poll.Terminated()) {
       LOG(ERROR) << "Connection broken!";
@@ -58,34 +63,18 @@ void Server::Run() {
     } else if (sock == nullptr) {
       continue;
     }
-    Msg* msg=dealer->Receive();
-    if (msg==nullptr) break;
-    Msg* response=nullptr;
-    int type=msg->type();
+    Msg* msg = dealer->Receive();
+    if (msg == nullptr) break; //  interrupted
+    Msg* response = nullptr;
+    int type = msg->type();
     int slice_id = SliceID(msg->trgt_val());
     if (type == kPut) {
       response = HandlePut(&msg);
-      if(slice2group_[slice_id] == grp_id_)
-        master_params.push_back(shard_->at(slice_id)->shares.at(0));
     } else {
-      if (shard_->find(slice_id) == shard_->end()) {
-        // delay the processing by re-queue the msg.
+      if (shard_.find(slice_id) == shard_.end()) {
+        // delay the processing by re-queue the msg. May sleep for a while?
         response = msg;
-      } else if (type == kSyncReminder) {
-        DeleteMsg(&msg);
-        if(syncEntry >= master_params.size())
-          continue;
-        auto param = master_params.at(syncEntry);
-        // control the frequency of synchronization
-        // currently sync is triggerred only when the slice is updated
-        // by local worker or other workers for at least nserver_groups times.
-        // TODO may optimize the trigger condition.
-        if (abs(param->local_version() - param->version()) >= nserver_grps) {
-          for (auto msg : GenSyncMsgs(param))
-            dealer->Send(&msg);
-          syncEntry = (syncEntry+1) % master_params.size();
-        }
-      } else {
+      }  else {
         switch (type) {
           case kGet:
             response = HandleGet(&msg);
@@ -97,6 +86,9 @@ void Server::Run() {
           case kSyncRequest:
             response = HandleSyncRequest(&msg);
             break;
+          case kSyncResponse:
+            HandleSyncResponse(&msg);
+            break;
           default:
             LOG(ERROR)<<"Unknown message type "<<type;
             break;
@@ -117,31 +109,10 @@ void Server::Run() {
   delete dealer;
 }
 
-const vector<Msg*> Server::GenSyncMsgs(Param* param) {
-  vector<Msg*> ret;
-  // TODO replace the argument (0,0) to sync a chunk instead of a slice
-  auto msg = param->GenSyncMsg(0, 0);
-  auto cluster = Cluster::Get();
-  for (int i = 0; i < cluster->nserver_groups(); i++) {
-    if (i != grp_id_) {
-      Msg* tmp = msg;
-      if (i < cluster->nserver_groups() - 1)
-        tmp = new Msg(*msg);
-      // assume only one server per group, TODO generalize it
-      tmp->set_dst(Addr(i, 0, kServer));
-      tmp->set_src(Addr(grp_id_, id_, kServer));
-      ret.push_back(tmp);
-      param->set_version(param->local_version());
-      //LOG(ERROR)<<"sync slice="<<param->id()<<" to procs "<<i;
-    }
-  }
-  return ret;
-}
-
 Msg* Server::HandlePut(Msg **msg) {
   int version = (*msg)->trgt_version();
   int slice_id = SliceID((*msg)->trgt_val());
-  if (shard_->find(slice_id) != shard_->end())
+  if (shard_.find(slice_id) != shard_.end())
     LOG(FATAL) << "Param (" << slice_id << ") is put more than once";
 
   // TODO(wangwei) replace hard coded param type 0
@@ -152,17 +123,15 @@ Msg* Server::HandlePut(Msg **msg) {
   if ((*msg)->NextFrame())
     (*msg)->ParseFormatFrame("i", &num_shares);
   DeleteMsg(msg);
-  (*shard_)[slice_id] = new ParamEntry(num_shares, param);
+  shard_[slice_id] = new ParamEntry(num_shares, param);
   // must set version after HandlePutMsg which allocates the memory
   param->set_version(version);
   param->set_local_version(version);
   param->set_id(slice_id);
-  //LOG(ERROR)<<"put norm "<<param->data().asum_data()<<", "<<pid;
   // allocate blob for param sync between groups.
-  if (Cluster::Get()->nserver_groups() > 1 && slice2group_[slice_id] != grp_id_)
{
-    last_data_[slice_id] = std::make_shared<Blob<float>>();
-    last_data_[slice_id]->ReshapeLike(param->data());
-    last_data_[slice_id]->CopyFrom(param->data());
+  if (slice2group_[slice_id] != grp_id_) {
+    last_sync_[slice_id].ReshapeLike(param->data());
+    last_sync_[slice_id].CopyFrom(param->data());
   }
   LOG(INFO)<<"server (group = " << grp_id_ << ", id = " << id_ <<")
put slice="
     << slice_id << " size=" << param->size();
@@ -171,7 +140,7 @@ Msg* Server::HandlePut(Msg **msg) {
 
 Msg* Server::HandleGet(Msg **msg) {
   int val = (*msg)->trgt_val();
-  auto param = shard_->at(SliceID(val))->shares.at(0);
+  auto param = shard_.at(SliceID(val))->shares.at(0);
   // re-queue the request if the param is not updated to the required version
   if(param->version()<(*msg)->trgt_version())
     return *msg;
@@ -186,15 +155,14 @@ Msg* Server::HandleGet(Msg **msg) {
 const vector<Msg*> Server::HandleUpdate(Msg **msg) {
   vector<Msg*> ret;
   int sliceid = SliceID((*msg)->trgt_val());
-  auto entry = shard_->at(sliceid);
+  auto entry = shard_.at(sliceid);
   buffer_requests_[sliceid].push_back(*msg);
   int num_update;
   (*msg)->LastFrame();
   (*msg)->ParseFormatFrame("i", &num_update);
   (*msg)->FirstFrame();
   entry->num_update += num_update;
-  // LOG(ERROR) << "update "<<sliceid<< " from "<<(*msg)->src_second()
-  //  << ", " << num_update << " total " << entry->num_total;
+  // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src())
<< ", " << num_update << " total " << entry->num_total;
   // do update until recv gradients from all shares of this param/slice
   if (entry->num_update >= entry->num_total) {
     CHECK_EQ(entry->num_update, entry->num_total);
@@ -211,6 +179,26 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
       ret.push_back(response);
     }
     entry->num_update = 0;
+    nUpdates_[sliceid]++;
+    // sync with master group after at least sync_freq local updates
+    // the last check is to avoid sending msg to stopped servers
+    if (slice2group_[sliceid] != grp_id_
+        && nUpdates_[sliceid] >= Cluster::Get()->sync_freq()
+        && nPendingSync_[sliceid] <= Cluster::Get()->sync_freq()) {
+      auto shape = Shape1(param->size());
+      Tensor<cpu, 1> tmp(last_sync_[sliceid].mutable_cpu_data(), shape);
+      Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
+      tmp = cur - tmp;
+      int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer);
+      Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr);
+      sync->set_type(kSyncRequest);
+      sync->set_trgt((*msg)->trgt_val(), param->local_version());
+      sync->AddFrame(tmp.dptr, param->size() * sizeof(float));
+      Copy(tmp, cur);
+      ret.push_back(sync);
+      nUpdates_[sliceid] = 0;
+      nPendingSync_[sliceid]++;
+    }
   }
   *msg = nullptr;
   return ret;
@@ -219,38 +207,33 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
 Msg* Server::HandleSyncRequest(Msg **msg) {
   Msg* msgg = *msg;
   int slice = SliceID(msgg->trgt_val());
-  auto param = shard_->at(slice)->shares.at(0);
-  Msg* response=nullptr;
-  auto shape=Shape1(param->size());
+  auto param = shard_.at(slice)->shares.at(0);
+  auto shape = Shape1(param->size());
   CHECK_EQ(msgg->FrameSize(), param->size()*sizeof(float));
-  Tensor<cpu, 1> tmp(static_cast<float*>(msgg->FrameData()), shape);
+  Tensor<cpu, 1> inc(static_cast<float*>(msgg->FrameData()), shape);
   Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
-  //LOG(ERROR)<<"Recv sync for "<<param->id();
-  if (slice2group_[slice] == grp_id_) {
-    // recv sync msg on slice I am mastering
-    cur+=tmp;
-    param->set_local_version(param->local_version()+1);
-  } else {  // recv sync msg on slice mastered by others
-    TensorContainer<cpu, 1> diff(shape);
-    Tensor<cpu, 1> prev(last_data_[param->id()]->mutable_cpu_data(), shape);
-    diff=cur-prev;
-    msgg->NextFrame();
-    int bandwidth;
-    msgg->ParseFormatFrame("i", &bandwidth);
-    if (bandwidth > 0) {
-      // send back my updates to the server group mastering this param
-      response=new Msg(msgg->dst(), msgg->src());
-      response->set_type(kSyncRequest);
-      response->set_trgt(param->id(), param->version());
-      response->AddFrame(diff.dptr, param->size()*sizeof(float));
-      prev=diff+tmp;
-      Copy(cur, prev);
-    } else {  // no bandwidth, aggregate my updates for next sync
-      Copy(prev, tmp);
-      cur=tmp+diff;
-    }
-  }
+  // recv sync msg on the slice I am maintaining
+  cur += inc;
+  msgg->SwapAddr();
+  msgg->set_type(kSyncResponse);
+  // copy the fresh param value into the response msg
+  Copy(inc, cur);
+  return msgg;
+}
+
+// recv sync msg on slice mastered by others
+void Server::HandleSyncResponse(Msg **msg) {
+  Msg* msgg = *msg;
+  int slice = SliceID(msgg->trgt_val());
+  auto param = shard_.at(slice)->shares.at(0);
+  auto shape=Shape1(param->size());
+  Tensor<cpu, 1> prev(last_sync_[param->id()].mutable_cpu_data(), shape);
+  Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
+  Tensor<cpu, 1> master(static_cast<float*>(msgg->FrameData()), shape);
+  cur += master - prev;  // cur = master + (cur - prev);
+  Copy(prev, cur);
   DeleteMsg(msg);
-  return response;
+  nPendingSync_[slice]--;
 }
+
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index b6dc729..b02ef3e 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -21,10 +21,6 @@ using std::make_shared;
 
 /***********************Trainer****************************/
 Trainer::~Trainer() {
-  // free Params (i.e., slices) in server shard
-  for (auto entry : server_shard_)
-    for (auto param : entry.second->shares)
-      delete param;
   delete router_;
 }
 
@@ -120,10 +116,11 @@ void Trainer::SetupWorkerServer(
 
   //  partition among server groups, each group maintains one sub-set for sync
   auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
-  for (auto server : servers)
-    server->Setup(job_conf.updater(), &server_shard_, slice2group);
   //  partition within one server group, each server updates for one sub-set
   slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
+
+  for (auto server : servers)
+    server->Setup(job_conf.updater(), slice2group, slice2server_);
 }
 
 vector<Server*> Trainer::CreateServers(int nthreads, const JobProto& job) {
@@ -132,46 +129,33 @@ vector<Server*> Trainer::CreateServers(int nthreads, const JobProto&
job) {
   if (!cluster->has_server())
     return servers;
 
-  int pid = cluster->procs_id();
+  int server_procs = cluster->procs_id();
   // if true, server procs (logical) id starts after worker procs
   if (cluster->server_worker_separate())
-    pid -= cluster->nworker_procs();
-  int procs_size = cluster->nservers_per_procs();
-  int grp_size = cluster->nservers_per_group();
-  int gid = pid *  procs_size / grp_size;
-  int start = pid * procs_size % grp_size;
-  int end = start + procs_size;
-  for (int sid = start; sid < end; sid++) {
-    auto server = new Server(nthreads++, gid, sid);
-    servers.push_back(server);
+    server_procs -= cluster->nworker_procs();
+  const vector<int> rng = cluster->ExecutorRng(server_procs,
+      cluster->nservers_per_group(),
+      cluster->nservers_per_procs());
+  int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
+  for (int gid = gstart; gid < gend; gid++) {
+    for (int sid = start; sid < end; sid++) {
+      auto server = new Server(nthreads++, gid, sid);
+      servers.push_back(server);
+    }
   }
   return servers;
 }
 
+
 vector<Worker*> Trainer::CreateWorkers(int nthreads, const JobProto& job) {
   auto cluster=Cluster::Get();
   vector<Worker*> workers;
   if(!cluster->has_worker())
     return workers;
-  int pid = cluster->procs_id();
-  int grp_size = cluster->nworkers_per_group();
-  int procs_size = cluster->nworkers_per_procs();
-  int gstart, gend, wstart, wend;
-  if (grp_size >= procs_size) {
-    // all workers in this procs are from the same group
-    gstart = pid * procs_size / grp_size;
-    gend = gstart + 1;
-    wstart = pid * procs_size % grp_size;
-    wend = wstart + procs_size;
-  } else {
-    // there are multiple (complete) groups in this procs.
-    CHECK_EQ(procs_size % grp_size, 0);
-    int groups_per_procs = procs_size / grp_size;
-    gstart = pid * groups_per_procs;
-    gend = (pid+1) * groups_per_procs;
-    wstart = 0;
-    wend = grp_size;
-  }
+  const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
+      cluster->nworkers_per_group(),
+      cluster->nworkers_per_procs());
+  int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int wid = wstart; wid < wend; wid++) {
       auto *worker = Worker::Create(job);
@@ -260,12 +244,6 @@ void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto*
job) {
     delete worker;
 }
 
-inline int bandwidth(int bytes, system_clock::time_point start) {
-  auto now=system_clock::now();
-  auto duration=duration_cast<std::chrono::milliseconds> (now - start);
-  return static_cast<int>(bytes*1000.f/duration.count());
-}
-
 void Trainer::Run(
     const vector<Worker*>& workers,
     const vector<Server*>& servers) {
@@ -274,42 +252,20 @@ void Trainer::Run(
   procs_id_ = cluster->procs_id();
   LOG(INFO) << "Stub in process " << procs_id_ << " starts";
 
-  // for sync among server groups
-  auto start = std::chrono::system_clock::now();
-  float trans_size = 0.f;  // total size of msg transferred since start time
-  int sync_server_id = 0;
-  int max_bandwidth = cluster->bandwidth();
-  int nserver_grps = cluster->nserver_groups();
-
   map<int, Dealer*> inter_dealers;  // for sending msg to other procs
 
   std::queue<Msg*> msg_queue;
-  Poller poll(router_);
-  bool stop=false;
-  while (!stop || !msg_queue.empty()) {
+  while (true) {
+    Msg* msg = nullptr;
     if (msg_queue.empty()) {
-      // if the poll time is large, then the poller may not expire
-      // if it is small, then many reminder messages will be sent which may
-      // slow done the process of other request. TODO tune it.
-      auto *sock = poll.Wait(cluster->poll_time());
-      if (poll.Terminated()) {
-        LOG(ERROR) << "Connection broken!";
-        exit(0);
-      } else if (sock == nullptr) {
-        if (nserver_grps > 1 && bandwidth(trans_size, start) < max_bandwidth)
{
-          Msg* msg = GenSyncReminderMsg(sync_server_id, servers);
-          router_->Send(&msg) ;
-          sync_server_id = (sync_server_id + 1) % nservers;
-        }
-        continue;
-      }
-      Msg* msg = router_->Receive();
-      msg_queue.push(msg);
+      msg = router_->Receive();
+    } else {
+      msg = msg_queue.front();
+      msg_queue.pop();
     }
-    Msg* msg = msg_queue.front();
-    msg_queue.pop();
     int type = msg->type(), dst = msg->dst(), flag = AddrType(dst);
     if (flag == kStub && (AddrProc(dst) == procs_id_ || AddrGrp(dst) == -1)) {
+      //  the following statements are ordered!
       if (type == kConnect) {
         DeleteMsg(&msg);
       } else if (type == kMetric) {
@@ -320,28 +276,18 @@ void Trainer::Run(
         else if (src_flag == kWorkerParam) nworkers--;
         DeleteMsg(&msg);
         if (nworkers == 0 && nservers == 0) break;
-      } else if (nserver_grps > 0) {
-        HandleLocalMsg(&msg_queue, &msg);
       } else {
-        DeleteMsg(&msg);
+        HandleLocalMsg(&msg_queue, &msg);
       }
     } else {
       int dst_procs = AddrProc(dst);
       if (flag != kStub)
         dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag);
       if (dst_procs != procs_id_) {
-        if (bandwidth(trans_size, start) <= cluster->bandwidth()) {
-          start = std::chrono::system_clock::now();
-          trans_size = 0;
-        }
-        trans_size += msg->size();
-
         if (inter_dealers.find(dst_procs) == inter_dealers.end())
           inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs);
         inter_dealers[dst_procs]->Send(&msg);
       } else {
-        if (type == kSyncRequest)
-          msg->AddFormatFrame("i", max_bandwidth - bandwidth(trans_size, start));
         router_->Send(&msg);
       }
     }
@@ -351,14 +297,6 @@ void Trainer::Run(
     delete entry.second;
 }
 
-Msg* Trainer::GenSyncReminderMsg(int server, const vector<Server*>& servers ) {
-  Msg* msg = new Msg();
-  msg->set_src(Addr(-1,-1, kStub));
-  msg->set_dst(Addr(servers[server]->grp_id(), servers[server]->id(), kServer));
-  msg->set_type(kSyncReminder);
-  return msg;
-}
-
 void Trainer::DisplayMetric(Msg** msg) {
   Msg* msgg = *msg;
   // only display metrics from the first group
@@ -436,16 +374,16 @@ void Trainer::GenMsgs(int type, int version, ParamEntry* entry,
   for (int idx = 0 ; idx < param->num_slices(); idx++) {
     int slice_id =param->slice_start() + idx;
     int server = slice2server_[slice_id];
-    int procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
+    int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
     Msg* new_msg = nullptr;
     if (type == kPut) {
       CHECK_GT(entry->num_total, 0);
-      new_msg = param->GenPutMsg(procs != procs_id_, idx);
+      new_msg = param->GenPutMsg(dst_procs != procs_id_, idx);
       new_msg->AddFormatFrame("i", entry->num_total);
     } else if (type == kGet) {
-      new_msg = param->GenGetMsg(procs != procs_id_, idx);
+      new_msg = param->GenGetMsg(dst_procs != procs_id_, idx);
     } else if (type == kUpdate) {
-      new_msg = param->GenUpdateMsg(procs != procs_id_, idx);
+      new_msg = param->GenUpdateMsg(dst_procs != procs_id_, idx);
       new_msg->AddFormatFrame("i", entry->num_local);
     } else {
       LOG(FATAL) << "Wrong type";

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/ed9e3736/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 9664064..a1716a1 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -6,6 +6,7 @@
 #include <fstream>
 
 namespace singa {
+using std::vector;
 
 Cluster* Cluster::Setup(int job, const SingaProto& singaConf,
                         const ClusterProto& clusterConf) {
@@ -71,6 +72,26 @@ void Cluster::SetupFolders(const ClusterProto &cluster) {
   mkdir(checkpoint_folder().c_str(),  S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
 }
 
+const vector<int> Cluster::ExecutorRng(int pid, int grp_size, int procs_size) {
+  int gstart, gend, start, end;
+  if (grp_size >= procs_size) {
+    // all workers in this procs are from the same group
+    gstart = pid * procs_size / grp_size;
+    gend = gstart + 1;
+    start = pid * procs_size % grp_size;
+    end = start + procs_size;
+  } else {
+    // there are multiple (complete) groups in this procs.
+    CHECK_EQ(procs_size % grp_size, 0);
+    int groups_per_procs = procs_size / grp_size;
+    gstart = pid * groups_per_procs;
+    gend = (pid+1) * groups_per_procs;
+    start = 0;
+    end = grp_size;
+  }
+  return vector<int>{gstart, gend, start, end};
+}
+
 int Cluster::Hash(int gid, int id, int flag) {
   int ret = -1;
   if (flag == kServer) {


Mime
View raw message