mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zhash...@apache.org
Subject [incubator-mxnet] branch v1.6.x updated: Fix race condition in FusedOp (#18498) (#18505)
Date Mon, 08 Jun 2020 21:32:06 GMT
This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 467be8c  Fix race condition in FusedOp (#18498) (#18505)
467be8c is described below

commit 467be8c5b897b80cae3662fb3edd3add8b273d52
Author: Haibin Lin <linhaibin.eric@gmail.com>
AuthorDate: Mon Jun 8 14:30:58 2020 -0700

    Fix race condition in FusedOp (#18498) (#18505)
    
    Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
---
 src/operator/fusion/fused_op.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc
index 5e2d782..e78fb90 100644
--- a/src/operator/fusion/fused_op.cc
+++ b/src/operator/fusion/fused_op.cc
@@ -61,6 +61,7 @@ FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) :
 bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
                          std::vector<mxnet::TShape> *in_attrs,
                          std::vector<mxnet::TShape> *out_attrs) {
+  std::lock_guard<std::mutex> lock(my_mutex_);
   subgraph_.attrs.erase("shape");
   subgraph_.attrs.erase("shape_inputs");
   std::vector<mxnet::TShape> input_shapes(*in_attrs);
@@ -95,7 +96,6 @@ bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
     inferred = inferred && !op::shape_is_none(attr);
   }
   if (inferred) {
-    std::lock_guard<std::mutex> lock(my_mutex_);
     intermediate_shapes_.push_back({*in_attrs, *out_attrs, shapes});
   }
   return inferred;
@@ -104,6 +104,7 @@ bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
 bool FusedOp::InferType(const nnvm::NodeAttrs &attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
+  std::lock_guard<std::mutex> lock(my_mutex_);
   subgraph_.attrs.erase("dtype");
   subgraph_.attrs.erase("dtype_inputs");
   std::vector<int> input_types(*in_attrs);
@@ -138,7 +139,6 @@ bool FusedOp::InferType(const nnvm::NodeAttrs &attrs,
     inferred = inferred && !op::type_is_none(attr);
   }
   if (inferred) {
-    std::lock_guard<std::mutex> lock(my_mutex_);
     intermediate_dtypes_.push_back({*in_attrs, *out_attrs, types});
   }
   return inferred;


Mime
View raw message