tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From liuyi...@apache.org
Subject [incubator-tvm] branch v0.6 updated: [BACKPORT-0.6][AUTOTVM] Fix a bug in generating the search space (#5876)
Date Sun, 21 Jun 2020 22:32:55 GMT
This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.6 by this push:
     new 68fb0ae  [BACKPORT-0.6][AUTOTVM] Fix a bug in generating the search space (#5876)
68fb0ae is described below

commit 68fb0aed04600996f976ce74be0346a33fe23a35
Author: Yizhi Liu <liuyizhi@apache.org>
AuthorDate: Sun Jun 21 15:32:45 2020 -0700

    [BACKPORT-0.6][AUTOTVM] Fix a bug in generating the search space (#5876)
    
    Co-authored-by: wpan11nv <60017475+wpan11nv@users.noreply.github.com>
---
 python/tvm/autotvm/task/space.py            |  4 +++-
 tests/python/unittest/test_autotvm_space.py | 15 +++++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index f1422bf..d83a248 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -226,7 +226,9 @@ class SplitSpace(TransformSpace):
     def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
         """Generate space by DFS"""
         if now == self.num_output - 1:
-            prod = np.prod(tmp_stack, dtype=np.int64)
+            prod = functools.reduce(lambda x, y: x * y, tmp_stack)
+            if prod > self.product:
+                return
             if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
                 self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
         else:
diff --git a/tests/python/unittest/test_autotvm_space.py b/tests/python/unittest/test_autotvm_space.py
index 85d5724..95f3201 100644
--- a/tests/python/unittest/test_autotvm_space.py
+++ b/tests/python/unittest/test_autotvm_space.py
@@ -62,6 +62,21 @@ def test_split():
     cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
     assert len(cfg.space_map['tile_c']) == 84
 
+    # Count the number of non-negative integer solutions of a + b + c + d = n
+    def count4(n):
+        cnt = 0
+        for a in range(0, n + 1):
+            for b in range(0, n - a + 1):
+                cnt += n - a - b + 1
+        return cnt
+
+    # test overflow
+    n = 25
+    cfg = ConfigSpace()
+    cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4)
+    # count4(25) is 3276.
+    assert len(cfg.space_map['x']) == count4(n)
+
     # test fallback
     cfg = FallbackConfigEntity()
     cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)


Mime
View raw message