singa-commits mailing list archives

From wan...@apache.org
Subject incubator-singa git commit: SINGA-60 Make learning rate and param init modular
Date Wed, 19 Aug 2015 13:13:29 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 97141e2e0 -> 6afa895b8


SINGA-60 Make learning rate and param init modular

Created a base class (LRGenerator) for generating the learning rate,
which changes during training.
Created a base class (ParamGenerator) for initializing parameter values.

SINGA ships with several built-in implementations of the two base
classes.
Users can also implement their own learning-rate schedules and parameter
initialization methods by extending the corresponding base classes.
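
For illustration only (not part of this commit), a custom learning rate
schedule and a custom parameter initializer could be plugged in by
subclassing the two new base classes and registering them through the
Driver. The class names, the "my_step_lr"/"my_const" string IDs and the
main() wiring below are made up; the base-class and Driver APIs are the
ones introduced by this commit.

#include <cmath>
#include <string>

#include "driver.h"
#include "utils/param.h"
#include "utils/updater.h"

namespace singa {

// Hypothetical schedule: halve the base learning rate every 10000 steps.
class MyStepLRGen : public LRGenerator {
 public:
  float Get(int step) override {
    return proto_.base_lr() * std::pow(0.5f, step / 10000);
  }
};

// Hypothetical initializer: fill all parameter values with a constant 0.1.
class MyConstGen : public ParamGenerator {
 public:
  void Fill(Blob<float>* data) override {
    float* v = data->mutable_cpu_data();
    for (int i = 0; i < data->count(); ++i)
      v[i] = 0.1f;
  }
};

}  // namespace singa

int main(int argc, char** argv) {
  singa::Driver driver;
  driver.Init(argc, argv);
  // User-defined subclasses are registered under string IDs; job.conf then
  // selects them via the user_type fields of LRGenProto and ParamGenProto.
  driver.RegisterLRGenerator<singa::MyStepLRGen, std::string>("my_step_lr");
  driver.RegisterParamGenerator<singa::MyConstGen, std::string>("my_const");
  // driver.Submit(...) would start the job as usual.
  return 0;
}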


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6afa895b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6afa895b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6afa895b

Branch: refs/heads/master
Commit: 6afa895b8ea060a532ea01f1f4484c9db11a2496
Parents: 97141e2
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Wed Aug 19 17:36:39 2015 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Wed Aug 19 19:19:38 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/job.conf |  71 +++++++++++++----------
 examples/mnist/conv.conf  |  71 +++++++++++++----------
 examples/mnist/job.conf   | 108 +++++++++++++++++++++--------------
 include/driver.h          |  33 +++++++++++
 include/trainer/worker.h  |   2 -
 include/utils/param.h     |  46 ++++++++++++++-
 include/utils/updater.h   |  89 +++++++++++++++++++----------
 src/driver.cc             |  23 +++++++-
 src/proto/job.proto       |  67 ++++++++++++----------
 src/trainer/server.cc     |   1 -
 src/trainer/worker.cc     |  25 +++------
 src/utils/param.cc        |  93 +++++++++++++++++++------------
 src/utils/updater.cc      | 124 ++++++++++++++++++-----------------------
 13 files changed, 467 insertions(+), 286 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/cifar10/job.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf
index b294f03..f44ca50 100644
--- a/examples/cifar10/job.conf
+++ b/examples/cifar10/job.conf
@@ -5,16 +5,18 @@ test_freq:300
 disp_freq:30
 alg: kBP
 updater{
-  weight_decay:0.004
-  lr_change: kFixedStep
   type: kSGD
-  fixedstep_conf:{
-    step:0
-    step:60000
-    step:65000
-    step_lr:0.001
-    step_lr:0.0001
-    step_lr:0.00001
+  weight_decay:0.004
+  learning_rate {
+    type: kFixedStep
+    fixedstep_conf:{
+      step:0
+      step:60000
+      step:65000
+      step_lr:0.001
+      step_lr:0.0001
+      step_lr:0.00001
+    }
   }
 }
 neuralnet {
@@ -63,15 +65,18 @@ neuralnet {
     }
     param {
       name: "w1"
-      init_method:kGaussian
-      std:0.0001
-      lr_scale:1.0
+      init {
+        type:kGaussian
+        std:0.0001
+      }
     }
     param {
       name: "b1"
-      init_method: kConstant
       lr_scale:2.0
-      value:0
+      init {
+        type: kConstant
+        value:0
+      }
     }
   }
 
@@ -112,15 +117,18 @@ neuralnet {
     }
     param {
       name: "w2"
-      init_method:kGaussian
-      std:0.01
-      lr_scale:1.0
+      init {
+        type:kGaussian
+        std:0.01
+      }
     }
     param {
       name: "b2"
-      init_method: kConstant
       lr_scale:2.0
-      value:0
+      init {
+        type: kConstant
+        value:0
+      }
     }
   }
   layer {
@@ -160,13 +168,17 @@ neuralnet {
     }
     param {
       name: "w3"
-      init_method:kGaussian
-      std:0.01
+      init {
+        type:kGaussian
+        std:0.01
+      }
     }
     param {
       name: "b3"
-      init_method: kConstant
-      value:0
+      init {
+        type: kConstant
+        value:0
+      }
     }
   }
   layer {
@@ -193,17 +205,20 @@ neuralnet {
     }
     param {
       name: "w4"
-      init_method:kGaussian
-      std:0.01
-      lr_scale:1.0
       wd_scale:250
+      init {
+        type:kGaussian
+        std:0.01
+      }
     }
     param {
       name: "b4"
-      init_method: kConstant
       lr_scale:2.0
       wd_scale:0
-      value:0
+      init {
+        type: kConstant
+        value:0
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/mnist/conv.conf
----------------------------------------------------------------------
diff --git a/examples/mnist/conv.conf b/examples/mnist/conv.conf
index 3509a36..1d4d740 100644
--- a/examples/mnist/conv.conf
+++ b/examples/mnist/conv.conf
@@ -4,15 +4,17 @@ test_steps:100
 test_freq:500
 disp_freq:50
 alg: kBP
-updater{
-  base_lr:0.01
+updater {
   momentum:0.9
   weight_decay:0.0005
-  lr_change: kInverse
   type: kSGD
-  inverse_conf {
-    gamma:0.0001
-    pow:0.75
+  learning_rate {
+    type : kInverse
+    base_lr:0.01
+    inverse_conf {
+      gamma:0.0001
+      pow:0.75
+    }
   }
 }
 neuralnet {
@@ -61,16 +63,19 @@ neuralnet {
       stride: 1
     }
     param{
-        name: "w1"
-        init_method:kUniformSqrtFanIn
-        lr_scale:1.0
+      name: "w1"
+      init {
+        type : kUniformSqrtFanIn
       }
+    }
     param{
-        name: "b1"
-        init_method: kConstant
-        lr_scale:2.0
+      name: "b1"
+      init {
+        type : kConstant
         value:0
       }
+      lr_scale:2.0
+    }
   }
   layer {
     name: "pool1"
@@ -92,16 +97,19 @@ neuralnet {
       stride: 1
     }
     param{
-        name: "w2"
-        init_method:kUniformSqrtFanIn
-        lr_scale:1.0
+      name: "w2"
+      init {
+        type :kUniformSqrtFanIn
       }
+    }
     param{
-        name: "b2"
-        init_method: kConstant
-        lr_scale:2.0
+      name: "b2"
+      init {
+        type : kConstant
         value:0
       }
+      lr_scale:2.0
+    }
   }
   layer {
     name: "pool2"
@@ -121,17 +129,19 @@ neuralnet {
       num_output: 500
     }
     param{
-        name: "w3"
-        init_method:kUniformSqrtFanIn
-        lr_scale:1.0
+      name: "w3"
+      init {
+        type :kUniformSqrtFanIn
       }
+    }
     param{
-        name: "b3"
-        init_method: kConstant
-        lr_scale:2.0
+      name: "b3"
+      init {
+        type : kConstant
         value:0
+      }
+      lr_scale:2.0
     }
-
   }
 
   layer {
@@ -149,14 +159,17 @@ neuralnet {
     }
     param {
       name: "w4"
-      init_method:kUniformSqrtFanIn
-      lr_scale:1
+      init {
+        type :kUniformSqrtFanIn
+      }
     }
     param {
       name: "b4"
-      init_method: kConstant
+      init {
+        type : kConstant
+        value:0
+      }
       lr_scale:2
-      value:0
     }
   }
   layer{

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/examples/mnist/job.conf
----------------------------------------------------------------------
diff --git a/examples/mnist/job.conf b/examples/mnist/job.conf
index 34fbca2..360e1ec 100644
--- a/examples/mnist/job.conf
+++ b/examples/mnist/job.conf
@@ -5,12 +5,14 @@ test_freq:60
 disp_freq:10
 alg: kBP
 updater{
-  base_lr: 0.001
-  lr_change: kStep
   type: kSGD
-  step_conf{
-    change_freq: 60
-    gamma: 0.997
+  learning_rate{
+    type : kStep
+    base_lr: 0.001
+    step_conf{
+      change_freq: 60
+      gamma: 0.997
+    }
   }
 }
 
@@ -61,15 +63,19 @@ neuralnet {
     }
     param{
       name: "w1"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init {
+        type: kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b1"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type : kUniform
+        low: -0.05
+        high:0.05
+      }
     }
   }
 
@@ -87,15 +93,19 @@ neuralnet {
     }
     param{
       name: "w2"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init {
+        type: kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b2"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type: kUniform
+        low: -0.05
+        high:0.05
+      }
     }
   }
 
@@ -113,15 +123,19 @@ neuralnet {
     }
     param{
       name: "w3"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init{
+        type: kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b3"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type : kUniform
+        low: -0.05
+        high:0.05
+      }
     }
 
   }
@@ -140,15 +154,19 @@ neuralnet {
     }
     param{
       name: "w4"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init {
+        type : kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b4"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type : kUniform
+        low: -0.05
+        high:0.05
+      }
     }
 
   }
@@ -167,15 +185,19 @@ neuralnet {
     }
     param{
       name: "w5"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init {
+        type : kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b5"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type : kUniform
+        low: -0.05
+        high:0.05
+      }
     }
 
   }
@@ -194,15 +216,19 @@ neuralnet {
     }
     param{
       name: "w6"
-      init_method: kUniform
-      low:-0.05
-      high:0.05
+      init {
+        type : kUniform
+        low:-0.05
+        high:0.05
+      }
     }
     param{
       name: "b6"
-      init_method: kUniform
-      low: -0.05
-      high:0.05
+      init {
+        type : kUniform
+        low: -0.05
+        high:0.05
+      }
     }
   }
   layer{

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/driver.h
----------------------------------------------------------------------
diff --git a/include/driver.h b/include/driver.h
index fcaab12..5a9ddfc 100644
--- a/include/driver.h
+++ b/include/driver.h
@@ -34,6 +34,16 @@ class Driver {
   template<typename Subclass, typename Type>
   int RegisterUpdater(const Type& type);
   /**
+   * Register a learning rate generator subclass.
+   *
+   * @param type ID of the subclass. If called to register built-in subclasses,
+   * it is from ChangeMethod; if called to register user-defined
+   * subclass, it is a string;
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename Subclass, typename Type>
+  int RegisterLRGenerator(const Type& type);
+  /**
    * Register a Worker subclass.
    *
    * @param type ID of the subclass. If called to register built-in subclasses,
@@ -54,6 +64,17 @@ class Driver {
   template<typename Subclass, typename Type>
   int RegisterParam(const Type& type);
   /**
+   * Register ParamGenerator subclasses for initializing Param objects.
+   *
+   * @param type ID of the subclass. If called to register built-in subclasses,
+   * it is from InitMethod; if called to register user-defined
+   * subclass, it is a string;
+   * @return 0 if success; otherwise -1.
+   */
+  template<typename Subclass, typename Type>
+  int RegisterParamGenerator(const Type& type);
+
+  /**
    * Submit the job configuration for starting the job.
    * @param resume resume from last checkpoint if true.
    * @param job job configuration
@@ -90,12 +111,24 @@ int Driver::RegisterParam(const Type& type) {
   return 1;
 }
 template<typename Subclass, typename Type>
+int Driver::RegisterParamGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::ParamGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, ParamGenerator));
+  return 1;
+}
+template<typename Subclass, typename Type>
 int Driver::RegisterUpdater(const Type& type) {
   auto factory = Singleton<Factory<singa::Updater>>::Instance();
   factory->Register(type, CreateInstance(Subclass, Updater));
   return 1;
 }
 template<typename Subclass, typename Type>
+int Driver::RegisterLRGenerator(const Type& type) {
+  auto factory = Singleton<Factory<singa::LRGenerator>>::Instance();
+  factory->Register(type, CreateInstance(Subclass, LRGenerator));
+  return 1;
+}
+template<typename Subclass, typename Type>
 int Driver::RegisterWorker(const Type& type) {
   auto factory = Singleton<Factory<singa::Worker>>::Instance();
   factory->Register(type, CreateInstance(Subclass, Worker));

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
index c50b54f..cc5a745 100644
--- a/include/trainer/worker.h
+++ b/include/trainer/worker.h
@@ -2,7 +2,6 @@
 #define SINGA_TRAINER_WORKER_H_
 #include "neuralnet/neuralnet.h"
 #include "proto/job.pb.h"
-#include "utils/updater.h"
 #include "communication/socket.h"
 
 namespace singa {
@@ -177,7 +176,6 @@ class Worker {
   JobProto job_conf_;
   shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
   Dealer* layer_dealer_, *dealer_;
-  Updater* updater_;
 };
 
 class BPWorker: public Worker{

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 83f64ed..f7a0982 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -6,6 +6,51 @@
 #include "utils/blob.h"
 #include "communication/msg.h"
 
+namespace singa {
+
+/**
+ * Base parameter generator which initializes parameter values.
+ */
+
+class ParamGenerator {
+ public:
+  static ParamGenerator* Create(const ParamGenProto& proto);
+  virtual ~ParamGenerator() {}
+
+  virtual void Init(const ParamGenProto& proto) {
+    proto_ = proto;
+  }
+
+  virtual void Fill(Blob<float>* data);
+
+ protected:
+  ParamGenProto proto_;
+};
+
+class GaussianGen: public ParamGenerator {
+ public:
+  void  Fill(Blob<float>* data) override;
+};
+
+class UniformGen: public ParamGenerator {
+ public:
+  void  Fill(Blob<float>* data) override;
+};
+
+class GaussianSqrtFanInGen: public GaussianGen {
+ public:
+  void  Fill(Blob<float>* data) override;
+};
+
+class UniformSqrtFanInGen: public UniformGen {
+ public:
+  void Fill(Blob<float>* data) override;
+};
+
+class UniformSqrtFanInOutGen: public UniformGen {
+ public:
+  void Fill(Blob<float>* data) override;
+};
 /**
  * Base paramter class.
  *
@@ -24,7 +69,6 @@
  * load-balance among servers. Hence, we slice large Param objects into small
  * pieces. At the server side, one slice is a Param object.
  */
-namespace singa {
 class Param {
  public:
   static Param* Create(const ParamProto& proto);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/include/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/utils/updater.h b/include/utils/updater.h
index 92ddf6c..46d2c53 100644
--- a/include/utils/updater.h
+++ b/include/utils/updater.h
@@ -6,60 +6,94 @@
 
 namespace singa {
 /**
+ * Base learning rate generator.
+ *
+ * Generates the learning rate for a given training step/iteration.
+ * There are many ways to change the learning rate over training steps.
+ * Users can inherit this class to implement their own change methods.
+ */
+class LRGenerator {
+ public:
+  static LRGenerator* Create(const LRGenProto& proto);
+  virtual ~LRGenerator() {}
+
+  virtual void Init(const LRGenProto& proto) {
+    proto_ = proto;
+  }
+
+  /**
+   * @param step training step/iteration.
+   * @return base learning rate regardless of step
+   */
+  virtual float Get(int step) {
+    return proto_.base_lr();
+  }
+
+ protected:
+  LRGenProto proto_;
+};
+
+class FixedStepLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+ private:
+  int last_idx_ = 0;
+};
+class StepLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+class LinearLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+class ExpLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+class InvLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+class InvTLRGen : public LRGenerator {
+ public:
+  float Get(int step) override;
+};
+/**
  * Updater for Param.
  */
 class Updater{
  public:
   static Updater* Create(const UpdaterProto& proto);
   virtual ~Updater() {}
-  virtual void Init(const UpdaterProto &proto) {
-    proto_ = proto;
-  }
+  virtual void Init(const UpdaterProto &proto);
   virtual void Update(int step, Param* param, float grad_scale = 1.0f) = 0;
 
-  float GetLearningRate(int step);
-
  protected:
   UpdaterProto proto_;
+  LRGenerator* lr_gen_;
+  float weight_decay_;
+  float momentum_;
 };
 
 class SGDUpdater : public Updater {
  public:
-  virtual void Init(const UpdaterProto& proto);
-  virtual void Update(int step, Param* param, float grad_scale = 1.0f);
-
- protected:
-  float base_lr_;
-  float momentum_;
-  float weight_decay_;
+  void Update(int step, Param* param, float grad_scale = 1.0f);
 };
 
 class AdaGradUpdater : public Updater{
  public:
-  virtual void Init(const UpdaterProto& proto);
-  virtual void Update(int step, Param* param, float grad_scale = 1.0f);
-
- protected:
-  float base_lr_;
-  float delta_;
-  float weight_decay_;
+  void Update(int step, Param* param, float grad_scale = 1.0f) override;
 };
 
 
 class NesterovUpdater : public Updater {
  public:
-  virtual void Init(const UpdaterProto& proto);
-  virtual void Update(int step, Param* param, float grad_scale = 1.0f);
-
- protected:
-  float base_lr_;
-  float momentum_;
-  float weight_decay_;
+  void Update(int step, Param* param, float grad_scale = 1.0f) override;
 };
 /*
 class RMSPropUpdater : public Updater{
  public:
-  virtual void Init(const UpdaterProto& proto);
   virtual void Update(int step, Param* param, float grad_scale=1.0f);
 
  protected:
@@ -71,7 +105,6 @@ class RMSPropUpdater : public Updater{
 
 class AdaDeltaUpdater : public Updater{
  public:
-  virtual void Init(const UpdaterProto& proto);
   virtual void Update(int step, Param* param, float grad_scale=1.0f);
 
  protected:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index b79b609..1bc712d 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -50,18 +50,35 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<LMDBDataLayer, int>(kLMDBData);
 #endif
 
-  // register updater
+  // register updaters
   RegisterUpdater<AdaGradUpdater>(kAdaGrad);
   RegisterUpdater<NesterovUpdater>(kNesterov);
   // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp);
   RegisterUpdater<SGDUpdater>(kSGD);
 
-  // register worker
+  // register learning rate change methods
+  RegisterLRGenerator<LRGenerator>(kFixed);
+  RegisterLRGenerator<FixedStepLRGen>(kFixedStep);
+  RegisterLRGenerator<StepLRGen>(kStep);
+  RegisterLRGenerator<LinearLRGen>(kLinear);
+  RegisterLRGenerator<ExpLRGen>(kExponential);
+  RegisterLRGenerator<InvLRGen>(kInverse);
+  RegisterLRGenerator<InvTLRGen>(kInverseT);
+
+  // register workers
   RegisterWorker<BPWorker>(kBP);
   RegisterWorker<CDWorker>(kCD);
 
-  // register param
+  // register params
   RegisterParam<Param>(0);
+
+  // register param init methods
+  RegisterParamGenerator<ParamGenerator>(kConstant);
+  RegisterParamGenerator<GaussianGen>(kGaussian);
+  RegisterParamGenerator<UniformGen>(kUniform);
+  RegisterParamGenerator<GaussianSqrtFanInGen>(kGaussianSqrtFanIn);
+  RegisterParamGenerator<UniformSqrtFanInGen>(kUniformSqrtFanIn);
+  RegisterParamGenerator<UniformSqrtFanInOutGen>(kUniformSqrtFanInOut);
 }
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index 80c8b65..b4abe68 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -101,21 +101,11 @@ message UpdaterProto {
   // configuration for RMSProp algorithm
   optional RMSPropProto rmsprop_conf = 3;
 
-  // built-in change method for learning rate
-  optional ChangeMethod lr_change = 10 [default = kUserChange];
-  // user-defined change method
-  optional string user_lr_change = 11;
-
-  optional FixedStepProto fixedstep_conf = 40;
-  optional StepProto step_conf = 41;
-  optional LinearProto linear_conf = 42;
-  optional ExponentialProto exponential_conf = 43;
-  optional InverseProto inverse_conf = 44;
-  optional InverseTProto inverset_conf = 45;
+  // learning rate generator
+  optional LRGenProto learning_rate = 11;
   optional float momentum = 31 [default = 0];
   optional float weight_decay = 32 [default = 0];
-  // base learning rate
-  optional float base_lr = 34 [default = 0];
+
   // used to avoid divide by 0, i.e. x/(y+delta)
   optional float delta = 35 [default = 0.00000001];
 
@@ -220,24 +210,13 @@ message LayerProto {
 message ParamProto {
   // used for identifying the same params from diff models and display deug info
   optional string name =  1 [default = ""];
-  optional InitMethod init_method = 2 [default = kGaussian];
   // for built-in Param
   optional ParamType type = 3 [default = kParam];
   // for user-defined Param
   optional string user_type = 4;
-  // constant init
-  optional float value = 5 [default = 1];
-  // for uniform sampling
-  optional UniformProto uniform_conf = 6;
-  optional float low = 7 [default = -1];
-  optional float high = 8 [default = 1];
-
-  // for gaussian sampling
-  optional GaussianProto gaussian_conf = 9;
-  optional float mean = 10 [default = 0];
-  optional float std = 11 [default = 1];
 
-  // multiplied on the global learning rate.
+  optional ParamGenProto init =5;
+    // multiplied on the global learning rate.
   optional float lr_scale = 15 [default = 1];
   // multiplied on the global weight decay.
   optional float wd_scale = 16 [default = 1];
@@ -260,6 +239,38 @@ message ParamProto {
 // ---------------------------
 // protos for different layers
 // ---------------------------
+// learning rate generator proto
+message LRGenProto {
+  // user-defined change method
+  optional ChangeMethod type = 1 [default = kUserChange];
+  optional string user_type = 2;
+
+  optional float base_lr = 3 [default = 0.01];
+
+  optional FixedStepProto fixedstep_conf = 40;
+  optional StepProto step_conf = 41;
+  optional LinearProto linear_conf = 42;
+  optional ExponentialProto exponential_conf = 43;
+  optional InverseProto inverse_conf = 44;
+  optional InverseTProto inverset_conf = 45;
+
+  extensions 101 to 200;
+}
+
+message ParamGenProto {
+  optional InitMethod type = 1 [default = kUserInit];
+  optional string user_type =2;
+  // constant init
+  optional float value = 3 [default = 1];
+  // for gaussian sampling
+  optional float mean = 4 [default = 0];
+  optional float std = 5 [default = 1];
+  // for uniform sampling
+  optional float low = 8 [default = -1];
+  optional float high = 9 [default = 1];
+
+  extensions 101 to 200;
+}
 
 message RGBImageProto {
   // scale factor for each pixel
@@ -476,11 +487,9 @@ enum InitMethod {
   kGaussian = 1;
   // uniform sampling between low and high
   kUniform = 2;
-  // copy the content and history which are from previous training
-  kPretrained = 3;
   // from Toronto Convnet, let a=1/sqrt(fan_in), w*=a after generating from
   // Gaussian distribution
-  kGaussainSqrtFanIn = 4;
+  kGaussianSqrtFanIn = 4;
   // from Toronto Convnet, rectified linear activation, let
   // a=sqrt(3)/sqrt(fan_in), range is [-a, +a]; no need to set value=sqrt(3),
   // the program will multiply it.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index a8483de..1fda336 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -21,7 +21,6 @@ void Server::Setup(const UpdaterProto& proto,
     std::unordered_map<int, ParamEntry*>* shard,
     const vector<int>& slice2group) {
   updater_ = Updater::Create(proto);
-  updater_->Init(proto);
   shard_ = shard;
   slice2group_ = slice2group;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 25fea7c..e047367 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -24,7 +24,6 @@ void Worker::Init(int thread_id, int grp_id, int id) {
   grp_id_ = grp_id;
   id_ = id;
   layer_dealer_ = dealer_ = nullptr;
-  updater_ = nullptr;
 }
 
 void Worker::Setup(
@@ -141,10 +140,8 @@ void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) {
 void Worker::Run() {
   LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") start";
   auto cluster = Cluster::Get();
-  if (updater_==nullptr) {
-    int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
-    CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
-  }
+  int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
+  CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
   dealer_ = new Dealer(2*thread_id_);
   ConnectStub(grp_id_, id_, dealer_, kWorkerParam);
   for (auto layer : train_net_->layers()) {
@@ -190,10 +187,7 @@ void Worker::Run() {
   Checkpoint(step_, train_net_);
 
   // clean up
-  if(updater_ == nullptr) {
-    int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
-    cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp);
-  }
+  cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp);
   // notify the stub on worker stop
   Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1,-1, kStub));
   msg->set_type(kStop);
@@ -224,15 +218,10 @@ int Worker::Get(Param* param, int step) {
 
 int Worker::Update(Param* param, int step) {
   param->set_local_version(param->version());
-  if (updater_) {
-    updater_->Update(step, param);
-    param->set_version(param->version() + 1);
-  } else {
-    Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
-    msg->set_trgt(ParamTrgt(param->owner(), 0), step);
-    msg->set_type(kUpdate);
-    dealer_->Send(&msg);
-  }
+  Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
+  msg->set_type(kUpdate);
+  dealer_->Send(&msg);
   return 1;
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index e658631..67f14ab 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -12,6 +12,60 @@ using namespace mshadow;
 using std::vector;
 using std::string;
 
+ParamGenerator* ParamGenerator::Create(const ParamGenProto& proto) {
+  auto factory = Singleton<Factory<ParamGenerator>>::Instance();
+  ParamGenerator * gen = nullptr;
+  if (proto.has_user_type())
+    gen = factory->Create(proto.user_type());
+  else
+    gen = factory->Create(proto.type());
+  gen->Init(proto);
+  return gen;
+}
+
+void ParamGenerator::Fill (Blob<float>* blob) {
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  data = proto_.value();
+}
+void GaussianGen::Fill (Blob<float>* blob) {
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  auto random = TSingleton<Random<cpu>>::Instance();
+  random->SampleGaussian(data, proto_.mean(), proto_.std());
+  if(proto_.value() != 1)
+    data *= proto_.value();
+}
+void GaussianSqrtFanInGen::Fill (Blob<float>* blob) {
+  // only valid for param matrix with num of cols as fan in
+  CHECK_EQ(blob->shape().size(), 2);
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  GaussianGen::Fill(blob);
+  data /= sqrt(blob->shape().at(1));
+}
+
+void UniformGen::Fill (Blob<float>* blob) {
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  auto random = TSingleton<Random<cpu>>::Instance();
+  random->SampleUniform(data, proto_.low(), proto_.high());
+  if(proto_.value() != 1)
+    data *= proto_.value();
+}
+
+void UniformSqrtFanInGen::Fill (Blob<float>* blob) {
+  // only valid for param matrix with num of cols as fan in
+  CHECK_EQ(blob->shape().size(), 2);
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  UniformGen::Fill(blob);
+  data /= sqrt(blob->shape().at(1) / 3.0f);
+}
+
+void UniformSqrtFanInOutGen::Fill (Blob<float>* blob) {
+  // only valid for param matrix with num of cols as fan in
+  CHECK_EQ(blob->shape().size(), 2);
+  Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
+  UniformGen::Fill(blob);
+  data /= sqrt(blob->shape()[0] + blob->shape()[1]);
+}
+/*****************Param***********************************/
 Param* Param::Create(const ParamProto& proto) {
   Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
   Param* p = nullptr;
@@ -51,43 +105,8 @@ void Param::AddSlice(int slice_id, int size) {
 }
 
 void Param::InitValues(int version) {
-  Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
-  auto random = TSingleton<Random<cpu>>::Instance();
-  switch (proto_.init_method()) {
-  case InitMethod::kConstant:
-    data = proto_.value();
-    break;
-  case InitMethod::kUniform:
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    if(proto_.value() != 1)
-      data *= proto_.value();
-    break;
-  case InitMethod::kUniformSqrtFanIn:
-    // only valid for param matrix with num of cols as fan in
-    CHECK_EQ(data_->shape().size(), 2);
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    data *= proto_.value() / sqrt(data_->shape().at(1) / 3.0f);
-    break;
-  case InitMethod::kUniformSqrtFanInOut:
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    if (proto_.value())
-      data *= proto_.value() / sqrt(data_->shape()[0] + data_->shape()[1]);
-    break;
-  case InitMethod::kGaussian:
-    random->SampleGaussian(data, proto_.mean(), proto_.std());
-    if(proto_.value() != 1)
-      data *= proto_.value();
-    break;
-  case InitMethod::kGaussainSqrtFanIn:
-    // only valid for param matrix with num of cols as fan in
-    CHECK_EQ(data_->shape().size(), 2);
-    random->SampleGaussian(data, proto_.mean(), proto_.std());
-    data *= proto_.value() / sqrt(data_->shape().at(1));
-    break;
-  default:
-    LOG(ERROR) << "Illegal parameter init method ";
-    break;
-  }
+  ParamGenerator* gen = ParamGenerator::Create(proto_.init());
+  gen->Fill(data_.get());
   set_version(version);
 }
 void Param::FromProto(const BlobProto& blob) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6afa895b/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index 7d80844..24487d3 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -10,6 +10,53 @@ namespace  singa {
 using namespace mshadow;
 using namespace mshadow::expr;
 
+/**********************Learning rate generator******************************/
+LRGenerator* LRGenerator::Create(const LRGenProto& proto) {
+  auto factory = Singleton<Factory<LRGenerator>>::Instance();
+  LRGenerator* gen = nullptr;
+  if (proto.has_user_type())
+    gen = factory->Create(proto.user_type());
+  else
+    gen = factory->Create(proto.type());
+  gen->Init(proto);
+  return gen;
+}
+
+float FixedStepLRGen::Get(int step) {
+  if (last_idx_ < proto_.fixedstep_conf().step_size() -1
+      && step >= proto_.fixedstep_conf().step(last_idx_ + 1)) {
+      last_idx_ ++;
+    }
+  return proto_.fixedstep_conf().step_lr(last_idx_);
+}
+
+float StepLRGen::Get(int step) {
+  // do not cast int to float
+  int freq = proto_.step_conf().change_freq();
+  return  proto_.base_lr() * pow(proto_.step_conf().gamma(), step / freq);
+}
+
+float LinearLRGen::Get(int step) {
+  int freq = proto_.linear_conf().change_freq();
+  float r = step * 1.0  / freq;
+  return (1.0 - r) * proto_.base_lr() + r * proto_.linear_conf().final_lr();
+}
+
+float ExpLRGen::Get(int step) {
+  int freq = proto_.exponential_conf().change_freq();
+  return proto_.base_lr() / pow(2, step * 1. / freq);
+}
+
+float InvLRGen::Get(int step) {
+  return proto_.base_lr() * pow(1.f + proto_.inverse_conf().gamma() * step,
+           - proto_.inverse_conf().pow());
+}
+
+float InvTLRGen::Get(int step) {
+  return proto_.base_lr() / (1 + step * 1. / proto_.inverset_conf().final_lr());
+}
+
+/***********************Updater********************************/
 
 Updater* Updater::Create(const UpdaterProto& proto) {
   auto factory = Singleton<Factory<Updater>>::Instance();
@@ -18,69 +65,23 @@ Updater* Updater::Create(const UpdaterProto& proto) {
     updater = factory->Create(proto.user_type());
   else
     updater = factory->Create(proto.type());
+  updater->Init(proto);
   return updater;
 }
-float Updater::GetLearningRate(int step) {
-  float ret = 0., r = 0., base = proto_.base_lr();
-  int freq = 0;
-  switch (proto_.lr_change()) {
-    case ChangeMethod::kFixed:
-      ret = base;
-      break;
-    case ChangeMethod::kLinear:
-      // a is init, b is the final
-      freq = proto_.linear_conf().change_freq();
-      r = step * 1.0  / freq;
-      ret = (1.0 - r) * base + r * proto_.linear_conf().final_lr();
-      break;
-    case ChangeMethod::kExponential:
-      // a is init, b is the final, from convnet
-      freq = proto_.exponential_conf().change_freq();
-      ret = base / pow(2, step * 1. / freq);
-      break;
-    case ChangeMethod::kInverseT:
-      // a is init, b is the final, from convnet
-      CHECK_EQ(base, 2 * proto_.inverset_conf().final_lr())
-        << "final value should be the half";
-      ret = base / (1. + step * 1. / proto_.inverset_conf().final_lr());
-      break;
-    case ChangeMethod::kInverse:
-      // a is init, b is gamma, c is pow
-      ret = base * pow(1.f + proto_.inverse_conf().gamma() * step,
-           - proto_.inverse_conf().pow());
-      break;
-    case ChangeMethod::kStep:
-      // a is the base learning rate, b is gamma, from caffe
-      // notice it is step/change_steps, not step*1.0/change_steps
-      freq = proto_.step_conf().change_freq();
-      ret = base * pow(proto_.step_conf().gamma(), step / freq);
-      break;
-    case ChangeMethod::kFixedStep:
-      for (int i = 0; i < proto_.fixedstep_conf().step_size(); i++) {
-        if (step > proto_.fixedstep_conf().step(i))
-          ret = proto_.fixedstep_conf().step_lr(i);
-      }
-      break;
-    default:
-      LOG(ERROR) << "Wrong hyper-parameter update method";
-  }
-  return ret;
-}
 
 /***********************SGD with momentum******************************/
-void SGDUpdater::Init(const UpdaterProto& proto) {
-  Updater::Init(proto);
-  base_lr_ = proto.base_lr();
+void Updater::Init(const UpdaterProto& proto) {
   momentum_ = proto.momentum();
   weight_decay_ = proto.weight_decay();
+  lr_gen_ = LRGenerator::Create(proto.learning_rate());
 }
 
 void SGDUpdater::Update(int step, Param* param, float grad_scale) {
   Shape<1> s = Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
-  float lr = GetLearningRate(step)*param->lr_scale();
-  float wd = weight_decay_*param->wd_scale();
+  float lr = lr_gen_->Get(step) * param->lr_scale();
+  float wd = weight_decay_ * param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
   if (wd > 0) {  // L2 regularization, should be done after timing grad_scale
@@ -97,20 +98,13 @@ void SGDUpdater::Update(int step, Param* param, float grad_scale) {
 }
 
 /***********************Nesterov******************************/
-void NesterovUpdater::Init(const UpdaterProto& proto) {
-  Updater::Init(proto);
-  base_lr_ = proto.base_lr();
-  CHECK_GT(base_lr_, 0);
-  weight_decay_ = proto.weight_decay();
-}
-
 void NesterovUpdater::Update(int step, Param* param, float grad_scale) {
   Shape<1> s = Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
   Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
   TensorContainer<cpu, 1> tmp(s);
-  float lr = GetLearningRate(step)*param->lr_scale();
+  float lr = lr_gen_->Get(step)*param->lr_scale();
   float wd = weight_decay_*param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
@@ -123,20 +117,12 @@ void NesterovUpdater::Update(int step, Param* param, float grad_scale) {
   data -= tmp;
 }
 /***********************AdaGrad******************************/
-void AdaGradUpdater::Init(const UpdaterProto& proto) {
-  Updater::Init(proto);
-  base_lr_ = proto.base_lr();
-  CHECK_GT(base_lr_, 0);
-  delta_ = proto.delta();
-  weight_decay_ = proto.weight_decay();
-}
-
 void AdaGradUpdater::Update(int step, Param* param, float grad_scale) {
   Shape<1> s = Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
   Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
-  float lr = GetLearningRate(step)*param->lr_scale();
+  float lr = lr_gen_->Get(step)*param->lr_scale();
   float wd = weight_decay_*param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
@@ -144,7 +130,7 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) {
     grad += data * wd;
   }
   history += F<op::square>(grad);
-  data -= lr * grad / (F<op::sqrtop>(history, delta_));
+  data -= lr * grad / (F<op::sqrtop>(history, proto_.delta()));
 }
 
 /***********************RMSProp******************************


