singa-commits mailing list archives

From wan...@apache.org
Subject [06/13] incubator-singa git commit: SINGA-70 Refactor API of Layer, Worker, Server and Driver
Date Sun, 27 Sep 2015 14:34:29 GMT
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
deleted file mode 100644
index ecfc94a..0000000
--- a/src/trainer/trainer.cc
+++ /dev/null
@@ -1,469 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "trainer/trainer.h"
-
-#include <glog/logging.h>
-#include <unistd.h>
-#include <map>
-#include <thread>
-#include "mshadow/tensor.h"
-#include "proto/common.pb.h"
-#include "utils/cluster.h"
-#include "utils/common.h"
-#include "utils/tinydir.h"
-
-namespace singa {
-
-using std::vector;
-using std::string;
-
-/***********************Trainer****************************/
-Trainer::~Trainer() {
-  delete router_;
-  for (NeuralNet* p : nets_)
-    delete p;
-}
-
-void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
-  // register job to zookeeper at the beginning
-  auto cluster = Cluster::Setup(job->id(), singaConf, job->cluster());
-  if (resume) Resume(job);
-  router_ = new Router();
-  router_->Bind(kInprocRouterEndpoint);
-  const string hostip = cluster->hostip();
-  int port = router_->Bind("tcp://" + hostip + ":*");
-  // register endpoint to zookeeper
-  cluster->Register(getpid(), hostip + ":" + std::to_string(port));
-  const vector<Worker*> workers = CreateWorkers(*job);
-  const vector<Server*> servers = CreateServers(*job);
-  SetupWorkerServer(*job, workers, servers);
-#ifdef USE_MPI
-  int nthreads = workers.size() + servers.size();
-  for (int i = 0; i < nthreads; i++)
-    MPIQueues.push_back(make_shared<SafeQueue>());
-#endif
-  vector<std::thread> threads;
-  for (auto server : servers)
-    threads.push_back(std::thread(&Server::Run, server));
-  for (auto worker : workers)
-    threads.push_back(std::thread(&Worker::Run, worker));
-  Run(workers, servers);
-  for (auto& thread : threads)
-    thread.join();
-  for (auto server : servers)
-    delete server;
-  for (auto worker : workers)
-    delete worker;
-}
-
-void Trainer::Resume(JobProto* jobConf) {
-  tinydir_dir dir;
-  string folder = Cluster::Get()->checkpoint_folder();
-  tinydir_open(&dir, folder.c_str());
-  int latest_step = 0;
-  // there would be multi checkpoint files (from diff workers) for one step
-  vector<string> ck_files;
-  // iterate all files to get the files for the last checkpoint
-  while (dir.has_next) {
-    tinydir_file file;
-    tinydir_readfile(&dir, &file);
-    tinydir_next(&dir);
-    char* ch = strstr(file.name, "step");
-    if (ch == nullptr) {
-      if (file.name[0] != '.')
-        LOG(INFO) << "Irregular file in checkpoint folder: " << file.name;
-      continue;
-    }
-    LOG(INFO) << "Add checkpoint file for resume: " << ch;
-    int step = atoi(ch+4);
-    if (step == latest_step) {
-      ck_files.push_back(file.name);
-    } else if (step > latest_step) {
-      latest_step = step;
-      ck_files.clear();
-      ck_files.push_back(string(file.name));
-    }
-  }
-  if (latest_step > 0) {
-    jobConf->set_step(latest_step);
-    if (!jobConf->has_reset_param_version())
-      jobConf->set_reset_param_version(false);
-    jobConf->clear_checkpoint_path();
-    for (auto ck_file : ck_files)
-      jobConf->add_checkpoint_path(folder + "/" + ck_file);
-  }
-  tinydir_close(&dir);
-}
-
-const vector<int> SliceParams(const vector<Param*>& params) {
-  // for load-balance among servers in a group and among server groups
-  int nserver_grps = Cluster::Get()->nserver_groups();
-  int nservers_per_grp = Cluster::Get()->nservers_per_group();
-  int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
-  // collect sizes of unique Params
-  std::vector<int> paramsize;
-  for (auto param : params)
-    if (param->id() == param->owner())
-      paramsize.push_back(param->size());
-  // slice into lcm pieces to achieve good load-balance for both intra-group
-  // partition (among servers in a group) and inter-group partition (each group
-  // is assgined a sub-set of slices)
-  auto param_slice = Slice(lcm, paramsize);
-  // construct map from Param ID to its slices <slice id, len>
-  std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
-  vector<int> slices;
-  auto it = param_slice.begin();
-  int slice_id = 0;
-  for (auto param : params) {
-    if (param->id() == param->owner()) {
-      for (int len : *it) {
-        slices.push_back(len);
-        paramid2slices[param->id()].push_back(std::make_pair(slice_id++, len));
-      }
-      it++;
-    }
-  }
-  // add slice info for every Param
-  for (auto param : params)
-    for (auto entry : paramid2slices[param->owner()]) {
-      param->AddSlice(entry.first, entry.second);
-      LOG(INFO) << "param id " << param->id() << " owner=" << param->owner()
-        << ": " << entry.first << ", " << entry.second;
-    }
-  return slices;
-}
-
-void Trainer::SetupWorkerServer(const JobProto& job_conf,
-                                const vector<Worker*>& workers,
-                                const vector<Server*>& servers) {
-  auto cluster = Cluster::Get();
-  int grp_size = cluster->nworkers_per_group();
-  const auto& net_conf = job_conf.neuralnet();
-  auto net = NeuralNet::Create(net_conf, kTrain, grp_size);
-  nets_.push_back(net);
-  // MUST do SliceParam before share param/net with others
-  auto slices = SliceParams(net->params());
-  std::unordered_map<int, NeuralNet*> grp_net;
-  int first_grp = workers.size() ? workers.at(0)->grp_id() : -1;
-  for (auto worker : workers) {
-    int grp_id = worker->grp_id();
-    int worker_id = worker->id();
-    NeuralNet* test_net = nullptr;
-    NeuralNet* valid_net = nullptr;
-    if (grp_net.find(grp_id) == grp_net.end()) {
-      if (grp_id == first_grp) {
-        // test are performed only by the first group now.
-        // TODO(wangwei) update.
-        if (first_grp == 0 && job_conf.test_steps() && worker_id == 0) {
-          // hard code for exp
-          // TODO(wangwei) move test unit out as an independent module
-          test_net = NeuralNet::Create(net_conf, kTest, 1);
-          test_net->ShareParamsFrom(net);
-          nets_.push_back(test_net);
-        }
-        // validation are performed only by the first group.
-        // TODO(wangwei) update.
-        if (first_grp == 0 && job_conf.valid_steps() && worker_id == 0) {
-          valid_net = NeuralNet::Create(net_conf, kValidation, 1);
-          valid_net->ShareParamsFrom(net);
-          nets_.push_back(valid_net);
-        }
-        grp_net[grp_id] = net;
-      } else {
-        grp_net[grp_id] = NeuralNet::Create(net_conf, kTrain, grp_size);
-        nets_.push_back(grp_net[grp_id]);
-        if (cluster->share_memory())
-          grp_net[grp_id]->ShareParamsFrom(net);
-      }
-      for (auto layer : grp_net[grp_id]->layers()) {
-        bool local = layer->partition_id() >= workers.front()->id()
-          && layer->partition_id() <= workers.back()->id();
-        for (auto param : layer->GetParams()) {
-          int hash = Hash(grp_id, param->owner());
-          if (worker_shard_.find(hash) == worker_shard_.end())
-            worker_shard_[hash] = new ParamEntry();
-          worker_shard_[hash]->AddParam(local, param);
-        }
-      }
-    }
-    LOG(INFO) << "grp " << worker->grp_id() << ", worker "
-              << worker->id() << " net " << grp_net[grp_id];
-    worker->Setup(job_conf, grp_net[grp_id], valid_net, test_net);
-  }
-  //  partition among server groups, each group maintains one sub-set for sync
-  auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
-  //  partition within one server group, each server updates for one sub-set
-  slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
-  for (auto server : servers)
-    server->Setup(job_conf.updater(), slice2group, slice2server_);
-}
-
-vector<Server*> Trainer::CreateServers(const JobProto& job) {
-  auto cluster = Cluster::Get();
-  vector<Server*> servers;
-  if (!cluster->has_server())
-    return servers;
-  int server_procs = cluster->procs_id();
-  // if true, server procs (logical) id starts after worker procs
-  if (cluster->server_worker_separate())
-    server_procs -= cluster->nworker_procs();
-  const vector<int> rng = cluster->ExecutorRng(server_procs,
-                                               cluster->nservers_per_group(),
-                                               cluster->nservers_per_procs());
-  int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
-  for (int gid = gstart; gid < gend; gid++) {
-    for (int sid = start; sid < end; sid++) {
-      auto server = new Server(gid, sid);
-      servers.push_back(server);
-    }
-  }
-  return servers;
-}
-
-vector<Worker*> Trainer::CreateWorkers(const JobProto& job) {
-  auto cluster = Cluster::Get();
-  vector<Worker*> workers;
-  if (!cluster->has_worker())
-    return workers;
-  const vector<int> rng = cluster->ExecutorRng(cluster->procs_id(),
-                                               cluster->nworkers_per_group(),
-                                               cluster->nworkers_per_procs());
-  int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3];
-  for (int gid = gstart; gid < gend; gid++) {
-    for (int wid = wstart; wid < wend; wid++) {
-      auto *worker = Worker::Create(job);
-      worker->Init(gid, wid);
-      workers.push_back(worker);
-    }
-  }
-  return workers;
-}
-
-void Trainer::Run(const vector<Worker*>& workers,
-                  const vector<Server*>& servers) {
-  int nworkers = workers.size(), nservers = servers.size();
-  auto cluster = Cluster::Get();
-  procs_id_ = cluster->procs_id();
-  LOG(INFO) << "Stub in process " << procs_id_ << " starts";
-  std::map<int, Dealer*> inter_dealers;  // for sending msg to other procs
-  std::queue<Msg*> msg_queue;
-  while (true) {
-    Msg* msg = nullptr;
-    if (msg_queue.empty()) {
-      msg = router_->Receive();
-    } else {
-      msg = msg_queue.front();
-      msg_queue.pop();
-    }
-    int type = msg->type(), dst = msg->dst(), flag = AddrType(dst);
-    if (flag == kStub && (AddrProc(dst) == procs_id_ || AddrGrp(dst) == -1)) {
-      //  the following statements are ordered!
-      if (type == kConnect) {
-        DeleteMsg(&msg);
-      } else if (type == kMetric) {
-        DisplayMetric(&msg);
-      } else if (type == kStop) {
-        int src_flag = AddrType(msg->src());
-        if (src_flag == kServer) nservers--;
-        else if (src_flag == kWorkerParam) nworkers--;
-        DeleteMsg(&msg);
-        if (nworkers == 0 && nservers == 0) break;
-      } else {
-        HandleLocalMsg(&msg_queue, &msg);
-      }
-    } else {
-      int dst_procs = AddrProc(dst);
-      if (flag != kStub)
-        dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag);
-      if (dst_procs != procs_id_) {
-        if (inter_dealers.find(dst_procs) == inter_dealers.end())
-          inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs);
-        inter_dealers[dst_procs]->Send(&msg);
-      } else {
-        router_->Send(&msg);
-      }
-    }
-  }
-  LOG(ERROR) << "Stub in process " << procs_id_ << " stops";
-  for (auto& entry : inter_dealers)
-    delete entry.second;
-}
-
-void Trainer::DisplayMetric(Msg** msg) {
-  Msg* msgg = *msg;
-  // only display metrics from the first group
-  if (AddrGrp(msgg->src()) == 0) {
-    int step = msgg->trgt_version();
-    char prefix[128];
-    msgg->ParseFormatFrame("s", prefix);
-    CHECK(msgg->NextFrame());
-    const string perf(static_cast<char*>(msgg->FrameData()), msgg->FrameSize());
-    Metric cur(perf);
-    LOG(ERROR) << prefix << " step-" << step <<", " << cur.ToLogString();
-  }
-  DeleteMsg(msg);
-}
-
-Dealer* Trainer::CreateInterProcsDealer(int dst_procs) {
-  // forward to other procs
-  auto cluster = Cluster::Get();
-  auto dealer = new Dealer();
-  while (cluster->endpoint(dst_procs) == "") {
-    // kCollectSleepTime));
-    std::this_thread::sleep_for(std::chrono::milliseconds(3000));
-    LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
-  }
-  dealer->Connect("tcp://"+cluster->endpoint(dst_procs));
-  return dealer;
-}
-
-void Trainer::HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg) {
-  Msg* msgg = *msg;
-  int paramid = ParamID(msgg->trgt_val());
-  int type = msgg->type();
-  int grp;
-  ParamEntry *entry = nullptr;
-  // TODO(wangwei) process other requests, e.g. RESTful
-  switch (type) {
-    case kUpdate:
-      grp = AddrGrp(msgg->src());
-      entry = worker_shard_.at(Hash(grp, paramid));
-      for (auto update_msg : HandleUpdate(entry, msg))
-        msg_queue->push(update_msg);
-      break;
-    case kRUpdate:
-      grp = AddrGrp(msgg->dst());
-      entry = worker_shard_.at(Hash(grp, paramid));
-      HandleUpdateResponse(entry, msg);
-      break;
-    case kGet:
-      grp = AddrGrp(msgg->src());
-      entry = worker_shard_.at(Hash(grp, paramid));
-      for (auto get_msg : HandleGet(entry, msg))
-        msg_queue->push(get_msg);
-      break;
-    case kRGet:
-      grp = AddrGrp(msgg->dst());
-      entry = worker_shard_.at(Hash(grp, paramid));
-      HandleGetResponse(entry, msg);
-      break;
-    case kPut:
-      grp = AddrGrp(msgg->src());
-      entry = worker_shard_.at(Hash(grp, paramid));
-      for (auto put_msg : HandlePut(entry, msg))
-        msg_queue->push(put_msg);
-      break;
-    default:
-      LOG(ERROR) << "Unknow message type:" << type;
-      break;
-  }
-}
-
-void Trainer::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg,
-                      vector<Msg*> *ret) {
-  int src_grp = AddrGrp(msg->src());
-  int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
-  auto param = entry->shares.at(0);
-  for (int idx = 0 ; idx < param->num_slices(); idx++) {
-    int slice_id = param->slice_start() + idx;
-    int server = slice2server_[slice_id];
-    int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
-    Msg* new_msg = nullptr;
-    if (type == kPut) {
-      CHECK_GT(entry->num_total, 0);
-      new_msg = param->GenPutMsg(dst_procs != procs_id_, idx);
-      new_msg->AddFormatFrame("i", entry->num_total);
-    } else if (type == kGet) {
-      new_msg = param->GenGetMsg(dst_procs != procs_id_, idx);
-    } else if (type == kUpdate) {
-      new_msg = param->GenUpdateMsg(dst_procs != procs_id_, idx);
-      new_msg->AddFormatFrame("i", entry->num_local);
-    } else {
-      LOG(FATAL) << "Wrong type";
-    }
-    new_msg->set_trgt(ParamTrgt(param->owner(), slice_id), version);
-    new_msg->set_src(Addr(src_grp, procs_id_, kStub));
-    new_msg->set_dst(Addr(dst_grp, server, kServer));
-    ret->push_back(new_msg);
-  }
-}
-
-const vector<Msg*> Trainer::HandleGet(ParamEntry* entry, Msg** msg) {
-  vector<Msg*> ret;
-  int version = (*msg)->trgt_version();
-  if (version > entry->next_version) {
-    entry->next_version = version;
-    GenMsgs(kGet, version, entry, *msg, &ret);
-  }
-  DeleteMsg(msg);
-  return ret;
-}
-
-const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) {
-  vector<Msg*> ret;
-  entry->num_update++;
-  if (entry->num_update >= entry->num_local) {
-    // average local gradient
-    if (entry->num_local > 1) {
-      auto it = entry->shares.begin();
-      auto shape = mshadow::Shape1((*it)->size());
-      mshadow::Tensor<mshadow::cpu, 1> sum((*it)->mutable_cpu_grad(), shape);
-      for (++it; it != entry->shares.end(); it++) {
-        mshadow::Tensor<mshadow::cpu, 1> grad((*it)->mutable_cpu_grad(), shape);
-        sum += grad;
-      }
-    }
-    int step = (*msg)->trgt_version();
-    GenMsgs(kUpdate, step, entry, *msg, &ret);
-    entry->num_update = 0;
-  }
-  DeleteMsg(msg);
-  return ret;
-}
-
-const vector<Msg*> Trainer::HandlePut(ParamEntry* entry, Msg** msg) {
-  vector<Msg*> ret;
-  int version = (*msg)->trgt_version();
-  GenMsgs(kPut, version, entry, *msg, &ret);
-  DeleteMsg(msg);
-  return ret;
-}
-
-void Trainer::HandleGetResponse(ParamEntry* entry, Msg** msg) {
-  int version = (*msg)->trgt_version();
-  int sliceid = SliceID((*msg)->trgt_val());
-  auto param = entry->shares.at(0);
-  if (param->ParseGetResponseMsg(*msg, sliceid-param->slice_start()))
-    param->set_version(version);
-  DeleteMsg(msg);
-}
-
-void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
-  int version = (*msg)->trgt_version();
-  int sliceid = SliceID((*msg)->trgt_val());
-  auto param = entry->shares.at(0);
-  if (param->ParseUpdateResponseMsg(*msg, sliceid-param->slice_start()))
-    param->set_version(version);
-  DeleteMsg(msg);
-}
-
-}  // namespace singa
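
The Resume() method deleted above selects the newest checkpoint set by scanning file names for the integer right after "step" and keeping every file of the largest step (there can be one file per worker). Below is a standalone sketch of just that selection; the helper name and the sample file list are invented for illustration and are not part of SINGA's API.

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Keep only the checkpoint files that belong to the largest step number.
std::vector<std::string> LatestCheckpoints(const std::vector<std::string>& files,
                                           int* latest_step) {
  *latest_step = 0;
  std::vector<std::string> picked;
  for (const auto& name : files) {
    const char* ch = std::strstr(name.c_str(), "step");
    if (ch == nullptr) continue;          // skip irregular files
    int step = std::atoi(ch + 4);         // the number right after "step"
    if (step > *latest_step) {            // a newer step: restart the list
      *latest_step = step;
      picked.clear();
    }
    if (step == *latest_step) picked.push_back(name);
  }
  return picked;
}

int main() {
  int step = 0;
  auto files = LatestCheckpoints(
      {"step100-worker0.bin", "step200-worker0.bin", "step200-worker1.bin"}, &step);
  std::cout << "resume from step " << step << " with " << files.size() << " files\n";
}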

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
deleted file mode 100644
index 70859de..0000000
--- a/src/trainer/worker.cc
+++ /dev/null
@@ -1,411 +0,0 @@
-/************************************************************
-*
-* Licensed to the Apache Software Foundation (ASF) under one
-* or more contributor license agreements.  See the NOTICE file
-* distributed with this work for additional information
-* regarding copyright ownership.  The ASF licenses this file
-* to you under the Apache License, Version 2.0 (the
-* "License"); you may not use this file except in compliance
-* with the License.  You may obtain a copy of the License at
-* 
-*   http://www.apache.org/licenses/LICENSE-2.0
-* 
-* Unless required by applicable law or agreed to in writing,
-* software distributed under the License is distributed on an
-* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-* KIND, either express or implied.  See the License for the
-* specific language governing permissions and limitations
-* under the License.
-*
-*************************************************************/
-
-#include "trainer/worker.h"
-
-#include <glog/logging.h>
-#include <chrono>
-#include <thread>
-#include <typeinfo>
-#include "utils/cluster.h"
-#include "utils/factory.h"
-#include "utils/singleton.h"
-
-namespace singa {
-
-using std::string;
-
-Worker* Worker::Create(const JobProto& proto) {
-  auto factory = Singleton<Factory<singa::Worker>>::Instance();
-  Worker* worker = nullptr;
-  const auto& conf = proto.train_one_batch();
-  if (conf.has_user_alg())
-    worker = factory->Create(conf.user_alg());
-  else
-    worker = factory->Create(conf.alg());
-  return worker;
-}
-
-void Worker::Init(int grp_id, int id) {
-  grp_id_ = grp_id;
-  id_ = id;
-  layer_dealer_ = dealer_ = nullptr;
-}
-
-Worker::~Worker() {
-  if (layer_dealer_)
-    delete layer_dealer_;
-  if (dealer_)
-    delete dealer_;
-}
-
-void Worker::Setup(const JobProto& job, NeuralNet* train_net,
-                   NeuralNet* valid_net, NeuralNet* test_net) {
-  job_conf_.CopyFrom(job);
-  train_net_ = train_net;
-  validation_net_ = valid_net;
-  test_net_ = test_net;
-}
-
-void Worker::InitLocalParams() {
-  // for each server grp, its first subscriber worker grp does the param init
-  if (grp_id_ % Cluster::Get()->nworker_groups_per_server_group() == 0) {
-    // extract params that should be initialized by this worker
-    // must gen a name for each param if the user doesn't config it
-    std::unordered_map<string, Param*> name2param;
-    for (auto layer : train_net_->layers()) {
-      if (layer->partition_id() == id_) {
-        for (auto param : layer->GetParams()) {
-          // only owners fill the memory of parameter values.
-          if (param->owner() == param->id()) {
-            CHECK(name2param.find(param->name()) == name2param.end());
-            name2param[param->name()] = param;
-          }
-        }
-      }
-    }
-    // load from checkpoints. get param blob based on param name.
-    // the param from previous checkpoint files will be overwritten by
-    // the param with the same name in later checkpoint files.
-    for (const auto checkpoint : job_conf_.checkpoint_path()) {
-      LOG(ERROR) << "Load from checkpoint file " << checkpoint;
-      BlobProtos bps;
-      ReadProtoFromBinaryFile(checkpoint.c_str(), &bps);
-      for (int i = 0; i < bps.name_size(); i++) {
-        if (name2param.find(bps.name(i)) != name2param.end()) {
-          name2param.at(bps.name(i))->FromProto(bps.blob(i));
-          //  if load from pre-training params, reset version to start step
-          if (job_conf_.reset_param_version())
-            name2param.at(bps.name(i))->set_version(job_conf_.step());
-          else  // if resume training, use the same version as last checkpoint
-            name2param.at(bps.name(i))->set_version(bps.version(i));
-        }
-      }
-    }
-    // init other params who do not have checkpoint version
-    for (auto entry : name2param)
-      if (entry.second->version() < 0) {
-        entry.second->InitValues(job_conf_.step());
-        if (!job_conf_.reset_param_version())
-          LOG(ERROR) << "better reset version of params from checkpoints "
-            << "to the same as other newly initialized params!";
-      }
-
-    Metric perf;
-    // warmup training before put params to servers
-    for (; step_ < job_conf_.warmup_steps(); step_++)
-      TrainOneBatch(step_, &perf);
-    for (auto layer : train_net_->layers()) {
-      if (layer->partition_id() == id_)
-        for (auto param : layer->GetParams())
-          if (param->owner() == param->id())
-            Put(param, param->version());
-    }
-  }
-  // wait owners in the same procs init params, then no get requests sent
-  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
-  for (auto layer : train_net_->layers()) {
-    if (layer->partition_id() == id_)
-      for (auto param : layer->GetParams())
-        Get(param, job_conf_.warmup_steps());
-  }
-}
-
-void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) {
-  dealer->Connect(kInprocRouterEndpoint);
-  Msg* ping = new Msg(Addr(grp, id, entity), Addr(-1, -1, kStub));
-  ping->set_type(kConnect);
-  dealer->Send(&ping);
-}
-
-void Worker::Run() {
-  LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") start";
-  auto cluster = Cluster::Get();
-  int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
-  CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
-  // TODO(wangsh): provide a unique sock id from cluster
-  dealer_ = new Dealer(0);
-  ConnectStub(grp_id_, id_, dealer_, kWorkerParam);
-  for (auto layer : train_net_->layers()) {
-    if (layer->partition_id() == id_) {
-      if (typeid(layer) == typeid(BridgeDstLayer)
-          || typeid(layer) == typeid(BridgeSrcLayer)) {
-        // TODO(wangsh): provide a unique socket id from cluster
-        layer_dealer_ = new Dealer(1);
-        ConnectStub(grp_id_, id_, layer_dealer_, kWorkerLayer);
-        break;
-      }
-    }
-  }
-
-  step_ = job_conf_.step();
-  InitLocalParams();
-  Metric perf;
-  while (!StopNow(step_)) {
-    if (ValidateNow(step_) && validation_net_ != nullptr) {
-      // LOG(ERROR)<<"Validation at step "<<step;
-      CollectAll(validation_net_, step_);
-      Test(job_conf_.valid_steps(), kValidation, validation_net_);
-    }
-    if (TestNow(step_) && test_net_ != nullptr) {
-      // LOG(ERROR)<<"Test at step "<<step;
-      CollectAll(test_net_, step_);
-      Test(job_conf_.test_steps(), kTest, test_net_);
-    }
-    if (CheckpointNow(step_)) {
-      CollectAll(train_net_, step_);
-      Checkpoint(step_, train_net_);
-      job_conf_.set_step(step_);
-    }
-    TrainOneBatch(step_, &perf);
-    // LOG(ERROR) << "Train " << step_;
-    if (DisplayNow(step_)) {
-      Report("Train", perf);
-      perf.Reset();
-    }
-    step_++;
-  }
-
-  // save the model
-  Checkpoint(step_, train_net_);
-  // clean up
-  cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp);
-  // notify the stub on worker stop
-  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-  msg->set_type(kStop);
-  dealer_->Send(&msg);  // use param dealer to send the stop msg
-  LOG(ERROR) << "Worker (group = " <<grp_id_ << ", id = " << id_ << ") stops";
-}
-
-void Worker::Checkpoint(int step, NeuralNet* net) {
-  if (grp_id_ == 0) {
-    BlobProtos bps;
-    for (auto layer : net->layers()) {
-      if (layer->partition_id() == id_) {
-        for (auto param : layer->GetParams()) {
-          // only owners fill the memory of parameter values.
-          if (param->owner() == param->id()) {
-            auto *blob = bps.add_blob();
-            param->ToProto(blob);
-            bps.add_version(param->version());
-            bps.add_name(param->name());
-          }
-        }
-      }
-    }
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%s/step%d-worker%d.bin",
-             Cluster::Get()->checkpoint_folder().c_str(), step, id_);
-    LOG(INFO) << "checkpoint to " << buf;
-    WriteProtoToBinaryFile(bps, buf);
-  }
-}
-
-int Worker::Put(Param* param, int step) {
-  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
-  msg->set_type(kPut);
-  dealer_->Send(&msg);
-  return 1;
-}
-
-int Worker::Get(Param* param, int step) {
-  if (param->version() >= step)
-    return 1;
-  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
-  msg->set_type(kGet);
-  dealer_->Send(&msg);
-  return 1;
-}
-
-int Worker::Update(Param* param, int step) {
-  param->set_local_version(param->version());
-  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
-  msg->set_type(kUpdate);
-  dealer_->Send(&msg);
-  return 1;
-}
-
-int Worker::CollectAll(NeuralNet* net, int step) {
-  auto& layers = net->layers();
-  for (auto& layer : layers) {
-    if (layer->partition_id() == id_) {
-      for (Param* p : layer->GetParams()) {
-        Collect(p, step);
-      }
-    }
-  }
-  return 1;
-}
-
-int Worker::Collect(Param* param, int step) {
-  while (param->version() <= param->local_version())
-    std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));
-  return 1;
-}
-
-void Worker::Report(const string& prefix, const Metric & perf) {
-  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-  msg->set_trgt(0, step_);
-  msg->set_type(kMetric);
-  const string disp = perf.ToString();
-  msg->AddFormatFrame("s", prefix.c_str());
-  msg->AddFrame(disp.c_str(), disp.length());
-  dealer_->Send(&msg);
-}
-
-void Worker::ReceiveBlobs(bool data, bool grad, BridgeLayer* layer,
-                          NeuralNet* net) {
-  while (!layer->ready()) {
-    auto msg = layer_dealer_->Receive();
-    CHECK_EQ(AddrGrp(msg->src()), grp_id_);
-    string name(static_cast<char*>(msg->FrameData()), msg->FrameSize());
-    auto receive_layer = net->name2layer(name);
-    auto data = receive_layer->mutable_data(nullptr);
-    msg->NextFrame();
-    memcpy(data->mutable_cpu_data(), msg->FrameData(), msg->FrameSize());
-    dynamic_cast<BridgeLayer*>(receive_layer)->set_ready(true);
-    delete msg;
-  }
-}
-
-void Worker::SendBlobs(bool data, bool grad, BridgeLayer* layer,
-                       NeuralNet* net) {
-  auto dst = layer->dstlayers().at(0);
-  Msg *msg = new Msg();
-  msg->set_src(Addr(grp_id_, id_, kWorkerLayer));
-  msg->set_dst(Addr(grp_id_, dst->partition_id(), kWorkerLayer));
-  msg->AddFrame(dst->name().c_str(), dst->name().length());
-  auto const & blob = layer->data(nullptr);
-  msg->AddFrame(blob.cpu_data(), blob.count() * sizeof(float));
-  layer_dealer_->Send(&msg);
-}
-
-void Worker::Test(int nsteps, Phase phase, NeuralNet* net) {
-  Metric perf;
-  for (int step = 0; step < nsteps; step++)
-    TestOneBatch(step, phase, net, &perf);
-  if (phase == kValidation)
-    Report("Validation", perf);
-  else if (phase == kTest)
-    Report("Test", perf);
-}
-
-/****************************BPWorker**********************************/
-void BPWorker::TrainOneBatch(int step, Metric* perf) {
-  Forward(step, kTrain, train_net_, perf);
-  Backward(step, train_net_);
-}
-
-void BPWorker::TestOneBatch(int step, Phase phase, NeuralNet* net,
-                            Metric* perf) {
-  Forward(step, phase, net, perf);
-}
-
-void BPWorker::Forward(int step, Phase phase, NeuralNet* net, Metric* perf) {
-  for (auto& layer : net->layers()) {
-    if (layer->partition_id() == id_) {
-      // TODO(wangwei): enable this for model partition
-      // recv data from other workers
-      // if (typeid(*layer) == typeid(BridgeDstLayer))
-      //   ReceiveBlobs(true, false, dynamic_cast<BridgeLayer*>(layer), net);
-      if (phase == kTrain) {
-        // wait until param is updated
-        for (Param* p : layer->GetParams()) {
-          Collect(p, step);
-        }
-      }
-      layer->ComputeFeature(phase | kForward, perf);
-      // TODO(wangwei): enable this for model partition
-      // send data to other workers
-      // if (typeid(*layer) == typeid(BridgeSrcLayer))
-      //   SendBlobs(true, false, dynamic_cast<BridgeLayer*>(layer), net);
-      if (DisplayDebugInfo(step))
-        LOG(INFO) << layer->DebugString(step, phase | kForward);
-    }
-  }
-}
-
-void BPWorker::Backward(int step, NeuralNet* net) {
-  auto& layers = net->layers();
-  for (auto it = layers.rbegin(); it != layers.rend(); it++) {
-    Layer* layer = *it;
-    if (layer->partition_id() == id_) {
-      // TODO(wangwei): enable this for model partition
-      // send data to other workers
-      // if (typeid(layer) == typeid(BridgeSrcLayer))
-      //   ReceiveBlobs(false, true, layer, net);
-      layer->ComputeGradient(kTrain | kBackward, nullptr);
-      if (DisplayDebugInfo(step))
-        LOG(INFO) << layer->DebugString(step, kTrain | kBackward);
-      for (Param* p : layer->GetParams())
-        Update(p, step);
-      // TODO(wangwei): enable this for model partition
-      // recv data from other workers
-      // if (typeid(layer) == typeid(BridgeDstLayer))
-      //   SendBlobs(false, true, dynamic_cast<BridgeDstLayer*>(layer), net);
-    }
-  }
-}
-
-/****************************CDWorker**********************************/
-void CDWorker::TrainOneBatch(int step, Metric* perf) {
-  const auto& layers = train_net_->layers();
-  for (auto* layer : layers) {
-    for (Param* p : layer->GetParams())  // wait until param is updated
-      Collect(p, step);
-    layer->ComputeFeature(kPositive, perf);
-  }
-  for (auto* layer : layers)
-    if (typeid(*layer) == typeid(RBMVisLayer)
-          || typeid(*layer) == typeid(RBMHidLayer))
-      layer->ComputeFeature(kNegative | kTest, perf);
-  for (int i = 1; i < job_conf_.train_one_batch().cd_conf().cd_k(); i++) {
-    for (auto* layer : layers) {
-      if (typeid(*layer) == typeid(RBMVisLayer)
-          || typeid(*layer) == typeid(RBMHidLayer))
-      layer->ComputeFeature(kNegative, perf);
-    }
-  }
-  for (auto* layer : layers) {
-    if (typeid(*layer) == typeid(RBMVisLayer)
-        || typeid(*layer) == typeid(RBMHidLayer)) {
-      layer->ComputeGradient(kTrain, nullptr);
-      for (Param* p : layer->GetParams()) {
-        Update(p, step);
-      }
-    }
-  }
-}
-
-void CDWorker::TestOneBatch(int step, Phase phase, NeuralNet* net,
-                            Metric* perf) {
-  auto& layers = net->layers();
-  for (auto *layer : layers)
-    layer->ComputeFeature(kPositive, perf);
-  for (auto *layer : layers)
-    if (typeid(*layer) == typeid(RBMVisLayer))
-      layer->ComputeFeature(kNegative | kTest, perf);
-}
-
-}  // namespace singa
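
One detail of the deleted worker that carries over to the new one: Update() records the version at which a Param's gradient was sent, and Collect() then blocks until a strictly newer version has been installed, which is how a worker waits for fresh values from the servers. A simplified, self-contained sketch of that handshake follows; FakeParam and the responder thread are invented stand-ins for SINGA's Param and the stub/server path.

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

struct FakeParam {
  std::atomic<int> version{0};
  int local_version = 0;
};

// Mirrors Worker::Update(): remember which version the gradient was sent at.
void Update(FakeParam* p) { p->local_version = p->version.load(); }

// Mirrors Worker::Collect(): block until a strictly newer version is installed.
void Collect(FakeParam* p) {
  while (p->version.load() <= p->local_version)
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
}

int main() {
  FakeParam p;
  Update(&p);
  // A plain thread stands in for the stub/server path that answers the update.
  std::thread responder([&p] {
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    p.version.store(p.local_version + 1);
  });
  Collect(&p);
  std::cout << "param advanced to version " << p.version.load() << "\n";
  responder.join();
}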

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 3b09417..c3cdc62 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -38,7 +38,7 @@ Cluster* Cluster::Setup(int job, const SingaProto& singaConf,
 Cluster* Cluster::Get() {
   if (!Singleton<Cluster>::Instance()->nprocs_) {
     LOG(ERROR) << "The first call to Get should "
-               << "provide the sys/model conf path";
+               << "provide the job conf path";
   }
   return Singleton<Cluster>::Instance();
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 65b2ec2..13f2552 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -244,7 +244,7 @@ string GetHostIP() {
   close(fd);
   string ip(inet_ntoa(((struct sockaddr_in *)&ifr.ifr_addr)->sin_addr));
   /* display result */
-  LOG(INFO) << "Host IP=(" << ip;
+  LOG(INFO) << "Host IP= " << ip;
   return ip;
 }
 
@@ -290,7 +290,7 @@ string Metric::ToLogString() const {
   string ret;
   size_t k = 0;
   for (auto e : entry_) {
-    ret += e.first + " : ";
+    ret += e.first + " = ";
     ret += std::to_string(e.second.second / e.second.first);
     if (++k < entry_.size())
       ret += ", ";

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 07c238c..1ee4dcd 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -7,9 +7,9 @@
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
-* 
+*
 *   http://www.apache.org/licenses/LICENSE-2.0
-* 
+*
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -24,9 +24,11 @@
 #include <glog/logging.h>
 #include <cmath>
 #include <random>
+#include <unordered_map>
 #include "mshadow/tensor.h"
 #include "utils/factory.h"
 #include "utils/singleton.h"
+#include "utils/common.h"
 
 namespace singa {
 
@@ -93,6 +95,7 @@ void UniformSqrtFanInOutGen::Fill(Blob<float>* blob) {
   data /= sqrt(blob->shape()[0] + blob->shape()[1]);
 }
 
+/****************** Param functions *********************************/
 Param* Param::Create(const ParamProto& proto) {
   Factory<Param>* factory = Singleton<Factory<Param>>::Instance();
   Param* p = nullptr;
@@ -104,6 +107,49 @@ Param* Param::Create(const ParamProto& proto) {
   return p;
 }
 
+const vector<int> Param::ComputeSlices(int num, const vector<Param*>& params) {
+  // collect sizes of unique Params
+  std::vector<int> paramsize;
+  for (auto param : params)
+    if (param->id() == param->owner())
+      paramsize.push_back(param->size());
+  // slice into lcm pieces to achieve good load-balance for both intra-group
+  // partition (among servers in a group) and inter-group partition (each group
+  // is assigned a sub-set of slices)
+  auto param_slice = Slice(num, paramsize);
+  vector<int> slices;
+  for (auto const vec: param_slice)
+    for (int len : vec)
+      slices.push_back(len);
+  return slices;
+}
+
+void Param::SliceParams(int num, const vector<Param*>& params) {
+  auto slices = ComputeSlices(num, params);
+  // construct map from Param ID to its slices <slice id, len>
+  std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
+  int slice_id = 0;
+  auto it = slices.begin();
+  for (auto param : params) {
+    if (param->id() == param->owner()) {
+      int len = 0;
+      while (len < param->size() && it != slices.end()) {
+        paramid2slices[param->id()].push_back(std::make_pair(slice_id++, *it));
+        len += *it;
+        it++;
+      }
+      CHECK_EQ(param->size(), len) << "length misamtch for ID=" << param->id();
+    }
+  }
+  for (auto param : params) {
+    for (auto entry : paramid2slices[param->owner()]) {
+      param->AddSlice(entry.first, entry.second);
+      LOG(INFO) << "param id " << param->id() << " owner=" << param->owner()
+        << ", slice id = " << entry.first << ", size = " << entry.second;
+    }
+  }
+}
+
 void Param::Setup(const vector<int>& shape) {
   data_ = std::make_shared<Blob<float>>(shape);
   grad_.Reshape(shape);
@@ -329,14 +375,14 @@ Msg* Param::HandleSyncMsg(Msg** msg, bool reserve) {
 }
 
 int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
-  CHECK_EQ(pending_get_[slice_idx], true);
+  CHECK(pending_get_[slice_idx]) << slice_idx;
   pending_get_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
   return (--num_pending_requests_) % num_slices_ == 0;
 }
 
 int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
-  CHECK_EQ(pending_update_[slice_idx], true);
+  CHECK(pending_update_[slice_idx]) << id() << " " << slice_idx;
   pending_update_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
   return (--num_pending_requests_) % num_slices_ == 0;
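
The comment in ComputeSlices() above is the heart of the load-balance scheme: owner Params are cut into lcm(nserver_groups, nservers_per_group) pieces so the slices divide evenly whether they are partitioned across server groups or across the servers inside one group. Here is a standalone sketch of that idea; SliceSketch() is an assumed, simplified stand-in for SINGA's Slice() that cuts into roughly equal pieces without crossing Param boundaries.

#include <algorithm>
#include <iostream>
#include <numeric>   // std::lcm, std::accumulate (C++17)
#include <vector>

std::vector<std::vector<int>> SliceSketch(int num, const std::vector<int>& sizes) {
  int total = std::accumulate(sizes.begin(), sizes.end(), 0);
  int piece = (total + num - 1) / num;              // target slice length
  std::vector<std::vector<int>> result;
  for (int size : sizes) {                          // one inner vector per Param
    std::vector<int> cuts;
    for (int done = 0; done < size; done += piece)
      cuts.push_back(std::min(piece, size - done));
    result.push_back(cuts);
  }
  return result;
}

int main() {
  int num = std::lcm(2, 3);                         // 2 server groups, 3 servers each
  auto slices = SliceSketch(num, {4000, 2000});     // sizes of two owner Params
  for (size_t p = 0; p < slices.size(); ++p)
    for (int len : slices[p])
      std::cout << "param " << p << ": slice of length " << len << "\n";
}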

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/321ef96a/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
new file mode 100644
index 0000000..153e1a1
--- /dev/null
+++ b/src/worker.cc
@@ -0,0 +1,410 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "./worker.h"
+
+#include <glog/logging.h>
+#include <chrono>
+#include <thread>
+#include <typeinfo>
+#include "utils/cluster.h"
+#include "utils/factory.h"
+#include "utils/singleton.h"
+
+namespace singa {
+
+using std::string;
+
+Worker* Worker::Create(const AlgProto& conf) {
+  auto factory = Singleton<Factory<singa::Worker>>::Instance();
+  Worker* worker = nullptr;
+  if (conf.has_user_alg())
+    worker = factory->Create(conf.user_alg());
+  else
+    worker = factory->Create(conf.alg());
+  return worker;
+}
+
+void Worker::Setup(int grp_id, int id, const JobProto& conf,
+    NeuralNet* train_net, NeuralNet* val_net, NeuralNet* test_net) {
+  grp_id_ = grp_id;
+  id_ = id;
+  job_conf_ = conf;
+  train_net_ = train_net;
+  val_net_ = val_net;
+  test_net_ = test_net;
+  layer_dealer_ = dealer_ = nullptr;
+}
+
+Worker::~Worker() {
+  if (layer_dealer_)
+    delete layer_dealer_;
+  if (dealer_)
+    delete dealer_;
+}
+
+void Worker::InitNetParams(const JobProto& job_conf, NeuralNet* net) {
+  // for each server grp, its first subscriber worker grp does the param init
+  if (grp_id_ % Cluster::Get()->nworker_groups_per_server_group() == 0) {
+    // extract params that should be initialized by this worker
+    // must gen a name for each param if the user doesn't config it
+    std::unordered_map<string, Param*> name2param;
+    for (auto layer : net->layers()) {
+      if (layer->partition_id() == id_) {
+        for (auto param : layer->GetParams()) {
+          // only owners fill the memory of parameter values.
+          if (param->owner() == param->id()) {
+            CHECK(name2param.find(param->name()) == name2param.end());
+            name2param[param->name()] = param;
+          }
+        }
+      }
+    }
+    // load from checkpoints. get param blob based on param name.
+    // the param from previous checkpoint files will be overwritten by
+    // the param with the same name in later checkpoint files.
+    for (const auto path : job_conf.checkpoint_path()) {
+      LOG(ERROR) << "Load from checkpoint file " << path;
+      BlobProtos bps;
+      ReadProtoFromBinaryFile(path.c_str(), &bps);
+      for (int i = 0; i < bps.name_size(); i++) {
+        if (name2param.find(bps.name(i)) != name2param.end()) {
+          name2param.at(bps.name(i))->FromProto(bps.blob(i));
+          //  if load from pre-training params, reset version to start step
+          if (job_conf.reset_param_version())
+            name2param.at(bps.name(i))->set_version(job_conf.step());
+          else  // if resume training, use the same version as last checkpoint
+            name2param.at(bps.name(i))->set_version(bps.version(i));
+        }
+      }
+    }
+    // init other params who do not have checkpoint version
+    for (auto entry : name2param)
+      if (entry.second->version() < 0) {
+        entry.second->InitValues(job_conf.step());
+        if (!job_conf.reset_param_version())
+          LOG(ERROR) << "better reset version of params from checkpoints "
+            << "to the same as other newly initialized params!";
+      }
+
+    // warmup training before put params to servers
+    for (; step_ < job_conf.warmup_steps(); step_++)
+      TrainOneBatch(step_, net);
+    for (auto layer : net->layers()) {
+      if (layer->partition_id() == id_)
+        for (auto param : layer->GetParams())
+          if (param->owner() == param->id())
+            Put(param->version(), param);
+    }
+  }
+  // wait owners in the same procs init params, then no get requests sent
+  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  for (auto layer : net->layers()) {
+    if (layer->partition_id() == id_)
+      for (auto param : layer->GetParams())
+        Get(job_conf.warmup_steps(), param);
+  }
+}
+
+void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) {
+  dealer->Connect(kInprocRouterEndpoint);
+  Msg* ping = new Msg(Addr(grp, id, entity), Addr(-1, -1, kStub));
+  ping->set_type(kConnect);
+  dealer->Send(&ping);
+}
+
+void Worker::Run() {
+  LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") start";
+  auto cluster = Cluster::Get();
+  int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
+  CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
+  // TODO(wangsh): provide a unique sock id from cluster
+  dealer_ = new Dealer(0);
+  ConnectStub(grp_id_, id_, dealer_, kWorkerParam);
+  for (auto layer : train_net_->layers()) {
+    if (layer->partition_id() == id_) {
+      if (typeid(layer) == typeid(BridgeDstLayer)
+          || typeid(layer) == typeid(BridgeSrcLayer)) {
+        // TODO(wangsh): provide a unique socket id from cluster
+        layer_dealer_ = new Dealer(1);
+        ConnectStub(grp_id_, id_, layer_dealer_, kWorkerLayer);
+        break;
+      }
+    }
+  }
+
+  step_ = job_conf_.step();
+  InitNetParams(job_conf_, train_net_);
+  while (!StopNow(step_)) {
+    if (ValidateNow(step_) && val_net_ != nullptr) {
+      CollectAll(step_, val_net_);
+      for (int step = 0; step < job_conf_.validate_steps(); step++)
+        TestOneBatch(step, kVal, val_net_);
+      Display(kVal, "Validation @ step " + std::to_string(step_), val_net_);
+    }
+    if (TestNow(step_) && test_net_ != nullptr) {
+      CollectAll(step_, test_net_);
+      for (int step = 0; step < job_conf_.test_steps(); step++)
+        TestOneBatch(step, kTest, test_net_);
+      Display(kTest, "Test @ step " + std::to_string(step_), test_net_);
+    }
+    if (CheckpointNow(step_) && grp_id_ == 0) {
+      CollectAll(step_, train_net_);
+      Checkpoint(step_, Cluster::Get()->checkpoint_folder(), train_net_);
+      job_conf_.set_step(step_);
+    }
+    TrainOneBatch(step_, train_net_);
+    if (DisplayNow(step_) && grp_id_ == 0 && id_ == 0)
+      Display(kTrain, "Train @ step " + std::to_string(step_), train_net_);
+    step_++;
+  }
+
+  // save the model
+  if (grp_id_ == 0)
+    Checkpoint(step_, Cluster::Get()->checkpoint_folder(), train_net_);
+  // clean up
+  cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp);
+  // notify the stub on worker stop
+  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_type(kStop);
+  dealer_->Send(&msg);  // use param dealer to send the stop msg
+  LOG(ERROR) << "Worker (group = " <<grp_id_ << ", id = " << id_ << ") stops";
+}
+
+void Worker::Checkpoint(int step, const std::string& folder, NeuralNet* net) {
+  BlobProtos bps;
+  for (auto layer : net->layers()) {
+    if (layer->partition_id() == id_) {
+      for (auto param : layer->GetParams()) {
+        // only owners fill the memory of parameter values.
+        if (param->owner() == param->id()) {
+          auto *blob = bps.add_blob();
+          param->ToProto(blob);
+          bps.add_version(param->version());
+          bps.add_name(param->name());
+        }
+      }
+    }
+  }
+  char buf[256];
+  snprintf(buf, sizeof(buf), "%s/step%d-worker%d", folder.c_str(), step, id_);
+  LOG(INFO) << "checkpoint to " << buf;
+  WriteProtoToBinaryFile(bps, buf);
+}
+
+int Worker::Put(int step, Param* param) {
+  if (dealer_ == nullptr) {
+    LOG(ERROR) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
+    return 1;
+  }
+  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
+  msg->set_type(kPut);
+  dealer_->Send(&msg);
+  return 1;
+}
+
+int Worker::Get(int step, Param* param) {
+  if (param->version() >= step)
+    return 1;
+  if (dealer_ == nullptr) {
+    LOG(ERROR) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
+    return 1;
+  }
+  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
+  msg->set_type(kGet);
+  dealer_->Send(&msg);
+  return 1;
+}
+
+int Worker::Update(int step, Param* param) {
+  param->set_local_version(param->version());
+  if (dealer_ == nullptr) {
+    LOG(ERROR) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
+    return 1;
+  }
+  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
+  msg->set_type(kUpdate);
+  dealer_->Send(&msg);
+  return 1;
+}
+
+int Worker::CollectAll(int step, NeuralNet* net) {
+  auto& layers = net->layers();
+  for (auto& layer : layers) {
+    if (layer->partition_id() == id_) {
+      for (Param* p : layer->GetParams()) {
+        Collect(step, p);
+      }
+    }
+  }
+  return 1;
+}
+
+int Worker::Collect(int step, Param* param) {
+  while (param->version() <= param->local_version())
+    std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));
+  return 1;
+}
+
+void Worker::Display(int flag, const std::string& prefix, NeuralNet* net) {
+  for (auto layer : net->layers()) {
+    if (layer->partition_id() == id_) {
+      const string& disp = layer->ToString(false, flag);
+      if (disp.length())
+        LOG(ERROR) << prefix << ": " << disp;
+      if (job_conf_.debug()) {
+        const string& info = layer->ToString(true, flag);
+        if (info.length()) {
+          LOG(INFO) <<  prefix << info;
+        }
+      }
+    }
+  }
+}
+
+void Worker::ReceiveBlobs(bool data, bool grad, BridgeLayer* layer,
+                          NeuralNet* net) {
+  if (layer_dealer_ == nullptr) {
+    LOG(ERROR) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
+  }
+  while (!layer->ready()) {
+    auto msg = layer_dealer_->Receive();
+    CHECK_EQ(AddrGrp(msg->src()), grp_id_);
+    string name(static_cast<char*>(msg->FrameData()), msg->FrameSize());
+    auto receive_layer = net->name2layer(name);
+    auto data = receive_layer->mutable_data(nullptr);
+    msg->NextFrame();
+    memcpy(data->mutable_cpu_data(), msg->FrameData(), msg->FrameSize());
+    dynamic_cast<BridgeLayer*>(receive_layer)->set_ready(true);
+    delete msg;
+  }
+}
+
+void Worker::SendBlobs(bool data, bool grad, BridgeLayer* layer,
+                       NeuralNet* net) {
+  if (layer_dealer_ == nullptr) {
+    LOG(ERROR) << "Null dealer in worker (" << grp_id_ << ", " << id_ << ")";
+  }
+  auto dst = net->srclayers(layer).at(0);
+  Msg *msg = new Msg();
+  msg->set_src(Addr(grp_id_, id_, kWorkerLayer));
+  msg->set_dst(Addr(grp_id_, dst->partition_id(), kWorkerLayer));
+  msg->AddFrame(dst->name().c_str(), dst->name().length());
+  auto const & blob = layer->data(nullptr);
+  msg->AddFrame(blob.cpu_data(), blob.count() * sizeof(float));
+  layer_dealer_->Send(&msg);
+}
+
+/****************************BPWorker**********************************/
+void BPWorker::TrainOneBatch(int step, NeuralNet* net) {
+  Forward(step, kTrain, net);
+  Backward(step, net);
+}
+
+void BPWorker::TestOneBatch(int step, Phase phase, NeuralNet* net) {
+  Forward(step, phase, net);
+}
+
+void BPWorker::Forward(int step, Phase phase, NeuralNet* net) {
+  for (auto& layer : net->layers()) {
+    if (layer->partition_id() == id_) {
+      // TODO(wangwei): enable this for model partition
+      // recv data from other workers
+      // if (typeid(*layer) == typeid(BridgeDstLayer))
+      //   ReceiveBlobs(true, false, dynamic_cast<BridgeLayer*>(layer), net);
+      if (phase == kTrain) {
+        // wait until param is updated
+        for (Param* p : layer->GetParams()) {
+          Collect(step, p);
+        }
+      }
+      layer->ComputeFeature(phase | kForward, net->srclayers(layer));
+      // TODO(wangwei): enable this for model partition
+      // send data to other workers
+      // if (typeid(*layer) == typeid(BridgeSrcLayer))
+      //   SendBlobs(true, false, dynamic_cast<BridgeLayer*>(layer), net);
+    }
+  }
+}
+
+void BPWorker::Backward(int step, NeuralNet* net) {
+  auto& layers = net->layers();
+  for (auto it = layers.rbegin(); it != layers.rend(); it++) {
+    Layer* layer = *it;
+    if (layer->partition_id() == id_) {
+      // TODO(wangwei): enable this for model partition
+      // send data to other workers
+      // if (typeid(layer) == typeid(BridgeSrcLayer))
+      //   ReceiveBlobs(false, true, layer, net);
+      layer->ComputeGradient(kTrain | kBackward, net->srclayers(layer));
+      for (Param* p : layer->GetParams())
+        Update(step, p);
+      // TODO(wangwei): enable this for model partition
+      // recv data from other workers
+      // if (typeid(layer) == typeid(BridgeDstLayer))
+      //   SendBlobs(false, true, dynamic_cast<BridgeDstLayer*>(layer), net);
+    }
+  }
+}
+
+/****************************CDWorker**********************************/
+void CDWorker::TrainOneBatch(int step, NeuralNet* net) {
+  const auto& layers = net->layers();
+  for (auto* layer : layers) {
+    for (Param* p : layer->GetParams())  // wait until param is updated
+      Collect(step, p);
+    layer->ComputeFeature(kPositive, net->srclayers(layer));
+  }
+  for (auto* layer : layers)
+    if (typeid(*layer) == typeid(RBMVisLayer)
+          || typeid(*layer) == typeid(RBMHidLayer))
+      layer->ComputeFeature(kNegative | kTest, net->srclayers(layer));
+  for (int i = 1; i < job_conf_.train_one_batch().cd_conf().cd_k(); i++) {
+    for (auto* layer : layers) {
+      if (typeid(*layer) == typeid(RBMVisLayer)
+          || typeid(*layer) == typeid(RBMHidLayer))
+      layer->ComputeFeature(kNegative, net->srclayers(layer));
+    }
+  }
+  for (auto* layer : layers) {
+    if (typeid(*layer) == typeid(RBMVisLayer)
+        || typeid(*layer) == typeid(RBMHidLayer)) {
+      layer->ComputeGradient(kTrain, net->srclayers(layer));
+      for (Param* p : layer->GetParams()) {
+        Update(step, p);
+      }
+    }
+  }
+}
+
+void CDWorker::TestOneBatch(int step, Phase phase, NeuralNet* net) {
+  auto& layers = net->layers();
+  for (auto *layer : layers)
+    layer->ComputeFeature(kPositive, net->srclayers(layer));
+  for (auto *layer : layers)
+    if (typeid(*layer) == typeid(RBMVisLayer))
+      layer->ComputeFeature(kNegative | kTest, net->srclayers(layer));
+}
+
+}  // namespace singa
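
Worker::Run() above ends by sending a kStop message through the param dealer; on the other side (the stub loop shown in the removed trainer.cc earlier in this message) each kStop decrements a worker or server counter, and the stub exits once both reach zero. A simplified sketch of that countdown, with the message and queue types invented for illustration rather than taken from SINGA's Msg API:

#include <iostream>
#include <queue>

// Simplified message: just the type and whether it came from a server.
enum MsgType { kStopMsg, kOtherMsg };
struct StopMsg { MsgType type; bool from_server; };

void StubLoop(std::queue<StopMsg>* inbox, int nworkers, int nservers) {
  while (!inbox->empty()) {
    StopMsg msg = inbox->front();
    inbox->pop();
    if (msg.type == kStopMsg) {
      if (msg.from_server) --nservers; else --nworkers;
      if (nworkers == 0 && nservers == 0) break;  // every executor has finished
    }
    // other message types would be routed to workers/servers here
  }
  std::cout << "stub stops\n";
}

int main() {
  std::queue<StopMsg> inbox;
  inbox.push({kStopMsg, false});  // stop notice from the only worker
  inbox.push({kStopMsg, true});   // stop notice from the only server
  StubLoop(&inbox, /*nworkers=*/1, /*nservers=*/1);
}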

