tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch v0.6 updated: [ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5761)
Date Wed, 10 Jun 2020 00:32:00 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.6 by this push:
     new 85af4a9  [ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5761)
85af4a9 is described below

commit 85af4a908ee7b66e81cc443b1673e890b20decb8
Author: xqdan <danxiaoqiang@126.com>
AuthorDate: Wed Jun 10 08:31:50 2020 +0800

    [ARITH][BACKPORT-0.6] fix a min/max simplify bug (#5761)
    
    Co-authored-by: d00221512 <d00221512@huawei.com>
---
 src/arithmetic/rewrite_simplify.cc                   | 10 ++++++++--
 tests/python/unittest/test_arith_rewrite_simplify.py | 12 ++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc
index b26f833..4d4242b 100644
--- a/src/arithmetic/rewrite_simplify.cc
+++ b/src/arithmetic/rewrite_simplify.cc
@@ -1157,8 +1157,11 @@ Mutate_(const Min* op, const Expr& self) {
     if (min(x * c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
+      if (c1val == 0) {
+        return c2val < 0 ? c2.Eval() : c1.Eval();
+      }
       if (c2val % c1val == 0) {
-        if (c2val / c1val >= 0) {
+        if (c1val > 0) {
           return (min(x, c2val / c1val) * c1val).Eval();
         } else {
           return (max(x, c2val / c1val) * c1val).Eval();
@@ -1331,8 +1334,11 @@ Mutate_(const Max* op, const Expr& self) {
     if (max(x * c1, c2).Match(ret)) {
       int64_t c1val = c1.Eval()->value;
       int64_t c2val = c2.Eval()->value;
+      if (c1val == 0) {
+        return c2val > 0 ? c2.Eval() : c1.Eval();
+      }
       if (c2val % c1val == 0) {
-        if (c2val / c1val >= 0) {
+        if (c1val > 0) {
           return (max(x, c2val / c1val) * c1val).Eval();
         } else {
           return (min(x, c2val / c1val) * c1val).Eval();
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 99c2942..f380493 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -529,6 +529,12 @@ def test_min_index_simplify():
 
     ck.verify(tvm.min(x * 3, 9), tvm.min(x, 3) * 3)
     ck.verify(tvm.min(3 - x, 2), 3 - tvm.max(x,  1))
+    ck.verify(tvm.min(x * 2, 0), tvm.min(x, 0) * 2)
+    ck.verify(tvm.min(0 - x * 2, 0), tvm.max(x, 0) * -2)
+    ck.verify(tvm.min(x * (-2), -4), tvm.max(x, 2) * -2)
+    ck.verify(tvm.min(x * (-2), 4), tvm.max(x, -2) * -2)
+    ck.verify(tvm.min(x * (0), 4), 0)
+    ck.verify(tvm.min(x * (0), -4), -4)
 
     # DivMod rules
     # truc div
@@ -609,6 +615,12 @@ def test_max_index_simplify():
 
     ck.verify(tvm.max(x * 3, 9), tvm.max(x, 3) * 3)
     ck.verify(tvm.max(3 - x, 1), 3 - tvm.min(x,  2))
+    ck.verify(tvm.max(x * 2, 0), tvm.max(x, 0) * 2)
+    ck.verify(tvm.max(0 - x * 2, 0), tvm.min(x, 0) * -2)
+    ck.verify(tvm.max(x * (-2), -4), tvm.min(x, 2) * -2)
+    ck.verify(tvm.max(x * (-2), 4), tvm.min(x, -2) * -2)
+    ck.verify(tvm.max(x * (0), 4), 4)
+    ck.verify(tvm.max(x * (0), -4), 0)
 
     # DivMod rules
     # truc div


Mime
View raw message