singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [08/10] incubator-singa git commit: SINGA-120 - Implemented GRU and BPTT
Date Tue, 05 Jan 2016 18:10:37 GMT
SINGA-120 - Implemented GRU and BPTT

Change the new-memory computation formula to follow char-rnn (i.e., the element-wise
multiplication of the reset gate with the context is applied before the matrix multiplication)


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/6a4c9960
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/6a4c9960
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/6a4c9960

Branch: refs/heads/master
Commit: 6a4c9960e0795aeac6df282d7190b6f93b305c58
Parents: 959ef70
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Tue Jan 5 18:14:46 2016 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Wed Jan 6 01:55:08 2016 +0800

----------------------------------------------------------------------
 Makefile.am                                |  2 +
 include/singa/neuralnet/connection_layer.h | 28 +++++++++++++
 include/singa/neuralnet/layer.h            | 14 +++++++
 include/singa/neuralnet/neuron_layer.h     |  2 +-
 include/singa/neuralnet/output_layer.h     | 15 +++++++
 include/singa/utils/updater.h              |  8 ++--
 src/driver.cc                              |  6 ++-
 src/neuralnet/neuralnet.cc                 |  9 ++---
 src/neuralnet/neuron_layer/gru.cc          | 52 +++++++++++--------------
 src/proto/job.proto                        | 14 +++++++
 src/utils/updater.cc                       | 28 ++++++-------
 src/worker.cc                              | 16 ++++++--
 12 files changed, 133 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/Makefile.am
----------------------------------------------------------------------
diff --git a/Makefile.am b/Makefile.am
index d2b2aa8..7ae4537 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -75,6 +75,7 @@ SINGA_SRCS := src/driver.cc \
               src/neuralnet/connection_layer/concate.cc \
               src/neuralnet/connection_layer/slice.cc \
               src/neuralnet/connection_layer/split.cc \
+              src/neuralnet/connection_layer/rnn_dummy.cc \
               src/neuralnet/input_layer/char_rnn.cc \
               src/neuralnet/input_layer/onehot.cc \
               src/neuralnet/input_layer/csv.cc \
@@ -88,6 +89,7 @@ SINGA_SRCS := src/driver.cc \
               src/neuralnet/output_layer/argsort.cc \
               src/neuralnet/output_layer/csv.cc \
               src/neuralnet/output_layer/record.cc \
+              src/neuralnet/output_layer/char_rnn.cc \
               src/neuralnet/loss_layer/euclidean.cc \
               src/neuralnet/loss_layer/softmax.cc \
               src/neuralnet/neuron_layer/activation.cc \

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/connection_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/connection_layer.h b/include/singa/neuralnet/connection_layer.h
index a18f458..481d991 100644
--- a/include/singa/neuralnet/connection_layer.h
+++ b/include/singa/neuralnet/connection_layer.h
@@ -153,6 +153,34 @@ class SplitLayer : public ConnectionLayer {
   Layer2Index layer_idx_;
 };
 
+/**
+ * Dummy layer for RNN models, which provides input for other layers.
+ *
+ * Particularly, it is used in the test phase of RNN models to connect other
+ * layers and avoid cycles in the neural net config.
+ */
+class RNNDummyLayer : public ConnectionLayer {
+ public:
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+  void ComputeGradient(int flag, const vector<Layer*>& srclayers) {
+    LOG(FATAL) << "Not implemented";
+  }
+
+  const string srclayer(int step) const {
+    if (step > 0)
+      return dynamic_src_;
+    else
+      return "";
+  }
+
+ private:
+  string dynamic_src_;
+  float low_, high_;
+  bool integer_;
+  Layer* srclayer_;
+};
+
 
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/layer.h b/include/singa/neuralnet/layer.h
index f4738fa..c1612a2 100644
--- a/include/singa/neuralnet/layer.h
+++ b/include/singa/neuralnet/layer.h
@@ -36,6 +36,20 @@ using std::string;
 
 // TODO(wangwei) make AuxType a template argument for Layer.
 using AuxType = int;
