ZiyueHuang commented on a change in pull request #8732: rsp push and rsp pull for comm device,
used in kvstore('device')
URL: https://github.com/apache/incubator-mxnet/pull/8732#discussion_r153445175
##########
File path: src/kvstore/comm.h
##########
@@ -526,26 +522,66 @@ class CommDevice : public Comm {
InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
- CopyFromTo(src[0], &(buf.merged), priority);
- reduce[0] = buf.merged;
- if (buf.copy_buf.empty()) {
- // TODO(mli) this results in large device memory usage for huge ndarray,
- // such as the largest fullc in VGG. consider to do segment reduce with
- // NDArray.Slice or gpu direct memory access. for the latter, we need to
- // remove some ctx check, and also it reduces 20% perf
- buf.copy_buf.resize(src.size()-1);
+ if (buf.merged.storage_type() == kDefaultStorage) {
+ CopyFromTo(src[0], &(buf.merged), priority);
+ reduce[0] = buf.merged;
+
+ if (buf.copy_buf.empty()) {
+ // TODO(mli) this results in large device memory usage for huge ndarray,
+ // such as the largest fullc in VGG. consider to do segment reduce with
+ // NDArray.Slice or gpu direct memory access. for the latter, we need to
+ // remove some ctx check, and also it reduces 20% perf
+ buf.copy_buf.resize(src.size()-1);
+ for (size_t i = 0; i < src.size()-1; ++i) {
+ buf.copy_buf[i] = NDArray(
+ buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ }
+ }
for (size_t i = 0; i < src.size()-1; ++i) {
- buf.copy_buf[i] = NDArray(
- buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+ reduce[i+1] = buf.copy_buf[i];
}
- }
- for (size_t i = 0; i < src.size()-1; ++i) {
- CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
- reduce[i+1] = buf.copy_buf[i];
- }
- ElementwiseSum(reduce, &buf.merged);
+ ElementwiseSum(reduce, &buf.merged);
+ } else {
+ std::vector<Engine::VarHandle> const_vars(src.size());
+ if (buf.copy_buf.empty()) {
+ buf.copy_buf.resize(src.size());
+ for (size_t j = 0; j < src.size(); ++j) {
+ buf.copy_buf[j] = NDArray(
+ buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+ true, buf.merged.dtype());
+ }
+ }
+ for (size_t i = 0; i < src.size(); ++i) {
+ CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+ reduce[i] = buf.copy_buf[i];
+ const_vars[i] = reduce[i].var();
+ }
+ auto result = buf.merged;
+ Engine::Get()->PushAsync(
Review comment:
Why should this move into `ndarray.cc`? I think it is fine here, i.e. pushing the operation
into the engine in comm.h.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services
|