tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
Date Fri, 27 Mar 2020 17:57:03 GMT
shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r399444523
 
 

 ##########
 File path: src/relay/qnn/op/op_common.h
 ##########
 @@ -30,14 +30,152 @@
 #include <tvm/relay/qnn/attrs.h>
 #include <vector>
 #include "../../op/type_relations.h"
+#include "../../transforms/infer_layout_util.h"
+#include "../util.h"
 
 namespace tvm {
 namespace relay {
 namespace qnn {
 
-static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const
Attrs& attrs,
+/*
+ * \brief Number of input expressions for the QNN binary operators:
+ * lhs, rhs, and the (scale, zero_point) pair for lhs, rhs and the
+ * output (see QnnBinaryOpArguments below for the exact order).
+ * Refer to the QNN_REGISTER_BINARY_OP macro to see
+ * what the operators are.
+ */
+static constexpr int kNumQnnBinaryOpInputs = 8;
+
+/*
+ * \brief Number of expected arg types passed to the type relation —
+ * presumably the 8 inputs above plus the result type; confirm
+ * against the QNN binary op type relation.
+ */
+static constexpr int kNumQnnBinaryOpArgTypes = 9;
+
+/*
+ * \brief Bundles the eight inputs of a QNN binary operator into
+ * named fields. Gathering them in one struct provides a single,
+ * central place to perform the checks shared by all the binary ops.
+ */
+struct QnnBinaryOpArguments {
+  Expr lhs;
+  Expr rhs;
+  Expr lhs_scale;
+  Expr lhs_zero_point;
+  Expr rhs_scale;
+  Expr rhs_zero_point;
+  Expr output_scale;
+  Expr output_zero_point;
+
+  explicit QnnBinaryOpArguments(const Array<Expr>& new_args) {
+    CHECK_EQ(new_args.size(), kNumQnnBinaryOpInputs);
+    // Unpack in the positional order used by the QNN binary ops.
+    lhs = new_args[0];
+    rhs = new_args[1];
+    lhs_scale = new_args[2];
+    lhs_zero_point = new_args[3];
+    rhs_scale = new_args[4];
+    rhs_zero_point = new_args[5];
+    output_scale = new_args[6];
+    output_zero_point = new_args[7];
+  }
+};
+
+/*
+ * \brief Captures the dtype and shape of one input tensor. Keeping
+ * them in a small struct gives a common point for the validation
+ * checks performed by the QNN binary operators.
+ */
+struct QnnBinaryOpTensorType {
+  DataType dtype;
+  Array<PrimExpr> shape;
+
+  explicit QnnBinaryOpTensorType(const Array<tvm::relay::Type>& arg_types,
+                                 const int32_t arg_idx) {
+    CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes);
+    // The arg must be a tensor type; anything else is a caller bug.
+    const auto* tensor_type = arg_types[arg_idx].as<TensorTypeNode>();
+    CHECK(tensor_type != nullptr);
+    dtype = tensor_type->dtype;
+    shape = tensor_type->shape;
+  }
+};
+
+/*
+ * \brief Casts an expression to target_dtype, first clipping it to
+ * the quantized range of that dtype. Mainly used to narrow Int32
+ * intermediate computations down to Int8 or UInt8 results.
+ * \param expr The expression whose dtype needs conversion.
+ * \param target_dtype The dtype of the target expression.
+ * \return New expression with target dtype and possibly lower
+ * precision.
+ */
+inline Expr ConvertDtype(const Expr& expr,
+                         const DataType& target_dtype) {
+  // Saturate to [qmin, qmax] of the target type before the cast so
+  // out-of-range values clamp instead of wrapping.
+  const auto clipped = Clip(expr, GetQmin(target_dtype), GetQmax(target_dtype));
+  return Cast(clipped, target_dtype);
+}
+
+/*
+ * \brief Requantizes the given expression if its scale or zero
+ * point differs from the target's (note: requantization happens
+ * when EITHER differs, not only when both do); otherwise the
+ * expression is simply cast to the target dtype, as no
+ * requantization is needed in that case. This is mainly needed to
+ * bring the input tensors to the output tensor's scale and zero
+ * point, which eases the computation of the final quantized tensor.
+ * \param expr The expression on which the check needs to be performed.
+ * \param expr_scale The scale of the expression.
+ * \param expr_zero_point The zero point of the expression.
+ * \param target_scale The scale of the output tensor.
+ * \param target_zero_point The zero point of the output tensor.
+ * \param expr_shape The shape of the input expression.
+ * \param target_dtype The dtype used for the plain upcast path
+ * (defaults to Int32).
+ * \return The requantized expression, or the expression upcast to
+ * target_dtype when scale and zero point already match.
+ */
+inline Expr RequantizeOrUpcast(const Expr& expr,
+                               const Expr& expr_scale,
+                               const Expr& expr_zero_point,
+                               const Expr& target_scale,
+                               const Expr& target_zero_point,
+                               const Array <PrimExpr>& expr_shape,
+                               const DataType& target_dtype=DataType::Int(32)) {
+  auto result = expr;
+  // Requantize if either the scale or the zero point differs.
+  if (!IsEqualScalar(expr_scale, target_scale) ||
+     !IsEqualScalar(expr_zero_point, target_zero_point)) {
 
 Review comment:
   done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message