mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-mxnet] zixuanweeei commented on a change in pull request #15741: MKL-DNN LBR-GRU Inference Integration (FP32 LBR-GRU)
Date Mon, 05 Aug 2019 04:55:33 GMT
zixuanweeei commented on a change in pull request #15741: MKL-DNN LBR-GRU Inference Integration (FP32 LBR-GRU)
URL: https://github.com/apache/incubator-mxnet/pull/15741#discussion_r310438190
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_rnn_impl.h
 ##########
 @@ -437,91 +450,118 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs,
   auto dst_iter_md = mkldnn::memory::desc(
       {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
 
-  for (int l = 0; l < L; l++) {
+  for (int l = 0; l < num_layer; l++) {
     if (mode == rnn_enum::kLstm) {
       std::vector<void*> srcs_data;
       srcs_data.push_back(hx_ptr);
       srcs_data.push_back(cx_ptr);
-      auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index];
+      mkldnn::memory& tmp_src_iter_memory = mkldnn_mems->concat_iter_memory[l + layer_index];
       ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
-          {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype,
-          2, srcs_data, tmp_src_iter_memory);
+          {{1, 1, 1, batch_size, hidden_size}, {1, 1, 1, batch_size, hidden_size}},
+          {1, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 2, srcs_data,
+          tmp_src_iter_memory, &(mkldnn_mems->uni_states_memory));
     } else {
-      (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr);
+      mkldnn_mems->concat_iter_memory[l + layer_index].set_data_handle(hx_ptr);
     }
     hx_ptr += cell_size;
     if (mode == rnn_enum::kLstm) {
       cx_ptr += cell_size;
     }
   }
 
-  auto user_src_iter_memory = null_memory_;
-  if (L == 1) {
-    user_src_iter_memory = (*concat_iter_memory)[layer_index];
+  mkldnn::memory* user_src_iter_memory;
+  if (num_layer == 1) {
+    user_src_iter_memory = &(mkldnn_mems->concat_iter_memory[layer_index]);
   } else {
-    user_src_iter_memory = (*concat_iter_memory)[L + layer_index];
+    user_src_iter_memory = &(mkldnn_mems->concat_iter_memory[num_layer + layer_index]);
     std::vector<void*> src_l_data;
     std::vector<mkldnn::memory::dims> src_l_dim;
-    for (int l = 0; l < L; l++) {
+    for (int l = 0; l < num_layer; l++) {
       src_l_data.push_back(reinterpret_cast<DType *>
-          ((*concat_iter_memory)[l + layer_index].get_data_handle()));
-      src_l_dim.push_back({1, 1, nstates, N, H});
+          (mkldnn_mems->concat_iter_memory[l + layer_index].get_data_handle()));
+      src_l_dim.push_back({1, 1, nstates, batch_size, hidden_size});
     }
     ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim,
-        {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory);
+        {num_layer, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 0, src_l_data,
+        *user_src_iter_memory, &(mkldnn_mems->concat_states_memory));
   }
-  (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle());
+  mkldnn_mems->hcx_memory[layer_index].set_data_handle(user_src_iter_memory->get_data_handle());
 
-  auto src_wx_f = (*concat_weight_memory)[2 * layer_index];
-  auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1];
+  mkldnn::memory& src_wx_f = mkldnn_mems->concat_weight_memory[2 * layer_index];
+  mkldnn::memory& src_wh_f = mkldnn_mems->concat_weight_memory[2 * layer_index + 1];
 
   std::vector<void*> srcs_data_x;
   std::vector<void*> srcs_data_h;
   std::vector<mkldnn::memory::dims> src_l_dim_x;
   std::vector<mkldnn::memory::dims> src_l_dim_h;
-  if (!initialized) {
-    if (L == 1) {
+
+  bool has_adjusted = false;
+  if (!initialized || is_train) {
+    if (num_layer == 1) {
       DType* wx = w_ptr;
-      DType* wh = w_ptr + I * H * ngates;
+      DType* wh = wx + input_size * hidden_size * ngates;
       if (mode == rnn_enum::kGru) {
-        AdjustGruWeightGateOrder(wx, I, H);
-        AdjustGruWeightGateOrder(wh, H, H);
-        AdjustGruBiasGateOrder(b_ptr, H);
-        AdjustGruBiasGateOrder(b_ptr + H * ngates, H);
+        AdjustGruWeightGateOrder(wx, input_size, hidden_size);
+        AdjustGruWeightGateOrder(wh, hidden_size, hidden_size);
+        has_adjusted = true;
       }
       src_wx_f.set_data_handle(wx);
       src_wh_f.set_data_handle(wh);
     } else {
-      for (int l = 0; l < L; l++) {
-        DType* wx = w_ptr;
-        DType* wh = w_ptr + I * H * ngates;
-        DType* bx = b_ptr + l * ngates * H * 2;
-        DType* bh = b_ptr + l * ngates * H * 2 + H * ngates;
+      for (int l = 0; l < num_layer; l++) {
+        DType* wx = w_ptr + l * w_size;
+        DType* wh = wx + input_size * hidden_size * ngates;
         if (mode == rnn_enum::kGru) {
-          AdjustGruWeightGateOrder(wx, I, H);
-          AdjustGruWeightGateOrder(wh, H, H);
-          AdjustGruBiasGateOrder(bx, H);
-          AdjustGruBiasGateOrder(bh, H);
+          AdjustGruWeightGateOrder(wx, input_size, hidden_size);
+          AdjustGruWeightGateOrder(wh, hidden_size, hidden_size);
+          has_adjusted = true;
         }
         srcs_data_x.push_back(wx);
         srcs_data_h.push_back(wh);
         src_l_dim_x.push_back(weights_layer_r_tz);
         src_l_dim_h.push_back(weights_iter_r_tz);
-        w_ptr = w_ptr + w_size;
       }
       ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
-          src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f);
+          src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f,
+          &(mkldnn_mems->weight_layer_mems));
       ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
-          src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f);
+          src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f,
+          &(mkldnn_mems->weight_iter_mems));
     }
-    MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index]));
-    MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index]));
-
-    DType* user_bias_f = reinterpret_cast<DType *> ((*bias_memory)[layer_index].get_data_handle());
-    #pragma omp parallel for num_threads(omp_threads)
-    for (int j = 0; j < L * single_b_size; j++) {
-      int k = j / single_b_size;
-      user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size];
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, mkldnn_mems->wx_memory[layer_index]));
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, mkldnn_mems->wh_memory[layer_index]));
+
+    DType* user_bias_f = reinterpret_cast<DType *>(
+        mkldnn_mems->bias_memory[layer_index].get_data_handle());
+    if (mode == rnn_enum::kGru) {
+      const int mx_single_b_sz = ngates * hidden_size;
+      for (int l = 0; l < num_layer; l++) {
+        #pragma omp parallel for num_threads(omp_threads)
 
 Review comment:
   Thanks for noting that.
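   
   For context on the `AdjustGruWeightGateOrder` calls in the hunk above: MXNet and MKL-DNN lay out the three GRU gate blocks of each weight matrix in different orders, so the leading two blocks have to trade places in `w_ptr` before the weights are reordered into MKL-DNN's format. Below is a minimal sketch of that kind of in-place block swap, assuming a gate-major layout; `SwapLeadingGateBlocks` is a hypothetical name for illustration, not the helper used in this PR.
   
   ```cpp
   #include <algorithm>
   
   // Hypothetical sketch: a GRU weight matrix stored gate-major as
   // [gate0 | gate1 | gate2], each block holding rows * cols elements.
   // Swapping the leading two blocks converts between, e.g., a
   // (reset, update, new) ordering and an (update, reset, new) ordering.
   template <typename DType>
   void SwapLeadingGateBlocks(DType* weights, int rows, int cols) {
     const int block = rows * cols;  // elements per gate block
     std::swap_ranges(weights, weights + block, weights + block);
   }
   
   // Mirrors the call sites in the hunk, e.g.:
   //   SwapLeadingGateBlocks(wx, input_size, hidden_size);
   //   SwapLeadingGateBlocks(wh, hidden_size, hidden_size);
   ```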

