From: GitBox
To: commits@tvm.apache.org
Subject: [GitHub] [incubator-tvm] shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
Message-ID: <158524839238.3974.5750732949562811245.gitbox@gitbox.apache.org>
Date: Thu, 26 Mar 2020 18:46:32 -0000

shoubhik commented on a change in pull request #5153: Adding support for QNN subtract op
URL: https://github.com/apache/incubator-tvm/pull/5153#discussion_r398810414

##########
File path: src/relay/qnn/op/op_common.h
##########

@@ -30,14 +30,155 @@
 #include 
 #include 
 #include "../../op/type_relations.h"
+#include "../../transforms/infer_layout_util.h"

 namespace tvm {
 namespace relay {
 namespace qnn {

-static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+/*
+ * Number of inputs for the Qnn binary operators.
+ * Refer to the QNN_REGISTER_BINARY_OP macro to see
+ * what the operators are.
+ */
+static constexpr int numQnnBinaryOpInputs = 8;
+
+/*
+ * Number of expected arg types.
+ */
+static constexpr int numQnnBinaryOpArgTypes = 9;
+
+/*
+ * \brief Simple struct to organize the inputs to the Qnn
+ * binary operators. The main reason to have a struct
+ * is to be able to perform the common checks needed at a
+ * central location.
+ */
+struct QnnBinaryOpArguments {
+  Expr lhs;
+  Expr rhs;
+  Expr lhs_scale;
+  Expr lhs_zero_point;
+  Expr rhs_scale;
+  Expr rhs_zero_point;
+  Expr output_scale;
+  Expr output_zero_point;
+
+  explicit QnnBinaryOpArguments(const Array<Expr>& new_args) {
+    CHECK_EQ(new_args.size(), numQnnBinaryOpInputs);
+    int idx = 0;
+    lhs = new_args[idx++];
+    rhs = new_args[idx++];
+    lhs_scale = new_args[idx++];
+    lhs_zero_point = new_args[idx++];
+    rhs_scale = new_args[idx++];
+    rhs_zero_point = new_args[idx++];
+    output_scale = new_args[idx++];
+    output_zero_point = new_args[idx++];
+    CHECK_EQ(idx, numQnnBinaryOpInputs);
+  }
+};
+
+/*
+ * \brief Simple structure to hold the input tensor's dtype
+ * and shape. This structure allows a common point to do
+ * all the validation checks.
+ */
+struct QnnBinaryOpDtypeAndShape {
+  DataType input_dtype;
+  Array<IndexExpr> input_shape;
+
+  explicit QnnBinaryOpDtypeAndShape(const Array<Type>& arg_types) {
+    CHECK_EQ(arg_types.size(), numQnnBinaryOpArgTypes);
+    auto tensor_type = arg_types[0].as<TensorTypeNode>();
+    CHECK(tensor_type != nullptr);
+    input_dtype = tensor_type->dtype;
+    input_shape = tensor_type->shape;
+  }
+};
+
+/*
+ * \brief Converts the expression from the expression's dtype
+ * to the target dtype.
+ * This is mainly used for converting
+ * computations done in Int32 to the lower precision Int8 or
+ * UInt8.
+ * \param expr The expression whose dtype needs conversion.
+ * \param target_dtype The dtype of the target expression.
+ * \return New expression with the target dtype and possibly lower
+ * precision.
+ */
+inline Expr lowerPrecision(const Expr& expr,
+                           const DataType& target_dtype) {
+  auto q_min = GetQmin(target_dtype);
+  auto q_max = GetQmax(target_dtype);
+  auto output = Clip(expr, q_min, q_max);
+  return Cast(output, target_dtype);
+}
+
+/*
+ * Full precision Int32 data type, used to explicitly cast
+ * Int8/UInt8 to Int32 and to create Int32 constants.
+ */
+const auto fullPrecisionInt32 = DataType::Int(32);
+
+/*
+ * \brief Requantizes the given expression if its scale and
+ * zero point do not both match the target scale and zero
+ * point. This is mainly needed to requantize the input
+ * tensors with the output tensor's scale and zero point,
+ * which eases the computation of the final quantized tensor.
+ * \param expr The expression on which the check needs to be performed.
+ * \param expr_scale The scale of the expression.
+ * \param expr_zero_point The zero point of the expression.
+ * \param target_scale The scale of the output tensor.
+ * \param target_zero_point The zero point of the output tensor.
+ * \param expr_shape The shape of the input expression.
+ * \return New expression requantized to the target scale and zero
+ * point if the expression's scale and zero point differ; otherwise
+ * the given expression is simply cast to Int32, as no requantization
+ * is needed in that case.
+ */
+inline Expr requantizeIfNeeded(const Expr& expr,
+                               const Expr& expr_scale,
+                               const Expr& expr_zero_point,
+                               const Expr& target_scale,
+                               const Expr& target_zero_point,
+                               const Array<IndexExpr>& expr_shape) {
+  auto result = expr;
+  if (!IsEqualScalar(expr_scale, target_scale) ||
+      !IsEqualScalar(expr_zero_point, target_zero_point)) {
+    result = Requantize(expr, expr_shape, expr_scale, expr_zero_point,
+                        target_scale, target_zero_point, fullPrecisionInt32);
+  } else {
+    result = Cast(result, fullPrecisionInt32);

Review comment:
   `RequantizeOrUpcast`?
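For context, here is a minimal sketch of how the helpers in this hunk are meant to compose in a binary-op canonicalization such as qnn.subtract. It is illustrative only: the function name below is made up, and the Subtract/Add call builders are assumed to be in scope from TVM's internal pattern utilities, so this is not the code added by the PR.

// Illustrative sketch only -- not the code from PR #5153.
// Assumes the Subtract/Add/Cast call builders from TVM's internal
// pattern utilities are in scope, alongside the helpers in this hunk.
Expr QnnSubtractCanonicalizeSketch(const Array<Expr>& new_args,
                                   const Array<Type>& arg_types) {
  // Unpack the eight QNN inputs and run the common size checks.
  QnnBinaryOpArguments args(new_args);
  // Input dtype (Int8/UInt8) and shape taken from the lhs tensor type.
  QnnBinaryOpDtypeAndShape input_info(arg_types);

  // Bring both operands onto the output scale and zero point, or just
  // upcast them to Int32 when they already match.
  auto lhs = requantizeIfNeeded(args.lhs, args.lhs_scale, args.lhs_zero_point,
                                args.output_scale, args.output_zero_point,
                                input_info.input_shape);
  auto rhs = requantizeIfNeeded(args.rhs, args.rhs_scale, args.rhs_zero_point,
                                args.output_scale, args.output_zero_point,
                                input_info.input_shape);

  // In the requantized domain the two output zero points cancel under
  // subtraction, so the output zero point is added back once.
  auto output = Subtract(lhs, rhs);
  output = Add(output, args.output_zero_point);

  // Clip to the quantized range of the narrow dtype and cast back down.
  return lowerPrecision(output, input_info.input_dtype);
}

Requantizing both inputs to the output quantization parameters up front is what keeps the integer arithmetic this simple: once both operands share a scale and zero point, the elementwise op reduces to plain Int32 arithmetic plus a single zero-point correction before narrowing.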