singa-commits mailing list archives

From wang...@apache.org
Subject [1/2] incubator-singa git commit: SINGA-346 Update cudnn from V5 to V7
Date Sun, 08 Jul 2018 08:06:12 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 56292f1fb -> e16cea129


SINGA-346 Update cudnn from V5 to V7

support cudnn5 (conv and rnn have API changes from v5 to v7)
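
For reference, the guards added in this patch key off the cuDNN version macros
from <cudnn.h>. Below is a minimal standalone sketch (not part of the commit),
assuming only the standard CUDNN_MAJOR, CUDNN_MINOR and CUDNN_PATCHLEVEL macros
and cudnnGetVersion(): it prints the version the code was compiled against and
the version of the libcudnn loaded at run time, which is what branches like
"#if CUDNN_MAJOR == 5" below depend on.

    // Sketch only: report the cuDNN version at compile time and run time.
    #include <cudnn.h>
    #include <cstdio>

    int main() {
      // Compile-time check: the headers this file was built against.
      std::printf("compiled against cuDNN %d.%d.%d\n",
                  CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
    #if CUDNN_MAJOR == 5
      std::printf("v5 code paths (old conv/rnn descriptor calls) selected\n");
    #else
      std::printf("v6+/v7 code paths selected\n");
    #endif
      // Run-time check: the library actually loaded, which can differ if the
      // installed libcudnn does not match the build headers.
      std::printf("loaded libcudnn reports version %zu\n", cudnnGetVersion());
      return 0;
    }

Keeping both checks in mind helps explain failures where the build-time headers
and the run-time library come from different cuDNN releases.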


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e2092030
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e2092030
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e2092030

Branch: refs/heads/master
Commit: e20920309bfeb6ed7e0adf3d529c2fba1d44ad2f
Parents: 56292f1
Author: Wang Wei <wangwei.cs@gmail.com>
Authored: Thu Jul 5 22:57:33 2018 +0800
Committer: wang wei <wangwei@comp.nus.edu.sg>
Committed: Sun Jul 8 16:00:38 2018 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_convolution.cc | 101 +++++++++---------
 src/model/layer/cudnn_rnn.cc         | 165 +++++++++++++++---------------
 2 files changed, 137 insertions(+), 129 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_convolution.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_convolution.cc b/src/model/layer/cudnn_convolution.cc
index 8846746..1b12f93 100644
--- a/src/model/layer/cudnn_convolution.cc
+++ b/src/model/layer/cudnn_convolution.cc
@@ -44,7 +44,7 @@ void CudnnConvolution::Setup(const Shape& in_sample, const LayerConf &conf) {
   CHECK(prefer_ == "fastest" || prefer_ == "limited_workspace" ||
         prefer_ == "no_workspace" || prefer_ == "autotune")
       << "CudnnConvolution only supports four algorithm preferences: fastest, "
-         "limited_workspace, no_workspace and autotune";
+      "limited_workspace, no_workspace and autotune";
 }
 
 void CudnnConvolution::ToDevice(std::shared_ptr<Device> device) {
@@ -70,16 +70,19 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
                                          GetCudnnDataType(dtype), batchsize,
                                          channels_, height_, width_));
   CUDNN_CHECK(cudnnSetTensor4dDescriptor(
-      y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
-      num_filters_, conv_height_, conv_width_));
+                y_desc_, CUDNN_TENSOR_NCHW, GetCudnnDataType(dtype), batchsize,
+                num_filters_, conv_height_, conv_width_));
   if (bias_term_)
     CUDNN_CHECK(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW,
                                            GetCudnnDataType(dtype), 1,
                                            num_filters_, 1, 1));
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(conv_desc_, pad_h_, pad_w_,
-                                              stride_h_, stride_w_, 1, 1,
-                                              CUDNN_CROSS_CORRELATION, 
-                                              GetCudnnDataType(dtype)));
+              stride_h_, stride_w_, 1, 1,  // dilation x and y
+              CUDNN_CROSS_CORRELATION
+#if CUDNN_MAJOR == 5
+              , GetCudnnDataType(dtype)
+#endif  // CUDNN_MAJOR
+                                             ));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(filter_desc_, GetCudnnDataType(dtype),
                                          CUDNN_TENSOR_NCHW, num_filters_,
                                          channels_, kernel_h_, kernel_w_));
@@ -102,15 +105,15 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
       bwd_data_pref = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;
     }
     CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
-        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
-        workspace_byte_limit_, &fp_alg_));
+                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fwd_pref,
+                  workspace_byte_limit_, &fp_alg_));
     CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
-        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-        bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
+                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+                  bwd_filt_pref, workspace_byte_limit_, &bp_filter_alg_));
     // deprecated in cudnn v7
     CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
-        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-        bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
+                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+                  bwd_data_pref, workspace_byte_limit_, &bp_data_alg_));
   } else if (prefer_ == "autotune") {
     const int topk = 1;
     int num_fp_alg, num_bp_filt_alg, num_bp_data_alg;
@@ -118,16 +121,16 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
     cudnnConvolutionBwdFilterAlgoPerf_t bp_filt_perf[topk];
     cudnnConvolutionBwdDataAlgoPerf_t bp_data_perf[topk];
     CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithm(
-        ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
-        &num_fp_alg, fp_alg_perf));
+                  ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, topk,
+                  &num_fp_alg, fp_alg_perf));
     fp_alg_ = fp_alg_perf[0].algo;
     CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
-        ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
-        &num_bp_filt_alg, bp_filt_perf));
+                  ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_, topk,
+                  &num_bp_filt_alg, bp_filt_perf));
     bp_filter_alg_ = bp_filt_perf[0].algo;
     CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithm(
-        ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
-        &num_bp_data_alg, bp_data_perf));
+                  ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_, topk,
+                  &num_bp_data_alg, bp_data_perf));
     bp_data_alg_ = bp_data_perf[0].algo;
   } else {
     LOG(FATAL) << "Preferred algorithm is not available!";
@@ -135,22 +138,22 @@ void CudnnConvolution::InitCudnn(const Tensor &input) {
 
   size_t fp_byte, bp_data_byte, bp_filter_byte;
   CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
-      ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
-      &fp_byte));
+                ctx->cudnn_handle, x_desc_, filter_desc_, conv_desc_, y_desc_, fp_alg_,
+                &fp_byte));
   CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
-      ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
-      bp_data_alg_, &bp_data_byte));
+                ctx->cudnn_handle, filter_desc_, y_desc_, conv_desc_, x_desc_,
+                bp_data_alg_, &bp_data_byte));
   CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
-      ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
-      bp_filter_alg_, &bp_filter_byte));
+                ctx->cudnn_handle, x_desc_, y_desc_, conv_desc_, filter_desc_,
+                bp_filter_alg_, &bp_filter_byte));
   workspace_count_ = std::max(std::max(fp_byte, bp_data_byte), bp_filter_byte) /
-                         sizeof(float) +
+                     sizeof(float) +
                      1;
   if (workspace_count_ * sizeof(float) > workspace_byte_limit_)
     LOG(WARNING) << "The required memory for workspace ("
-      << workspace_count_ * sizeof(float)
-      << ") is larger than the expected Bytes ("
-      << workspace_byte_limit_ << ")";
+                 << workspace_count_ * sizeof(float)
+                 << ") is larger than the expected Bytes ("
+                 << workspace_byte_limit_ << ")";
   workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
   has_init_cudnn_ = true;
 }
@@ -170,23 +173,23 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
     int n, c, h, w, s;
     cudnnDataType_t type;
     CUDNN_CHECK(cudnnGetTensor4dDescriptor(x_desc_, &type, &n, &c, &h, &w,
-          &s, &s, &s, &s));
+                                           &s, &s, &s, &s));
     if (batchsize != static_cast<size_t>(n))
       InitCudnn(input);
     CHECK(input.shape(1) == static_cast<size_t>(c)
-        && input.shape(2) == static_cast<size_t>(h)
-        && input.shape(3) == static_cast<size_t>(w))
-      << "input sample shape should not change"
-      << "previous shape " << c << ", " << h << ", " <<
w
-      << "current shape " << input.shape(1) << ", " << input.shape(2)
<< ", "
-      << input.shape(3);
+          && input.shape(2) == static_cast<size_t>(h)
+          && input.shape(3) == static_cast<size_t>(w))
+        << "input sample shape should not change"
+        << "previous shape " << c << ", " << h << ", " <<
w
+        << "current shape " << input.shape(1) << ", " << input.shape(2)
<< ", "
+        << input.shape(3);
   }
 
   Shape shape{batchsize, num_filters_, conv_height_, conv_width_};
   Tensor output(shape, dev, dtype);
-  output.device()->Exec([input, output, this](Context *ctx) {
+  output.device()->Exec([input, output, this](Context * ctx) {
     Block *inblock = input.block(), *outblock = output.block(),
-          *wblock = this->weight_.block();
+           *wblock = this->weight_.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionForward(ctx->cudnn_handle, &alpha, this->x_desc_,
                             inblock->data(), this->filter_desc_, wblock->data(),
@@ -197,7 +200,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
   }, {input.block(), weight_.block()}, {output.block()}, workspace_.block());
 
   if (bias_term_) {
-    output.device()->Exec([output, this](Context *ctx) {
+    output.device()->Exec([output, this](Context * ctx) {
       float beta = 1.f, alpha = 1.0f;
       Block *outblock = output.block(), *bblock = this->bias_.block();
       cudnnAddTensor(ctx->cudnn_handle, &alpha, this->bias_desc_,
@@ -209,7 +212,7 @@ const Tensor CudnnConvolution::Forward(int flag, const Tensor &input) {
 }
 
 const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
-    int flag, const Tensor &grad) {
+int flag, const Tensor &grad) {
   CHECK(has_init_cudnn_);
   CHECK_EQ(grad.device()->lang(), kCuda);
   CHECK_EQ(grad.nDim(), 4u);
@@ -225,7 +228,7 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
   // LOG(ERROR) << "backward bias";
   if (bias_term_) {
     db.ResetLike(bias_);
-    dx.device()->Exec([grad, db, this](Context *ctx) {
+    dx.device()->Exec([grad, db, this](Context * ctx) {
       Block *dyblock = grad.block(), *dbblock = db.block();
       float alpha = 1.f, beta = 0.f;
       cudnnConvolutionBackwardBias(ctx->cudnn_handle, &alpha, this->y_desc_,
@@ -234,22 +237,22 @@ const std::pair<Tensor, vector<Tensor>> CudnnConvolution::Backward(
     }, {grad.block()}, {db.block()});
   }
   // LOG(ERROR) << "backward w";
-  dx.device()->Exec([grad, dw, src_data, this](Context *ctx) {
+  dx.device()->Exec([grad, dw, src_data, this](Context * ctx) {
     Block *inblock = src_data.block(), *dyblock = grad.block(),
-          *dwblock = dw.block();
+           *dwblock = dw.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionBackwardFilter(
-        ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
-        this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
-        this->workspace_.block()->mutable_data(),
-        this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
-        dwblock->mutable_data());
+      ctx->cudnn_handle, &alpha, this->x_desc_, inblock->data(),
+      this->y_desc_, dyblock->data(), this->conv_desc_, this->bp_filter_alg_,
+      this->workspace_.block()->mutable_data(),
+      this->workspace_count_ * sizeof(float), &beta, this->filter_desc_,
+      dwblock->mutable_data());
   }, {grad.block(), src_data.block()}, {dw.block(), workspace_.block()});
 
   // LOG(ERROR) << "backward src";
-  dx.device()->Exec([dx, grad, this](Context *ctx) {
+  dx.device()->Exec([dx, grad, this](Context * ctx) {
     Block *wblock = this->weight_.block(), *dyblock = grad.block(),
-          *dxblock = dx.block();
+           *dxblock = dx.block();
     float alpha = 1.f, beta = 0.f;
     cudnnConvolutionBackwardData(ctx->cudnn_handle, &alpha, this->filter_desc_,
                                  wblock->data(), this->y_desc_, dyblock->data(),

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e2092030/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index fb5fee0..28a52c5 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -125,8 +125,8 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
   CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size));
   dropout_state_ = Tensor(Shape{state_size}, dev, kChar);
   CUDNN_CHECK(cudnnSetDropoutDescriptor(
-      dropout_desc_, ctx->cudnn_handle, 1 - dropout_,  // keep probability
-      dropout_state_.block()->mutable_data(), state_size, seed_));
+                dropout_desc_, ctx->cudnn_handle, 1 - dropout_,  // keep probability
+                dropout_state_.block()->mutable_data(), state_size, seed_));
 
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
   cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
@@ -144,10 +144,15 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) {
     rnn_mode = CUDNN_RNN_TANH;
   else if (rnn_mode_ == "gru")
     rnn_mode = CUDNN_GRU;
+#if CUDNN_MAJOR == 5
+  CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_,
+                                    dropout_desc_, input_mode, direction,
+                                    rnn_mode, dtype_));
+#else
   CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_,
                                     dropout_desc_, input_mode, direction,
                                     rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_));
-
+#endif
   size_t weight_size;
   CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
                                     &weight_size, dtype_));
@@ -199,7 +204,7 @@ void CudnnRNN::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
   }
 
   CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
-                                             seq_length, x_descs_, &count));
+              seq_length, x_descs_, &count));
   if (reserve_space_.Size() != count) {
     reserve_space_ = Tensor(Shape{count}, dev, kChar);
     // reserve_space_.SetValue(0);
@@ -263,8 +268,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
 
   if (rnn_desc_ != nullptr)
     CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-      << "Cannot change cudnn data type during training from " << dtype_
-      << " to " << GetCudnnDataType(dtype);
+        << "Cannot change cudnn data type during training from " << dtype_
+        << " to " << GetCudnnDataType(dtype);
   else
     dtype_ = GetCudnnDataType(dtype);
 
@@ -303,57 +308,57 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
   // LOG(INFO) << "hidden size " << hy.Size();
   // LOG(INFO) << "weight size " << weight_.Size() << " value " <<
weight_.L1();
   Block *inb = input.block(), *outb = output.block(),
-        *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
-        *hyb = hy.block(), *cyb = cy.block(),
-        *wspace = this->workspace_.block(),
-        *rspace = this->reserve_space_.block();
+         *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(),
+          *hyb = hy.block(), *cyb = cy.block(),
+           *wspace = this->workspace_.block(),
+            *rspace = this->reserve_space_.block();
   if (flag & kTrain) {
     CHECK_EQ(reserve_space_.device()->lang(), kCuda);
     CHECK_EQ(did, reserve_space_.device()->id());
     dev->Exec(
-        [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context *ctx) {
-        // clang-format off
-        cudnnRNNForwardTraining(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->x_descs_, inb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-            this->weight_desc_, wb->data(),
-            this->y_descs_, outb->mutable_data(),
-            this->hy_desc_, hyb->mutable_data(),
-            this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
-            wspace->mutable_data(),
-            this->workspace_.Size(), rspace->mutable_data(),
-            this->reserve_space_.Size());
-        // clang-format on
-        },
-        {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
+    [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, this](Context * ctx) {
+      // clang-format off
+      cudnnRNNForwardTraining(
+        ctx->cudnn_handle,
+        this->rnn_desc_,
+        this->seq_length_,
+        this->x_descs_, inb->data(),
+        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+        this->weight_desc_, wb->data(),
+        this->y_descs_, outb->mutable_data(),
+        this->hy_desc_, hyb->mutable_data(),
+        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+        wspace->mutable_data(),
+        this->workspace_.Size(), rspace->mutable_data(),
+        this->reserve_space_.Size());
+      // clang-format on
+    },
+    {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
     buf_.push(input);
     buf_.push(output);
     buf_.push(hx);
     buf_.push(cx);
   } else {
-    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context *ctx) {
+    dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, this](Context * ctx) {
       // clang-format off
       cudnnRNNForwardInference(
-          ctx->cudnn_handle,
-          this->rnn_desc_,
-          this->seq_length_,
-          this->x_descs_, inb->data(),
-          this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-          this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-          this->weight_desc_, wb->data(),
-          this->y_descs_, outb->mutable_data(),
-          this->hy_desc_, hyb->mutable_data(),
-          this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
-          wspace->mutable_data(), this->workspace_.Size());
+        ctx->cudnn_handle,
+        this->rnn_desc_,
+        this->seq_length_,
+        this->x_descs_, inb->data(),
+        this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+        this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+        this->weight_desc_, wb->data(),
+        this->y_descs_, outb->mutable_data(),
+        this->hy_desc_, hyb->mutable_data(),
+        this->cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+        wspace->mutable_data(), this->workspace_.Size());
       // clang-format on
     }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
   }
   auto outputs =
-      SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
+    SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output);
   outputs.push_back(hy);
   if (has_cell_) outputs.push_back(cy);
   return outputs;
@@ -361,7 +366,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) {
 
 // TODO(wangwei) check Tensor device to be on cuda?
 const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
-    int flag, const vector<Tensor> &grads) {
+int flag, const vector<Tensor> &grads) {
   // dhy (and dcy) is at last
   const Tensor cx = buf_.top();  // cannot use const Tensor& due to pop()
   buf_.pop();
@@ -395,45 +400,45 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
     dcx.ResetLike(dhx);
   dw.SetValue(0.0f);
   Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
-        *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
-        *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
-        *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
-        *wspace = workspace_.block(), *rspace = reserve_space_.block();
+         *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
+          *wb = weight_.block(), *dwb = dw.block(), *hxb = hx.block(),
+           *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),
+            *wspace = workspace_.block(), *rspace = reserve_space_.block();
 
   y.device()->Exec(
-      [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
-       rspace, this](Context *ctx) {
-        // clang-format off
-        cudnnRNNBackwardData(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->y_descs_, yb->data(),
-            this->dy_descs_, dyb->data(),
-            this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
-            this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
-            this->weight_desc_, wb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-            this->dx_descs_, dxb->mutable_data(),
-            this->dhx_desc_, dhxb->mutable_data(),
-            this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
-            wspace->mutable_data(), this->workspace_.Size(),
-            rspace->mutable_data(), this->reserve_space_.Size());
-        cudnnRNNBackwardWeights(
-            ctx->cudnn_handle,
-            this->rnn_desc_,
-            this->seq_length_,
-            this->x_descs_, xb->data(),
-            this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-            this->y_descs_, yb->data(),
-            wspace->data(), this->workspace_.Size(),
-            this->dweight_desc_, dwb->mutable_data(),
-            rspace->data(), this->reserve_space_.Size());
-        // clang-format on
-      },
-      {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
-      {dxb, dwb, dhxb, dcxb, wspace, rspace});
+    [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace,
+  rspace, this](Context * ctx) {
+    // clang-format off
+    cudnnRNNBackwardData(
+      ctx->cudnn_handle,
+      this->rnn_desc_,
+      this->seq_length_,
+      this->y_descs_, yb->data(),
+      this->dy_descs_, dyb->data(),
+      this->dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(),
+      this->dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(),
+      this->weight_desc_, wb->data(),
+      this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+      this->cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+      this->dx_descs_, dxb->mutable_data(),
+      this->dhx_desc_, dhxb->mutable_data(),
+      this->dcx_desc_, dcxb == nullptr ? nullptr : dcxb->mutable_data(),
+      wspace->mutable_data(), this->workspace_.Size(),
+      rspace->mutable_data(), this->reserve_space_.Size());
+    cudnnRNNBackwardWeights(
+      ctx->cudnn_handle,
+      this->rnn_desc_,
+      this->seq_length_,
+      this->x_descs_, xb->data(),
+      this->hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+      this->y_descs_, yb->data(),
+      wspace->data(), this->workspace_.Size(),
+      this->dweight_desc_, dwb->mutable_data(),
+      rspace->data(), this->reserve_space_.Size());
+    // clang-format on
+  },
+  {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace},
+  {dxb, dwb, dhxb, dcxb, wspace, rspace});
 
   vector <Tensor> param_grad{dw};
   auto data_grads = SplitOutput(num_dy, input_size_, grads, dx);

