tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [CUDA] Fix codegen for warp shuffle intrinsics (#5606)
Date Mon, 18 May 2020 02:55:15 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new cb7bd98  [CUDA] Fix codegen for warp shuffle intrinsics (#5606)
cb7bd98 is described below

commit cb7bd986cadea53d6f41143a6ce747224e25aefb
Author: Shizhi Tang <rd0x01@gmail.com>
AuthorDate: Mon May 18 10:55:05 2020 +0800

    [CUDA] Fix codegen for warp shuffle intrinsics (#5606)
    
    * fix shfl intrin
    
    * improve test_lower_warp_memory_cuda_half_a_warp
---
 src/target/source/intrin_rule_cuda.cc              |  2 +-
 .../test_tir_transform_lower_warp_memory.py        | 26 ++++++++++++----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 4e4abd9..7ebcfa6 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -116,7 +116,7 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
   const CallNode* call = e.as<CallNode>();
   CHECK(call != nullptr);
   CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
-  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
+  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
   const char* name = T()(call->dtype, call->name);
   *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
 }
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index bd55377..c3cf289 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -136,30 +136,32 @@ def test_lower_warp_memory_cuda_half_a_warp():
             print("Skip because gpu does not have fp16 support")
             return
 
-        m = 16
-        A = te.placeholder((m,), name='A', dtype=dtype)
-        B = te.compute((m,), lambda i: A[(i + 1) % m], name='B')
+        n, m = 16, 16
+        A = te.placeholder((n, m,), name='A', dtype=dtype)
+        B = te.compute((n, m,), lambda j, i: A[j, (i + 1) % m], name='B')
 
         cuda_target = tvm.target.create("cuda")
         assert cuda_target.thread_warp_size == 2 * m
         with cuda_target:
             s = te.create_schedule(B.op)
             tx = te.thread_axis("threadIdx.x")
+            ty = te.thread_axis("threadIdx.y")
             bx = te.thread_axis("blockIdx.x")
 
             AA = s.cache_read(A, "warp", [B])
-            xo, xi = s[B].split(B.op.axis[0], nparts=1)
-            s[B].bind(xi, tx)
-            s[B].bind(xo, bx)
-            s[AA].compute_at(s[B], xo)
-            xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
-            s[AA].bind(xo, bx)
-            s[AA].bind(xi, tx)
+            y, x = B.op.axis
+            z, y = s[B].split(y, nparts=2)
+            s[B].bind(x, tx)
+            s[B].bind(y, ty)
+            s[B].bind(z, bx)
+            s[AA].compute_at(s[B], y)
+            _, x = AA.op.axis
+            s[AA].bind(x, tx)
 
             ctx = tvm.gpu(0)
             func = tvm.build(s, [A, B], "cuda")
-            A_np = np.array(list(range(m)), dtype=dtype)
-            B_np = np.array(list(range(1, m)) + [0], dtype=dtype)
+            A_np = np.array([list(range(i, m + i)) for i in range(n)], dtype=dtype)
+            B_np = np.array([list(range(1 + i, m + i)) + [i] for i in range(n)], dtype=dtype)
             A_nd = tvm.nd.array(A_np, ctx)
             B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
             func(A_nd, B_nd)


Mime
View raw message