singa-commits mailing list archives

From wan...@apache.org
Subject [4/6] incubator-singa git commit: SINGA-8 Implement distributed Hogwild
Date Thu, 25 Jun 2015 13:45:02 GMT
SINGA-8 Implement distributed Hogwild

The original Param objects are sliced so that the sizes of the parameters mastered by the
server groups are (roughly) equal.
Following Caffe's implementation, we let each server group master a subset of the param slices.
Each server group updates all model parameters for its corresponding worker groups and
synchronizes with the other server groups on the slices it masters.
Tested on a single node with multiple processes, each of which runs one server group with one
server and one worker group with one worker.
The training loss does not decrease as fast as with shared-memory Hogwild. TODO: optimize and
test on multiple nodes.
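
For concreteness: the slicing (Trainer::SliceParams, visible in the trainer.cc diff below)
cuts the parameters into lcm(nserver_groups, nservers_per_group) slices, and slice2group_
records which group masters each slice. A hypothetical sketch of that bookkeeping follows;
SliceSizes, SliceToGroup and the modulo assignment are illustrative names and assumptions,
not the actual implementation:

  #include <vector>

  // Cut `total` parameter values into `nslices` pieces of roughly equal size.
  std::vector<int> SliceSizes(int total, int nslices) {
    std::vector<int> sizes(nslices, total / nslices);
    for (int i = 0; i < total % nslices; ++i)
      ++sizes[i];  // spread the remainder one value at a time
    return sizes;
  }

  // Assign each slice to the server group that masters it; a simple modulo
  // scheme (an assumption here) keeps the mastered sizes balanced.
  std::vector<int> SliceToGroup(int nslices, int nserver_groups) {
    std::vector<int> slice2group(nslices);
    for (int s = 0; s < nslices; ++s)
      slice2group[s] = s % nserver_groups;
    return slice2group;
  }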


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/884b9d70
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/884b9d70
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/884b9d70

Branch: refs/heads/master
Commit: 884b9d70a631bee4961fb3907e47a747c5dd2b89
Parents: ad13d03
Author: wang wei <wangwei@comp.nus.edu.sg>
Authored: Thu Jun 25 11:39:45 2015 +0800
Committer: wang wei <wangwei@comp.nus.edu.sg>
Committed: Thu Jun 25 11:50:28 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/model.conf    |  1 +
 include/communication/socket.h |  2 +-
 include/trainer/trainer.h      |  3 +-
 include/utils/cluster.h        | 10 ++---
 include/utils/param.h          |  5 +--
 src/communication/msg.cc       |  4 +-
 src/communication/socket.cc    |  2 +-
 src/proto/cluster.proto        |  1 -
 src/proto/common.proto         |  1 +
 src/trainer/server.cc          | 79 ++++++++++++++++++-------------------
 src/trainer/trainer.cc         | 37 +++++++++--------
 src/utils/cluster.cc           | 15 +++----
 src/utils/param.cc             | 14 +++----
 13 files changed, 83 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
index 42be6dd..2bf76b0 100644
--- a/examples/cifar10/model.conf
+++ b/examples/cifar10/model.conf
@@ -25,6 +25,7 @@ layer{
   sharddata_conf {
     path: "examples/cifar10/cifar10_train_shard"
     batchsize: 16
+    random_skip: 5000
   }
   exclude: kTest
 }
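
For context on this config change: random_skip plausibly exists so that concurrent worker
groups do not all read the shard from the same starting record, which matters for
Hogwild-style parallelism. A hypothetical sketch of the intent (the shard layer's actual
code is not part of this diff; RandomStartOffset is an invented name):

  #include <cstdlib>

  // Pick a random start offset below random_skip so that each worker group
  // begins reading the shard at a different record.
  int RandomStartOffset(int random_skip, int nrecords) {
    int skip = random_skip > 0 ? std::rand() % random_skip : 0;
    return skip % nrecords;  // wrap around to stay inside the shard
  }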

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/include/communication/socket.h
----------------------------------------------------------------------
diff --git a/include/communication/socket.h b/include/communication/socket.h
index b98656e..5a9598c 100644
--- a/include/communication/socket.h
+++ b/include/communication/socket.h
@@ -59,7 +59,7 @@ class Poller {
   /**
    * @return true if the poller is terminated due to process interupt
    */
-  virtual bool Terminated()=0;
+  virtual bool Terminated();
 
  protected:
 #ifdef USE_ZMQ

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index fb716bc..31d3704 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -101,8 +101,7 @@ class Trainer{
       const ModelProto& mproto, vector<int> *slice_size);
 
   void Run(const vector<shared_ptr<Worker>>& workers,
-      const vector<shared_ptr<Server>>& servers,
-      const std::map<int, shared_ptr<ParamShard>>& shards);
+      const vector<shared_ptr<Server>>& servers);
   /**
    * Register default implementations for all base classes used in the system,
    * e.g., the Updater, BaseMsg, etc.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/include/utils/cluster.h
----------------------------------------------------------------------
diff --git a/include/utils/cluster.h b/include/utils/cluster.h
index 3830383..55b10a9 100644
--- a/include/utils/cluster.h
+++ b/include/utils/cluster.h
@@ -32,7 +32,10 @@ class Cluster {
   int nworkers_per_procs()const{return cluster_.nworkers_per_procs();}
   int nservers_per_procs()const{return cluster_.nservers_per_procs();}
   int nworker_groups_per_server_group() const {
-    return cluster_.nworker_groups()/cluster_.nserver_groups();
+    if(nserver_groups()==0||nservers_per_group()==0)
+      return 1;
+    else
+      return cluster_.nworker_groups()/cluster_.nserver_groups();
   }
 
   /**
@@ -49,10 +52,7 @@ class Cluster {
    * @return true if the calling procs has worker threads.
    */
   bool has_worker()const {
-    if(server_worker_separate()){
-      return procs_id_<nworker_procs();
-    }else
-      return procs_id_<nprocs_;
+    return procs_id_<nworker_procs();
   }
   /**
    * @return global procs id, which starts from 0.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index d449fba..61e862b 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -39,7 +39,7 @@ class Param {
    * This function is called at server side where the Param is actually a slice
    * of an original Param object.
    * */
-  virtual Msg* GenSyncMsg();
+  virtual Msg* GenSyncMsg(int offset, int size);
   /**
    * Generate the message to response the update request.
    *
@@ -70,8 +70,6 @@ class Param {
    * \copydetails HandleGetMsg(Msg**)
    */
   virtual Msg* HandleSyncMsg(Msg** msg);
-
-<<<<<<< HEAD
   /**
    * Server parses update request message.
    *
@@ -106,7 +104,6 @@ class Param {
    * @param shape
    */
   virtual void Setup(const ParamProto& proto, const std::vector<int>& shape);
-  virtual void Setup(const vector<int>& shape);
   /*
    * Fill the values according to initmethod, e.g., gaussian distribution
    *

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/communication/msg.cc
----------------------------------------------------------------------
diff --git a/src/communication/msg.cc b/src/communication/msg.cc
index 7ee8cad..38512d2 100644
--- a/src/communication/msg.cc
+++ b/src/communication/msg.cc
@@ -11,8 +11,8 @@ Msg::Msg(const Msg& msg){
   src_=msg.src_;
   dst_=msg.dst_;
   type_=msg.type_;
-  target_first_=msg.target_first_;
-  target_second_=msg.target_second_;
+  trgt_first_=msg.trgt_first_;
+  trgt_second_=msg.trgt_second_;
   msg_=zmsg_dup(msg.msg_);
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
index c6925d8..0cb0982 100644
--- a/src/communication/socket.cc
+++ b/src/communication/socket.cc
@@ -90,7 +90,7 @@ Router::~Router() {
       zmsg_destroy(&msg);
   }
 }
-int Router::Bind(std::string endpoint){
+int Router::Bind(const std::string& endpoint){
   int port=-1;
   if(endpoint.length()){
     port=zsock_bind(router_, "%s", endpoint.c_str());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/proto/cluster.proto
----------------------------------------------------------------------
diff --git a/src/proto/cluster.proto b/src/proto/cluster.proto
index 3317f2a..8fbdbbe 100644
--- a/src/proto/cluster.proto
+++ b/src/proto/cluster.proto
@@ -43,7 +43,6 @@ message ClusterProto {
   optional int32 bandwidth=50 [default=134217728];
   // poll time in milliseconds
   optional int32 poll_time=51 [default =100];
->>>>>>> SINGA-8 Implement distributed Hogwild
 }
 
 message ServerTopology {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/proto/common.proto
----------------------------------------------------------------------
diff --git a/src/proto/common.proto b/src/proto/common.proto
index 6bc0919..70b743c 100644
--- a/src/proto/common.proto
+++ b/src/proto/common.proto
@@ -13,6 +13,7 @@ enum MsgType {
   kRUpdate = 9;
   kConnect = 10;
   kMetric = 11;
+  kSyncReminder = 12;
 };
 
 enum EntityType {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 9e0dee3..9ea4509 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -7,6 +7,7 @@
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "utils/cluster.h"
+#include "proto/common.pb.h"
 
 using namespace mshadow;
 namespace singa {
@@ -34,8 +35,9 @@ void Server::Run(){
   ping->add_frame("PING", 4);
   ping->set_type(kConnect);
   dealer_->Send(&ping);
-  int syncEntry=0;
-	//start recv loop and process requests
+  vector<shared_ptr<Param>> master_params;
+  size_t syncEntry=0;
+  //start recv loop and process requests
   while (true){
     Msg* msg=dealer_->Receive();
     if (msg==nullptr)
@@ -53,46 +55,39 @@ void Server::Run(){
       CHECK_STREQ("PONG", pong.c_str());
       DeleteMsg(&msg);
     }else if(type==kPut){
+      int pid = msg->trgt_second();
       response = HandlePut(&msg);
+      if(slice2group_[pid]==group_id_)
+        master_params.push_back(shard_->at(pid));
     }else{
       int pid=msg->trgt_second();
       if(shard_->find(pid)==shard_->end()){
         // delay the processing by re-queue the msg.
         response=msg;
         DLOG(ERROR)<<"Requeue msg";
-    }else if(type==kSyncReminder){
-      DeleteMsg(&msg);
-      unsigned nchecks=0, nparams=shard_->size();
-      while(nchecks<nparams
-          &&group_locator_->at(shard_->at(syncEntry))!=group_id_){
-        syncEntry=(syncEntry+1)%nparams;
-        nchecks++;
-      }
-      if(nchecks==nparams) continue;
-      auto param=shard_->at(syncEntry);
-      if(param->local_version()!=param->version()){
-        sync=param->GenSyncMsg(true);
-        for(int i=0;i<cluster->nserver_groups();i++){
-          if(i!=group_id_) {
-            Msg* tmp=sync;
-            if(i<cluster->nserver_groups()-1)
-              tmp= new Msg(*sync);
-            tmp->set_dst(i, server_locator_->at(param), kServer);
-            tmp->set_src(group_id_, server_id_, kServer);
-            dealer_->Send(&tmp);
-            param->set_version(param->local_version());
-            //DLOG(ERROR)<<"sync";
+      }else if(type == kSyncReminder){
+        DeleteMsg(&msg);
+        if(syncEntry>=master_params.size())
+          continue;
+        auto param=master_params.at(syncEntry);
+        if(param->local_version()!=param->version()){
+          sync=param->GenSyncMsg(0,0);
+          for(int i=0;i<cluster->nserver_groups();i++){
+            if(i!=group_id_) {
+              Msg* tmp=sync;
+              if(i<cluster->nserver_groups()-1)
+                tmp= new Msg(*sync);
+              // assume only one server per group, TODO generalize it
+              tmp->set_dst(i, 0, kServer);
+              tmp->set_src(group_id_, server_id_, kServer);
+              dealer_->Send(&tmp);
+              param->set_version(param->local_version());
+              //DLOG(ERROR)<<"sync";
+            }
           }
+          syncEntry=(syncEntry+1)%master_params.size();
         }
-      }
-    }else {
-      int pid=msg->target_first();
-      if(shard_->find(pid)==shard_->end()){
-        // delay the processing by re-queue the msg.
-        response=msg;
-        LOG(ERROR)<<"Requeue";
->>>>>>> SINGA-8 Implement distributed Hogwild
-      } else{
+      }else{
         auto param=shard_->at(pid);
         switch (type){
           case kGet:
@@ -118,7 +113,7 @@ void Server::Run(){
 
 Msg* Server::HandlePut(Msg **msg){
   int version=(*msg)->trgt_third();
-  int pid=(*msg)->target_first();
+  int pid=(*msg)->trgt_second();
   shared_ptr<Param> param=nullptr;
   if(shard_->find(pid)!=shard_->end()){
     LOG(ERROR)<<"Param ("<<pid<<") is put more than once";
@@ -126,19 +121,21 @@ Msg* Server::HandlePut(Msg **msg){
   }else{
     auto factory=Singleton<Factory<Param>>::Instance();
     param=shared_ptr<Param>(factory ->Create("Param"));
-    param->set_id(pid);
     (*shard_)[pid]=param;
   }
   auto response=param->HandlePutMsg(msg);
   // must set version after HandlePutMsg which allocates the memory
   param->set_version(version);
+  param->set_local_version(version);
+  param->set_id(pid);
   if(Cluster::Get()->nserver_groups()>1 &&
-      group_locator_->at(param)!=group_id_){
+      slice2group_[pid]!=group_id_){
     last_data_[pid]=std::make_shared<Blob<float>>();
     last_data_[pid]->ReshapeLike(param->data());
     last_data_[pid]->CopyFrom(param->data());
   }
-  LOG(INFO)<<"Server put param "<<pid<<" size="<<param->size()<<" Bytes";
+  LOG(INFO)<<"server ("<<group_id_<<", "<<server_id_
+    <<") put slice="<<pid<<" size="<<param->size();
   return response;
 }
 
@@ -161,9 +158,9 @@ Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) {
   int step=(*msg)->trgt_third();
   bool copy=param->ParseUpdateMsg(msg);
   updater_->Update(step, param);
-  param->set_version(param->version()+1);
+  param->set_local_version(param->local_version()+1);
   auto response=param->GenUpdateResponseMsg(copy);
-  response->set_trgt(paramid, sliceid, param->version());
+  response->set_trgt(paramid, sliceid, param->local_version());
   response->SetAddr(tmp);
   delete tmp;
   return response;
@@ -175,7 +172,7 @@ Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){
   CHECK_EQ((*msg)->frame_size(), param->size()*sizeof(float));
   Tensor<cpu, 1> tmp(static_cast<float*>((*msg)->frame_data()), shape);
   Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
-  if(group_locator_->at(param)==group_id_){
+  if(slice2group_[param->id()]==group_id_){
     cur+=tmp;
     param->set_local_version(param->local_version()+1);
   }else{
@@ -188,7 +185,7 @@ Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){
     if(bandwidth>0){
       response=new Msg();
       response->set_type(kSyncRequest);
-      response->set_target(param->id(), param->version());
+      response->set_trgt(-1, param->id(), param->version());
       response->add_frame(diff.dptr, param->size()*sizeof(float));
       (*msg)->SwapAddr();
       response->SetAddr(*msg);
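
Summarizing the HandleSyncRequest logic above: the group that masters a slice folds an
incoming delta into its copy (cur += tmp), while a non-master group replies with the delta
it has accumulated locally since the previous sync, tracked via last_data_. A condensed
sketch on plain vectors; ApplyRemoteDiff and TakeLocalDiff are hypothetical names, and the
real code operates on mshadow tensors:

  #include <cstddef>
  #include <vector>

  // Master of the slice: the incoming frame carries another group's delta,
  // which is folded into the master copy.
  void ApplyRemoteDiff(std::vector<float>& cur, const std::vector<float>& tmp) {
    for (std::size_t i = 0; i < cur.size(); ++i)
      cur[i] += tmp[i];
  }

  // Non-master: compute the delta accumulated since the previous sync
  // (cur - last) to ship back to the master, then refresh the snapshot.
  std::vector<float> TakeLocalDiff(std::vector<float>& last,
                                   const std::vector<float>& cur) {
    std::vector<float> diff(cur.size());
    for (std::size_t i = 0; i < cur.size(); ++i) {
      diff[i] = cur[i] - last[i];
      last[i] = cur[i];
    }
    return diff;
  }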

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index cd80296..4e7f932 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -179,12 +179,12 @@ vector<shared_ptr<Worker>> Trainer::CreateWorkers(int nthreads,
   auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
       cluster->nworkers_per_group());
   int lcm=LeastCommonMultiple(cluster->nserver_groups(), cluster->nservers_per_group());
-    auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
-    for(auto param: net->params()){
-      if(param->id()==param->owner())
-        for(auto entry: paramid2slices[param->id()])
-          slice_size->push_back(entry.second);
-    }
+  auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
+  for(auto param: net->params()){
+    if(param->id()==param->owner())
+      for(auto entry: paramid2slices[param->id()])
+        slice_size->push_back(entry.second);
+  }
 
   for(int gid=gstart;gid<gend;gid++){
     shared_ptr<NeuralNet> train_net, test_net, validation_net;
@@ -249,15 +249,20 @@ vector<shared_ptr<Worker>> Trainer::CreateWorkers(int nthreads,
 
 void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     int procs_id){
-  procs_id_=procs_id;
+  // procs_id is only used for resume training
+  CHECK_EQ(procs_id, -1);
   RegisterDefaultClasses(mproto);
 
   auto cluster=Cluster::Get(cproto, procs_id);
   router_=make_shared<Router>();
   router_->Bind(kInprocRouterEndpoint);
-  if(cluster->nprocs()>1)
-    router_->Bind(cluster->endpoint());
+  if(cluster->nprocs()>1){
+    int port=router_->Bind("tcp://127.0.0.1:*");
+    cluster->Register(cluster->hostname()+":"+std::to_string(port));
+  }else
+    cluster->set_procs_id(0);
 
+  procs_id_ = cluster->procs_id();
   int nthreads=1;
   // create workers
   vector<int> slices;
@@ -280,7 +285,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
     threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run(workers, servers, shards);
+  Run(workers, servers);
   for(auto& thread: threads)
     thread.join();
   for(auto x: ctx)
@@ -292,9 +297,9 @@ inline int bandwidth(int bytes, system_clock::time_point start){
   auto duration=duration_cast<TimeT> (now - start);
   return static_cast<int>(bytes*1000.f/duration.count());
 }
+
 void Trainer::Run(const vector<shared_ptr<Worker>>& workers,
-    const vector<shared_ptr<Server>>& servers,
-    const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
+    const vector<shared_ptr<Server>>& servers){
   auto cluster=Cluster::Get();
   procs_id_=cluster->procs_id();
   LOG(INFO)<<"Stub in process "<<procs_id_<<" starts";
@@ -307,7 +312,7 @@ void Trainer::Run(const vector<shared_ptr<Worker>>& workers,
   poll.Add(router_.get());
   int sync_server=0, nworkers=workers.size(), nservers=servers.size();
   while(!stop){
-    Socket *sock=poll.Wait(cluster->poll_time());
+    auto *sock=poll.Wait(cluster->poll_time());
     if(poll.Terminated()){
       LOG(ERROR)<<"Connection broken!";
       exit(0);
@@ -321,7 +326,6 @@ void Trainer::Run(const vector<shared_ptr<Worker>>& workers,
         msg->set_type(kSyncReminder);
         sync_server=(sync_server+1)%servers.size();
         router_->Send(&msg);
-        //LOG(ERROR)<<"Reminder";
       }
       continue;
     }
@@ -345,14 +349,13 @@ void Trainer::Run(const vector<shared_ptr<Worker>>& workers,
             nservers--;
           else if (msg->src_flag()==kWorkerParam)
             nworkers--;
-          delete msg;
-          msg=nullptr;
+          DeleteMsg(&msg);
           if(nworkers==0&&nservers==0){
             stop=true;
             break;
           }
         }else if(type==kMetric){
-          if(msg->src_first()==0){
+          if(msg->src_first()>=0){
             int step=msg->trgt_first();
             string prefix((char*)msg->frame_data(), msg->frame_size());
             msg->next_frame();
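
The stub loop above doubles as the sync clock: each poll timeout sends one kSyncReminder,
rotating over the local servers so reminders are spread evenly. A minimal sketch of that
rotation (RemindNextServer and send_reminder are hypothetical):

  #include <cstddef>
  #include <functional>

  // On each poll timeout, nudge one local server with kSyncReminder and
  // advance the round-robin cursor (sync_server in the diff above).
  void RemindNextServer(std::size_t nservers, std::size_t* sync_server,
                        const std::function<void(std::size_t)>& send_reminder) {
    if (nservers == 0) return;
    send_reminder(*sync_server);  // would emit a kSyncReminder msg
    *sync_server = (*sync_server + 1) % nservers;
  }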

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 023abbe..bf423b4 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -14,21 +14,16 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) {
   procs_id_=procs_id;
   cluster_ = cluster;
   SetupFolders(cluster);
-  int nprocs;
   if(server_worker_separate())
     nprocs_=nworker_procs()+nserver_procs();
   else
-    nprocs=std::max(nworker_procs(), nserver_procs());
-  CHECK_LT(procs_id, nprocs);
-  if (cluster_.has_nprocs())
-    CHECK_EQ(cluster.nprocs(), nprocs);
-  else
-    cluster_.set_nprocs(nprocs);
-  if(nprocs>1&&procs_id>-1){
+    nprocs_=std::max(nworker_procs(), nserver_procs());
+  CHECK_LT(procs_id, nprocs_);
+  if(nprocs_>1&&procs_id>-1){
     std::ifstream ifs(cluster.hostfile(), std::ifstream::in);
     std::string line;
-    while(std::getline(ifs, line)
-        &&endpoints_.size()<static_cast<size_t>(nprocs_)){
+    while(std::getline(ifs, line)&&
+        endpoints_.size()< static_cast<size_t>(nprocs_)){
       endpoints_.push_back(line);
     }
     CHECK_EQ(endpoints_.size(), nprocs_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/884b9d70/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 4ad17ce..ac3a6bb 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -133,11 +133,11 @@ Msg* Param::GenUpdateMsg(bool copy, int idx){
   return msg;
 }
 
-Msg* Param::GenSyncMsg(bool copy, int v){
+Msg* Param::GenSyncMsg(int offset, int size){
   Msg* msg=new Msg();
   msg->set_type(kSyncRequest);
-  msg->set_target(id(), local_version());
-  msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+  msg->set_trgt(-1, id(), local_version());
+  msg->add_frame(mutable_cpu_data(), data_->count()*sizeof(float));
   return msg;
 }
 
@@ -150,9 +150,10 @@ Msg* Param::HandlePutMsg(Msg** msg){
   proto_.set_learning_rate_multiplier(lr);
   proto_.set_weight_decay_multiplier(wc);
   vector<int> shape{size};
-  Setup(shape);
-  set_local_version((*msg)->target_second());
-  set_version((*msg)->target_second());
+  ParamProto proto;
+  Setup(proto, shape);
+  set_local_version((*msg)->trgt_third());
+  set_version((*msg)->trgt_third());
   if(ptr==nullptr){
     CHECK((*msg)->next_frame());
     CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
@@ -205,7 +206,6 @@ Msg* Param::HandleSyncMsg(Msg** msg){
   return nullptr;
 }
 
-<<<<<<< HEAD
 int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){
   DeleteMsg(msg);
   return 1;