+
+inline const string AddUnrollingPrefix(int unroll_idx, const string& name) {
+  return std::to_string(unroll_idx) + "#" + name;
+}
+inline const string AddPartitionSuffix(int partition_idx, const string& name) {
+  return name + "@" + std::to_string(partition_idx);
+}
+
+
+inline const string AddPrefixSuffix(int unroll_idx, int partition_idx,
+    const string& name) {
+  return std::to_string(unroll_idx) + "#" + name + "@" +
+    std::to_string(partition_idx);
+}
 /**
  * Base layer class.
  *

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/neuron_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/neuron_layer.h b/include/singa/neuralnet/neuron_layer.h
index e587e38..e1a63a2 100644
--- a/include/singa/neuralnet/neuron_layer.h
+++ b/include/singa/neuralnet/neuron_layer.h
@@ -203,7 +203,7 @@ class GRULayer : public NeuronLayer {
   int batchsize_; // batch size
   int vdim_, hdim_; // dimensions
 
-  Blob<float> *update_gate, *reset_gate, *new_memory;
+  Blob<float> *update_gate, *reset_gate, *new_memory, *reset_context;
   //!< gru layer connect to two dst layers, hence need to grad blobs.
   Blob<float> aux_grad_;
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/neuralnet/output_layer.h
----------------------------------------------------------------------
diff --git a/include/singa/neuralnet/output_layer.h b/include/singa/neuralnet/output_layer.h
index c7e5d6a..9071f33 100644
--- a/include/singa/neuralnet/output_layer.h
+++ b/include/singa/neuralnet/output_layer.h
@@ -80,5 +80,20 @@ class RecordOutputLayer : public OutputLayer {
   int inst_ = 0;  //!< instance No.
   io::Store* store_ = nullptr;
 };
+
+/**
+ * Output layer for char rnn model, which convert sample id back to char and
+ * dump to stdout.
+ */
+class CharRNNOutputLayer : public OutputLayer {
+ public:
+  void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
+
+  void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
+
+ private:
+  string vocab_;
+};
+
 }  // namespace singa
 #endif  // SINGA_NEURALNET_OUTPUT_LAYER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/include/singa/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/updater.h b/include/singa/utils/updater.h
index 575ab86..7fec78c 100644
--- a/include/singa/utils/updater.h
+++ b/include/singa/utils/updater.h
@@ -118,18 +118,16 @@ class NesterovUpdater : public Updater {
   void Update(int step, Param* param, float grad_scale) override;
 };
 
