singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [4/5] incubator-singa git commit: SINGA-21 Code review 5
Date Wed, 23 Sep 2015 15:07:30 GMT
SINGA-21 Code review 5

review trainer.cc/h, driver.cc/.h, singa.h, main.cc
 - rewrite headers in driver.h
 - move template impl from driver.h to driver.cc
 - format code


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/366e6a82
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/366e6a82
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/366e6a82

Branch: refs/heads/master
Commit: 366e6a82684aff9c0b31e904e3c45dcca2163490
Parents: f50d293
Author: wang sheng <wangsheng1001@gmail.com>
Authored: Wed Sep 23 15:20:20 2015 +0800
Committer: wang sheng <wangsheng1001@gmail.com>
Committed: Wed Sep 23 15:28:43 2015 +0800

----------------------------------------------------------------------
 include/driver.h           |  45 +-------
 include/trainer/trainer.h  |  80 ++++++-------
 src/driver.cc              |  52 ++++++++-
 src/main.cc                |  10 +-
 src/neuralnet/neuralnet.cc |   4 +-
 src/trainer/trainer.cc     | 250 +++++++++++++++++++---------------------
 6 files changed, 211 insertions(+), 230 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/driver.h
----------------------------------------------------------------------
diff --git a/include/driver.h b/include/driver.h
index 7d15c98..563be77 100644
--- a/include/driver.h
+++ b/include/driver.h
@@ -22,7 +22,8 @@
 #ifndef SINGA_DRIVER_H_
 #define SINGA_DRIVER_H_
 
-#include "singa.h"
+#include "proto/job.pb.h"
+#include "proto/singa.pb.h"
 
 namespace singa {
 
@@ -119,48 +120,6 @@ class Driver {
   SingaProto singa_conf_;
 };
 
-template<typename Subclass, typename Type>
-int Driver::RegisterLayer(const Type& type) {
-  auto factory = Singleton<Factory<singa::Layer>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Layer));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParam(const Type& type) {
-  auto factory = Singleton<Factory<singa::Param>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Param));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterParamGenerator(const Type& type) {
-  auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, ParamGenerator));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterUpdater(const Type& type) {
-  auto factory = Singleton<Factory<singa::Updater>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Updater));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterLRGenerator(const Type& type) {
-  auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, LRGenerator));
-  return 1;
-}
-
-template<typename Subclass, typename Type>
-int Driver::RegisterWorker(const Type& type) {
-  auto factory = Singleton<Factory<singa::Worker>>::Instance();
-  factory->Register(type, CreateInstance(Subclass, Worker));
-  return 1;
-}
-
 }  // namespace singa
 
 #endif  // SINGA_DRIVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index d3d332f..1c0e039 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -19,26 +19,24 @@
 *
 *************************************************************/
 
-#ifndef INCLUDE_TRAINER_TRAINER_H_
-#define INCLUDE_TRAINER_TRAINER_H_
+#ifndef SINGA_TRAINER_TRAINER_H_
+#define SINGA_TRAINER_TRAINER_H_
 
 #include <queue>
-#include <vector>
 #include <unordered_map>
+#include <vector>
+#include "communication/socket.h"
+#include "neuralnet/neuralnet.h"
 #include "proto/job.pb.h"
 #include "proto/singa.pb.h"
