tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From liuyi...@apache.org
Subject [incubator-tvm] branch v0.6 updated: [BACKPORT-0.6][Quantization] Fix annotation for multiply op (#4458) (#5850)
Date Fri, 19 Jun 2020 06:32:39 GMT
This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.6 by this push:
     new ab76831  [BACKPORT-0.6][Quantization] Fix annotation for multiply op (#4458) (#5850)
ab76831 is described below

commit ab76831126e21d224fbd7d04c5f26f3ad3628c2e
Author: masahi <masahi129@gmail.com>
AuthorDate: Fri Jun 19 15:31:25 2020 +0900

    [BACKPORT-0.6][Quantization] Fix annotation for multiply op (#4458) (#5850)
    
    * fix mul rewrite
    
    * register Realize Rewrite for global avg pool and add test
    
    * remove unnecessary check
    
    * improve the test case
---
 python/tvm/relay/quantize/_annotate.py        |  6 ++--
 src/relay/pass/quantize/realize.cc            |  7 ++--
 tests/python/relay/test_pass_auto_quantize.py | 49 +++++++++++++++++++++++++++
 3 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index 9d679d2..ab98f3c 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -214,8 +214,10 @@ def multiply_rewrite(ref_call, new_args, ctx):
         # quantize lhs to INPUT field
         if lhs_kind == QAnnotateKind.ACTIVATION:
             lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
-        # quantize rhs to WEIGHT field
-        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+        if _analysis.check_constant(rhs_expr):
+            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+        else:
+            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
         return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc
index 4cf84f4..773551a 100644
--- a/src/relay/pass/quantize/realize.cc
+++ b/src/relay/pass/quantize/realize.cc
@@ -278,13 +278,9 @@ Expr MulRealize(const Call& ref_call,
     DataType dtype = cfg->dtype_activation;
     if (lhs->dtype != dtype) {
       ldata = Cast(ldata, dtype);
-    } else {
-      CHECK_EQ(lhs->dtype, dtype);
     }
     if (rhs->dtype != dtype) {
       rdata = Cast(rdata, dtype);
-    } else {
-      CHECK_EQ(rhs->dtype, dtype);
     }
 
     Expr ret = ForwardOp(ref_call, {ldata, rdata});
@@ -499,6 +495,9 @@ Expr AvgPoolRealize(const Call& ref_call,
 RELAY_REGISTER_OP("nn.avg_pool2d")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
 
+RELAY_REGISTER_OP("nn.global_avg_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+
 Expr CastHintRealize(const Call& ref_call,
                      const Array<Expr>& new_args,
                      const NodeRef& ctx) {
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
new file mode 100644
index 0000000..e4aa36b
--- /dev/null
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import testing
+
+
+def quantize_and_build(out):
+    f = relay.Function(relay.analysis.free_vars(out), out)
+    mod, params = testing.create_workload(f)
+
+    with relay.quantize.qconfig(skip_conv_layers=[]):
+        qmod = relay.quantize.quantize(mod, params)
+
+    relay.build(qmod, "llvm", params=params)
+
+
+def test_mul_rewrite():
+    """a test case where rhs of mul is not constant"""
+    data = relay.var("data", shape=(1, 16, 64, 64))
+    multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1)))
+    conv = relay.nn.conv2d(data, relay.var("weight"),
+                           kernel_size=(3, 3),
+                           padding=(1, 1),
+                           channels=16)
+    act = relay.nn.relu(data=conv)
+
+    quantize_and_build(act * multiplier)
+
+    pool = relay.nn.global_avg_pool2d(data=act)
+
+    quantize_and_build(act * pool)
+
+if __name__ == "__main__":
+    test_mul_rewrite()


Mime
View raw message