tvm-commits mailing list archives

From: GitBox <...@apache.org>
Subject: [GitHub] [incubator-tvm] shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
Date: Thu, 26 Mar 2020 18:46:32 GMT
shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398810414
 
 

 ##########
 File path: src/relay/qnn/op/op_common.h
 ##########
 @@ -30,14 +30,155 @@
 #include <tvm/relay/qnn/attrs.h>
 #include <vector>
 #include "../../op/type_relations.h"
+#include "../../transforms/infer_layout_util.h"
 
 namespace tvm {
 namespace relay {
 namespace qnn {
 
-static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+/*
+ * Number of inputs for the Qnn binary operators.
+ * Refer to the QNN_REGISTER_BINARY_OP macro to see
+ * what the operators are.
+ */
+static constexpr int numQnnBinaryOpInputs = 8;
+
+/*
+ * Number of expected arg types.
+ */
+static constexpr int numQnnBinaryOpArgTypes = 9;
+
+/*
+ * \brief Simple struct to organize the inputs to the Qnn
+ * binary operators. The main reason to have a struct
+ * is to be able to perform the common checks needed at a
+ * central location.
+ */
+struct QnnBinaryOpArguments {
+  Expr lhs;
+  Expr rhs;
+  Expr lhs_scale;
+  Expr lhs_zero_point;
+  Expr rhs_scale;
+  Expr rhs_zero_point;
+  Expr output_scale;
+  Expr output_zero_point;
+
+  explicit QnnBinaryOpArguments(const Array<Expr>& new_args) {
+    CHECK_EQ(new_args.size(), numQnnBinaryOpInputs);
+    int idx = 0;
+    lhs = new_args[idx++];
+    rhs = new_args[idx++];
+    lhs_scale = new_args[idx++];
+    lhs_zero_point = new_args[idx++];
+    rhs_scale = new_args[idx++];
+    rhs_zero_point = new_args[idx++];
+    output_scale = new_args[idx++];
+    output_zero_point = new_args[idx++];
+    CHECK_EQ(idx, numQnnBinaryOpInputs);
+  }
+};
+
+/*
+ * \brief Simple structure to hold the input tensor's dtype
+ * and shape. This structure provides a common point for
+ * performing all the validation checks.
+ */
+struct QnnBinaryOpDtypeAndShape {
+  DataType input_dtype;
+  Array<PrimExpr> input_shape;
+
+  explicit QnnBinaryOpDtypeAndShape(const Array<tvm::relay::Type>& arg_types) {
+    CHECK_EQ(arg_types.size(), numQnnBinaryOpArgTypes);
+    auto tensor_type = arg_types[0].as<TensorTypeNode>();
+    CHECK(tensor_type != nullptr);
+    input_dtype = tensor_type->dtype;
+    input_shape = tensor_type->shape;
+  }
+};
+
+/*
+ * \brief Converts the given expression from its current
+ * dtype to the target dtype. This is mainly used for
+ * converting computations done in Int32 to lower precision
+ * Int8 or UInt8.
+ * \param expr The expression whose dtype needs conversion.
+ * \param target_dtype The dtype of the target expression.
+ * \return New expression with the target dtype and possibly
+ * lower precision.
+ */
+inline Expr lowerPrecision(const Expr& expr,
+                           const DataType& target_dtype) {
+  auto q_min = GetQmin(target_dtype);
+  auto q_max = GetQmax(target_dtype);
+  auto output = Clip(expr, q_min, q_max);
+  return Cast(output, target_dtype);
+}
+
+/*
+ * Full precision Int32 data type for explicitly casting
+ * Int8/UInt8 to Int32 and for creating Int32 constants.
+ */
+const auto fullPrecisionInt32 = DataType::Int(32);
+
+/*
+ * \brief Requantizes the given expression if its scale or
+ * zero point does not match the target scale and zero point.
+ * This is mainly needed for requantizing the input tensors
+ * with the output tensor's scale and zero point to ease the
+ * computation of the final quantized tensor.
+ * \param expr The expression on which the check needs to be performed.
+ * \param expr_scale The scale of the expression.
+ * \param expr_zero_point The zero point of the expression.
+ * \param target_scale The scale of the output tensor.
+ * \param target_zero_point The zero point of the output tensor.
+ * \param expr_shape The shape of the input expression.
+ * \return New expression requantized to the target scale and zero
+ * point if the expression's scale or zero point differs; otherwise,
+ * the given expression is simply cast to Int32, as no requantization
+ * is needed in that case.
+ */
+inline Expr requantizeIfNeeded(const Expr& expr,
+                               const Expr& expr_scale,
+                               const Expr& expr_zero_point,
+                               const Expr& target_scale,
+                               const Expr& target_zero_point,
+                               const Array<PrimExpr>& expr_shape) {
+  auto result = expr;
+  if (!IsEqualScalar(expr_scale, target_scale) ||
+     !IsEqualScalar(expr_zero_point, target_zero_point)) {
+    result = Requantize(expr, expr_shape, expr_scale, expr_zero_point,
+                        target_scale, target_zero_point, fullPrecisionInt32);
+  } else {
+    result = Cast(result, fullPrecisionInt32);
 
 Review comment:
   `RequantizeOrUpcast`?
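
   For illustration only (not part of the PR or the review), a minimal sketch of
   how a helper with that name could look, assuming it keeps the body of
   `requantizeIfNeeded` shown in the diff above and uses `DataType::Int(32)`
   directly in place of the `fullPrecisionInt32` constant:

   inline Expr RequantizeOrUpcast(const Expr& expr,
                                  const Expr& expr_scale,
                                  const Expr& expr_zero_point,
                                  const Expr& target_scale,
                                  const Expr& target_zero_point,
                                  const Array<PrimExpr>& expr_shape) {
     // Requantize when either the scale or the zero point differs from the
     // target; otherwise only an upcast to Int32 is needed.
     if (!IsEqualScalar(expr_scale, target_scale) ||
         !IsEqualScalar(expr_zero_point, target_zero_point)) {
       return Requantize(expr, expr_shape, expr_scale, expr_zero_point,
                         target_scale, target_zero_point, DataType::Int(32));
     }
     return Cast(expr, DataType::Int(32));
   }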

