mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] ZiyueHuang commented on a change in pull request #8732: rsp push and rsp pull for comm device, used in kvstore('device')
Date Thu, 01 Jan 1970 00:00:00 GMT
ZiyueHuang commented on a change in pull request #8732: rsp push and rsp pull for comm device,
used in kvstore('device')
URL: https://github.com/apache/incubator-mxnet/pull/8732#discussion_r153445175
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -526,26 +522,66 @@ class CommDevice : public Comm {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
 
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
+    if (buf.merged.storage_type() == kDefaultStorage) {
+      CopyFromTo(src[0], &(buf.merged), priority);
+      reduce[0] = buf.merged;
+
+      if (buf.copy_buf.empty()) {
+        // TODO(mli) this results in large device memory usage for huge ndarray,
+        // such as the largest fullc in VGG. consider to do segment reduce with
+        // NDArray.Slice or gpu direct memory access. for the latter, we need to
+        // remove some ctx check, and also it reduces 20% perf
+        buf.copy_buf.resize(src.size()-1);
+        for (size_t i = 0; i < src.size()-1; ++i) {
+          buf.copy_buf[i] = NDArray(
+            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        }
+      }
       for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+        reduce[i+1] = buf.copy_buf[i];
       }
-    }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
-    }
 
-    ElementwiseSum(reduce, &buf.merged);
+      ElementwiseSum(reduce, &buf.merged);
+    } else {
+      std::vector<Engine::VarHandle> const_vars(src.size());
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j) {
+          buf.copy_buf[j] = NDArray(
+            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+            true, buf.merged.dtype());
+        }
+      }
+      for (size_t i = 0; i < src.size(); ++i) {
+        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+        reduce[i] = buf.copy_buf[i];
+        const_vars[i] = reduce[i].var();
+      }
+      auto result = buf.merged;
+      Engine::Get()->PushAsync(
 
 Review comment:
   Why should this move into `ndarray.cc`? I think it is fine here, i.e. pushing the operation
into the engine in `comm.h`.
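
For readers unfamiliar with the pattern under discussion, here is a minimal, self-contained sketch of "copy every source into its own buffer, then push one reduce operation into the engine with the buffers as read dependencies and the merged array as the write dependency". It only illustrates the shape of the sparse branch in the hunk above. `ToyArray`, `ToyEngine`, `NewVar`, `Push`, `WaitAll` and `ReduceViaEngine` are hypothetical stand-ins invented for this example, not MXNet types or APIs, and the toy engine simply runs queued work in submission order instead of scheduling by dependency.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for an array that carries a dependency variable.
struct ToyArray {
  std::vector<float> data;
  int var = 0;  // identifier the toy engine uses to track dependencies
};

// Hypothetical stand-in for a dependency engine (not MXNet's Engine).
struct ToyEngine {
  struct Op {
    std::function<void()> fn;
    std::vector<int> const_vars;    // variables the op only reads
    std::vector<int> mutable_vars;  // variables the op writes
  };
  std::vector<Op> queue;
  int next_var = 0;

  int NewVar() { return next_var++; }

  // Queue work together with the variables it reads and writes.
  void Push(std::function<void()> fn, std::vector<int> const_vars,
            std::vector<int> mutable_vars) {
    queue.push_back({std::move(fn), std::move(const_vars), std::move(mutable_vars)});
  }

  // Run queued ops in submission order; a real engine would order them
  // by their declared dependencies and run them asynchronously.
  void WaitAll() {
    for (auto& op : queue) op.fn();
    queue.clear();
  }
};

// Copy each source into its own buffer, then push a single reduce op that
// reads every buffer and writes `merged`. The copies happen immediately here
// for simplicity. `copy_buf` and `merged` must outlive the call to WaitAll().
inline void ReduceViaEngine(ToyEngine* engine, const std::vector<ToyArray>& src,
                            std::vector<ToyArray>* copy_buf, ToyArray* merged) {
  if (copy_buf->empty()) {
    // Allocate the buffers once, each with its own dependency variable.
    copy_buf->resize(src.size());
    for (auto& buf : *copy_buf) buf.var = engine->NewVar();
  }
  std::vector<int> const_vars(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    (*copy_buf)[i].data = src[i].data;
    const_vars[i] = (*copy_buf)[i].var;
  }
  engine->Push(
      [copy_buf, merged]() {
        std::fill(merged->data.begin(), merged->data.end(), 0.0f);
        for (const ToyArray& a : *copy_buf) {
          assert(a.data.size() == merged->data.size());
          for (std::size_t k = 0; k < a.data.size(); ++k) merged->data[k] += a.data[k];
        }
      },
      const_vars, {merged->var});
}
```

The point of the sketch is only that the caller (the analogue of `comm.h`) can declare the dependencies and push the work itself; whether that logic belongs there or in a shared helper is the design question raised in this comment.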

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