-/*
 class RMSPropUpdater : public Updater {
  public:
-  virtual void Update(int step, Param* param, float grad_scale);
+  void Init(const UpdaterProto &proto) override;
+  void Update(int step, Param* param, float grad_scale) override;
 
  protected:
-  float base_lr_;
-  float delta_;
   float rho_;
-  float weight_decay_;
 };
 
+/*
 class AdaDeltaUpdater : public Updater {
  public:
   virtual void Update(int step, Param* param, float grad_scale);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/driver.cc
----------------------------------------------------------------------
diff --git a/src/driver.cc b/src/driver.cc
index 21968bb..1e4929f 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -74,6 +74,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<CharRNNInputLayer, int>(kCharRNN);
   RegisterLayer<RNNLabelLayer, int>(kRNNLabel);
   RegisterLayer<OneHotLayer, int>(kOneHot);
+  RegisterLayer<CharRNNOutputLayer, int>(kCharRNNOutput);
 
   // connection layers
   RegisterLayer<BridgeDstLayer, int>(kBridgeDst);
@@ -81,6 +82,7 @@ void Driver::Init(int argc, char **argv) {
   RegisterLayer<ConcateLayer, int>(kConcate);
   RegisterLayer<SliceLayer, int>(kSlice);
   RegisterLayer<SplitLayer, int>(kSplit);
+  RegisterLayer<RNNDummyLayer, int>(kRNNDummy);
 
   RegisterLayer<AccuracyLayer, int>(kAccuracy);
   RegisterLayer<ArgSortLayer, int>(kArgSort);
@@ -125,7 +127,7 @@ void Driver::Init(int argc, char **argv) {
   // register updaters
   RegisterUpdater<AdaGradUpdater>(kAdaGrad);
   RegisterUpdater<NesterovUpdater>(kNesterov);
-  // TODO(wangwei) RegisterUpdater<kRMSPropUpdater>(kRMSProp);
+  RegisterUpdater<RMSPropUpdater>(kRMSProp);
   RegisterUpdater<SGDUpdater>(kSGD);
 
   // register learning rate change methods
@@ -198,6 +200,8 @@ void Driver::Test(const JobProto& job_conf) {
   auto worker = Worker::Create(job_conf.train_one_batch());
   worker->Setup(0, 0, job_conf, nullptr, nullptr, nullptr);
   auto net = NeuralNet::Create(job_conf.neuralnet(), kTest, 1);
+  WriteStringToTextFile(Cluster::Get()->vis_folder() + "/test_net.json",
+      net->ToGraph(true).ToJson());
   vector<string> paths;
   for (const auto& p : job_conf.checkpoint_path())
     paths.push_back(p);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index f9579b1..49978a1 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -144,7 +144,7 @@ const NetProto NeuralNet::Unrolling(const NetProto& net_conf) {
     for (int i = 0; i < org_layer.unroll_len(); i ++) { // unroll
       LayerProto* unroll_layer = conf.add_layer();
       unroll_layer->CopyFrom(org_layer); // create a new layer conf
-      if (org_layer.unroll_len() > 1) {
+      // if (org_layer.unroll_len() > 1) {
         // update layer names
         std::stringstream sstm;
         sstm << i << '#' << unroll_layer->name();
@@ -160,7 +160,7 @@ const NetProto NeuralNet::Unrolling(const NetProto& net_conf) {
           sstm1 << i << '#' << param->name();
           param->set_name(sstm1.str());
         }
-      }
+      // }
       // clear unrolling related fields
       unroll_layer->clear_unroll_len();
       unroll_layer->clear_unroll_conn_type();
@@ -257,6 +257,7 @@ void NeuralNet::Load(const vector<string>& paths,
     ReadProtoFromBinaryFile(path.c_str(), &bps);
     for (int i = 0; i < bps.name_size(); i++) {
       if (params.find(bps.name(i)) != params.end()) {
+        // LOG(ERROR) << "Loading param = " << bps.name(i);
         params.at(bps.name(i))->FromProto(bps.blob(i));
         params.at(bps.name(i))->set_version(bps.version(i));
       }
@@ -458,12 +459,10 @@ Graph* NeuralNet::CreateGraph(const NetProto& netproto, int npartitions) {
   map<string, const LayerProto*> name2proto;
   for (const LayerProto& layer : net_w_connection.layer()) {
     vector<Node*> nodes;
-    char suffix[4];
     for (int i = 0; i < npartitions; i++) {
       LayerProto *proto = new LayerProto(layer);
-      snprintf(suffix, sizeof(suffix), "%02d", i);
       // differentiate partitions
-      string nodename = layer.name() + "@" + string(suffix);
+      string nodename = layer.name() + "@" + std::to_string(i);
       proto->set_name(nodename);
       proto->set_type(layer.type());
       proto->set_partition_dim(layer.partition_dim());

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/neuralnet/neuron_layer/gru.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/gru.cc b/src/neuralnet/neuron_layer/gru.cc
index 9ba5a50..cf7425b 100644
--- a/src/neuralnet/neuron_layer/gru.cc
+++ b/src/neuralnet/neuron_layer/gru.cc
@@ -98,6 +98,8 @@ void GRULayer::Setup(const LayerProto& conf,
 
   update_gate = new Blob<float>(batchsize_, hdim_);
   reset_gate = new Blob<float>(batchsize_, hdim_);
+  // reset gate x context
+  reset_context = new Blob<float>(batchsize_, hdim_);
   new_memory = new Blob<float>(batchsize_, hdim_);
 }
 
@@ -130,24 +132,23 @@ void GRULayer::ComputeFeature(int flag,
 	Map<op::Sigmoid<float>,float>(*update_gate, update_gate);
 
 	// Compute the reset gate
-	GEMM(1.0f, 0.0f, src,*w_r_hx_t,reset_gate);
+	GEMM(1.0f, 0.0f, src, *w_r_hx_t, reset_gate);
 	if (bias_r_ != nullptr)
-		MVAddRow(1.0f,1.0f,bias_r_->data(),reset_gate);
+		MVAddRow(1.0f,1.0f, bias_r_->data(),reset_gate);
 	GEMM(1.0f, 1.0f, *context, *w_r_hh_t, reset_gate);
 	Map<op::Sigmoid<float>,float>(*reset_gate, reset_gate);
 
 	// Compute the new memory
-	GEMM(1.0f, 0.0f, src, *w_c_hx_t, new_memory);
+	Mult<float>(*reset_gate, *context, reset_context);
+	GEMM(1.0f, 0.0f, *reset_context, *w_c_hh_t, new_memory);
+	GEMM(1.0f, 1.0f, src, *w_c_hx_t, new_memory);
 	if (bias_c_ != nullptr)
-		MVAddRow(1.0f,1.0f,bias_c_->data(), new_memory);
-	Mult<float>(*reset_gate, *new_memory, new_memory);
-	GEMM(1.0f, 1.0f, *context, *w_c_hh_t, new_memory);
+		MVAddRow(1.0f, 1.0f, bias_c_->data(), new_memory);
 	Map<op::Tanh<float>,float>(*new_memory, new_memory);
 
-
-  Sub(*context, *new_memory, &data_);
+  Sub(*new_memory, *context, &data_);
   Mult(data_, *update_gate, &data_);
-  Add(data_, *new_memory, &data_);
+  AXPY(1.0f, *context, &data_);
 
 	// delete the pointers
 	if (srclayers.size() == 1)
@@ -192,24 +193,19 @@ void GRULayer::ComputeGradient(int flag,
 	Map<singa::op::TanhGrad<float>, float>(*new_memory, &dnewmdc);
 
 	Blob<float> dLdz (batchsize_, hdim_);
-	Sub<float>(*context, *new_memory, &dLdz);
+	Sub<float>(*new_memory, *context, &dLdz);
 	Mult<float>(dLdz, grad_, &dLdz);
 	Mult<float>(dLdz, dugatedz, &dLdz);
 
 	Blob<float> dLdc (batchsize_,hdim_);
-	Blob<float> z1 (batchsize_,hdim_);
-  z1.SetValue(1.0f);
-	AXPY<float>(-1.0f, *update_gate, &z1);
-	Mult(grad_,z1,&dLdc);
-	Mult(dLdc,dnewmdc,&dLdc);
+	Mult(grad_, *update_gate, &dLdc);
+	Mult(dLdc, dnewmdc, &dLdc);
 
 	Blob<float> reset_dLdc (batchsize_,hdim_);
-	Mult(dLdc, *reset_gate, &reset_dLdc);
+  GEMM(1.0f, 0.0f, dLdc, weight_c_hh_->data(), &reset_dLdc);
 
 	Blob<float> dLdr (batchsize_, hdim_);
-	Blob<float> cprev (batchsize_, hdim_);
-	GEMM(1.0f, 0.0f, *context, weight_c_hh_->data().T(), &cprev);
-	Mult(dLdc, cprev, &dLdr);
+	Mult(reset_dLdc, *context, &dLdr);
 	Mult(dLdr, drgatedr, &dLdr);
 
 	// Compute gradients for parameters of update gate
@@ -230,29 +226,25 @@ void GRULayer::ComputeGradient(int flag,
 
 	// Compute gradients for parameters of new memory
 	Blob<float> *dLdc_t = Transpose(dLdc);
-	GEMM(1.0f, beta, *dLdc_t, src,weight_c_hx_->mutable_grad());
+	GEMM(1.0f, beta, *dLdc_t, src, weight_c_hx_->mutable_grad());
+	GEMM(1.0f, beta, *dLdc_t, *reset_context, weight_c_hh_->mutable_grad());
 	if (bias_c_ != nullptr)
 		MVSumRow(1.0f, beta, dLdc, bias_c_->mutable_grad());
 	delete dLdc_t;
 
-	Blob<float> *reset_dLdc_t = Transpose(reset_dLdc);
-	GEMM(1.0f, beta, *reset_dLdc_t, *context, weight_c_hh_->mutable_grad());
-	delete reset_dLdc_t;
-
 	// Compute gradients for data input layer
 	if (srclayers[0]->mutable_grad(this) != nullptr) {
-		GEMM(1.0f,0.0f,dLdc, weight_c_hx_->data(), ilayer->mutable_grad(this));
-		GEMM(1.0f,1.0f,dLdz, weight_z_hx_->data(), ilayer->mutable_grad(this));
-		GEMM(1.0f,1.0f,dLdr, weight_r_hx_->data(), ilayer->mutable_grad(this));
+		GEMM(1.0f,0.0f, dLdc, weight_c_hx_->data(), ilayer->mutable_grad(this));
+		GEMM(1.0f,1.0f, dLdz, weight_z_hx_->data(), ilayer->mutable_grad(this));
+		GEMM(1.0f,1.0f, dLdr, weight_r_hx_->data(), ilayer->mutable_grad(this));
 	}
 
 	if (clayer != nullptr && clayer->mutable_grad(this) != nullptr) {
 		// Compute gradients for context layer
-		GEMM(1.0f, 0.0f, reset_dLdc, weight_c_hh_->data(),
-        clayer->mutable_grad(this));
+    Mult(reset_dLdc, *reset_gate, clayer->mutable_grad(this));
 		GEMM(1.0f, 1.0f, dLdr, weight_r_hh_->data(), clayer->mutable_grad(this));
 		GEMM(1.0f, 1.0f, dLdz, weight_z_hh_->data(), clayer->mutable_grad(this));
-		Add(clayer->grad(this), *update_gate, clayer->mutable_grad(this));
+		AXPY(-1.0f, *update_gate, clayer->mutable_grad(this));
     // LOG(ERROR) << "grad to prev gru " << Asum(clayer->grad(this));
 	}
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/proto/job.proto
----------------------------------------------------------------------
diff --git a/src/proto/job.proto b/src/proto/job.proto
index e520eba..28a3a68 100644
--- a/src/proto/job.proto
+++ b/src/proto/job.proto
@@ -253,6 +253,7 @@ message LayerProto {
   optional ConcateProto concate_conf = 502;
   optional SliceProto slice_conf = 503;
   optional SplitProto split_conf = 504;
+  optional RNNDummyProto rnn_dummy_conf = 505;
 
   extensions 1001 to 1100;
 }
@@ -456,6 +457,17 @@ message DummyProto {
   repeated int32 shape = 3;
 }
 
+message RNNDummyProto {
+  optional string dynamic_srclayer = 1;
+  // if shape set, random generate the data blob
+  repeated int32 shape = 2;
+  // if integer is true, generate integer data
+  optional bool integer = 3 [default = false];
+  // range of the random generation
+  optional float low = 4 [default = 0];
+  optional float high = 5 [default = 0];
+}
+
 // Message that stores parameters used by DropoutLayer
 message DropoutProto {
   // dropout ratio
@@ -667,6 +679,7 @@ enum LayerType {
   kArgSort = 401;
   kCSVOutput = 402;
   kRecordOutput = 403;
+  kCharRNNOutput = 404;
 
   /*
    * Connection layers
@@ -677,6 +690,7 @@ enum LayerType {
   kConcate = 502;
   kSlice = 503;
   kSplit = 504;
+  kRNNDummy = 505;
 
   /*
    * User defined layer

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index a9f70c0..200670a 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -174,31 +174,27 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) {
   data -= lr * grad / (F<sqrtop>(history, proto_.delta()));
 }
 
-/***********************RMSProp******************************
-void RMSPropUpdater::Init(const UpdaterProto& proto){
+/***********************RMSProp******************************/
+void RMSPropUpdater::Init(const UpdaterProto& proto) {
   Updater::Init(proto);
-  base_lr_ = proto.base_lr();
-  CHECK_GT(base_lr_, 0);
-  delta_ = proto.delta();
   rho_ = proto.rmsprop_conf().rho();
-  weight_decay_ = proto.weight_decay();
 }
 