+#include "trainer/server.h"
+#include "trainer/worker.h"
+#include "utils/factory.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
-#include "utils/factory.h"
-#include "neuralnet/neuralnet.h"
-#include "trainer/worker.h"
-#include "trainer/server.h"
-#include "communication/socket.h"
 
 namespace singa {
 
-using std::vector;
-  
 /**
  * Every running process has a training object which launches one or more
  * worker (and server) threads.
@@ -77,7 +75,7 @@ class Trainer{
    * @param jobConf
    * @return server instances
    */
-  vector<Server*> CreateServers(const JobProto& jobConf);
+  std::vector<Server*> CreateServers(const JobProto& jobConf);
   /**
    * Create workers instances.
    * @param nthread total num of threads in current procs which is used to
@@ -86,8 +84,7 @@ class Trainer{
    * @param jobConf
    * @return worker instances
    */
-  vector<Worker*> CreateWorkers(const JobProto& jobConf);
-
+  std::vector<Worker*> CreateWorkers(const JobProto& jobConf);
   /**
    * Setup workers and servers.
    *
@@ -98,12 +95,11 @@ class Trainer{
    * @param workers
    * @param servers
    */
-  void SetupWorkerServer(
-    const JobProto& jobConf,
-    const vector<Worker*>& workers,
-    const vector<Server*>& servers);
-
-  void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
+  void SetupWorkerServer(const JobProto& jobConf,
+                         const std::vector<Worker*>& workers,
+                         const std::vector<Server*>& servers);
+  void Run(const std::vector<Worker*>& workers,
+           const std::vector<Server*>& servers);
   /**
    * Display metrics to log (standard output)
    */
@@ -118,24 +114,20 @@ class Trainer{
    * Handle messages to local servers and local stub
    */
   void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
-
-	/**
-	 * Generate a request message to Get the parameter object.
-	 */
-	const vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
-	void HandleGetResponse(ParamEntry* entry, Msg** msg);
-
-	/**
-	 * Generate a request message to Update the parameter object.
-	 */
-	const vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Get the parameter object.
+   */
+  const std::vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
+  void HandleGetResponse(ParamEntry* entry, Msg** msg);
+  /**
+   * Generate a request message to Update the parameter object.
+   */
+  const std::vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
   void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
-
   /**
-	 * Generate a request message to Put the parameter object.
-	 */
-	const vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
-
+   * Generate a request message to Put the parameter object.
+   */
+  const std::vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
   /**
    * Called by HandlePut, HandleUpdate and HandleGet functions
    * @param type message type
@@ -145,7 +137,7 @@ class Trainer{
    * @param ret generated messages
    */
   void GenMsgs(int type, int version, ParamEntry* entry,
-    Msg* msg, vector<Msg*> *ret);
+    Msg* msg, std::vector<Msg*> *ret);
   /**
    * Get a hash id for a Param object from a group.
    *
@@ -157,13 +149,15 @@ class Trainer{
   }
 
  protected:
-  int procs_id_;
-  Router *router_;
+  int procs_id_ = -1;
+  Router *router_ = nullptr;
   std::unordered_map<int, ParamEntry*> worker_shard_;
   //!< map from slice to the server that updates it
-  vector<int> slice2server_;
-  //stub will destroy all neuralnets in the end
-  vector<NeuralNet*> nets_;
+  std::vector<int> slice2server_;
+  // a buffer of created nets, will destroy them all in destructor
+  std::vector<NeuralNet*> nets_;
 };
-} /* singa */
-#endif // INCLUDE_TRAINER_TRAINER_H_
+
+}  // namespace singa
+
+#endif  // SINGA_TRAINER_TRAINER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 28d21c2..42a1330 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -24,24 +24,27 @@
 #include <cblas.h>
 #include <glog/logging.h>
 #include <string>
+#include "neuralnet/neuralnet.h"
+#include "neuralnet/layer.h"
+#include "trainer/trainer.h"
+#include "utils/common.h"
+#include "utils/factory.h"
+#include "utils/singleton.h"
 #include "utils/tinydir.h"
 
 namespace singa {
 
 void Driver::Init(int argc, char **argv) {
   google::InitGoogleLogging(argv[0]);
-
   //  unique job ID generated from singa-run.sh, passed in as "-singa_job <id>"
   int arg_pos = ArgPos(argc, argv, "-singa_job");
   job_id_ = (arg_pos != -1) ? atoi(argv[arg_pos+1]) : -1;
-
   //  global signa conf passed by singa-run.sh as "-singa_conf <path>"
   arg_pos = ArgPos(argc, argv, "-singa_conf");
   if (arg_pos != -1)
     ReadProtoFromTextFile(argv[arg_pos+1], &singa_conf_);
   else
     ReadProtoFromTextFile("conf/singa.conf", &singa_conf_);
-
   //  job conf passed by users as "-conf <path>"
   arg_pos = ArgPos(argc, argv, "-conf");
   CHECK_NE(arg_pos, -1);
@@ -107,7 +110,47 @@ void Driver::Init(int argc, char **argv) {
   RegisterParamGenerator<UniformSqrtFanInOutGen>(kUniformSqrtFanInOut);
 }
 
+template<typename Subclass, typename Type>
+int Driver::RegisterLayer(const Type& type) {
+  auto factory = Singleton<Factory<singa::Layer>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Layer));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParam(const Type& type) {
+  auto factory = Singleton<Factory<singa::Param>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Param));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterParamGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, ParamGenerator));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterUpdater(const Type& type) {
+  auto factory = Singleton<Factory<singa::Updater>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Updater));
+  return 1;
+}
 
