singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wan...@apache.org
Subject [06/12] incubator-singa git commit: SINGA-55 Refactor main.cc and singa.h
Date Sat, 15 Aug 2015 08:11:19 GMT
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/include/singa.h
----------------------------------------------------------------------
diff --cc include/singa.h
index 6fb9e97,82df64b..52d1f90
--- a/include/singa.h
+++ b/include/singa.h
@@@ -1,92 -1,40 +1,88 @@@
  #ifndef SINGA_SINGA_H_
  #define SINGA_SINGA_H_
++
++#include <cblas.h>
  #include <gflags/gflags.h>
  #include <glog/logging.h>
--#include <cblas.h>
 -
 -#include "utils/common.h"
 +#include <string>
- 
- #include "utils/common.h"
++#include "communication/socket.h"
++#include "neuralnet/neuralnet.h"
  #include "proto/job.pb.h"
  #include "proto/singa.pb.h"
--
++#include "trainer/trainer.h"
++#include "utils/common.h"
  #include "utils/param.h"
  #include "utils/singleton.h"
  #include "utils/factory.h"
  
--#include "neuralnet/neuralnet.h"
--#include "trainer/trainer.h"
--#include "communication/socket.h"
- 
 +namespace singa {
+ 
 -DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
 +class Driver {
 + public:
 +  /**
 +   * Init SINGA, including init glog, parse job id and job conf from cmd line,
 +   * and register built-in layer, worker, updater, param subclasses.
 +   *
 +   * May be used for MPI init if it is used for message passing.
 +   */
 +  void Init(int argc, char** argv);
 +  /**
 +   * Register a Layer subclass.
 +   *
 +   * T is the subclass.
 +   * @param type layer type ID. If called by users, it should be different to
 +   * the types of built-in layers.
 +   * @return 0 if success; otherwise -1.
 +   */
 +  template<typename T>
 +  int RegisterLayer(int type);
 +  /**
 +   * Register Updater subclasses.
 +   *
 +   * T is the subclass.
 +   * @param type updater type ID. If called by users, it should be different to
 +   * the types of built-in updaters.
 +   * @return 0 if success; otherwise -1.
 +   */
 +  template<typename T>
 +  int RegisterUpdater(int type);
 +  /**
 +   * Register Worker subclasses.
 +   *
 +   * T is the subclass.
 +   * @param type worker type ID. If called by users, it should be different to
 +   * the types of built-in workers
 +   * @return 0 if success; otherwise -1.
 +   */
 +  template<typename T>
 +  int RegisterWorker(int type);
 +  /**
 +   * Register Param subclasses.
 +   *
 +   * T is the subclass.
 +   * @param type param type. If called by users, it should be different to the
 +   * types of built-in params. SINGA currently provides only one built-in Param
 +   * implementation whose type ID is 0.
 +   * @return 0 if success; otherwise -1.
 +   */
 +  template<typename T>
 +  int RegisterParam(int type);
- 
 +  /**
 +   * Submit the job configuration for starting the job.
 +   * @param resume resume from last checkpoint if true.
 +   * @param job job configuration
 +   */
 +  void Submit(bool resume, const JobProto& job);
- 
 +  /**
 +   * @return job ID which is generated by zookeeper and passed in by the
 +   * launching script.
 +   */
-   int job_id() const {
-     return job_id_;
-   }
++  inline int job_id() const { return job_id_; }
 +
 + private:
 +  int job_id_;
 +};
+ 
 -namespace singa {
 -void SubmitJob(int job, bool resume, const JobProto& jobConf) {
 -  SingaProto singaConf;
 -  ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
 -  if (singaConf.has_log_dir())
 -    SetupLog(singaConf.log_dir(),
 -        std::to_string(job) + "-" + jobConf.name());
 -  if (jobConf.num_openblas_threads() != 1)
 -    LOG(WARNING) << "openblas is set with " << jobConf.num_openblas_threads()
 -      << " threads";
 -  openblas_set_num_threads(jobConf.num_openblas_threads());
 -  JobProto proto;
 -  proto.CopyFrom(jobConf);
 -  proto.set_id(job);
 -  Trainer trainer;
 -  trainer.Start(resume, singaConf, &proto);
 -}
  }  // namespace singa
--#endif  //  SINGA_SINGA_H_
  
++#endif  // SINGA_SINGA_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2498ff13/src/driver.cc
----------------------------------------------------------------------
diff --cc src/driver.cc
index 05c1195,0000000..5469583
mode 100644,000000..100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@@ -1,101 -1,0 +1,102 @@@
 +#include "singa.h"
++
 +namespace singa {
 +
 +/**
 + * the job and singa_conf arguments are passed by the singa script which is
 + * transparent to users
 + */
 +DEFINE_int32(job, -1, "Unique job ID generated from singa-run.sh");
 +DEFINE_string(singa_conf, "conf/singa.conf", "Global config file");
 +
 +void Driver::Init(int argc, char **argv) {
 +  google::InitGoogleLogging(argv[0]);
 +  gflags::ParseCommandLineFlags(&argc, &argv, true);
 +  job_id_ = FLAGS_job;
 +
 +  // register layers
 +  RegisterLayer<BridgeDstLayer>(kBridgeDst);
 +  RegisterLayer<BridgeSrcLayer>(kBridgeSrc);
 +  RegisterLayer<ConvolutionLayer>(kConvolution);
 +  RegisterLayer<ConcateLayer>(kConcate);
 +  RegisterLayer<DropoutLayer>(kDropout);
 +  RegisterLayer<InnerProductLayer>(kInnerProduct);
 +  RegisterLayer<LabelLayer>(kLabel);
 +  RegisterLayer<LRNLayer>(kLRN);
 +  RegisterLayer<MnistLayer>(kMnist);
 +  RegisterLayer<PrefetchLayer>(kPrefetch);
 +  RegisterLayer<PoolingLayer>(kPooling);
 +  RegisterLayer<RGBImageLayer>(kRGBImage);
 +  RegisterLayer<ReLULayer>(kReLU);
 +  RegisterLayer<ShardDataLayer>(kShardData);
 +  RegisterLayer<SliceLayer>(kSlice);
 +  RegisterLayer<SoftmaxLossLayer>(kSoftmaxLoss);
 +  RegisterLayer<SplitLayer>(kSplit);
 +  RegisterLayer<TanhLayer>(kTanh);
 +  RegisterLayer<RBMVisLayer>(kRBMVis);
 +  RegisterLayer<RBMHidLayer>(kRBMHid);
 +#ifdef USE_LMDB
-   RegisterLayer(factory, LMDBData);
++  RegisterLayer<LMDBDataLayer>(kLMDBData);
 +#endif
 +
 +  // register updater
 +  RegisterUpdater<AdaGradUpdater>(kAdaGrad);
 +  RegisterUpdater<NesterovUpdater>(kNesterov);
-   //  TODO(wangwei)  RegisterUpdater<kRMSPropUpdater>(kRMSProp);
++  // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp);
 +  RegisterUpdater<SGDUpdater>(kSGD);
 +
 +  // register worker
 +  RegisterWorker<BPWorker>(kBP);
 +  RegisterWorker<CDWorker>(kCD);
 +
 +  // register param
 +  RegisterParam<Param>(0);
 +}
 +
 +template<typename T>
 +int Driver::RegisterLayer(int type) {
 +  auto factory = Singleton<Factory<singa::Layer>>::Instance();
 +  factory->Register(type, CreateInstance(T, Layer));
 +  return 1;
 +}
 +
 +template<typename T>
 +int Driver::RegisterParam(int type) {
 +  auto factory = Singleton<Factory<singa::Param>>::Instance();
 +  factory->Register(type, CreateInstance(T, Param));
 +  return 1;
 +}
 +
 +template<typename T>
 +int Driver::RegisterUpdater(int type) {
 +  auto factory = Singleton<Factory<singa::Updater>>::Instance();
 +  factory->Register(type, CreateInstance(T, Updater));
 +  return 1;
 +}
 +
 +template<typename T>
 +int Driver::RegisterWorker(int type) {
 +  auto factory = Singleton<Factory<singa::Worker>>::Instance();
 +  factory->Register(type, CreateInstance(T, Worker));
 +  return 1;
 +}
 +
 +void Driver::Submit(bool resume, const JobProto& jobConf) {
 +  SingaProto singaConf;
 +  ReadProtoFromTextFile(FLAGS_singa_conf.c_str(), &singaConf);
 +  if (singaConf.has_log_dir())
 +    SetupLog(singaConf.log_dir(), std::to_string(FLAGS_job)
 +        + "-" + jobConf.name());
 +  if (jobConf.num_openblas_threads() != 1)
 +    LOG(WARNING) << "openblas with "
 +      << jobConf.num_openblas_threads() << " threads";
 +  openblas_set_num_threads(jobConf.num_openblas_threads());
 +
 +  JobProto job;
 +  job.CopyFrom(jobConf);
 +  job.set_id(job_id_);
 +  Trainer trainer;
 +  trainer.Start(resume, singaConf, &job);
 +}
 +
 +}  // namespace singa


Mime
View raw message