mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-mxnet] TaoLv commented on a change in pull request #17187: [MKLDNN] Support channel wise quantization for FullyConnected
Date Mon, 30 Dec 2019 09:13:15 GMT
TaoLv commented on a change in pull request #17187: [MKLDNN] Support channel wise quantization for FullyConnected
URL: https://github.com/apache/incubator-mxnet/pull/17187#discussion_r361932973
 
 

 ##########
 File path: src/operator/subgraph/mkldnn/mkldnn_fc.cc
 ##########
 @@ -115,84 +151,153 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   CHECK_EQ(out_data.size(), total_num_outputs);
 
   NDArray data = in_data[fullc::kData];
-  NDArray weight = cached_weight_ ? *cached_weight_ : in_data[fullc::kWeight];
+  NDArray weight = in_data[fullc::kWeight];
   NDArray output = out_data[fullc::kOut];
-
-  mkldnn::memory::desc out_md = GetMemDesc(output);
-  MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md);
+  MKLDNNFCFlattenData(default_param, &data);
 
   if (initialized_ && mkldnn_param.quantized) {
-    if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
-        cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
-        (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
-          initialized_ = false;
-        }
+    if (channel_wise) {
+      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+          weight_ver_ != weight.version() ||
+          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
+        initialized_ = false;
+      }
+    } else {
+      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+          cached_min_weight_ != min_weight || cached_max_weight_ != max_weight ||
+          (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != max_bias))) {
+        initialized_ = false;
+      }
+    }
   }
 
   if (!initialized_) {
+    const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     cached_min_data_ = min_data;
     cached_max_data_ = max_data;
     cached_min_weight_ = min_weight;
     cached_max_weight_ = max_weight;
+    weight_ver_ = weight.version();
+    cached_weight_ = weight;
     if (has_bias) {
-      cached_bias_ = in_data[fullc::kBias];
       cached_min_bias_ = min_bias;
       cached_max_bias_ = max_bias;
+      bias_ver_ = in_data[fullc::kBias].version();
+      cached_bias_ = in_data[fullc::kBias];
     } else {
       cached_bias_ = NDArray();
     }
 
+    // create cached out_md
+    const mxnet::TShape ishape = data.shape();
+    const mxnet::TShape oshape = output.shape();
+    mkldnn::memory::dims out_dims(2);
+    if (oshape.ndim() == 2) {
+      out_dims[0] = static_cast<int>(oshape[0]);
+      out_dims[1] = static_cast<int>(oshape[1]);
+    } else {
+      if (!default_param.flatten) {
+        out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1));
+        out_dims[1] = static_cast<int>(oshape[oshape.ndim()-1]);
+      } else {
+        out_dims[0] = static_cast<int>(oshape[0]);
+        out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
+      }
+    }
+    mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
+      static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
+    cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, CpuEngine::Get()->get_engine());
+
+    bool support_channelwise_scale = false;
     if (mkldnn_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      float data_scale  = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
-      float weight_scale = GetQuantizeScale(mshadow::kInt8, cached_min_weight_, cached_max_weight_);
-      if (has_bias) {
-        NDArray bias = in_data[fullc::kBias];
-        float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_min_bias_, cached_max_bias_);
-        float bias_int32_rescale = data_scale * weight_scale / bias_scale;
-        // TODO(zhennan): mkldnn has bug to handle INT_MAX in bias, so set the maximum value of bias
-        // to INT_MAX / 2.
-        float bias_max_rescale =
-            MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, cached_max_bias_) / bias_scale;
-        if (bias_int32_rescale > bias_max_rescale) {
-          // avoid overflow on bias
-          bias_int32_rescale = bias_max_rescale;
-          float weight_rescale = bias_int32_rescale * bias_scale / data_scale / weight_scale;
-          cached_weight_.reset(new NDArray(weight.storage_type(), weight.shape(), weight.ctx(),
-                                           true, mshadow::kInt8));
-          int8_t *weight_ptr = weight.data().dptr<int8_t>();
-          int8_t *quantized_weight_ptr = cached_weight_->data().dptr<int8_t>();
-          size_t weight_size = weight.shape().Size();
-#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-          for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
-            quantized_weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+      data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, cached_max_data_);
+
+      bool fuse_requantize = false;
+      // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
+      if (mkldnn_param.min_calib_range.has_value() &&
+          mkldnn_param.max_calib_range.has_value()) {
+        cached_min_output_ = mkldnn_param.min_calib_range.value();
+        cached_max_output_ = mkldnn_param.max_calib_range.value();
+        support_channelwise_scale = true;
+        fuse_requantize = true;
+      }
+      if (mkldnn_param.enable_float_output) {
+        support_channelwise_scale = true;
+      }
+      // channel_wise  support_channelwise_scale  result
+      // True          True                       True
+      // True          False                      Error
+      // False         True/False                 False
+      if (channel_wise && !support_channelwise_scale) {
+        LOG(FATAL)
+          << "Currently, channel-wise quantization requires fuse requantize or dequantize.";
 
 Review comment:
  Is this something that users may encounter? If so, what action should they take?
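
  As a side note for readers of this diff, here is a small standalone sketch of the bias-rescale arithmetic that both the removed and the new code rely on. It is not MXNet code: QuantizeScaleInt8 and the example calibration ranges are invented for illustration and stand in for GetQuantizeScale and the cached min/max values.

  // Standalone illustration (hypothetical, not from the MXNet sources) of the
  // int8 quantization scale and the bias_int32_rescale cap used in the diff above.
  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <iostream>
  #include <limits>

  // Scale that maps real values in [min, max] onto the signed int8 range.
  float QuantizeScaleInt8(float min_v, float max_v) {
    const float max_abs = std::max(std::fabs(min_v), std::fabs(max_v));
    return 127.0f / max_abs;
  }

  int main() {
    // Example calibration ranges (assumed values, for illustration only).
    const float data_scale   = QuantizeScaleInt8(-1.0f, 1.0f);   // activations
    const float weight_scale = QuantizeScaleInt8(-0.5f, 0.5f);   // weights
    const float min_bias = -2.0f, max_bias = 2.0f;
    const float bias_scale   = QuantizeScaleInt8(min_bias, max_bias);

    // The int8 bias has to be re-expressed in the int32 accumulator's scale,
    // which is data_scale * weight_scale.
    float bias_int32_rescale = data_scale * weight_scale / bias_scale;

    // Cap the rescale so the largest rescaled bias stays below INT_MAX / 2,
    // mirroring the TODO(zhennan) workaround quoted in the diff.
    const float max_abs_bias = std::max(std::fabs(min_bias), std::fabs(max_bias));
    const float bias_max_rescale =
        std::numeric_limits<int32_t>::max() / 2 / max_abs_bias / bias_scale;
    bias_int32_rescale = std::min(bias_int32_rescale, bias_max_rescale);

    std::cout << "bias_int32_rescale = " << bias_int32_rescale << std::endl;
    return 0;
  }

  In the code path quoted above, hitting this cap additionally re-quantizes the cached weights with a compensating weight_rescale so that the product of scales stays consistent.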

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