+template<typename Subclass, typename Type>
+int Driver::RegisterLRGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, LRGenerator));
+  return 1;
+}
+
+template<typename Subclass, typename Type>
+int Driver::RegisterWorker(const Type& type) {
+  auto factory = Singleton<Factory<singa::Worker>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, Worker));
+  return 1;
+}
 
 void Driver::Submit(bool resume, const JobProto& jobConf) {
   if (singa_conf_.has_log_dir())
@@ -118,9 +161,8 @@ void Driver::Submit(bool resume, const JobProto& jobConf) {
     LOG(FATAL) << "workspace does not exist: " << jobConf.cluster().workspace();
   if (jobConf.num_openblas_threads() != 1)
     LOG(WARNING) << "openblas with "
-      << jobConf.num_openblas_threads() << " threads";
+                 << jobConf.num_openblas_threads() << " threads";
   openblas_set_num_threads(jobConf.num_openblas_threads());
-
   JobProto job;
   job.CopyFrom(jobConf);
   job.set_id(job_id_);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 5e94de4..5d2ab2f 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -45,20 +45,20 @@
  */
 
 int main(int argc, char **argv) {
-  //  must create driver at the beginning and call its Init method.
+  // must create driver at the beginning and call its Init method.
   singa::Driver driver;
   driver.Init(argc, argv);
 
-  //  if -resume in argument list, set resume to true; otherwise false
+  // if -resume in argument list, set resume to true; otherwise false
   int resume_pos = singa::ArgPos(argc, argv, "-resume");
   bool resume = (resume_pos != -1);
 
-  //  users can register new subclasses of layer, updater, etc.
+  // users can register new subclasses of layer, updater, etc.
 
-  //  get the job conf, and custmize it if need
+  // get the job conf, and custmize it if need
   singa::JobProto jobConf = driver.job_conf();
 
-  //  submit the job
+  // submit the job
   driver.Submit(resume, jobConf);
   return 0;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 200824a..775a5a7 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -19,10 +19,10 @@
 *
 *************************************************************/
 
+#include "neuralnet/neuralnet.h"
+
 #include <algorithm>
 #include <queue>
-
-#include "neuralnet/neuralnet.h"
 #include "utils/singleton.h"
 
 namespace singa {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/366e6a82/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 8a0589e..ecfc94a 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -19,25 +19,21 @@
 *
 *************************************************************/
 
-#include <thread>
-#include <vector>
-#include <map>
-#include <chrono>
+#include "trainer/trainer.h"
+
 #include <glog/logging.h>
-#include "utils/tinydir.h"
 #include <unistd.h>
+#include <map>
+#include <thread>
+#include "mshadow/tensor.h"
+#include "proto/common.pb.h"
 #include "utils/cluster.h"
 #include "utils/common.h"
-#include "proto/common.pb.h"
-#include "trainer/trainer.h"
-#include "mshadow/tensor.h"
-
+#include "utils/tinydir.h"
 
 namespace singa {
+
 using std::vector;
-using std::map;
-using std::queue;
-using namespace std::chrono;
 using std::string;
 
 /***********************Trainer****************************/
@@ -47,12 +43,82 @@ Trainer::~Trainer() {
     delete p;
 }
 
+void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
+  // register job to zookeeper at the beginning
+  auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
+  if (resume) Resume(job);
+  router_ = new Router();
+  router_->Bind(kInprocRouterEndpoint);
+  const string hostip = cluster->hostip();
+  int port = router_->Bind("tcp://" + hostip + ":*");
+  // register endpoint to zookeeper
+  cluster->Register(getpid(), hostip + ":" + std::to_string(port));
+  const vector<Worker*> workers = CreateWorkers(*job);
+  const vector<Server*> servers = CreateServers(*job);
+  SetupWorkerServer(*job, workers, servers);
+#ifdef USE_MPI
+  int nthreads = workers.size() + servers.size();
+  for (int i = 0; i < nthreads; i++)
+    MPIQueues.push_back(make_shared<SafeQueue>());
+#endif
+  vector<std::thread> threads;
+  for (auto server : servers)
+    threads.push_back(std::thread(&Server::Run, server));
+  for (auto worker : workers)
+    threads.push_back(std::thread(&Worker::Run, worker));
+  Run(workers, servers);
+  for (auto& thread : threads)
+    thread.join();
+  for (auto server : servers)
+    delete server;
+  for (auto worker : workers)
+    delete worker;
+}
+
+void Trainer::Resume(JobProto* jobConf) {
+  tinydir_dir dir;
+  string folder = Cluster::Get()->checkpoint_folder();
+  tinydir_open(&dir, folder.c_str());
+  int latest_step = 0;
+  // there would be multi checkpoint files (from diff workers) for one step
+  vector<string> ck_files;
+  // iterate all files to get the files for the last checkpoint
+  while (dir.has_next) {
+    tinydir_file file;
+    tinydir_readfile(&dir, &file);
+    tinydir_next(&dir);
+    char* ch = strstr(file.name, "step");
+    if (ch == nullptr) {
+      if (file.name[0] != '.')
+        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
+      continue;
+    }
+    LOG(INFO) << "Add checkpoint file for resume: " << ch;
+    int step = atoi(ch+4);
+    if (step == latest_step) {
+      ck_files.push_back(file.name);
+    } else if (step > latest_step) {
+      latest_step = step;
+      ck_files.clear();
+      ck_files.push_back(string(file.name));
+    }
+  }
+  if (latest_step > 0) {
+    jobConf->set_step(latest_step);
+    if (!jobConf->has_reset_param_version())
+      jobConf->set_reset_param_version(false);
+    jobConf->clear_checkpoint_path();
+    for (auto ck_file : ck_files)
+      jobConf->add_checkpoint_path(folder + "/" + ck_file);
+  }
+  tinydir_close(&dir);
+}
+
 const vector<int> SliceParams(const vector<Param*>& params) {
   // for load-balance among servers in a group and among server groups
   int nserver_grps = Cluster::Get()->nserver_groups();
   int nservers_per_grp = Cluster::Get()->nservers_per_group();
   int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
-
   // collect sizes of unique Params
   std::vector<int> paramsize;
   for (auto param : params)
@@ -86,10 +152,9 @@ const vector<int> SliceParams(const vector<Param*>& params) {
   return slices;
 }
 
-void Trainer::SetupWorkerServer(
-    const JobProto& job_conf,
-    const vector<Worker*>& workers,
-    const vector<Server*>& servers) {
+void Trainer::SetupWorkerServer(const JobProto& job_conf,
+                                const vector<Worker*>& workers,
+                                const vector<Server*>& servers) {
   auto cluster = Cluster::Get();
   int grp_size = cluster->nworkers_per_group();
   const auto& net_conf = job_conf.neuralnet();
@@ -97,7 +162,6 @@ void Trainer::SetupWorkerServer(
   nets_.push_back(net);
   // MUST do SliceParam before share param/net with others
   auto slices = SliceParams(net->params());
-
   std::unordered_map<int, NeuralNet*> grp_net;
   int first_grp = workers.size() ? workers.at(0)->grp_id() : -1;
   for (auto worker : workers) {
@@ -107,13 +171,17 @@ void Trainer::SetupWorkerServer(
     NeuralNet* valid_net = nullptr;
     if (grp_net.find(grp_id) == grp_net.end()) {
       if (grp_id == first_grp) {
-        //  test are performed only by the first group now. TODO update.
+        // test are performed only by the first group now.
+        // TODO(wangwei) update.
         if (first_grp == 0 && job_conf.test_steps() && worker_id == 0) {
-          test_net = NeuralNet::Create(net_conf, kTest, 1); // hard code for exp
+          // hard code for exp
+          // TODO(wangwei) move test unit out as an independent module
+          test_net = NeuralNet::Create(net_conf, kTest, 1);
           test_net->ShareParamsFrom(net);
           nets_.push_back(test_net);
         }
-        //  validation are performed only by the first group. TODO update.
+        // validation are performed only by the first group.
+        // TODO(wangwei) update.
         if (first_grp == 0 && job_conf.valid_steps() && worker_id == 0) {
           valid_net = NeuralNet::Create(net_conf, kValidation, 1);
           valid_net->ShareParamsFrom(net);
@@ -123,7 +191,7 @@ void Trainer::SetupWorkerServer(
       } else {
         grp_net[grp_id] = NeuralNet::Create(net_conf, kTrain, grp_size);
         nets_.push_back(grp_net[grp_id]);
-        if(cluster->share_memory())
+        if (cluster->share_memory())
           grp_net[grp_id]->ShareParamsFrom(net);
       }
       for (auto layer : grp_net[grp_id]->layers()) {
@@ -141,12 +209,10 @@ void Trainer::SetupWorkerServer(
               << worker->id() << " net " << grp_net[grp_id];
     worker->Setup(job_conf, grp_net[grp_id], valid_net, test_net);
   }
-
   //  partition among server groups, each group maintains one sub-set for sync
   auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
   //  partition within one server group, each server updates for one sub-set
   slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
-
   for (auto server : servers)
     server->Setup(job_conf.updater(), slice2group, slice2server_);
 }
@@ -156,14 +222,13 @@ vector<Server*> Trainer::CreateServers(const JobProto& job) {
   vector<Server*> servers;
   if (!cluster->has_server())
     return servers;
-
   int server_procs = cluster->procs_id();
   // if true, server procs (logical) id starts after worker procs
   if (cluster->server_worker_separate())
     server_procs -= cluster->nworker_procs();
   const vector<int> rng = cluster->ExecutorRng(server_procs,
-      cluster->nservers_per_group(),
-      cluster->nservers_per_procs());
+                                               cluster->nservers_per_group(),
+                                               cluster->nservers_per_procs());
   int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int sid = start; sid < end; sid++) {
@@ -174,15 +239,14 @@ vector<Server*> Trainer::CreateServers(const JobProto& job) {
   return servers;
 }
 
-
 vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
-  auto cluster=Cluster::Get();
+  auto cluster = Cluster::Get();
   vector<Worker*> workers;
-  if(!cluster->has_worker())
+  if (!cluster->has_worker())
     return workers;
   const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
-      cluster->nworkers_per_group(),
-      cluster->nworkers_per_procs());
+                                               cluster->nworkers_per_group(),
+                                               cluster->nworkers_per_procs());
   int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int wid = wstart; wid < wend; wid++) {
@@ -194,93 +258,13 @@ vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
   return workers;
 }
 
-void Trainer::Resume(JobProto* jobConf) {
-  tinydir_dir dir;
-  string folder = Cluster::Get()->checkpoint_folder();
-  tinydir_open(&dir, folder.c_str());
-  int latest_step = 0;
-  // there would be multi checkpoint files (from diff workers) for one step
-  vector<string> ck_files;
-  // iterate all files to get the files for the last checkpoint
-  while (dir.has_next) {
-    tinydir_file file;
-    tinydir_readfile(&dir, &file);
-    tinydir_next(&dir);
-    char* ch = strstr(file.name, "step");
-    if (ch == nullptr) {
-      if (file.name[0] != '.')
-        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
-      continue;
-    }
-
-    LOG(INFO) << "Add checkpoint file for resume: " << ch;
-    int step = atoi(ch+4);
-    if (step == latest_step) {
-      ck_files.push_back(file.name);
-    } else if(step > latest_step) {
-      latest_step = step;
-      ck_files.clear();
-      ck_files.push_back(string(file.name));
-    }
-  }
-
-  if (latest_step > 0) {
-    jobConf->set_step(latest_step);
-    if (!jobConf->has_reset_param_version())
-      jobConf->set_reset_param_version(false);
-    jobConf->clear_checkpoint_path();
-    for (auto ck_file : ck_files)
-      jobConf->add_checkpoint_path(folder + "/" + ck_file);
-  }
-  tinydir_close(&dir);
-}
-
-void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
-  // register job to zookeeper at the beginning
-  auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
-  if (resume)
-    Resume(job);
-
-  router_ = new Router();
-  router_->Bind(kInprocRouterEndpoint);
-  const string hostip = cluster->hostip();
-  int port = router_->Bind("tcp://" + hostip + ":*");
-  // register endpoint to zookeeper
-  cluster->Register(getpid(), hostip + ":" + std::to_string(port));
-
-  const vector<Worker*> workers = CreateWorkers(*job);
-  const vector<Server*> servers = CreateServers(*job);
-  SetupWorkerServer(*job, workers, servers);
-
-#ifdef USE_MPI
-  int nthreads = workers.size() + servers.size();
-  for (int i = 0; i < nthreads; i++)
-    MPIQueues.push_back(make_shared<SafeQueue>());
-#endif
-  vector<std::thread> threads;
-  for(auto server : servers)
-    threads.push_back(std::thread(&Server::Run, server));
-  for(auto worker : workers)
-    threads.push_back(std::thread(&Worker::Run, worker));
-  Run(workers, servers);
-  for(auto& thread : threads)
-    thread.join();
-  for(auto server : servers)
-    delete server;
-  for(auto worker : workers)
-    delete worker;
-}
-
-void Trainer::Run(
-    const vector<Worker*>& workers,
-    const vector<Server*>& servers) {
+void Trainer::Run(const vector<Worker*>& workers,
+                  const vector<Server*>& servers) {
   int nworkers = workers.size(), nservers = servers.size();
   auto cluster = Cluster::Get();
   procs_id_ = cluster->procs_id();
   LOG(INFO) << "Stub in process " << procs_id_ << " starts";
-
-  map<int, Dealer*> inter_dealers;  // for sending msg to other procs
-
+  std::map<int, Dealer*> inter_dealers;  // for sending msg to other procs
   std::queue<Msg*> msg_queue;
   while (true) {
     Msg* msg = nullptr;
@@ -343,26 +327,27 @@ Dealer* Trainer::CreateInterProcsDealer(int dst_procs) {
   // forward to other procs
   auto cluster = Cluster::Get();
   auto dealer = new Dealer();
-  while(cluster->endpoint(dst_procs)=="") {
-    //kCollectSleepTime));
+  while (cluster->endpoint(dst_procs) == "") {
+    // kCollectSleepTime));
     std::this_thread::sleep_for(std::chrono::milliseconds(3000));
-    LOG(ERROR)<<"waiting for procs "<< dst_procs<<" to register";
+    LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
   }
   dealer->Connect("tcp://"+cluster->endpoint(dst_procs));
   return dealer;
 }
 
-void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
+void Trainer::HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg) {
   Msg* msgg = *msg;
   int paramid = ParamID(msgg->trgt_val());
   int type = msgg->type();
   int grp;
   ParamEntry *entry = nullptr;
-  switch (type) {  // TODO process other requests, e.g. RESTful
+  // TODO(wangwei) process other requests, e.g. RESTful
+  switch (type) {
     case kUpdate:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto update_msg : HandleUpdate(entry, msg))
+      for (auto update_msg : HandleUpdate(entry, msg))
         msg_queue->push(update_msg);
       break;
     case kRUpdate:
@@ -373,7 +358,7 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
     case kGet:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto get_msg : HandleGet(entry, msg))
+      for (auto get_msg : HandleGet(entry, msg))
         msg_queue->push(get_msg);
       break;
     case kRGet:
@@ -384,22 +369,22 @@ void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
     case kPut:
       grp = AddrGrp(msgg->src());
       entry = worker_shard_.at(Hash(grp, paramid));
-      for(auto put_msg : HandlePut(entry, msg))
+      for (auto put_msg : HandlePut(entry, msg))
         msg_queue->push(put_msg);
       break;
     default:
-      LOG(ERROR)<<"Unknow message type:"<<type;
+      LOG(ERROR) << "Unknow message type:" << type;
       break;
   }
 }
 
-void Trainer::GenMsgs(int type, int version, ParamEntry* entry,
-    Msg* msg, vector<Msg*> *ret) {
+void Trainer::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
+                      vector<Msg*> *ret) {
   int src_grp = AddrGrp(msg->src());
   int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
-  auto param=entry->shares.at(0);
+  auto param = entry->shares.at(0);
   for (int idx = 0 ; idx < param->num_slices(); idx++) {
-    int slice_id =param->slice_start() + idx;
+    int slice_id = param->slice_start() + idx;
     int server = slice2server_[slice_id];
     int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
     Msg* new_msg = nullptr;
@@ -440,10 +425,10 @@ const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) {
     // average local gradient
     if (entry->num_local > 1) {
       auto it = entry->shares.begin();
-      auto shape=mshadow::Shape1((*it)->size());
-      mshadow::Tensor<mshadow::cpu,1> sum((*it)->mutable_cpu_grad(), shape);
+      auto shape = mshadow::Shape1((*it)->size());
+      mshadow::Tensor<mshadow::cpu, 1> sum((*it)->mutable_cpu_grad(), shape);
       for (++it; it != entry->shares.end(); it++) {
-        mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+        mshadow::Tensor<mshadow::cpu, 1> grad((*it)->mutable_cpu_grad(), shape);
         sum += grad;
       }
     }
@@ -480,4 +465,5 @@ void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
     param->set_version(version);
   DeleteMsg(msg);
 }
-} /* singa */
+
+}  // namespace singa


Mime
View raw message