tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] spectrometerHBH opened a new pull request #5483: [TIR][Printer] text format printer considering future parsing use
Date Thu, 30 Apr 2020 03:31:53 GMT

spectrometerHBH opened a new pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483


   ## Several Points
   1. Rename vars/size_vars/buffers to use name hints to identify them
   2. A node is either in META or not. If ref A is put in META and ref B is contained in A, then
when we need to print ref B we print its META reference.
   3. Now only the `node` in `AttrStmt` may be put in META, except when `node` is a `StringImm/Var/Buffer/IterVar`.
   4. Print the Type of vars when we first encounter them on the tree.
   5. This PR hasn't combined relay's printer with tir's printer. I list several options here
for a combined API, for discussion.
   - add an `astext` member function for IRModule: we only do mix printing when printing an
IRModule. Keep `relay.astext` and `tir.astext` separate for other uses.
   - merge relay's printer & tir's printer into one printer and use only one API `ir.astext`
   
   cc @tqchen @Hzfengsy 
   
   ## Several Examples
   ### Simple
   ```c++
   primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {A1: Buffer(A1_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1:
int32], type="auto"),
              C: Buffer(C_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32],
type="auto"),
              A0: Buffer(A0_2: handle, float32, [m, n], [stride_4: int32, stride_5: int32],
type="auto")}
     buffer_map = {C_1: C, A0_1: A0, A1_1: A1} {
     attr [B.v0: handle] "storage_scope" = "global";
     allocate(B.v0, float32, [n])  {
       attr [B.v1: handle] "storage_scope" = "global";
       allocate(B.v1, float32, [n])  {
         for (i: int32, 0, m) {
           for (j: int32, 0, n) {
             B.v0[j] = (load(float32, A0_2[((i*stride_4) + (j*stride_5))]) + float32(2))
             B.v1[j] = (load(float32, A0_2[((i*stride_4) + (j*stride_5))])*float32(3))
           }
           for (j_1: int32, 0, n) {
             C_2[((i*stride_2) + (j_1*stride_3))] = (load(float32, A1_2[((i*stride) + (j_1*stride_1))])
+ load(float32, B.v0[j_1]))
           }
         }
       }
     }
   }
   ```
   ### Conv on GPU
   ```c++
   primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "cuda"}
     buffers = {W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
              B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
              A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
     buffer_map = {B_1: B, A_1: A, W_1: W} {
     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent"
= 196;
     attr [B.local: handle] "storage_scope" = "local";
     allocate(B.local, float32, [64])  {
       attr [Apad.shared: handle] "storage_scope" = "shared";
       allocate(Apad.shared, float32, [512])  {
         attr [W.shared: handle] "storage_scope" = "shared";
         allocate(W.shared, float32, [512])  {
           attr [Apad.shared.local: handle] "storage_scope" = "local";
           allocate(Apad.shared.local, float32, [8])  {
             attr [W.shared.local: handle] "storage_scope" = "local";
             allocate(W.shared.local, float32, [8])  {
               attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")]
"thread_extent" = 8;
               attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")]
"thread_extent" = 4;
               attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent"
= 8;
               attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent"
= 8 {
                 for (ff.c.init: int32, 0, 4) {
                   for (nn.c.init: int32, 0, 4) {
                     B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
                   }
                 }
                 for (rc.outer: int32, 0, 32) {
                   for (ry: int32, 0, 3) {
                     for (rx: int32, 0, 3) {
                       for (ax3.inner.outer: int32, 0, 2) {
                         Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)),
1, 4)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z,
14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z,
14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536))
+ (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4))
- 983040), 1, 4)]), broadcast(float32(0), 4)], float32x4, "pure_intrin", 0)
                       }
                       for (ax3.inner.outer_1: int32, 0, 2) {
                         W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)),
1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512))
+ (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
                       }
                       for (rc.inner: int32, 0, 8) {
                         for (ax3: int32, 0, 4) {
                           Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64)
+ (threadIdx.x*4)) + ax3)])
                           Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64)
+ (threadIdx.x*4)) + ax3) + 32)])
                         }
                         for (ax3_1: int32, 0, 4) {
                           W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64)
+ (threadIdx.y*4)) + ax3_1)])
                           W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64)
+ (threadIdx.y*4)) + ax3_1) + 32)])
                         }
                         for (ff.c: int32, 0, 4) {
                           for (nn.c: int32, 0, 4) {
                             B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4)
+ nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
                             B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4)
+ nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c
+ 4)])))
                             B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4)
+ nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
                             B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4)
+ nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c
+ 4)])))
                           }
                         }
                       }
                     }
                   }
                 }
                 for (ff.inner.inner.inner: int32, 0, 4) {
                   for (nn.inner.inner.inner: int32, 0, 4) {
                     B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024))
+ (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)]
= load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024))
+ (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)
+ 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024))
+ (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)
+ 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024))
+ (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)
+ 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   ### TensorCore for Conv
   ```c++
   primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
              A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
              Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
     buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent"
