mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] marcoabreu commented on a change in pull request #10731: Fix a bug in getting MKLDNN memory
Date Mon, 30 Apr 2018 06:59:42 GMT
marcoabreu commented on a change in pull request #10731: Fix a bug in getting MKLDNN memory
URL: https://github.com/apache/incubator-mxnet/pull/10731#discussion_r184919959
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -495,24 +495,41 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   const mkldnn::memory *mem = GetMKLDNNData();
   // If the memory descriptor matches, it's easy.
   MKLDNNStream *stream = MKLDNNStream::Get();
-  if (mem->get_primitive_desc() == desc) {
-    return GetMKLDNNExact(mem, desc);
+  if (mem->get_primitive_desc() == new_pd) {
+    return GetMKLDNNExact(mem, new_pd);
   }
 
-  mkldnn::memory::primitive_desc _desc = desc;
+  mkldnn::memory::primitive_desc _pd = new_pd;
+  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
+  mkldnn::memory::desc desc2 = _pd.desc();
   // Now we need to determine if we should reorder the memory.
   // If both use the default formats, we think we don't need to reorder.
-  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
-  mkldnn::memory::desc desc2 = _desc.desc();
   if (desc1.data.format == GetDefaultFormat(desc1) &&
       desc2.data.format == GetDefaultFormat(desc2)) {
-    mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle()));
+    mkldnn_mem_ptr ret(new mkldnn::memory(new_pd, mem->get_data_handle()));
     stream->RegisterMem(ret);
     return ret.get();
-  } else {
-    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(desc);
+  } else if (same_shape(desc1, desc2)) {
+    // If they have the same shape, we can reorder data directly.
+    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(new_pd);
     stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
     return ret;
+  } else {
+    // If they have different shapes, we need to reshape the array first.
+    // Since this method will only be used inside an operator, we can call
+    // MKLDNNDataReshape to reshape an array.
+    TShape required_shape(desc2.data.ndims);
+    for (int i = 0; i < desc2.data.ndims; i++)
+      required_shape[i] = desc2.data.dims[i];
+    NDArray reshaped = MKLDNNDataReshape(required_shape);
+    const mkldnn::memory *ret = reshaped.GetMKLDNNData();
 
 Review comment:
   Where is ret being destroyed? 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message