tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
Date Thu, 12 Mar 2020 16:35:47 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 173b4fc  Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
173b4fc is described below

commit 173b4fc4c46499056ebc5682c20fcff2582bc9db
Author: pankratz <35379668+dpankratz@users.noreply.github.com>
AuthorDate: Thu Mar 12 10:35:36 2020 -0600

    Fixed div by zero core dump. Fixed rounding intrinsics on int crash (#5026)
---
 src/arith/const_fold.h                   |  2 ++
 src/tir/ir/op.cc                         | 15 +++++++++++++++
 tests/python/unittest/test_lang_basic.py | 19 +++++++++++++++++--
 tests/python/unittest/test_tvm_intrin.py | 11 +++++++++++
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index bae34bd..a440af9 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -181,6 +181,7 @@ inline PrimExpr TryConstFold<tir::ModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
+        CHECK_NE(pb->value, 0) << "Divide by zero";
         return IntImm(rtype, pa->value % pb->value);
       }
       if (pa) {
@@ -226,6 +227,7 @@ inline PrimExpr TryConstFold<tir::FloorModNode>(PrimExpr a, PrimExpr b) {
   TVM_INDEX_CONST_PROPAGATION({
       const DataType& rtype = a.dtype();
       if (pa && pb) {
+        CHECK_NE(pb->value, 0) << "Divide by zero";
         return IntImm(rtype, floormod(pa->value, pb->value));
       }
       if (pa) {
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 452c3bb..2882fea 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
 }
 
 PrimExpr floor(PrimExpr x) {
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return x;
+  }
   using tir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
@@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) {
 }
 
 PrimExpr ceil(PrimExpr x) {
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return x;
+  }
   using tir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
@@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) {
 }
 
 PrimExpr round(PrimExpr x) {
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return x;
+  }
   using tir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
@@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) {
 }
 
 PrimExpr nearbyint(PrimExpr x) {
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return x;
+  }
   using tir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
@@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) {
 }
 
 PrimExpr trunc(PrimExpr x) {
+  if (x.dtype().is_int() || x.dtype().is_uint()) {
+    return x;
+  }
   using tir::FloatImmNode;
   const FloatImmNode* fx = x.as<FloatImmNode>();
   if (fx) {
diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py
index cd532a0..c279194 100644
--- a/tests/python/unittest/test_lang_basic.py
+++ b/tests/python/unittest/test_lang_basic.py
@@ -187,14 +187,14 @@ def test_bitwise():
     assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
     assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
 
+
 def test_float_bitwise():
     t = tvm.tir.const(1.5,dtype='float32')
     for test in [lambda lhs, rhs : lhs << rhs,
                     lambda lhs, rhs : lhs >> rhs,
                     lambda lhs, rhs : lhs | rhs,
                     lambda lhs, rhs : lhs ^ rhs,
-                    lambda lhs, rhs : lhs & rhs
-                ]:
+                    lambda lhs, rhs : lhs & rhs]:
         try:
             test(t,10.0)
             assert False
@@ -206,6 +206,20 @@ def test_float_bitwise():
     except RuntimeError:
         pass
 
+
+def test_divide_by_zero():
+    for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs),
+                    lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs),
+                    lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs),
+                    lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs),
+                    lambda lhs, rhs : tvm.tir.div(lhs,rhs)]:
+        try:
+            test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32'))
+            assert False
+        except tvm.TVMError:
+            pass
+
+
 def test_isnan():
     x = te.var('x', 'float32')
     assert str(tvm.tir.isnan(x)) == 'isnan(x)'
@@ -250,6 +264,7 @@ if __name__ == "__main__":
     test_all()
     test_bitwise()
     test_float_bitwise()
+    test_divide_by_zero()
     test_isnan()
     test_equality()
     test_equality_string_imm()
diff --git a/tests/python/unittest/test_tvm_intrin.py b/tests/python/unittest/test_tvm_intrin.py
index 0054273..52ae440 100644
--- a/tests/python/unittest/test_tvm_intrin.py
+++ b/tests/python/unittest/test_tvm_intrin.py
@@ -44,6 +44,16 @@ def test_nearbyint():
     tvm.testing.assert_allclose(
         a_rounded.asnumpy(), np.rint(a.asnumpy()))
 
+def test_round_intrinsics_on_int():
+    i = tvm.te.var("i", 'int32')
+    for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil,
+                            tvm.tir.floor, tvm.tir.nearbyint]:
+        assert op(tvm.tir.const(10,'int32')).value == 10
+        assert op(tvm.tir.const(True,'bool')).value == True
+        assert op(i).same_as(i)
+
+    assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False
+
 
 def test_unary_intrin():
     test_funcs = [
@@ -75,3 +85,4 @@ def test_unary_intrin():
 if __name__ == "__main__":
     test_nearbyint()
     test_unary_intrin()
+    test_round_intrinsics_on_int()


Mime
View raw message