= 196;
     attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
     allocate(Conv.wmma.accumulator, float32, [2048])  {
       attr [Apad.shared: handle] "storage_scope" = "shared";
       allocate(Apad.shared, float16, [12288])  {
         attr [W.shared: handle] "storage_scope" = "shared";
         allocate(W.shared, float16, [12288])  {
           attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
           allocate(Apad.shared.wmma.matrix_a, float16, [512])  {
             attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
             allocate(W.shared.wmma.matrix_b, float16, [1024])  {
               attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")]
"thread_extent" = 2;
               attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")]
"thread_extent" = 4;
               attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")]
"thread_extent" = 4;
               attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")]
"thread_extent" = 2 {
                 for (n.c.init: int32, 0, 2) {
                   for (o.c.init: int32, 0, 4) {
                     eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4)
+ o.c.init), float32(0)], handle, "intrin", 0))
                   }
                 }
                 for (ic.outer: int32, 0, 8) {
                   for (kh: int32, 0, 3) {
                     for (ax2: int32, 0, 3) {
                       for (ax3: int32, 0, 2) {
                         for (ax4.ax5.fused.outer: int32, 0, 8) {
                           attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")]
"thread_extent" = 32;
                           Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512))
+ (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1
<= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and
(1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)),
load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816))
+ (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32))
+ threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
                         }
                       }
                     }
                     for (ax1: int32, 0, 3) {
                       for (ax2_1: int32, 0, 2) {
                         attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")]
"thread_extent" = 32;
                         W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512))
+ (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216)
+ (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512))
+ (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
                       }
                     }
                     for (ic.inner: int32, 0, 2) {
                       for (kw: int32, 0, 3) {
                         for (ax0: int32, 0, 2) {
                           eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16,
16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0),
Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle,
"intrin", 0), 16, "row_major"], handle, "intrin", 0))
                         }
                         for (ax3_1: int32, 0, 4) {
                           eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16,
16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin",
0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1],
handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                         }
                         for (n.c: int32, 0, 2) {
                           for (o.c: int32, 0, 4) {
                             eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) +
o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator,
((n.c*4) + o.c)], handle, "intrin", 0))
                           }
                         }
                       }
                     }
                   }
                 }
                 for (n.inner: int32, 0, 2) {
                   for (o.inner: int32, 0, 4) {
                     eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16,
((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin",
0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192))
+ (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0),
16, "row_major"], handle, "intrin", 0))
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   ### GEMM on CPU
   ```c++
   primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
              B: Buffer(B_2: handle, float32, [1024, 1024], []),
              A: Buffer(A_2: handle, float32, [1024, 1024], [])}
     buffer_map = {C_1: C, A_1: A, B_1: B} {
     attr [packedB: handle] "storage_scope" = "global";
     allocate(packedB, float32x32, [32768])  {
       for (x: int32, 0, 32) "parallel" {
         for (y: int32, 0, 1024) {
           packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024)
+ (x*32)), 1, 32)])
         }
       }
       for (x.outer: int32, 0, 32) "parallel" {
         attr [C.global: handle] "storage_scope" = "global";
         allocate(C.global, float32, [1024])  {
           for (y.outer: int32, 0, 32) {
             for (x.c.init: int32, 0, 32) {
               C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
             }
             for (k.outer: int32, 0, 256) {
               for (x.c: int32, 0, 32) {
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32),
1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]),
32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32),
1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) +
1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32),
1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) +
2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32),
1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) +
3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
               }
             }
             for (x.inner: int32, 0, 32) {
               for (y.inner: int32, 0, 32) {
                 C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32,
C.global[((x.inner*32) + y.inner)])
               }
             }
           }
         }
       }
     }
   }
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message