singa-commits mailing list archives

From wan...@apache.org
Subject [01/12] incubator-singa git commit: SINGA-55 Refactor main.cc and singa.h
Date Sat, 15 Aug 2015 08:11:14 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master f3cc20a90 -> 2498ff135


SINGA-55 Refactor main.cc and singa.h

A Driver class is implemented for initializing SINGA, including parsing the
job ID and registering built-in subclasses of Layer, Updater, Worker and Param.
It may also be used to init MPI if MPI is used as the message passing lib.

The main.cc file is updated to provide an example main func.

Update the GaussianSqrtFanin init method for Param (ref SINGA-58).
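
For illustration, a minimal sketch of the user-side entry point this enables
(it mirrors the updated main.cc; the job.conf path and the resume value are
placeholders):

  #include "singa.h"

  int main(int argc, char** argv) {
    singa::Driver driver;
    driver.Init(argc, argv);  // parse job ID, register built-in subclasses
    singa::JobProto jobConf;
    singa::ReadProtoFromTextFile("./job.conf", &jobConf);
    driver.Submit(false, jobConf);  // false: start afresh instead of resuming
    return 0;
  }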


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3ec12b92
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3ec12b92
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3ec12b92

Branch: refs/heads/master
Commit: 3ec12b92636ec893be845b08580ec54e302d6689
Parents: fcd377a
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Sat Aug 15 11:48:38 2015 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Sat Aug 15 14:59:11 2015 +0800

----------------------------------------------------------------------
 include/neuralnet/neuralnet.h |   4 --
 include/singa.h               |  88 +++++++++++++++++++++++++-------
 include/trainer/trainer.h     |  10 ----
 include/trainer/worker.h      |  12 ++---
 src/driver.cc                 | 101 +++++++++++++++++++++++++++++++++++++
 src/main.cc                   |  48 +++++++++---------
 src/neuralnet/layer.cc        |  18 +++----
 src/neuralnet/neuralnet.cc    |  37 --------------
 src/proto/job.proto           |   4 +-
 src/trainer/server.cc         |   5 +-
 src/trainer/trainer.cc        |  20 ++------
 src/trainer/worker.cc         |  17 ++++---
 src/utils/param.cc            |   7 +--
 13 files changed, 233 insertions(+), 138 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/include/neuralnet/neuralnet.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/neuralnet.h b/include/neuralnet/neuralnet.h
index 1dbf44a..99f4a5c 100644
--- a/include/neuralnet/neuralnet.h
+++ b/include/neuralnet/neuralnet.h
@@ -27,10 +27,6 @@ using std::shared_ptr;
 class NeuralNet {
  public:
   /**
-   * Register Layers, i.e., map layer type to layer class
-   */
-  static void RegisterLayers();
-  /**
    * Create the neural network for training, test or validation.
    *
    * Parameters for test/validation net can share those from training after

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/include/singa.h
----------------------------------------------------------------------
diff --git a/include/singa.h b/include/singa.h
index 82df64b..6fb9e97 100644
--- a/include/singa.h
+++ b/include/singa.h
@@ -3,6 +3,7 @@
 #include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <cblas.h>
+#include <string>
 
 #include "utils/common.h"
 #include "proto/job.pb.h"
@@ -16,25 +17,76 @@
 #include "trainer/trainer.h"
 #include "communication/socket.h"
 
-DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
-
 namespace singa {
-void SubmitJob(int job, bool resume, const JobProto& jobConf) {
-  SingaProto singaConf;
-  ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
-  if (singaConf.has_log_dir())
-    SetupLog(singaConf.log_dir(),
-        std::to_string(job) + "-" + jobConf.name());
-  if (jobConf.num_openblas_threads() != 1)
-    LOG(WARNING) << "openblas is set with " << jobConf.num_openblas_threads()
-      << " threads";
-  openblas_set_num_threads(jobConf.num_openblas_threads());
-  JobProto proto;
-  proto.CopyFrom(jobConf);
-  proto.set_id(job);
-  Trainer trainer;
-  trainer.Start(resume, singaConf, &proto);
-}
+class Driver {
+ public:
+  /**
+   * Init SINGA, including initializing glog, parsing the job ID and job conf
+   * from the cmd line, and registering built-in Layer, Worker, Updater and
+   * Param subclasses.
+   *
+   * It may also be used to init MPI if MPI is used for message passing.
+   */
+  void Init(int argc, char** argv);
+  /**
+   * Register a Layer subclass.
+   *
+   * T is the subclass.
+   * @param type layer type ID. If called by users, it should be different
+   * from the types of built-in layers.
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename T>
+  int RegisterLayer(int type);
+  /**
+   * Register Updater subclasses.
+   *
+   * T is the subclass.
+   * @param type updater type ID. If called by users, it should be different
+   * from the types of built-in updaters.
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename T>
+  int RegisterUpdater(int type);
+  /**
+   * Register Worker subclasses.
+   *
+   * T is the subclass.
+   * @param type worker type ID. If called by users, it should be different
+   * from the types of built-in workers.
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename T>
+  int RegisterWorker(int type);
+  /**
+   * Register Param subclasses.
+   *
+   * T is the subclass.
+   * @param type param type ID. If called by users, it should be different from
+   * the types of built-in params. SINGA currently provides only one built-in Param
+   * implementation whose type ID is 0.
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename T>
+  int RegisterParam(int type);
+
+  /**
+   * Submit the job configuration for starting the job.
+   * @param resume resume from last checkpoint if true.
+   * @param job job configuration
+   */
+  void Submit(bool resume, const JobProto& job);
+
+  /**
+   * @return job ID which is generated by zookeeper and passed in by the
+   * launching script.
+   */
+  int job_id() const {
+    return job_id_;
+  }
+
+ private:
+  int job_id_;
+};
 }  // namespace singa
 #endif  //  SINGA_SINGA_H_
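
To show how the registration methods above are meant to be used, here is a
hedged sketch of plugging in a user-defined subclass; FooLayer, its elided
Layer overrides, and the type ID 100 are hypothetical:

  // Hypothetical subclass; the virtual Layer methods it must override
  // (e.g., Setup, ComputeFeature, ComputeGradient) are elided.
  class FooLayer : public singa::Layer { /* ... */ };
  const int kFooLayer = 100;  // assumed not to clash with built-in type IDs

  singa::Driver driver;
  driver.Init(argc, argv);
  driver.RegisterLayer<FooLayer>(kFooLayer);  // must happen before Submit()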
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 1d28de6..8be5269 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -80,16 +80,6 @@ class Trainer{
 
   void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
   /**
-   * Register default implementations for all base classes used in the system,
-   * e.g., the Updater, BaseMsg, etc.
-   *
-   * All built-in layer implementations are
-   * registered here.
-   * For other base classes, use its base class name (string) as the key and the
-   * implementation class as the value, e.g., <"Updater" SGDUpdater>.
-   */
-  void RegisterDefaultClasses();
-  /**
    * Generate msg to trigger synchronization with other server groups.
    *
    * @param server the local server index whom the message is sent to

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
index 0557ee2..2adbc66 100644
--- a/include/trainer/worker.h
+++ b/include/trainer/worker.h
@@ -26,7 +26,7 @@ class Worker {
    * @param grp_id global worker group ID
    * @param id worker ID within the group
    */
-  Worker(int thread_id, int grp_id, int id);
+  virtual void Init(int thread_id, int grp_id, int id);
   virtual ~Worker();
   /**
    * Setup members
@@ -181,8 +181,8 @@ class Worker {
 
 class BPWorker: public Worker{
  public:
-  BPWorker(int thread_id, int grp_id, int id);
   ~BPWorker(){}
+  void Init(int thread_id, int grp_id, int id) override;
   void TrainOneBatch(int step, Metric* perf) override;
   void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
       Metric* perf) override;
@@ -193,11 +193,11 @@ class BPWorker: public Worker{
 
 class CDWorker: public Worker{
  public:
-  CDWorker(int thread_id, int group_id, int worker_id);
   ~CDWorker() {}
-  virtual void TrainOneBatch(int step, Metric* perf);
-  virtual void TestOneBatch(int step, Phase phase,
-       shared_ptr<NeuralNet> net, Metric* perf);
+  void Init(int thread_id, int group_id, int worker_id) override;
+  void TrainOneBatch(int step, Metric* perf) override;
+  void TestOneBatch(int step, Phase phase,
+       shared_ptr<NeuralNet> net, Metric* perf) override;
   void PositivePhase(int step, shared_ptr<NeuralNet> net, Metric* perf);
   void NegativePhase(int step, shared_ptr<NeuralNet> net, Metric* perf);
   void GradientPhase(int step, shared_ptr<NeuralNet> net);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
new file mode 100644
index 0000000..05c1195
--- /dev/null
+++ b/src/driver.cc
@@ -0,0 +1,101 @@
+#include "singa.h"
+namespace singa {
+
+/**
+ * The job and singa_conf flags are set by the singa launching script; this is
+ * transparent to users.
+ */
+DEFINE_int32(job, -1, "Unique job ID generated from singa-run.sh");
+DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
+
+void Driver::Init(int argc, char **argv) {
+  google::InitGoogleLogging(argv[0]);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  job_id_ = FLAGS_job;
+
+  // register layers
+  RegisterLayer<BridgeDstLayer>(kBridgeDst);
+  RegisterLayer<BridgeSrcLayer>(kBridgeSrc);
+  RegisterLayer<ConvolutionLayer>(kConvolution);
+  RegisterLayer<ConcateLayer>(kConcate);
+  RegisterLayer<DropoutLayer>(kDropout);
+  RegisterLayer<InnerProductLayer>(kInnerProduct);
+  RegisterLayer<LabelLayer>(kLabel);
+  RegisterLayer<LRNLayer>(kLRN);
+  RegisterLayer<MnistLayer>(kMnist);
+  RegisterLayer<PrefetchLayer>(kPrefetch);
+  RegisterLayer<PoolingLayer>(kPooling);
+  RegisterLayer<RGBImageLayer>(kRGBImage);
+  RegisterLayer<ReLULayer>(kReLU);
+  RegisterLayer<ShardDataLayer>(kShardData);
+  RegisterLayer<SliceLayer>(kSlice);
+  RegisterLayer<SoftmaxLossLayer>(kSoftmaxLoss);
+  RegisterLayer<SplitLayer>(kSplit);
+  RegisterLayer<TanhLayer>(kTanh);
+  RegisterLayer<RBMVisLayer>(kRBMVis);
+  RegisterLayer<RBMHidLayer>(kRBMHid);
+#ifdef USE_LMDB
+  RegisterLayer<LMDBDataLayer>(kLMDBData);
+#endif
+
+  // register updater
+  RegisterUpdater<AdaGradUpdater>(kAdaGrad);
+  RegisterUpdater<NesterovUpdater>(kNesterov);
+  //  TODO(wangwei)  RegisterUpdater<RMSPropUpdater>(kRMSProp);
+  RegisterUpdater<SGDUpdater>(kSGD);
+
+  // register worker
+  RegisterWorker<BPWorker>(kBP);
+  RegisterWorker<CDWorker>(kCD);
+
+  // register param
+  RegisterParam<Param>(0);
+}
+
+template<typename T>
+int Driver::RegisterLayer(int type) {
+  auto factory = Singleton<Factory<singa::Layer>>::Instance();
+  factory->Register(type, CreateInstance(T, Layer));
+  return 0;
+}
+
+template<typename T>
+int Driver::RegisterParam(int type) {
+  auto factory = Singleton<Factory<singa::Param>>::Instance();
+  factory->Register(type, CreateInstance(T, Param));
+  return 0;
+}
+
+template<typename T>
+int Driver::RegisterUpdater(int type) {
+  auto factory = Singleton<Factory<singa::Updater>>::Instance();
+  factory->Register(type, CreateInstance(T, Updater));
+  return 0;
+}
+
+template<typename T>
+int Driver::RegisterWorker(int type) {
+  auto factory = Singleton<Factory<singa::Worker>>::Instance();
+  factory->Register(type, CreateInstance(T, Worker));
+  return 0;
+}
+
+void Driver::Submit(bool resume, const JobProto& jobConf) {
+  SingaProto singaConf;
+  ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
+  if (singaConf.has_log_dir())
+    SetupLog(singaConf.log_dir(), std::to_string(FLAGS_job)
+        + "-" + jobConf.name());
+  if (jobConf.num_openblas_threads() != 1)
+    LOG(WARNING) << "openblas with "
+      << jobConf.num_openblas_threads() << " threads";
+  openblas_set_num_threads(jobConf.num_openblas_threads());
+
+  JobProto job;
+  job.CopyFrom(jobConf);
+  job.set_id(job_id_);
+  Trainer trainer;
+  trainer.Start(resume, singaConf, &job);
+}
+
+}  // namespace singa
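
The Register* templates above delegate to a Singleton<Factory<Base>> keyed by
the integer type ID. The real Factory and CreateInstance helpers live in
utils/ and are not shown in this commit; a minimal sketch of the pattern they
imply, assuming std::function-based creators:

  #include <functional>
  #include <map>

  template <typename Base>
  class Factory {
   public:
    // Map an integer type ID to a creator callable.
    void Register(int type, std::function<Base*()> creator) {
      creators_[type] = creator;
    }
    // Instantiate the subclass registered under the given type ID.
    Base* Create(int type) {
      return creators_.at(type)();  // throws std::out_of_range if unregistered
    }
   private:
    std::map<int, std::function<Base*()>> creators_;
  };

  // Under this assumption, CreateInstance(Sub, Base) would expand to
  // something like: []() -> Base* { return new Sub(); }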

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 00b75ff..692fe2c 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -1,40 +1,40 @@
 #include "singa.h"
 /**
- * \file main.cc is the main entry of SINGA, like the driver program for Hadoop.
+ * \file main.cc provides an example main func.
+ *
+ * Like the main func of Hadoop, it prepares the job configuration and submits
+ * it to the Driver, which starts the training.
+ *
+ * Users can define their own main func to prepare the job configuration in
+ * ways other than reading it from a configuration file. But the main
+ * func must call Driver::Init at the beginning, and pass the job configuration
+ * and resume option to the Driver for job submission.
+ *
+ * Optionally, users can register their own subclasses, e.g., of Layer or
+ * Updater, through the registration funcs provided by the Driver.
  *
- * 1. Users register their own implemented classes, e.g., layer, updater, etc.
- * 2. Users prepare the google protobuf object for the job configuration.
- * 3. Users call trainer to start the training.
  *
  * TODO
- * 1. Add helper functions for users to configure their model easily,
- * e.g., AddLayer(layer_type, source_layers, meta_data).
+ * Add helper functions for users to generate their configurations easily,
+ * e.g., AddLayer(layer_type, source_layers, meta_data)
+ * or MLP(layer1_size, layer2_size, tanh, loss).
  */
 
-DEFINE_int32(job, -1, "Unique job ID generated from singa-run.sh");
 DEFINE_bool(resume, false, "Resume from checkpoint passed at cmd line");
 DEFINE_string(conf, "./job.conf", "job conf passed at cmd line");
 
-/**
- * Register layers, and other customizable classes.
- *
- * If users want to use their own implemented classes, they should register
- * them here. Refer to the Worker::RegisterDefaultClasses()
- */
-void RegisterClasses() {
-
-}
-
-
 int main(int argc, char **argv) {
-  google::InitGoogleLogging(argv[0]);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  //  must create driver at the beginning and call its Init method.
+  singa::Driver driver;
+  driver.Init(argc, argv);
+
+  //  users can register new subclasses of layer, updater, etc.
 
+  //  prepare the job conf
   singa::JobProto jobConf;
-  std::string job_file = FLAGS_conf;
-  singa::ReadProtoFromTextFile(job_file.c_str(), &jobConf);
+  singa::ReadProtoFromTextFile(FLAGS_conf.c_str(), &jobConf);
 
-  RegisterClasses();
-  singa::SubmitJob(FLAGS_job, FLAGS_resume, jobConf);
+  //  submit the job
+  driver.Submit(FLAGS_resume, jobConf);
   return 0;
 }
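
As the comment block above notes, a custom main func may also build the job
configuration in code instead of reading job.conf. A sketch using the
generated protobuf setters (field and enum names are taken from this commit;
the values are placeholders):

  singa::JobProto jobConf;
  jobConf.set_name("example-job");                // placeholder job name
  jobConf.set_alg(singa::TrainOneBatchAlg::kBP);  // train with the BP worker
  jobConf.set_num_openblas_threads(1);
  driver.Submit(FLAGS_resume, jobConf);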

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index 810d0b4..5c3d688 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -72,9 +72,9 @@ void ConvolutionLayer::Setup(const LayerProto& proto, int npartitions) {
   col_grad_.Reshape(vector<int>{col_height_, col_width_});
 
   Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
-  weight_ = factory->Create("Param");
+  weight_ = factory->Create(proto.param(0).type());
   weight_->Setup(proto.param(0), vector<int>{num_filters_, col_height_});
-  bias_ = factory->Create("Param");
+  bias_ = factory->Create(proto.param(1).type());
   bias_->Setup(proto.param(1), vector<int>{num_filters_});
 }
 
@@ -189,8 +189,8 @@ void RBMVisLayer::Setup(const LayerProto& proto,
   data_.Reshape(vector<int>{batchsize_, vdim_});  // this is visible dimension
   vis_sample_.Reshape(vector<int>{neg_batchsize_, vdim_});
   Factory<Param>* factory = Singleton<Factory<Param>>::Instance();
-  weight_ = factory->Create("Param");
-  bias_ = factory->Create("Param");
+  weight_ = factory->Create(proto.param(0).type());
+  bias_ = factory->Create(proto.param(1).type());
   weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_});
   bias_->Setup(proto.param(1), vector<int>{vdim_});
 }
@@ -282,10 +282,10 @@ void RBMHidLayer::Setup(const LayerProto& proto,
   data_.Reshape(vector<int>{batchsize_, hdim_});
   hid_sample_.Reshape(vector<int>{neg_batchsize_, hdim_});
   Factory<Param>* factory = Singleton<Factory<Param>>::Instance();
-  bias_ = factory->Create("Param");
-  weight_ = factory->Create("Param");
-  bias_->Setup(proto.param(1), vector<int>{hdim_});
+  weight_ = factory->Create(proto.param(0).type());
+  bias_ = factory->Create(proto.param(1).type());
   weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_});
+  bias_->Setup(proto.param(1), vector<int>{hdim_});
 }
 
 void RBMHidLayer::ComputeFeature(Phase phase, Metric* perf) {
@@ -339,8 +339,8 @@ void InnerProductLayer::Setup(const LayerProto& proto, int npartitions) {
   data_.Reshape(vector<int>{batchsize_, hdim_});
   grad_.ReshapeLike(data_);
   Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
-  weight_ = factory->Create("Param");
-  bias_ = factory->Create("Param");
+  weight_ = factory->Create(proto.param(0).type());
+  bias_ = factory->Create(proto.param(1).type());
   weight_->Setup(proto.param(0), vector<int>{hdim_, vdim_});
   bias_->Setup(proto.param(1), vector<int>{hdim_});
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index e2565e3..08e8e0f 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -2,46 +2,9 @@
 #include <queue>
 
 #include "neuralnet/neuralnet.h"
-#include "neuralnet/optional_layer.h"
 #include "utils/singleton.h"
 
 namespace singa {
-// macros to shorten the code
-#define LayerT(x) LayerType::k##x
-
-#define RegisterLayer(factory, id) \
-  factory->Register(LayerType::k##id, \
-      CreateInstance(id##Layer, Layer))
-
-void NeuralNet::RegisterLayers() {
-  Factory<Layer>* factory = Singleton<Factory<Layer>>::Instance();
-  // FooLayer's type is kFoo, register using Foo
-  RegisterLayer(factory, BridgeDst);
-  RegisterLayer(factory, BridgeSrc);
-  RegisterLayer(factory, Convolution);
-  RegisterLayer(factory, Concate);
-  RegisterLayer(factory, Dropout);
-  RegisterLayer(factory, InnerProduct);
-  RegisterLayer(factory, Label);
-  RegisterLayer(factory, LRN);
-  RegisterLayer(factory, Mnist);
-  RegisterLayer(factory, Prefetch);
-  RegisterLayer(factory, Pooling);
-  RegisterLayer(factory, RGBImage);
-  RegisterLayer(factory, ReLU);
-  RegisterLayer(factory, ShardData);
-  RegisterLayer(factory, Slice);
-  RegisterLayer(factory, SoftmaxLoss);
-  RegisterLayer(factory, Split);
-  RegisterLayer(factory, Tanh);
-  RegisterLayer(factory, RBMVis);
-  RegisterLayer(factory, RBMHid);
-
-#ifdef USE_LMDB
-  RegisterLayer(factory, LMDBData);
-#endif
-}
-
 shared_ptr<NeuralNet> NeuralNet::Create(
     const NetProto& net_conf,
     Phase phase,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index fe8dc21..43ea42e 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -205,6 +205,8 @@ message ParamProto {
   // used for identifying the same params from diff models and display debug info
   optional string name =  1 [default = ""];
   optional InitMethod init_method = 2 [default = kGaussian];
+  // currently there is only one built-in Param impl with type 0.
+  optional int32 type = 3 [default = 0];
   // constant init
   optional float value = 5 [default = 1];
   // for uniform sampling
@@ -422,7 +424,7 @@ message ReLUProto {
 }
 
 message RMSPropProto {
-  // history=history*rho_+(1-rho_)*(grad*grad_scale);
+ // history=history*rho_+(1-rho_)*(grad*grad_scale);
   required float rho = 1;
 }
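
With the new type field, a param block in the job configuration can select
which Param implementation to instantiate. Built programmatically, it might
look like this sketch (the name and values are placeholders; 0 is the only
built-in Param type in this commit):

  singa::ParamProto p;
  p.set_name("w1");                                 // placeholder param name
  p.set_init_method(singa::InitMethod::kGaussian);
  p.set_type(0);                                    // built-in Param impl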
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 8c97d5d..a3749d1 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -18,7 +18,7 @@ Server::Server(int thread_id,int group_id, int server_id):
 void Server::Setup(const UpdaterProto& proto,
     std::unordered_map<int, ParamEntry*>* shard,
     const vector<int>& slice2group) {
-  updater_ = Singleton<Factory<Updater>>::Instance()->Create("Updater");
+  updater_ = Singleton<Factory<Updater>>::Instance()->Create(proto.type());
   updater_->Init(proto);
   shard_ = shard;
   slice2group_ = slice2group;
@@ -143,7 +143,8 @@ Msg* Server::HandlePut(Msg **msg) {
   if (shard_->find(slice_id) != shard_->end())
     LOG(FATAL) << "Param (" << slice_id << ") is put more than once";
 
-  auto  param = Singleton<Factory<Param>>::Instance()->Create("Param");
+  // TODO(wangwei) replace hard-coded param type 0
+  auto  param = Singleton<Factory<Param>>::Instance()->Create(0);
   auto response = param->HandlePutMsg(msg, true);
   // parse num of shares of this param from a worker group
   int num_shares = 1;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 699fc30..ea060ec 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -28,15 +28,6 @@ Trainer::~Trainer() {
   delete router_;
 }
 
-void Trainer::RegisterDefaultClasses() {
-  // register all implemented layers
-  singa::NeuralNet::RegisterLayers();
-  auto param_factory = Singleton<Factory<singa::Param>>::Instance();
-  param_factory->Register("Param", CreateInstance(Param, Param));
-  auto updater_factory = Singleton<Factory<singa::Updater>>::Instance();
-  updater_factory->Register("Updater", CreateInstance(SGDUpdater, Updater));
-}
-
 const vector<int> SliceParams(const vector<Param*>& params) {
   // for load-balance among servers in a group and among server groups
   int nserver_grps = Cluster::Get()->nserver_groups();
@@ -181,15 +172,11 @@ vector<Worker*> Trainer::CreateWorkers(int nthreads, const JobProto& job) {
     wstart = 0;
     wend = grp_size;
   }
+  auto factory = Singleton<Factory<singa::Worker>>::Instance();
   for (int gid = gstart; gid < gend; gid++) {
     for (int wid = wstart; wid < wend; wid++) {
-      Worker* worker=nullptr;
-      if (job.alg() == TrainOneBatchAlg::kBP)
-        worker = new BPWorker(nthreads++,gid, wid);
-      else if (job.alg() == TrainOneBatchAlg::kCD)
-        worker=new CDWorker(nthreads++,gid, wid);
-      else
-        LOG(FATAL) << "unknown alg for trainonebatch func " << job.alg();
+      Worker* worker = factory->Create(job.alg());
+      worker->Init(nthreads++, gid, wid);
       workers.push_back(worker);
     }
   }
@@ -240,7 +227,6 @@ void Trainer::Resume(JobProto* jobConf) {
 void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
   // register job to zookeeper at the beginning
   auto cluster = Cluster::Get(job->id(), singaConf, job->cluster());
-  RegisterDefaultClasses();
   if (resume)
     Resume(job);
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index b6f9d44..6052f3a 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -10,9 +10,12 @@
 namespace singa {
 using std::thread;
 
-Worker::Worker(int thread_id, int grp_id, int id):
-  thread_id_(thread_id), grp_id_(grp_id), id_(id),
-  layer_dealer_(nullptr), dealer_(nullptr), updater_(nullptr) {
+void Worker::Init(int thread_id, int grp_id, int id) {
+  thread_id_ = thread_id;
+  grp_id_ = grp_id;
+  id_ = id;
+  layer_dealer_ = dealer_ = nullptr;
+  updater_ = nullptr;
 }
 
 void Worker::Setup(
@@ -327,8 +330,8 @@ bool Worker::ValidateNow(const int step) const {
 
 
 /****************************BPWorker**********************************/
-BPWorker::BPWorker(int thread_id, int group_id, int worker_id):
-  Worker(thread_id, group_id, worker_id) {
+void BPWorker::Init(int thread_id, int group_id, int worker_id) {
+  Worker::Init(thread_id, group_id, worker_id);
 }
 
 void BPWorker::Forward(
@@ -382,8 +385,8 @@ void BPWorker::TestOneBatch(int step, Phase phase,
 }
 
 /****************************CDWorker**********************************/
-CDWorker::CDWorker(int thread_id, int group_id, int worker_id):
-  Worker(thread_id, group_id, worker_id) {
+void CDWorker::Init(int thread_id, int group_id, int worker_id) {
+  Worker::Init(thread_id, group_id, worker_id);
 }
 
 void CDWorker::PositivePhase(int step,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3ec12b92/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 7adea7c..06c16e9 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -52,9 +52,9 @@ void Param::InitValues(int version) {
       data *= proto_.value();
     break;
   case InitMethod::kUniformSqrtFanIn:
-    random->SampleUniform(data, proto_.low(), proto_.high());
     // only valid for param matrix with num of cols as fan in
     CHECK_EQ(data_->shape().size(), 2);
+    random->SampleUniform(data, proto_.low(), proto_.high());
     data *= proto_.value() / sqrt(data_->shape().at(1) / 3.0f);
     break;
   case InitMethod::kUniformSqrtFanInOut:
@@ -68,9 +68,10 @@ void Param::InitValues(int version) {
       data *= proto_.value();
     break;
   case InitMethod::kGaussainSqrtFanIn:
+    // only valid for param matrix with num of cols as fan in
+    CHECK_EQ(data_->shape().size(), 2);
     random->SampleGaussian(data, proto_.mean(), proto_.std());
-    if (proto_.value())
-      data *= proto_.value() / sqrt(data_->shape()[0]);
+    data *= proto_.value() / sqrt(data_->shape().at(1));
     break;
   default:
     LOG(ERROR) << "Illegal parameter init method ";

