singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wan...@apache.org
Subject [01/10] incubator-singa git commit: implement shared memory hogwild. update param at worker side
Date Mon, 15 Jun 2015 06:42:29 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 856fc1fbe -> 4df2bb5a8


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 3f343af..dbf8a48 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -180,7 +180,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
     threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run(servers.size(), workers.size(), shards);
+  Run(workers.size(), servers.size(), shards);
   for(auto& thread: threads)
     thread.join();
 }
@@ -191,8 +191,6 @@ void Trainer::Run(int nworkers, int nservers,
   procs_id_=cluster->procs_id();
   map<int, shared_ptr<Dealer>> interprocs_dealers;
   Metric perf;
-  int perf_step=-1;
-  string perf_prefix;
   bool stop=false;
   while(!stop){
     Msg* msg=router_->Receive();
@@ -209,9 +207,9 @@ void Trainer::Run(int nworkers, int nservers,
           msg =HandleConnect(&msg);
         }else if(type==kStop){
           if(msg->src_flag()==kServer)
-            nworkers--;
-          else if (msg->src_flag()==kWorkerParam)
             nservers--;
+          else if (msg->src_flag()==kWorkerParam)
+            nworkers--;
           delete msg;
           msg=nullptr;
           if(nworkers==0&&nservers==0){
@@ -219,26 +217,19 @@ void Trainer::Run(int nworkers, int nservers,
             break;
           }
         }else if(type==kMetric){
-          int step=msg->target_first();
-          string prefix((char*)msg->frame_data(), msg->frame_size());
-          if(step!=perf_step||perf_prefix!=prefix){
-            if(perf_step>=0){
-              perf.Avg();
-              LOG(ERROR)<<perf_prefix<<" step-"
-                <<perf_step<<", "<<perf.ToString();
-              perf.Reset();
-            }
-            perf_step=step;
-            perf_prefix=prefix;
+          if(msg->src_first()==0){
+            int step=msg->target_first();
+            string prefix((char*)msg->frame_data(), msg->frame_size());
+            msg->next_frame();
+            Metric cur;
+            cur.ParseString(string((char*)msg->frame_data(), msg->frame_size()));
+            perf.AddMetrics(cur);
+            LOG(ERROR)<<prefix<<" step-" <<step<<", "<<perf.ToString();
+            perf.Reset();
           }
-          msg->next_frame();
-          Metric cur;
-          cur.ParseString(string((char*)msg->frame_data(), msg->frame_size()));
-          perf.AddMetrics(cur);
-          perf.Inc();
           delete msg;
           msg=nullptr;
-        }else {
+        }else if(cluster->nserver_groups()>1){
           int group_id=msg->src_first();
           int paramid=msg->target_first();
           auto entry=shards.at(group_id)->at(paramid);
@@ -261,6 +252,9 @@ void Trainer::Run(int nworkers, int nservers,
             default:
               break;
           }
+        }else{
+          delete msg;
+          msg=nullptr;
         }
       }else{
         int dst_procs_id;
@@ -282,9 +276,11 @@ void Trainer::Run(int nworkers, int nservers,
       }
     }
   }
+  /*
   perf.Avg();
   if(perf_step>=0)
     LOG(ERROR)<<perf_prefix<<" step-"<<perf_step<<", "<<perf.ToString();
+    */
 }
 Msg* Trainer::HandleConnect(Msg** msg){
   string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index abfcdf0..3d400ee 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -22,6 +22,11 @@ void Worker::Setup(const ModelProto& model,
   auto cluster=Cluster::Get();
   int sgid=group_id_/cluster->nworker_groups_per_server_group();
   CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid));
+  if(model.hogwild()){
+    updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
+        ->Create("Updater"));
+    updater_->Init(model.updater());
+  }
 }
 
 void Worker::Run(){
@@ -124,12 +129,17 @@ int Worker::Get(shared_ptr<Param> param, int step){
   return 1;
 }
 int Worker::Update(shared_ptr<Param> param, int step){
-  Msg* msg=new Msg();
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(-1, -1, kStub);
-  msg->set_type(kUpdate);
-  msg->set_target(param->owner(), step);
-  param_dealer_->Send(&msg);
+  if(updater_){
+    updater_->Update(step, param);
+    param->set_version(param->version()+1);
+  }else{
+    Msg* msg=new Msg();
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    msg->set_dst(-1, -1, kStub);
+    msg->set_type(kUpdate);
+    msg->set_target(param->owner(), step);
+    param_dealer_->Send(&msg);
+  }
   return 1;
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/806826eb/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index ac5566c..e616a1c 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -175,6 +175,116 @@ void Param::Init(int v){
   set_version(v);
 }
 
+/********************HogwildParam***************************/
+Msg* HogwildParam::GenPutMsg(void* arg){
+  char buf[128];
+  sprintf(buf, "%d %f %f %p", size(),
+      learning_rate_multiplier(), weight_decay_multiplier(), mutable_cpu_data());
+  Msg* msg=new Msg();
+  msg->set_type(kPut);
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
+  msg->add_frame(buf, strlen(buf));
+	return msg;
+}
+
+Msg* HogwildParam::GenGetMsg(void* arg){
+  Msg* msg=new Msg();
+  msg->set_type(kGet);
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
+  return msg;
+}
+
+Msg* HogwildParam::GenUpdateMsg(void* arg){
+  Msg* msg=new Msg();
+  msg->set_type(kUpdate);
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
+  void* p=mutable_cpu_grad();
+  msg->add_frame(p, sizeof(void*));
+  return msg;
+}
+
+Msg* HogwildParam::GenSyncMsg(void* arg){
+  return nullptr;
+}
+
+Msg* HogwildParam::HandlePutMsg(Msg** msg){
+  int size;
+  float lr, wc;
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f",
+      &size, &lr, &wc);
+  proto_.set_learning_rate_multiplier(lr);
+  proto_.set_weight_decay_multiplier(wc);
+  CHECK((*msg)->next_frame());
+  vector<int> shape{size};
+  // set pointer
+  data_=std::make_shared<Blob<float>>(shape);
+  data_->set_version((*msg)->target_second());
+  grad_.Reshape(shape);
+  history_.Reshape(shape);
+  delete (*msg);
+  *msg=nullptr;
+  return nullptr;
+}
+
+Msg* HogwildParam::HandleGetMsg(Msg** msg){
+  if((*msg)->target_second()<=version()){
+    (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
+    (*msg)->SwapAddr();
+    (*msg)->set_type(kRGet);
+  }
+  return *msg;
+}
+
+int HogwildParam::ParseUpdateMsg(Msg** msg){
+  delete (*msg);
+  *msg=nullptr;
+  return 1;
+}
+
+Msg* HogwildParam::GenUpdateResponseMsg(void* arg){
+  Msg* msg=new Msg();
+  msg->set_type(kRUpdate);
+  int v=version();
+  if(arg!=nullptr)
+    v=*(int*)arg;
+  msg->set_target(owner(), v);
+  return msg;
+}
+
+Msg* HogwildParam::HandleSyncMsg(Msg** msg){
+  delete *msg;
+  *msg=nullptr;
+  return nullptr;
+}
+
+int HogwildParam::ParseSyncResponseMsg(Msg** msg){
+  delete *msg;
+  *msg=nullptr;
+  return 1;
+}
+int HogwildParam::ParsePutResponseMsg(Msg **msg){
+  return ParseSyncResponseMsg(msg);
+}
+int HogwildParam::ParseGetResponseMsg(Msg **msg){
+  // must be set after all other settings are done!
+  set_version((*msg)->target_second());
+  delete *msg;
+  *msg=nullptr;
+  return 1;
+}
+int HogwildParam::ParseUpdateResponseMsg(Msg **msg){
+  return ParseGetResponseMsg(msg);
+}
+
 /**************************RandomSyncParam********************************
 const vector<int> RandomSyncParam::RandomSample(int seed, int m, int n){
   vector<int> samples(m);


Mime
View raw message