-void RMSPropUpdater::Update(int step, Param* param, float grad_scale){
+void RMSPropUpdater::Update(int step, Param* param, float grad_scale) {
   Shape<1> s=Shape1(param->size());
   Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
   Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
   Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
-  history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale);
-  float lr=GetLearningRate(step)*param->lr_scale();
-  float wd=weight_decay_*param->wd_scale();
-  if(wd>0){ // L2 regularization
-    grad+=data*wd;
-  }
-  data-=lr*grad/(F<op::sqrtop>(history,delta_));
+  float lr = lr_gen_->Get(step) * param->lr_scale();
+  float wd = weight_decay_ * param->wd_scale();
+  if (grad_scale != 1.f)
+    grad *= grad_scale;
+  if (wd > 0)  //  L2 regularization, should be done after timing grad_scale
+    grad += data * wd;
+  history = history * rho_ + (1 - rho_) * F<square>(grad);
+  data -= lr * grad / (F<sqrtop>(history, proto_.delta()));
 }
-
-***********************AdaDelta******************************
+/***********************AdaDelta******************************
 void AdaDeltaUpdater::Init(const UpdaterProto& proto){
   Updater::Init(proto);
   delta_=proto.delta();

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/6a4c9960/src/worker.cc
----------------------------------------------------------------------
diff --git a/src/worker.cc b/src/worker.cc
index 4e1dc75..abe74e7 100644
--- a/src/worker.cc
+++ b/src/worker.cc
@@ -390,13 +390,24 @@ void BPTTWorker::Forward(int step, Phase phase, NeuralNet* net) {
         }
       }
       vector<Layer*> src = net->srclayers(layer);
+      if ((phase & kTest) && typeid(*layer) == typeid(RNNDummyLayer)) {
+        CHECK_LE(src.size(), 1);
+        auto dummy = dynamic_cast<RNNDummyLayer*>(layer);
+        Layer* srclayer = net->name2layer(dummy->srclayer(step));
+        if (step > 0)
+          CHECK(srclayer != nullptr);
+        if (srclayer != nullptr) {
+          src.clear();
+          src.push_back(srclayer);
+        }
+      }
       // if full state rnn and not the starting of a new passing of the dataset,
       // feed the hidden state of the last unit to the first unit.
       if (layer->unroll_index() == 0 && full_state_ && !begin_) {
         Layer* last = net->last_unroll_layer(layer);
-        if (last != layer) {
+        CHECK(last != nullptr);
+        if (last != layer || (phase & kTest))
           src.push_back(last);
-        }
       }
       // LOG(ERROR) << layer->name() << " forward";
       // int ret =
@@ -405,7 +416,6 @@ void BPTTWorker::Forward(int step, Phase phase, NeuralNet* net) {
       if ((phase & Phase::kTrain) && ret == Status::kEnd)
         begin_ = true;
       */
-
       if (job_conf_.debug() && DisplayNow(step) && grp_id_ == 0)
         label[layer->name()] = layer->ToString(true, phase | kForward);
     }


Mime
View raw message