tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] roastduck opened a new issue #5245: lower_warp_memory not working
Date Mon, 06 Apr 2020 09:01:32 GMT
roastduck opened a new issue #5245: lower_warp_memory not working
URL: https://github.com/apache/incubator-tvm/issues/5245
 
 
   `src/tir/transforms/lower_warp_memory.cc` rewrites memory accesses in "warp" scope into accesses in "local" scope and generates warp shuffle primitives. I found that it does change the scope to "local", but it fails to generate the shuffle primitives.
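To make the expected behaviour concrete, here is a small pure-Python model of what lowering a "warp"-scope buffer should amount to. It assumes that each of the 32 lanes keeps the element at its own lane index in a one-element local register, and that a read of element `idx` becomes a shuffle from lane `idx`. The names `warp_shuffle` and `lower_warp_read` are hypothetical; they only mimic the semantics of a CUDA warp shuffle and are not TVM APIs.

```python
WARP_SIZE = 32


def warp_shuffle(registers, src_lanes):
    """Model of a warp shuffle: lane t receives the register value held by
    lane src_lanes[t]. `registers` holds one value per lane."""
    assert len(registers) == len(src_lanes) == WARP_SIZE
    return [registers[src % WARP_SIZE] for src in src_lanes]


def lower_warp_read(a, index_fn):
    """Compute out[t] = a_warp[index_fn(t)] the way a warp-memory lowering
    is expected to: each lane first loads its own element into a local
    register (a_warp[t] = a[t]), then fetches the requested element from
    the lane that owns it via a shuffle instead of indexing local memory."""
    registers = [a[t] for t in range(WARP_SIZE)]  # one local value per lane
    src_lanes = [index_fn(t) for t in range(WARP_SIZE)]
    return warp_shuffle(registers, src_lanes)


if __name__ == "__main__":
    a = [float(x) for x in range(WARP_SIZE)]
    # Same access pattern as the reproducer: compute[i] = a[(i + 1) % 32].
    out = lower_warp_read(a, lambda t: (t + 1) % WARP_SIZE)
    print(out[:4], out[-1])  # [1.0, 2.0, 3.0, 4.0] 0.0
```

In this model every lane ends up with its neighbour's value, which is the result the reproducer below should produce.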
   
   To reproduce:
   
   ```python
   import tvm

   dtype = "float32"
   target = "cuda"
   n = m = 32

   # b[i] reads the neighbouring element of a, so each thread needs a value
   # that another lane of the warp loaded.
   a = tvm.te.placeholder((m,), name="a", dtype=dtype)
   b = tvm.te.compute((n,), lambda i: a[(i + 1) % m])
   with tvm.target.cuda():
       s = tvm.te.create_schedule(b.op)
       th_x = tvm.te.thread_axis("threadIdx.x")
       blk_x = tvm.te.thread_axis("blockIdx.x")
       i, = b.op.axis
       blk, th = s[b].split(i, nparts=1)
       s[b].bind(blk, blk_x)
       s[b].bind(th, th_x)
       # Stage a through a "warp"-scope cache, one element per thread.
       a_cache = s.cache_read(a, "warp", b)
       a_axis, = a_cache.op.axis
       s[a_cache].bind(a_axis, th_x)
       s[a_cache].compute_at(s[b], blk)

   # The third positional argument of tvm.lower is the function name, not the target.
   print(tvm.lower(s, [a, b], simple_mode=True))
   compute = tvm.build(s, [a, b], target, name="run")
   print(compute.imported_modules[0].get_source())
   ```
   
   The output was:
   
   ```
   produce compute {
     // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
     // attr [a.warp] storage_scope = "warp"
     allocate a.warp[float32 * 32]
     produce a.warp {
       // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
       a.warp[threadIdx.x] = a[threadIdx.x]
     }
     // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
     compute[threadIdx.x] = a.warp[floormod((threadIdx.x + 1), 32)]
   }
   
   extern "C" __global__ void run_kernel0(void* __restrict__ a, void* __restrict__ compute) {
     float a_warp[1];
     a_warp[(0)] = (( float*)a)[(((int)threadIdx.x))];
     (( float*)compute)[(((int)threadIdx.x))] = a_warp[(((((int)threadIdx.x) + 1) & 31))];
   }
   ```
   
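   The shuffle primitive is missing. Instead of fetching the neighbouring lane's value with a warp shuffle, the kernel indexes the one-element local array `a_warp` directly with `(threadIdx.x + 1) & 31`, so it stores unshuffled data (and reads out of bounds for every thread except `threadIdx.x == 31`).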
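The generated kernel's indexing can be checked by hand: `a_warp` is declared with a single element, yet it is indexed with `(threadIdx.x + 1) & 31`. The short Python sketch below (an illustration, not TVM code; `generated_index` is a hypothetical helper) enumerates the 32 lanes and shows which of them stay in bounds.

```python
WARP_SIZE = 32
LOCAL_SIZE = 1  # float a_warp[1] in the generated kernel


def generated_index(tid):
    # Index expression copied from the generated kernel:
    # a_warp[(((int)threadIdx.x) + 1) & 31]
    return (tid + 1) & 31


in_bounds = [t for t in range(WARP_SIZE) if generated_index(t) < LOCAL_SIZE]
print(in_bounds)  # [31]
```

Only lane 31 wraps around to index 0; every other lane reads past the end of `a_warp`.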
   
   The unit test for this pass, `tests/python/unittest/test_tir_transform_lower_warp_memory.py`, only asserts that the scope becomes "local"; it does not check that the generated code is correct, so it is too weak to catch this problem.
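One way to strengthen the test, sketched here under assumptions, is to inspect the generated CUDA source for a warp shuffle in addition to checking the scope. The helper `has_warp_shuffle` is hypothetical, and it assumes the CUDA backend spells the primitive with a `__shfl` prefix (as in `__shfl_sync`). A complete test would also run the kernel on a GPU and compare the output against `a[(i + 1) % 32]`.

```python
import re

# Assumption: warp shuffles appear in CUDA source as __shfl, __shfl_sync,
# __shfl_down_sync, and similar intrinsics.
_SHFL_RE = re.compile(r"\b__shfl\w*\s*\(")


def has_warp_shuffle(cuda_source: str) -> bool:
    """Return True if the kernel source calls a __shfl* intrinsic."""
    return _SHFL_RE.search(cuda_source) is not None


# Kernel reported in this issue: no shuffle is emitted.
buggy_source = """
extern "C" __global__ void run_kernel0(void* __restrict__ a, void* __restrict__ compute) {
  float a_warp[1];
  a_warp[(0)] = (( float*)a)[(((int)threadIdx.x))];
  (( float*)compute)[(((int)threadIdx.x))] = a_warp[(((((int)threadIdx.x) + 1) & 31))];
}
"""

print(has_warp_shuffle(buggy_source))  # False
```

In a test built on the reproducer, the argument would be `compute.imported_modules[0].get_source()`.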
   
   I also traced the lowering pass with GDB and found that the `CallNode::make` call that creates the shuffle primitive (Line 248) is in fact invoked, but the primitive is lost somewhere afterwards.
