tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
Date Thu, 26 Mar 2020 20:05:31 GMT
shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398857587
 
 

 ##########
 File path: src/relay/qnn/op/add.cc
 ##########
 @@ -97,39 +66,29 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array<Expr>&
new_args,
   //          Q_c = Q_a' + Q_b' - zp_c
   // The add op is done in int32 precision.
 
-  // Requantize LHS if necessary.
-  auto requantized_lhs = lhs;
-  if (!IsEqualScalar(lhs_scale, output_scale) ||
-      !IsEqualScalar(lhs_zero_point, output_zero_point)) {
-    requantized_lhs = Requantize(lhs, input_shape, lhs_scale, lhs_zero_point, output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_lhs = Cast(requantized_lhs, DataType::Int(32));
-  }
 
-  // Requantize RHS if necessary.
-  auto requantized_rhs = rhs;
-  if (!IsEqualScalar(rhs_scale, output_scale) ||
-      !IsEqualScalar(rhs_zero_point, output_zero_point)) {
-    requantized_rhs = Requantize(rhs, input_shape, rhs_scale, rhs_zero_point, output_scale,
-                                 output_zero_point, DataType::Int(32));
-  } else {
-    requantized_rhs = Cast(requantized_rhs, DataType::Int(32));
-  }
 
+  // Requantize LHS if necessary. Computes Q_a'
+  auto requantized_lhs = requantizeIfNeeded(args.lhs, args.lhs_scale,
+                                            args.lhs_zero_point,
+                                            args.output_scale, args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Requantize RHS if necessary. Computes Q_b'
+  auto requantized_rhs = requantizeIfNeeded(args.rhs, args.rhs_scale,
+                                            args.rhs_zero_point,
+                                            args.output_scale, args.output_zero_point,
+                                            inputShapeAndDtype.input_shape);
+  // Computes Q_a' + Q_b'
   auto output = Add(requantized_lhs, requantized_rhs);
 
-  // Subtract zero point.
+  // Subtract zero point. Computes (Q_a' + Q_b') - zp_c
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
-  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
-    output = Subtract(output, output_zero_point);
+  if (!IsEqualScalar(args.output_zero_point, zero_scalar)) {
+    output = Subtract(output, args.output_zero_point);
   }
 
   // Go back to lower precision.
-  auto q_min = GetQmin(input_dtype);
-  auto q_max = GetQmax(input_dtype);
-  output = Clip(output, q_min, q_max);
-  return Cast(output, input_dtype);
+  return lowerPrecision(output, inputShapeAndDtype.input_dtype);
 
 Review comment:
   Well, `ShrinkBackToOutDtype` implies that the output will be converted to the output dtype,
but we always convert to the input dtype (what goes in comes out). `ConvertDtype(fromExpression,
toTargetDtype)` is the convention I am going with. What do you think?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message