tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
Date Tue, 21 Apr 2020 21:23:26 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d327787  [PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
d327787 is described below

commit d3277874a24e775d2476b0eb0ad89f3a46964a14
Author: Tianqi Chen <tqchen@users.noreply.github.com>
AuthorDate: Tue Apr 21 14:23:18 2020 -0700

    [PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
---
 include/tvm/target/target.h                        |   5 +-
 python/tvm/autotvm/measure/measure_methods.py      |   8 +-
 python/tvm/driver/build_module.py                  |  29 +-
 python/tvm/tir/function.py                         |  16 +
 src/target/target.cc                               |   4 +-
 tests/python/relay/test_pass_fold_constant.py      |   8 +-
 tests/python/unittest/test_target_codegen_cuda.py  |  10 +-
 tests/python/unittest/test_target_codegen_llvm.py  |  11 +-
 .../unittest/test_tir_pass_verify_gpu_code.py      |   8 +-
 tutorials/dev/low_level_custom_pass.py             |  11 +-
 vta/python/vta/build_module.py                     |  56 +-
 vta/python/vta/ir_pass.py                          | 995 ---------------------
 vta/python/vta/transform.py                        | 962 ++++++++++++++++++++
 13 files changed, 1050 insertions(+), 1073 deletions(-)

diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 59aa955..829de73 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -27,6 +27,7 @@
 #include <tvm/support/with.h>
 #include <tvm/node/container.h>
 #include <tvm/ir/expr.h>
+#include <tvm/ir/transform.h>
 
 #include <string>
 #include <vector>
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
   /*! \brief Whether to partition const loop */
   bool partition_const_loop = false;
 
-  /*! \brief Whether to dump the IR of each pass (only when building from python) */
-  std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
+  /*! \brief Passes injected into the low-level lowering pipeline, each paired with the integer phase it is inserted at. */
+  std::vector<std::pair<int, transform::Pass>> add_lower_pass;
 
   /*! \brief Whether to dump the IR of each pass (only when building from python) */
   bool dump_pass_ir = false;
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 698ddbc..5ddc5df 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
     """Verify the validity of a gpu kernel.
     This pass will check memory usage and number of threads per block.
     """
-    def verify_pass(stmt):
-        valid = ir_pass.VerifyGPUCode(stmt, kwargs)
+    def verify_pass(f, *_):
+        valid = ir_pass.VerifyGPUCode(f.body, kwargs)
         if not valid:
             raise InstantiationError("Skipped because of invalid gpu kernel")
-        return stmt
-    return verify_pass
+        return f
+    return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 35700ba..dcd6d44 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
     return tvm.IRModule({name: func})
 
 
-def _wrap_as_prim_func_pass(flist, name):
-    """Wrap flist as a function pass.
-
-    This is an temporary adapter before we fully
-    migrate to the new pass manager.
-    """
-    def _transform(func, *_):
-        stmt = func.body
-        for f in flist:
-            stmt = f(stmt)
-        # create a new function with updated body.
-        return tvm.tir.PrimFunc(func.params,
-                                stmt,
-                                func.ret_type,
-                                func.buffer_map,
-                                func.attrs)
-    return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
-
-
 def lower(sch,
           args,
           name="main",
@@ -190,15 +171,15 @@ def lower(sch,
     else:
         mod = sch
 
+    pass_list = list(lower_phase0)  # copy: the in-place += below must not mutate lower_phase0
     # Phase 1
-    pass_list = [
-        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
+    pass_list += [
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
-        _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
     ]
+    pass_list += lower_phase1
 
     # Phase 2
     if not simple_mode:
@@ -214,8 +195,8 @@ def lower(sch,
             cfg.auto_unroll_max_depth,
             cfg.auto_unroll_max_extent,
             cfg.unroll_explicit),
-        _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
     ]
+    pass_list += lower_phase2
 
     # Phase 3
     pass_list += [
@@ -225,7 +206,7 @@ def lower(sch,
 
     if not cfg.disable_select_rewriting:
         pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
-    pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
+    pass_list += lower_phase3
 
     # Instrument BoundCheckers
     if cfg.instrument_bound_checkers:
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 4ec1a71..47ad94f 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
 
         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+
+    def with_body(self, new_body):
+        """Create a new PrimFunc with the same signature and attributes but a new body.
+
+        Parameters
+        ----------
+        new_body : Stmt
+            The new body.
+
+        Returns
+        -------
+        new_func : PrimFunc
+            A function reusing this one's params, ret_type, buffer_map and attrs.
+        """
+        return PrimFunc(
+            self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
diff --git a/src/target/target.cc b/src/target/target.cc
index 50856d6..a72ce1c 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
 TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
   BuildConfig cfg = args[0];
-  std::vector< std::pair<int, PackedFunc> > add_lower_pass;
+  std::vector<std::pair<int, transform::Pass>> add_lower_pass;
   CHECK_EQ(args.size() % 2, 1);
   for (int i = 1; i < args.size(); i += 2) {
     add_lower_pass.push_back(std::make_pair(
       args[i].operator int(),
-      args[i + 1].operator tvm::runtime::PackedFunc()));
+      args[i + 1].operator transform::Pass()));
   }
   cfg->add_lower_pass = add_lower_pass;
   });
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 4f44d2b..b212b26 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -51,11 +51,13 @@ def test_fold_const():
         z = relay.add(y, relay.const(c_data))
         return relay.Function([x], z)
 
-    def fail(x):
-        raise RuntimeError()
+    def FailPass():
+        def _transform(m, *args):
+            raise RuntimeError()
+        return tvm.transform.module_pass(_transform, opt_level=0)
 
     # the fold constant should work on any context.
-    with tvm.target.build_config(add_lower_pass=[(0, fail)]):
+    with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
         with tvm.target.create("cuda"):
             zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index 739fc6f..4c2ec2e 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
     sch[c].bind(xo, thrx)
     sch[c].vectorize(xi)
 
-    def my_vectorize(stmt):
+    def MyVectorize():
         def vectorizer(op):
             if op.for_type == tvm.tir.For.Vectorized:
                 four = tvm.tir.const(4, 'int32')
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
                 new_b = tvm.tir.Shuffle(bs, ids)
                 return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
             return None
-        return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
 
-    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+        def _transform(f, *_):
+            return f.with_body(
+                tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
+
+    with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
         module = tvm.build(sch, [a, b, c], target='cuda')
         a_ = np.array(list(range(64)), dtype='int32')
         b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 44b05c9..26f9347 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
     c = te.compute((8, ), lambda x: a[x] + b[7-x])
     sch = te.create_schedule(c.op)
 
-    def my_vectorize(stmt):
-
+    def my_vectorize():
         def vectorizer(op):
             store = op.body
             idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
             value = new_a + new_b
             return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
 
-        return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+        def _transform(f, *_):
+            return f.with_body(
+                tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+
+        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
 
-    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
         ir = tvm.lower(sch, [a, b, c], simple_mode=True)
         module = tvm.build(sch, [a, b, c])
         a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
diff --git a/tests/python/unittest/test_tir_pass_verify_gpu_code.py b/tests/python/unittest/test_tir_pass_verify_gpu_code.py
index 6e138a2..091a374 100644
--- a/tests/python/unittest/test_tir_pass_verify_gpu_code.py
+++ b/tests/python/unittest/test_tir_pass_verify_gpu_code.py
@@ -19,10 +19,10 @@ import tvm
 from tvm import te
 
 def get_verify_pass(valid, **kwargs):
-    def verify_pass(stmt):
-        valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
-        return stmt
-    return verify_pass
+    def _fverify(f, *_):
+        valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
+        return f
+    return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
 
 def test_shared_memory():
     def check_shared_memory(dtype):
diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py
index d35913b..49e86fd 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -117,19 +117,20 @@ def vectorize8(op):
         return body
     return None
 
-def vectorize(stmt):
+@tvm.tir.transform.prim_func_pass(opt_level=0)
+def vectorize(f, mod, ctx):
     global loops
 
-    tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
+    tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
 
     if not loops:
-        return stmt
+        return f  # no width-8 loops were found: return the function unchanged
 
     # The last list arugment indicates what kinds of nodes will be transformed.
     # Thus, in this case only `For` nodes will call `vectorize8`
-    stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
+    return f.with_body(
+        tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
 
-    return stmt
 
 #####################################################################
 # Glue to Lowering
diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py
index 4c33d36..40bee86 100644
--- a/vta/python/vta/build_module.py
+++ b/vta/python/vta/build_module.py
@@ -14,25 +14,22 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-argument
+# pylint: disable=unused-argument, invalid-name
 """VTA specific buildin for runtime."""
 import tvm
-from . import ir_pass
+from . import transform
 from .environment import get_env
 
 
-def lift_coproc_scope(x):
-    """Lift coprocessings cope to the """
-    x = ir_pass.lift_alloc_to_scope_begin(x)
-    x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
-    return x
-
-def early_rewrite(stmt):
+def EarlyRewrite():
     """Try to do storage rewrite in early pass."""
-    try:
-        return tvm.tir.ir_pass.StorageRewrite(stmt)
-    except tvm.error.TVMError:
-        return stmt
+    def _transform(mod, ctx):
+        try:
+            return tvm.tir.transform.StorageRewrite()(mod)
+        except tvm.error.TVMError:
+            return mod
+    return tvm.transform.module_pass(
+        _transform, opt_level=0, name="tir.vta.EarlyRewrite")
 
 
 def build_config(debug_flag=0, **kwargs):
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
           vta_module = tvm.build(s, ...)
     """
     env = get_env()
-    def add_debug(stmt):
+
+    @tvm.tir.transform.prim_func_pass(opt_level=0)
+    def add_debug(f, *_):
         debug = tvm.tir.call_extern(
             "int32", "VTASetDebugMode",
             env.dev.command_handle,
             debug_flag)
 
-        return tvm.tir.stmt_seq(debug, stmt)
-    pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
-                 (1, ir_pass.inject_dma_intrin),
-                 (1, ir_pass.inject_skip_copy),
-                 (1, ir_pass.annotate_alu_coproc_scope),
-                 (1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
-                 (1, lift_coproc_scope),
-                 (1, ir_pass.inject_coproc_sync),
-                 (1, early_rewrite)]
+        return f.with_body(tvm.tir.stmt_seq(debug, f.body))
+
+
+    pass_list = [(0, transform.InjectConv2DTransposeSkip()),
+                 (1, transform.InjectDMAIntrin()),
+                 (1, transform.InjectSkipCopy()),
+                 (1, transform.AnnotateALUCoProcScope()),
+                 (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
+                 (1, transform.LiftAllocToScopeBegin()),
+                 (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
+                 (1, transform.InjectCoProcSync()),
+                 (1, EarlyRewrite())]
     if debug_flag:
         pass_list.append((1, add_debug))
-    pass_list.append((2, ir_pass.inject_alu_intrin))
-    pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
-    pass_list.append((3, ir_pass.fold_uop_loop))
-    pass_list.append((3, ir_pass.cpu_access_rewrite))
+    pass_list.append((2, transform.InjectALUIntrin()))
+    pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
+    pass_list.append((3, transform.FoldUopLoop()))
+    pass_list.append((3, transform.CPUAccessRewrite()))
     return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
 
 
diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py
deleted file mode 100644
index 9836d13..0000000
--- a/vta/python/vta/ir_pass.py
+++ /dev/null
@@ -1,995 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Additional IR Pass for VTA"""
-# pylint: disable=len-as-condition, no-else-return
-import tvm
-from tvm import te
-from topi import util
-
-from .environment import get_env
-
-
-def _match_pragma(stmt, key):
-    """Internal helper to match stmt to pragma stmt.
-
-    Parameters
-    ----------
-    stmt : Stmt
-        The AttrStmt
-
-    key : str
-        The pragma key
-    """
-    return ((stmt.attr_key == "pragma_" + key) or
-            (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
-
-
-def fold_uop_loop(stmt_in):
-    """Detect and fold uop loop.
-
-    VTA support uop programming model
-    that recognizes loop structure.
-    This pass detect the loop structure
-    and extract that into uop loop AST.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Output statement.
-    """
-    env = get_env()
-
-    def _fold_outermost_loop(body):
-        stmt = body
-        if not isinstance(stmt, tvm.tir.For):
-            return None, body, None
-
-        loop_var = stmt.loop_var
-        gemm_offsets = [None, None, None]
-        fail = [False]
-
-        def _post_order(op):
-            assert isinstance(op, tvm.tir.Call)
-            base_args = 2
-            if op.name == "VTAUopPush":
-                args = []
-                args += op.args[:base_args]
-                for i in range(3):
-                    m = tvm.arith.detect_linear_equation(
-                        op.args[i + base_args], [loop_var])
-                    if not m:
-                        fail[0] = True
-                        return op
-                    if gemm_offsets[i] is not None:
-                        if not tvm.ir.structural_equal(m[0], gemm_offsets[i]):
-                            fail[0] = True
-                            return op
-                        args.append(m[1])
-                    else:
-                        gemm_offsets[i] = m[0]
-                        args.append(m[1])
-                args += op.args[base_args+3:]
-                return tvm.tir.call_extern("int32", "VTAUopPush", *args)
-            if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
-                raise RuntimeError("unexpected op %s" % op)
-            return op
-
-        ret = tvm.tir.ir_pass.IRTransform(
-            stmt.body, None, _post_order, ["Call"])
-
-        if not fail[0] and all(x is not None for x in gemm_offsets):
-            def _visit(op):
-                if op.same_as(loop_var):
-                    fail[0] = True
-            tvm.tir.ir_pass.PostOrderVisit(ret, _visit)
-            if not fail[0]:
-                begin = tvm.tir.call_extern(
-                    "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
-                end = tvm.tir.call_extern("int32", "VTAUopLoopEnd")
-                return [begin, ret, end]
-        raise ValueError("Failed to fold the GEMM instructions..")
-
-    def _do_fold(stmt):
-        if (stmt.attr_key == "coproc_uop_scope" and
-                isinstance(stmt.value, tvm.tir.StringImm) and
-                stmt.value.value == env.dev.vta_push_uop.value):
-            body = stmt.body
-            begins = []
-            ends = []
-            try:
-                begin, body, end = _fold_outermost_loop(body)
-                if begin is not None:
-                    begins.append(begin)
-                if end is not None:
-                    ends.append(end)
-                begin, body, end = _fold_outermost_loop(body)
-                if begin is not None:
-                    begins.append(begin)
-                if end is not None:
-                    ends.append(end)
-            except ValueError:
-                pass
-            if body == stmt.body:
-                return stmt
-            ends = list(reversed(ends))
-            body = tvm.tir.stmt_seq(*(begins + [body] + ends))
-            return tvm.tir.AttrStmt(
-                stmt.node, stmt.attr_key, stmt.value, body)
-        return None
-    out = tvm.tir.ir_pass.IRTransform(
-        stmt_in, _do_fold, None, ["AttrStmt"])
-    return out
-
-
-def cpu_access_rewrite(stmt_in):
-    """Detect CPU access to VTA buffer and get address correctly.
-
-    VTA's buffer is an opaque handle that do not
-    correspond to address in CPU.
-    This pass detect CPU access and rewrite to use pointer
-    returned VTABufferCPUPtr for CPU access.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    env = get_env()
-    rw_info = {}
-    def _post_order(op):
-        if isinstance(op, tvm.tir.Allocate):
-            buffer_var = op.buffer_var
-            if not buffer_var in rw_info:
-                return None
-            new_var = rw_info[buffer_var]
-            let_stmt = tvm.tir.LetStmt(
-                new_var, tvm.tir.call_extern(
-                    "handle", "VTABufferCPUPtr",
-                    env.dev.command_handle,
-                    buffer_var), op.body)
-            alloc = tvm.tir.Allocate(
-                buffer_var, op.dtype, op.extents,
-                op.condition, let_stmt)
-            del rw_info[buffer_var]
-            return alloc
-        if isinstance(op, tvm.tir.Load):
-            buffer_var = op.buffer_var
-            if not buffer_var in rw_info:
-                rw_info[buffer_var] = te.var(
-                    buffer_var.name + "_ptr", "handle")
-            new_var = rw_info[buffer_var]
-            return tvm.tir.Load(op.dtype, new_var, op.index)
-        if isinstance(op, tvm.tir.Store):
-            buffer_var = op.buffer_var
-            if not buffer_var in rw_info:
-                rw_info[buffer_var] = te.var(
-                    buffer_var.name + "_ptr", "handle")
-            new_var = rw_info[buffer_var]
-            return tvm.tir.Store(new_var, op.value, op.index)
-        raise RuntimeError("not reached")
-    stmt = tvm.tir.ir_pass.IRTransform(
-        stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
-    for buffer_var, new_var in rw_info.items():
-        stmt = tvm.tir.LetStmt(
-            new_var, tvm.tir.call_extern(
-                "handle", "VTABufferCPUPtr",
-                env.dev.command_handle,
-                buffer_var), stmt)
-    return stmt
-
-
-def lift_alloc_to_scope_begin(stmt_in):
-    """Lift allocate to beginning of the current scope.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    lift_stmt = [[]]
-    def _merge_block(slist, body):
-        for op in slist:
-            if op.body == body:
-                body = op
-            elif isinstance(op, tvm.tir.Allocate):
-                body = tvm.tir.Allocate(
-                    op.buffer_var, op.dtype,
-                    op.extents, op.condition, body)
-            elif isinstance(op, tvm.tir.AttrStmt):
-                body = tvm.tir.AttrStmt(
-                    op.node, op.attr_key, op.value, body)
-            elif isinstance(op, tvm.tir.For):
-                body = tvm.tir.For(
-                    op.loop_var, op.min, op.extent, op.for_type,
-                    op.device_api, body)
-            else:
-                raise RuntimeError("unexpected op")
-        del slist[:]
-        return body
-
-    def _pre_order(op):
-        if isinstance(op, tvm.tir.For):
-            lift_stmt.append([])
-        elif isinstance(op, tvm.tir.AttrStmt):
-            if op.attr_key == "virtual_thread":
-                lift_stmt.append([])
-
-    def _post_order(op):
-        if isinstance(op, tvm.tir.Allocate):
-            lift_stmt[-1].append(op)
-            return op.body
-        if isinstance(op, tvm.tir.AttrStmt):
-            if op.attr_key == "storage_scope":
-                lift_stmt[-1].append(op)
-                return op.body
-            if op.attr_key == "virtual_thread":
-                return _merge_block(lift_stmt.pop() + [op], op.body)
-            return op
-        if isinstance(op, tvm.tir.For):
-            return _merge_block(lift_stmt.pop() + [op], op.body)
-        raise RuntimeError("not reached")
-    stmt = tvm.tir.ir_pass.IRTransform(
-        stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
-    assert len(lift_stmt) == 1
-    return _merge_block(lift_stmt[0], stmt)
-
-
-def inject_skip_copy(stmt_in):
-    """Pass to inject skip copy stmt, used for debug purpose.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    def _do_fold(stmt):
-        if _match_pragma(stmt, "skip_dma_copy"):
-            return tvm.tir.Evaluate(0)
-        return None
-    return tvm.tir.ir_pass.IRTransform(
-        stmt_in, _do_fold, None, ["AttrStmt"])
-
-
-def inject_coproc_sync(stmt_in):
-    """Pass to inject skip copy stmt, used in debug.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    success = [False]
-    def _do_fold(stmt):
-        if _match_pragma(stmt, "coproc_sync"):
-            success[0] = True
-            sync = tvm.tir.Call(
-                "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
-            return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
-        if _match_pragma(stmt, "trim_loop"):
-            op = stmt.body
-            assert isinstance(op, tvm.tir.For)
-            return tvm.tir.For(
-                op.loop_var, op.min, 2, op.for_type,
-                op.device_api, op.body)
-        return None
-    stmt = tvm.tir.ir_pass.IRTransform(
-        stmt_in, None, _do_fold, ["AttrStmt"])
-    stmt = tvm.tir.ir_pass.CoProcSync(stmt)
-    return stmt
-
-
-def inject_dma_intrin(stmt_in):
-    """Pass to inject DMA copy intrinsics.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    env = get_env()
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    def _check_compact(buf):
-        ndim = len(buf.shape)
-        size = tvm.tir.const(1, buf.shape[0].dtype)
-        for i in reversed(range(ndim)):
-            if not util.equal_const_int(size - buf.strides[i], 0):
-                raise RuntimeError(
-                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
-            size = size * buf.shape[i]
-
-    def _fold_buffer_dim(buf, scope, elem_block):
-        ndim = len(buf.shape)
-        x_size = 1
-        base = 0
-        for i in range(1, ndim + 1):
-            if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
-                raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
-            x_size = x_size * buf.shape[ndim - i]
-            if util.equal_const_int(x_size - elem_block, 0):
-                base = i + 1
-                break
-        if base == 0:
-            raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
-                scope, elem_block, buf.shape))
-        shape = [elem_block]
-        strides = [1]
-
-        if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block):
-            shape.append(1)
-            strides.append(elem_block)
-
-        analyzer = tvm.arith.Analyzer()
-        while base < ndim + 1:
-            x_size = 1
-            x_stride = buf.strides[ndim - base]
-            next_base = base
-            if not util.equal_const_int(idxm(x_stride, elem_block), 0):
-                raise RuntimeError(
-                    "scope %s need to have block=%d, shape=%s, strides=%s" % (
-                        scope, elem_block, buf.shape, buf.strides))
-            for i in range(base, ndim + 1):
-                k = ndim - i
-                if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
-                    break
-                x_size = x_size * buf.shape[k]
-                next_base = i + 1
-            shape.append(analyzer.simplify(x_size))
-            strides.append(x_stride)
-            assert next_base != base
-            base = next_base
-
-        strides = list(reversed(strides))
-        shape = list(reversed(shape))
-        return shape, strides
-
-    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
-        elem_block = elem_bytes * 8 // elem_width
-        if buf.dtype != dtype:
-            raise RuntimeError("Expect buffer type to be %s instead of %s" %
-                               (dtype, buf.dtype))
-        shape, strides = buf.shape, buf.strides
-        if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
-            raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
-        if allow_fold:
-            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
-        else:
-            shape = list(x for x in shape)
-            strides = list(x for x in strides)
-
-        def raise_error():
-            """Internal function to raise error """
-            raise RuntimeError(
-                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
-                 " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
-
-        ndim = len(shape)
-
-        # Check if the inner-tensor is already flat
-        flat = util.equal_const_int(shape[-1], elem_block)
-
-        if flat:
-            if not util.equal_const_int(strides[-1], 1):
-                raise_error()
-
-            if ndim == 1:
-                x_size = 1
-                x_stride = 1
-                y_size = 1
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-            if not util.equal_const_int(strides[-2] - elem_block, 0):
-                raise_error()
-
-            if ndim == 2:
-                x_size = shape[-2]
-                x_stride = shape[-2]
-                y_size = 1
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-            if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
-                raise_error()
-
-            if ndim == 3:
-                x_size = shape[-2]
-                x_stride = idxd(strides[-3], elem_block)
-                y_size = shape[-3]
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-
-        else:
-            if not util.equal_const_int(strides[-1], 1):
-                raise_error()
-            if not util.equal_const_int(strides[-2] - shape[-1], 0):
-                raise_error()
-            if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
-                raise_error()
-
-            if ndim == 2:
-                x_size = 1
-                x_stride = 1
-                y_size = 1
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-            if not util.equal_const_int(strides[-3], elem_block):
-                raise_error()
-
-            if ndim == 3:
-                x_size = shape[-3]
-                x_stride = shape[-3]
-                y_size = 1
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-            if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
-                raise_error()
-
-            if ndim == 4:
-                x_size = shape[-3]
-                x_stride = idxd(strides[-4], elem_block)
-                y_size = shape[-4]
-                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-
-        raise_error()
-
-
-    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
-        # FIXME: pad_value is ignored...
-        _ = pad_value
-        if dst.scope == "global":
-            # Store
-            if pad_before or pad_after:
-                raise RuntimeError("Do not support copy into DRAM with pad")
-            if src.scope == env.acc_scope:
-                elem_width = env.OUT_WIDTH
-                elem_bytes = env.OUT_ELEM_BYTES
-                mem_type = env.dev.MEM_ID_OUT
-                data_type = "int%d" % env.OUT_WIDTH
-                task_qid = env.dev.QID_STORE_OUT
-            else:
-                raise RuntimeError("Do not support copy %s->dram" % (src.scope))
-            _check_compact(src)
-            x_size, y_size, x_stride, offset = _get_2d_pattern(
-                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
-            irb = tvm.tir.ir_builder.create()
-            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                           env.dev.get_task_qid(task_qid))
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAStoreBuffer2D",
-                env.dev.command_handle,
-                src.access_ptr("r", "int32"),
-                mem_type, dst.data, offset, x_size, y_size, x_stride))
-            return irb.get()
-        elif src.scope == "global":
-            if dst.scope == env.acc_scope:
-                elem_width = env.ACC_WIDTH
-                elem_bytes = env.ACC_ELEM_BYTES
-                mem_type = env.dev.MEM_ID_ACC
-                data_type = "int%d" % env.ACC_WIDTH
-                task_qid = env.dev.QID_LOAD_OUT
-            elif dst.scope == env.inp_scope:
-                elem_width = env.INP_WIDTH
-                elem_bytes = env.INP_ELEM_BYTES
-                mem_type = env.dev.MEM_ID_INP
-                data_type = "int%d" % env.INP_WIDTH
-                task_qid = env.dev.QID_LOAD_INP
-            elif dst.scope == env.wgt_scope:
-                elem_width = env.WGT_WIDTH
-                elem_bytes = env.WGT_ELEM_BYTES
-                mem_type = env.dev.MEM_ID_WGT
-                data_type = "int%d" % env.WGT_WIDTH
-                task_qid = env.dev.QID_LOAD_WGT
-            else:
-                raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
-            # collect pad statistics
-            if pad_before:
-                assert pad_after
-                ndim = len(pad_before)
-                if ndim <= 2 or ndim > 5:
-                    raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
-                if ndim == 5:
-                    # This case occurs when batch size N > 1
-                    y_pad_before = pad_before[1]
-                    x_pad_before = pad_before[2]
-                    y_pad_after = pad_after[1]
-                    x_pad_after = pad_after[2]
-                    for dim in range(3, ndim):
-                        if not util.equal_const_int(pad_before[dim], 0):
-                            raise ValueError("Do not support pad on the innermost block")
-                        if not util.equal_const_int(pad_after[dim], 0):
-                            raise ValueError("Do not support pad on the innermost block")
-                else:
-                    y_pad_before = pad_before[0]
-                    x_pad_before = pad_before[1]
-                    y_pad_after = pad_after[0]
-                    x_pad_after = pad_after[1]
-                    for dim in range(2, ndim):
-                        if not util.equal_const_int(pad_before[dim], 0):
-                            raise ValueError("Do not support pad on the innermost block")
-                        if not util.equal_const_int(pad_after[dim], 0):
-                            raise ValueError("Do not support pad on the innermost block")
-                allow_fold = False
-            else:
-                x_pad_before = 0
-                y_pad_before = 0
-                x_pad_after = 0
-                y_pad_after = 0
-                allow_fold = True
-
-            _check_compact(dst)
-            x_size, y_size, x_stride, offset = _get_2d_pattern(
-                src, elem_width, elem_bytes, data_type,
-                dst.scope, allow_fold=allow_fold)
-
-            irb = tvm.tir.ir_builder.create()
-            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                           env.dev.get_task_qid(task_qid))
-
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTALoadBuffer2D",
-                env.dev.command_handle,
-                src.data, offset, x_size, y_size, x_stride,
-                x_pad_before, y_pad_before,
-                x_pad_after, y_pad_after,
-                dst.access_ptr("r", "int32"), mem_type))
-            return irb.get()
-
-        else:
-            raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
-
-    return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
-
-
-def _get_gemm_intrin_buffer():
-    """Declare the buffers describing one VTA GEMM intrinsic tile.
-
-    Reads the tile sizes (BATCH, BLOCK_IN, BLOCK_OUT) and element widths
-    from the current VTA environment and asserts that one scope element
-    holds exactly one tile. A reference GEMM compute
-    (``out[i, j] = sum_k inp[i, k] * wgt[j, k]``) is built only to obtain
-    the shapes and dtypes of the declared buffers.
-
-    Returns
-    -------
-    wgt_layout, inp_layout, out_layout : tuple of Buffer
-        Buffers in the weight, input and accumulator scopes, each with
-        offset_factor and data_alignment equal to its lane count.
-    """
-    env = get_env()
-    # lanes = number of values packed in one scope element; it must equal
-    # the tile size so that one element is exactly one tile.
-    wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
-    assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
-    wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
-    assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
-    inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
-    assert inp_lanes == env.BATCH * env.BLOCK_IN
-    inp_shape = (env.BATCH, env.BLOCK_IN)
-    assert inp_shape[0] * inp_shape[1] == inp_lanes
-    out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
-    assert out_lanes == env.BATCH * env.BLOCK_OUT
-    out_shape = (env.BATCH, env.BLOCK_OUT)
-    assert out_shape[0] * out_shape[1] == out_lanes
-    wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
-                         dtype="int%d" % env.WGT_WIDTH,
-                         name=env.wgt_scope)
-    inp = te.placeholder((inp_shape[0], inp_shape[1]),
-                         dtype="int%d" % env.INP_WIDTH,
-                         name=env.inp_scope)
-    k = te.reduce_axis((0, wgt_shape[1]), name="k")
-    out_dtype = "int%d" % env.ACC_WIDTH
-    # Reference GEMM; only its shape and dtype are used below.
-    out = te.compute((out_shape[0], out_shape[1]),
-                     lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
-                                         wgt[j, k].astype(out_dtype),
-                                         axis=[k]),
-                     name="out")
-    # offset_factor / data_alignment are set to the lane count (one tile).
-    wgt_layout = tvm.tir.decl_buffer(
-        wgt.shape, wgt.dtype, env.wgt_scope,
-        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
-    inp_layout = tvm.tir.decl_buffer(
-        inp.shape, inp.dtype, env.inp_scope,
-        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
-    out_layout = tvm.tir.decl_buffer(
-        out.shape, out.dtype, env.acc_scope,
-        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
-
-    return wgt_layout, inp_layout, out_layout
-
-
-def inject_conv2d_transpose_skip(stmt_in):
-    """Pass to skip 0-weights in conv2d transpose with stride > 1.
-
-    Every AttrStmt carrying the ``conv2d_transpose_gemm`` pragma is
-    replaced by a ``VTAUopPush`` inside compute coproc scopes, with the
-    GEMM intrinsic buffers bound to the tensors accessed by the original
-    statement. For the update (non-init) statement the push is guarded by
-    the condition of a Select found in the body (if any), so iterations
-    where that condition is false emit no micro-op.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    env = get_env()
-    dwgt, dinp, dout = _get_gemm_intrin_buffer()
-
-    # BufferLoad / Select nodes gathered by _find_basics. These lists are
-    # shared by every _do_fold call and are never cleared.
-    calls = []
-    selects = []
-
-    def _find_basics(op):
-        if isinstance(op, tvm.tir.BufferLoad):
-            calls.append(op)
-        elif isinstance(op, tvm.tir.Select):
-            selects.append(op)
-
-    def _do_fold(op):
-        if _match_pragma(op, "conv2d_transpose_gemm"):
-            # The init (reset) statement is recognized by its printed form.
-            is_init = ".init" in str(op)
-            tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
-
-            if is_init:
-                # create inner most block
-                irb = tvm.tir.ir_builder.create()
-                dev = env.dev
-                irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
-                irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-                # Second argument is 1 here and 0 in the update push below;
-                # presumably the reset flag -- confirm against VTAUopPush.
-                irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
-                                             0, 1,
-                                             dout.access_ptr("rw", "int32"),
-                                             0, 0,
-                                             0, 0, 0))
-                inner = irb.get()
-                # TODO(@tmoreau89): This is only a temporary fix, please take a look.
-                # Descend through IfThenElse wrappers to reach the store
-                # into the result buffer.
-                body = op.body.body
-                while isinstance(body, tvm.tir.IfThenElse):
-                    body = body.then_case
-                args = body.indices
-                res_buffer = body.buffer
-                # Bind dout to res_buffer at the store's indices; the tuple is
-                # presumably (begin, extent) pairs per dimension -- confirm.
-                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
-                inner = tvm.tir.AttrStmt(
-                    [dout, res_buffer], 'buffer_bind_scope',
-                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-                return inner
-            else:
-                # Assumes the last three BufferLoads recorded (post-order) are
-                # the result, data and kernel reads, in that order.
-                conv_call, data_call, kernel_call = calls[-3:]
-                pad_data_tensor = data_call.buffer
-                kernel_tensor = kernel_call.buffer
-                res_tensor = conv_call.buffer
-
-                if selects:
-                    # NOTE(review): selects accumulates across all folded
-                    # statements, so selects[0] is the first Select seen in
-                    # the whole pass, not necessarily one from this op.
-                    condition = selects[0].condition
-                else:
-                    condition = tvm.tir.const(1, 'int')
-
-                # create inner most block
-                irb = tvm.tir.ir_builder.create()
-                with irb.if_scope(condition):
-                    dev = env.dev
-                    irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
-                    irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-                    irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
-                                                 0, 0,
-                                                 dout.access_ptr("rw", "int32"),
-                                                 dinp.access_ptr("r", "int32"),
-                                                 dwgt.access_ptr("r", "int32"),
-                                                 0, 0, 0))
-                inner = irb.get()
-
-                # Bind the intrinsic buffers (dout, dwgt, dinp) to the tensors
-                # read by the original statement, at that statement's indices.
-                args = conv_call.indices
-                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                       1, 0, 1, 0, env.BLOCK_OUT)
-                inner = tvm.tir.AttrStmt(
-                    [dout, res_tensor], 'buffer_bind_scope',
-                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-                args = kernel_call.indices
-                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                       1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
-                inner = tvm.tir.AttrStmt(
-                    [dwgt, kernel_tensor], 'buffer_bind_scope',
-                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-                args = data_call.indices
-                tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
-                       1, 0, 1, 0, env.BLOCK_IN)
-                inner = tvm.tir.AttrStmt(
-                    [dinp, pad_data_tensor], 'buffer_bind_scope',
-                    tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
-                return inner
-        return None
-    # _do_fold is passed as the pre-order callback, restricted to AttrStmt.
-    ret = tvm.tir.ir_pass.IRTransform(
-        stmt_in, _do_fold, None, ["AttrStmt"])
-    return ret
-
-
-def annotate_alu_coproc_scope(stmt_in):
-    """Pass to insert ALU instruction.
-
-    Statements carrying the ``alu`` pragma are wrapped in a compute
-    ``coproc_scope`` (QID_COMPUTE) and a ``coproc_uop_scope`` whose value
-    is ``"VTAPushALUOp"``. Statements carrying the ``skip_alu`` pragma are
-    replaced by ``Evaluate(0)``. Other AttrStmts are returned unchanged.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    env = get_env()
-    def _do_fold(stmt):
-        if _match_pragma(stmt, "alu"):
-            irb = tvm.tir.ir_builder.create()
-            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
-                           env.dev.get_task_qid(env.dev.QID_COMPUTE))
-            irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
-                           tvm.tir.StringImm("VTAPushALUOp"))
-            # The original pragma statement is kept inside both scopes.
-            irb.emit(stmt)
-            return irb.get()
-        if _match_pragma(stmt, "skip_alu"):
-            return tvm.tir.Evaluate(0)
-        return stmt
-
-    # _do_fold is passed as the post-order callback, restricted to AttrStmt.
-    stmt_out = tvm.tir.ir_pass.IRTransform(
-        stmt_in, None, _do_fold, ["AttrStmt"])
-
-    return stmt_out
-
-
-def inject_alu_intrin(stmt_in):
-    """Pass to inject ALU micro-ops.
-
-    Each AttrStmt carrying the ``alu`` pragma is expected to wrap a loop
-    nest whose innermost body is an in-place element-wise update of one
-    buffer. The update is pattern-matched to an ALU opcode. The innermost
-    loop(s) are absorbed into a single op over a BATCH * BLOCK_OUT tile,
-    and adjacent outer loops are flattened where their strides allow.
-    The result is emitted as VTAUopLoopBegin / VTAUopPush / VTAUopLoopEnd
-    calls.
-
-    Parameters
-    ----------
-    stmt_in : Stmt
-        Input statement
-
-    Returns
-    -------
-    stmt_out : Stmt
-        Transformed statement
-    """
-    env = get_env()
-    idxm = tvm.tir.indexmod
-    analyzer = tvm.arith.Analyzer()
-
-    def _do_fold(stmt):
-        def _equal(x, y):
-            # Symbolic equality: x - y simplifies to the constant 0.
-            return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
-
-        def _flatten_loop(src_coeff, dst_coeff, extents):
-            """Merge adjacent loops that can be expressed as one loop.
-
-            The coefficient lists end with the constant offset; the entries
-            before it, like ``extents``, go from outermost to innermost loop.
-            An outer loop is merged into the current one when its src and
-            dst coefficients both equal the current coefficient times the
-            current extent. Returns the new (src_coeff, dst_coeff, extents).
-            """
-            src_coeff = list(src_coeff)
-            dst_coeff = list(dst_coeff)
-            extents = list(extents)
-            rev_src_coeff = [src_coeff.pop()]
-            rev_dst_coeff = [dst_coeff.pop()]
-            rev_extents = []
-            assert src_coeff
-            vsrc = src_coeff.pop()
-            vdst = dst_coeff.pop()
-            vext = extents.pop()
-            while src_coeff:
-                next_src = src_coeff.pop()
-                next_dst = dst_coeff.pop()
-                next_ext = extents.pop()
-
-                if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
-                    vext = analyzer.simplify(vext * next_ext)
-                else:
-                    rev_src_coeff.append(vsrc)
-                    rev_dst_coeff.append(vdst)
-                    rev_extents.append(vext)
-                    vsrc = next_src
-                    vdst = next_dst
-                    vext = next_ext
-            rev_src_coeff.append(vsrc)
-            rev_dst_coeff.append(vdst)
-            rev_extents.append(vext)
-            rev_src_coeff.reverse()
-            rev_dst_coeff.reverse()
-            rev_extents.reverse()
-
-            return rev_src_coeff, rev_dst_coeff, rev_extents
-
-        if _match_pragma(stmt, "alu"):
-            # Get to the innermost loop body
-            loop_body = stmt.body
-            nest_size = 0
-            while isinstance(loop_body, tvm.tir.For):
-                loop_body = loop_body.body
-                nest_size += 1
-            # Get the src/dst arguments
-            dst_var = loop_body.buffer_var
-            dst_idx = loop_body.index
-            # Derive loop variables and extents
-            tmp_body = stmt.body
-            indices = []
-            extents = []
-            for _ in range(nest_size):
-                indices.append(tmp_body.loop_var)
-                extents.append(tmp_body.extent)
-                tmp_body = tmp_body.body
-            # Derive opcode
-            if isinstance(loop_body.value, tvm.tir.Add):
-                alu_opcode = env.dev.ALU_OPCODE_ADD
-                lhs = loop_body.value.a
-                rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.tir.Sub):
-                alu_opcode = env.dev.ALU_OPCODE_SUB
-                lhs = loop_body.value.a
-                rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.tir.Mul):
-                alu_opcode = env.dev.ALU_OPCODE_MUL
-                lhs = loop_body.value.a
-                rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.tir.Min):
-                alu_opcode = env.dev.ALU_OPCODE_MIN
-                lhs = loop_body.value.a
-                rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.tir.Max):
-                alu_opcode = env.dev.ALU_OPCODE_MAX
-                lhs = loop_body.value.a
-                rhs = loop_body.value.b
-            elif isinstance(loop_body.value, tvm.tir.Call):
-                # shift_left is encoded as SHR with a negated shift amount.
-                if loop_body.value.name == 'shift_left':
-                    alu_opcode = env.dev.ALU_OPCODE_SHR
-                    lhs = loop_body.value.args[0]
-                    rhs = analyzer.simplify(-loop_body.value.args[1])
-                elif loop_body.value.name == 'shift_right':
-                    alu_opcode = env.dev.ALU_OPCODE_SHR
-                    lhs = loop_body.value.args[0]
-                    rhs = loop_body.value.args[1]
-                else:
-                    raise RuntimeError(
-                        "Function call not recognized %s" % (loop_body.value.name))
-            elif isinstance(loop_body.value, tvm.tir.Load):
-                # A plain copy is encoded as SHR by 0.
-                alu_opcode = env.dev.ALU_OPCODE_SHR
-                lhs = loop_body.value
-                rhs = tvm.tir.const(0, "int32")
-            else:
-                raise RuntimeError(
-                    "Expression not recognized %s, %s, %s" % (
-                        type(loop_body.value), str(loop_body.value), str(stmt)))
-
-            # Derive array index coefficients
-            dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
-            # Check if lhs/rhs is immediate
-            # NOTE(review): if both operands are IntImm, the ``.buffer_var``
-            # accesses below raise; such inputs are presumably not expected.
-            use_imm = False
-            imm_val = None
-            if isinstance(rhs, tvm.tir.IntImm):
-                assert lhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
-                use_imm = True
-                imm_val = rhs
-            if isinstance(lhs, tvm.tir.IntImm):
-                assert rhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
-                use_imm = True
-                imm_val = lhs
-            if imm_val is None:
-                imm_val = 0
-                assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
-                src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
-                src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
-                # Determine which side has the same coefficients
-                lhs_equal = True
-                rhs_equal = True
-                for i, coef in enumerate(dst_coeff):
-                    if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
-                        lhs_equal = False
-                    if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
-                        rhs_equal = False
-                # Make sure at least one of the source is identical to the
-                # destination (in-place computation)
-                assert lhs_equal or rhs_equal
-                # Assign the source coefficients
-                if lhs_equal:
-                    src_coeff = src_rhs_coeff
-                else:
-                    src_coeff = src_lhs_coeff
-
-            # Ensure that we have the proper tensor dimensions in the
-            # innermost loop (pattern match)
-            # Coefficient lists end with the constant term (used as the
-            # offset); the innermost loop must have stride 1 and, for
-            # BATCH > 1, the next one stride BLOCK_OUT. Offsets must be
-            # multiples of the BATCH * BLOCK_OUT tile.
-            src_coeff = list(src_coeff)
-            dst_coeff = list(dst_coeff)
-            extents = list(extents)
-            assert len(src_coeff) > 1
-            assert len(dst_coeff) > 1
-            assert len(extents) != 0
-            assert tvm.ir.structural_equal(
-                analyzer.simplify(
-                    idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
-            assert tvm.ir.structural_equal(
-                analyzer.simplify(
-                    idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
-            assert tvm.ir.structural_equal(src_coeff[-2], 1)
-            assert tvm.ir.structural_equal(dst_coeff[-2], 1)
-            if env.BATCH > 1:
-                assert len(src_coeff) > 2
-                assert len(dst_coeff) > 2
-                assert len(extents) > 1
-                assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
-                assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
-
-            # Apply tensorization of the loop coefficients
-            # Drop the loop(s) covered by one tile, re-append the offset, and
-            # rescale every coefficient to tile units.
-            src_offset = src_coeff[-1]
-            dst_offset = dst_coeff[-1]
-            if env.BATCH == 1:
-                src_coeff = src_coeff[:-2]
-                dst_coeff = dst_coeff[:-2]
-                extents = extents[:-1]
-            else:
-                src_coeff = src_coeff[:-3]
-                dst_coeff = dst_coeff[:-3]
-                extents = extents[:-2]
-            src_coeff.append(src_offset)
-            dst_coeff.append(dst_offset)
-            src_coeff = [
-                analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
-            dst_coeff = [
-                analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
-
-            # Flatten the outer loops
-            if extents:
-                src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
-
-            # Insert ALU micro-ops
-            # One VTAUopLoopBegin per remaining loop, the ALU uop, then the
-            # matching VTAUopLoopEnd calls. The push's first argument is 1
-            # here (0 in the GEMM pushes); presumably the ALU mode -- confirm.
-            irb = tvm.tir.ir_builder.create()
-            for idx, extent in enumerate(extents):
-                irb.emit(tvm.tir.call_extern(
-                    "int32", "VTAUopLoopBegin",
-                    extent, dst_coeff[idx], src_coeff[idx], 0))
-            use_imm = int(use_imm)
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAUopPush",
-                1, 0,
-                dst_coeff[len(dst_coeff)-1],
-                src_coeff[len(src_coeff)-1],
-                0,
-                alu_opcode, use_imm, imm_val))
-            for extent in extents:
-                irb.emit(tvm.tir.call_extern(
-                    "int32", "VTAUopLoopEnd"))
-            return irb.get()
-        return stmt
-
-    # _do_fold is passed as the post-order callback, restricted to AttrStmt.
-    stmt_out = tvm.tir.ir_pass.IRTransform(
-        stmt_in, None, _do_fold, ["AttrStmt"])
-    return stmt_out
-
-
-def debug_print(stmt):
-    """A debug pass that prints the statement and passes it through.
-
-    Parameters
-    ----------
-    stmt : Stmt
-        The input statement
-
-    Returns
-    -------
-    stmt : Stmt
-        The same input statement, unchanged.
-    """
-    # pylint: disable=superfluous-parens
-    print(stmt)
-    return stmt
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
new file mode 100644
index 0000000..f930b3f
--- /dev/null
+++ b/vta/python/vta/transform.py
@@ -0,0 +1,962 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Additional Transformation Passes for VTA"""
+# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
+import tvm
+from tvm import te
+from topi import util
+
+from .environment import get_env
+
+
+def _match_pragma(stmt, key):
+    """Check whether an AttrStmt is the pragma named ``key``.
+
+    Parameters
+    ----------
+    stmt : Stmt
+        The AttrStmt to inspect.
+
+    key : str
+        The pragma key, without the ``pragma_`` prefix.
+
+    Returns
+    -------
+    matched : bool
+        True when ``stmt`` is ``pragma_<key>``, or a ``pragma_scope``
+        whose value equals ``key``.
+    """
+    if stmt.attr_key == "pragma_" + key:
+        return True
+    return stmt.attr_key == "pragma_scope" and stmt.value.value == key
+
+
+def FoldUopLoop():
+    """Detect and fold uop loops.
+
+    VTA's micro-op programming model can express loop structure directly.
+    Inside ``coproc_uop_scope`` regions whose value matches
+    ``env.dev.vta_push_uop``, this pass tries to fold up to two levels of
+    loops around ``VTAUopPush`` calls. A loop folds when every push
+    operand is linear in the loop variable with a common stride. Each
+    folded loop becomes a ``VTAUopLoopBegin`` / ``VTAUopLoopEnd`` pair.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _fold_outermost_loop(body):
+        # Returns (begin, new_body, end), or (None, body, None) when body
+        # is not a loop. Raises ValueError when the loop cannot be folded.
+        if not isinstance(body, tvm.tir.For):
+            return None, body, None
+
+        loop = body
+        loop_var = loop.loop_var
+        # Stride of each of the three push operands w.r.t. loop_var; every
+        # push inside the loop must agree on it.
+        gemm_offsets = [None, None, None]
+        fail = [False]
+
+        def _post_order(op):
+            assert isinstance(op, tvm.tir.Call)
+            base_args = 2
+            if op.name == "VTAUopPush":
+                new_args = list(op.args[:base_args])
+                for pos in range(3):
+                    coeffs = tvm.arith.detect_linear_equation(
+                        op.args[pos + base_args], [loop_var])
+                    if not coeffs:
+                        fail[0] = True
+                        return op
+                    stride = gemm_offsets[pos]
+                    if stride is None:
+                        gemm_offsets[pos] = coeffs[0]
+                    elif not tvm.ir.structural_equal(coeffs[0], stride):
+                        fail[0] = True
+                        return op
+                    # Keep only the loop-invariant part of the operand.
+                    new_args.append(coeffs[1])
+                new_args.extend(op.args[base_args + 3:])
+                return tvm.tir.call_extern("int32", "VTAUopPush", *new_args)
+            if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
+                raise RuntimeError("unexpected op %s" % op)
+            return op
+
+        ret = tvm.tir.ir_pass.IRTransform(
+            loop.body, None, _post_order, ["Call"])
+
+        if fail[0] or any(off is None for off in gemm_offsets):
+            raise ValueError("Failed to fold the GEMM instructions..")
+
+        def _visit(op):
+            # Any remaining use of loop_var means the loop cannot be removed.
+            if op.same_as(loop_var):
+                fail[0] = True
+        tvm.tir.ir_pass.PostOrderVisit(ret, _visit)
+        if fail[0]:
+            raise ValueError("Failed to fold the GEMM instructions..")
+
+        begin = tvm.tir.call_extern(
+            "int32", "VTAUopLoopBegin", loop.extent, *gemm_offsets)
+        end = tvm.tir.call_extern("int32", "VTAUopLoopEnd")
+        return [begin, ret, end]
+
+    def _do_fold(stmt):
+        env = get_env()
+        is_uop_scope = (
+            stmt.attr_key == "coproc_uop_scope" and
+            isinstance(stmt.value, tvm.tir.StringImm) and
+            stmt.value.value == env.dev.vta_push_uop.value)
+        if not is_uop_scope:
+            return None
+
+        body = stmt.body
+        begins = []
+        ends = []
+        try:
+            # Attempt at most two nested loop levels, outermost first.
+            for _ in range(2):
+                begin, body, end = _fold_outermost_loop(body)
+                if begin is not None:
+                    begins.append(begin)
+                if end is not None:
+                    ends.append(end)
+        except ValueError:
+            pass
+        if body == stmt.body:
+            return stmt
+        # Loop ends close in the reverse order of their begins.
+        body = tvm.tir.stmt_seq(*begins, body, *reversed(ends))
+        return tvm.tir.AttrStmt(
+            stmt.node, stmt.attr_key, stmt.value, body)
+
+    def _ftransform(f, mod, ctx):
+        new_body = tvm.tir.ir_pass.IRTransform(
+            f.body, _do_fold, None, ["AttrStmt"])
+        return f.with_body(new_body)
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
+
+
+def CPUAccessRewrite():
+    """Detect CPU access to VTA buffers and get the address correctly.
+
+    A VTA buffer is an opaque handle that does not correspond to an
+    address in CPU memory. This pass rewrites CPU loads and stores to go
+    through a pointer obtained from ``VTABufferCPUPtr``.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(f, mod, ctx):
+        # Maps each accessed buffer var to the handle var holding its CPU
+        # pointer. Entries are removed once bound at their Allocate; any
+        # left over (buffers not allocated in this function, presumably
+        # parameters) are bound around the whole body.
+        rw_info = {}
+        env = get_env()
+
+        def _cpu_ptr_var(buffer_var):
+            """Return (creating on first use) the CPU pointer var of buffer_var."""
+            if buffer_var not in rw_info:
+                rw_info[buffer_var] = te.var(
+                    buffer_var.name + "_ptr", "handle")
+            return rw_info[buffer_var]
+
+        def _bind_cpu_ptr(buffer_var, new_var, body):
+            """Wrap body in ``let new_var = VTABufferCPUPtr(cmd, buffer_var)``."""
+            return tvm.tir.LetStmt(
+                new_var, tvm.tir.call_extern(
+                    "handle", "VTABufferCPUPtr",
+                    env.dev.command_handle,
+                    buffer_var), body)
+
+        def _post_order(op):
+            if isinstance(op, tvm.tir.Allocate):
+                buffer_var = op.buffer_var
+                if buffer_var not in rw_info:
+                    return None
+                new_var = rw_info.pop(buffer_var)
+                let_stmt = _bind_cpu_ptr(buffer_var, new_var, op.body)
+                return tvm.tir.Allocate(
+                    buffer_var, op.dtype, op.extents,
+                    op.condition, let_stmt)
+            if isinstance(op, tvm.tir.Load):
+                # Forward the predicate so masked (vectorized) loads stay masked.
+                return tvm.tir.Load(
+                    op.dtype, _cpu_ptr_var(op.buffer_var), op.index,
+                    op.predicate)
+            if isinstance(op, tvm.tir.Store):
+                # Forward the predicate so masked (vectorized) stores stay masked.
+                return tvm.tir.Store(
+                    _cpu_ptr_var(op.buffer_var), op.value, op.index,
+                    op.predicate)
+            raise RuntimeError("not reached")
+
+        stmt = tvm.tir.ir_pass.IRTransform(
+            f.body, None, _post_order, ["Allocate", "Load", "Store"])
+
+        for buffer_var, new_var in rw_info.items():
+            stmt = _bind_cpu_ptr(buffer_var, new_var, stmt)
+        return f.with_body(stmt)
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
+
+
+def LiftAllocToScopeBegin():
+    """Lift allocate to beginning of the current scope.
+
+    Allocate nodes and ``storage_scope`` attributes are detached from where
+    they appear and re-attached around the body of the nearest enclosing
+    For loop or ``virtual_thread`` attribute. At top level they are
+    re-attached around the whole function body.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(f, mod, ctx):
+        # Stack of lifted statements: one pending list per open scope.
+        scopes = [[]]
+
+        def _merge_block(pending, body):
+            # Re-wrap body with each pending node in list order (earlier
+            # entries end up innermost), then empty the list.
+            for node in pending:
+                if node.body == body:
+                    # The node already wraps exactly this body; reuse it.
+                    body = node
+                    continue
+                if isinstance(node, tvm.tir.Allocate):
+                    body = tvm.tir.Allocate(
+                        node.buffer_var, node.dtype,
+                        node.extents, node.condition, body)
+                elif isinstance(node, tvm.tir.AttrStmt):
+                    body = tvm.tir.AttrStmt(
+                        node.node, node.attr_key, node.value, body)
+                elif isinstance(node, tvm.tir.For):
+                    body = tvm.tir.For(
+                        node.loop_var, node.min, node.extent, node.for_type,
+                        node.device_api, body)
+                else:
+                    raise RuntimeError("unexpected op")
+            del pending[:]
+            return body
+
+        def _opens_scope(op):
+            # For loops and virtual_thread attributes start a new lift scope.
+            if isinstance(op, tvm.tir.For):
+                return True
+            return (isinstance(op, tvm.tir.AttrStmt) and
+                    op.attr_key == "virtual_thread")
+
+        def _pre_order(op):
+            if _opens_scope(op):
+                scopes.append([])
+
+        def _post_order(op):
+            if _opens_scope(op):
+                return _merge_block(scopes.pop() + [op], op.body)
+            if isinstance(op, tvm.tir.Allocate):
+                scopes[-1].append(op)
+                return op.body
+            if isinstance(op, tvm.tir.AttrStmt):
+                if op.attr_key != "storage_scope":
+                    return op
+                scopes[-1].append(op)
+                return op.body
+            raise RuntimeError("not reached")
+
+        stmt = tvm.tir.ir_pass.IRTransform(
+            f.body, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
+        assert len(scopes) == 1
+        return f.with_body(_merge_block(scopes[0], stmt))
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
+
+
+def InjectSkipCopy():
+    """Pass to inject skip copy stmt, used for debug purpose.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _do_fold(stmt):
+        if _match_pragma(stmt, "skip_dma_copy"):
+            return tvm.tir.Evaluate(0)
+        return None
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(tvm.tir.ir_pass.IRTransform(
+            f.body, _do_fold, None, ["AttrStmt"]))
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
+
+
+def InjectCoProcSync():
+    """Pass inject coproc sync
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(f, *_):
+        success = [False]
+        def _do_fold(stmt):
+            if _match_pragma(stmt, "coproc_sync"):
+                success[0] = True
+                sync = tvm.tir.Call(
+                    "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
+                return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
+            if _match_pragma(stmt, "trim_loop"):
+                op = stmt.body
+                assert isinstance(op, tvm.tir.For)
+                return tvm.tir.For(
+                    op.loop_var, op.min, 2, op.for_type,
+                    op.device_api, op.body)
+            return None
+        return f.with_body(tvm.tir.ir_pass.IRTransform(
+            f.body, None, _do_fold, ["AttrStmt"]))
+    return tvm.transform.Sequential(
+        [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
+         tvm.tir.transform.CoProcSync()],
+        opt_level=0, name="tir.vta.InjectCoProcSync")
+
+
+def InjectDMAIntrin():
+    """Pass to inject DMA copy intrinsics.
+
+    Copies tagged ``dma_copy`` are lowered through
+    ``tvm.tir.transform.InjectCopyIntrin`` into VTA runtime calls:
+    ``VTAStoreBuffer2D`` for on-chip -> ``global`` copies and
+    ``VTALoadBuffer2D`` for ``global`` -> on-chip copies.  Each call is
+    emitted inside a ``coproc_scope`` attribute selecting the task queue.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    idxd = tvm.tir.indexdiv
+    idxm = tvm.tir.indexmod
+
+    def _check_compact(buf):
+        """Raise RuntimeError unless ``buf`` has dense row-major strides.
+
+        Walking from the innermost dimension outwards, each stride must
+        equal the product of all shapes inside it.
+        """
+        ndim = len(buf.shape)
+        size = tvm.tir.const(1, buf.shape[0].dtype)
+        for i in reversed(range(ndim)):
+            if not util.equal_const_int(size - buf.strides[i], 0):
+                raise RuntimeError(
+                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
+            size = size * buf.shape[i]
+
+    def _fold_buffer_dim(buf, scope, elem_block):
+        """Fold the dimensions of ``buf`` around an ``elem_block`` inner block.
+
+        The innermost dimensions must be contiguous and multiply out to
+        exactly ``elem_block``; they become one dimension with stride 1.
+        The remaining outer dimensions are merged greedily wherever their
+        strides are contiguous with each other, and each merged stride must
+        be a multiple of ``elem_block``.
+
+        Returns
+        -------
+        shape, strides : list
+            Folded shape and strides, outermost dimension first.
+        """
+        ndim = len(buf.shape)
+        x_size = 1
+        base = 0
+        # Consume innermost dimensions until their product equals
+        # elem_block.  Afterwards ``base`` is the 1-based position, counted
+        # from the innermost side, of the first dimension not yet consumed.
+        for i in range(1, ndim + 1):
+            if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
+                raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
+            x_size = x_size * buf.shape[ndim - i]
+            if util.equal_const_int(x_size - elem_block, 0):
+                base = i + 1
+                break
+        if base == 0:
+            raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
+                scope, elem_block, buf.shape))
+        # Lists are built innermost first and reversed before returning.
+        shape = [elem_block]
+        strides = [1]
+
+        # If the next outer dimension does not have stride elem_block,
+        # insert an extra size-1 dimension with stride elem_block.
+        if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block):
+            shape.append(1)
+            strides.append(elem_block)
+
+        analyzer = tvm.arith.Analyzer()
+        # Merge runs of outer dimensions: dimension k joins the current run
+        # while strides[k] == (size of the run so far) * (run's base stride).
+        while base < ndim + 1:
+            x_size = 1
+            x_stride = buf.strides[ndim - base]
+            next_base = base
+            if not util.equal_const_int(idxm(x_stride, elem_block), 0):
+                raise RuntimeError(
+                    "scope %s need to have block=%d, shape=%s, strides=%s" % (
+                        scope, elem_block, buf.shape, buf.strides))
+            for i in range(base, ndim + 1):
+                k = ndim - i
+                if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
+                    break
+                x_size = x_size * buf.shape[k]
+                next_base = i + 1
+            shape.append(analyzer.simplify(x_size))
+            strides.append(x_stride)
+            # Each iteration must consume at least one dimension.
+            assert next_base != base
+            base = next_base
+
+        strides = list(reversed(strides))
+        shape = list(reversed(shape))
+        return shape, strides
+
+    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
+        """Describe ``buf`` as a 2D strided access for the VTA DMA calls.
+
+        ``elem_block`` is the number of ``elem_width``-bit scalars in one
+        ``elem_bytes``-byte element.  ``buf`` must have dtype ``dtype`` and
+        an element offset divisible by ``elem_block``.  When ``allow_fold``
+        is set the dimensions are first folded with ``_fold_buffer_dim``.
+
+        Returns
+        -------
+        x_size, y_size, x_stride, offset
+            The 2D pattern, with ``x_stride`` and ``offset`` expressed in
+            units of ``elem_block`` scalars.
+
+        Raises
+        ------
+        RuntimeError
+            If the layout matches none of the supported patterns.
+        """
+        elem_block = elem_bytes * 8 // elem_width
+        if buf.dtype != dtype:
+            raise RuntimeError("Expect buffer type to be %s instead of %s" %
+                               (dtype, buf.dtype))
+        shape, strides = buf.shape, buf.strides
+        if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
+            raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
+        if allow_fold:
+            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
+        else:
+            shape = list(x for x in shape)
+            strides = list(x for x in strides)
+
+        def raise_error():
+            """Internal function to raise error """
+            raise RuntimeError(
+                ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
+                 " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
+
+        ndim = len(shape)
+
+        # Check if the inner-tensor is already flat
+        flat = util.equal_const_int(shape[-1], elem_block)
+
+        if flat:
+            # The innermost dimension is exactly one element block.
+            if not util.equal_const_int(strides[-1], 1):
+                raise_error()
+
+            if ndim == 1:
+                x_size = 1
+                x_stride = 1
+                y_size = 1
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(strides[-2] - elem_block, 0):
+                raise_error()
+
+            if ndim == 2:
+                x_size = shape[-2]
+                x_stride = shape[-2]
+                y_size = 1
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
+                raise_error()
+
+            if ndim == 3:
+                x_size = shape[-2]
+                x_stride = idxd(strides[-3], elem_block)
+                y_size = shape[-3]
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+
+        else:
+            # The element block is split over the two innermost dimensions
+            # (shape[-1] * shape[-2] must equal elem_block).
+            if not util.equal_const_int(strides[-1], 1):
+                raise_error()
+            if not util.equal_const_int(strides[-2] - shape[-1], 0):
+                raise_error()
+            if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
+                raise_error()
+
+            if ndim == 2:
+                x_size = 1
+                x_stride = 1
+                y_size = 1
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(strides[-3], elem_block):
+                raise_error()
+
+            if ndim == 3:
+                x_size = shape[-3]
+                x_stride = shape[-3]
+                y_size = 1
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+            if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
+                raise_error()
+
+            if ndim == 4:
+                x_size = shape[-3]
+                x_stride = idxd(strides[-4], elem_block)
+                y_size = shape[-4]
+                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+
+        # No supported pattern matched; raise_error() always raises.
+        raise_error()
+
+
+    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
+        """Build the VTA DMA statement copying ``src`` into ``dst``.
+
+        Supported directions are accumulator scope -> ``global`` (store) and
+        ``global`` -> accumulator/input/weight scope (load).  Padding is
+        accepted only on loads, and only on the two dimensions chosen as
+        y and x below.  ``pad_value`` is currently ignored.
+
+        Raises
+        ------
+        RuntimeError
+            For unsupported scope combinations or padding on a store.
+        ValueError
+            For an unsupported padding rank or padding on the inner dims.
+        """
+        # FIXME: pad_value is ignored...
+        env = get_env()
+        _ = pad_value
+        if dst.scope == "global":
+            # Store
+            if pad_before or pad_after:
+                raise RuntimeError("Do not support copy into DRAM with pad")
+            if src.scope == env.acc_scope:
+                elem_width = env.OUT_WIDTH
+                elem_bytes = env.OUT_ELEM_BYTES
+                mem_type = env.dev.MEM_ID_OUT
+                data_type = "int%d" % env.OUT_WIDTH
+                task_qid = env.dev.QID_STORE_OUT
+            else:
+                raise RuntimeError("Do not support copy %s->dram" % (src.scope))
+            # The on-chip side must be dense; the 2D pattern is taken from
+            # the DRAM side, with folding allowed.
+            _check_compact(src)
+            x_size, y_size, x_stride, offset = _get_2d_pattern(
+                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
+            irb = tvm.tir.ir_builder.create()
+            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+                           env.dev.get_task_qid(task_qid))
+            irb.emit(tvm.tir.call_extern(
+                "int32", "VTAStoreBuffer2D",
+                env.dev.command_handle,
+                src.access_ptr("r", "int32"),
+                mem_type, dst.data, offset, x_size, y_size, x_stride))
+            return irb.get()
+        elif src.scope == "global":
+            # Load: select width, memory id and queue from the destination scope.
+            if dst.scope == env.acc_scope:
+                elem_width = env.ACC_WIDTH
+                elem_bytes = env.ACC_ELEM_BYTES
+                mem_type = env.dev.MEM_ID_ACC
+                data_type = "int%d" % env.ACC_WIDTH
+                task_qid = env.dev.QID_LOAD_OUT
+            elif dst.scope == env.inp_scope:
+                elem_width = env.INP_WIDTH
+                elem_bytes = env.INP_ELEM_BYTES
+                mem_type = env.dev.MEM_ID_INP
+                data_type = "int%d" % env.INP_WIDTH
+                task_qid = env.dev.QID_LOAD_INP
+            elif dst.scope == env.wgt_scope:
+                elem_width = env.WGT_WIDTH
+                elem_bytes = env.WGT_ELEM_BYTES
+                mem_type = env.dev.MEM_ID_WGT
+                data_type = "int%d" % env.WGT_WIDTH
+                task_qid = env.dev.QID_LOAD_WGT
+            else:
+                raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
+            # collect pad statistics
+            if pad_before:
+                assert pad_after
+                ndim = len(pad_before)
+                if ndim <= 2 or ndim > 5:
+                    raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
+                if ndim == 5:
+                    # This case occurs when batch size N > 1
+                    y_pad_before = pad_before[1]
+                    x_pad_before = pad_before[2]
+                    y_pad_after = pad_after[1]
+                    x_pad_after = pad_after[2]
+                    for dim in range(3, ndim):
+                        if not util.equal_const_int(pad_before[dim], 0):
+                            raise ValueError("Do not support pad on the innermost block")
+                        if not util.equal_const_int(pad_after[dim], 0):
+                            raise ValueError("Do not support pad on the innermost block")
+                else:
+                    y_pad_before = pad_before[0]
+                    x_pad_before = pad_before[1]
+                    y_pad_after = pad_after[0]
+                    x_pad_after = pad_after[1]
+                    for dim in range(2, ndim):
+                        if not util.equal_const_int(pad_before[dim], 0):
+                            raise ValueError("Do not support pad on the innermost block")
+                        if not util.equal_const_int(pad_after[dim], 0):
+                            raise ValueError("Do not support pad on the innermost block")
+                # Padding refers to the original dimensions, so they must
+                # not be folded.
+                allow_fold = False
+            else:
+                x_pad_before = 0
+                y_pad_before = 0
+                x_pad_after = 0
+                y_pad_after = 0
+                allow_fold = True
+
+            _check_compact(dst)
+            x_size, y_size, x_stride, offset = _get_2d_pattern(
+                src, elem_width, elem_bytes, data_type,
+                dst.scope, allow_fold=allow_fold)
+
+            irb = tvm.tir.ir_builder.create()
+            irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+                           env.dev.get_task_qid(task_qid))
+
+            # NOTE(review): the destination pointer is requested with "r"
+            # access although this DMA writes into it -- confirm intended.
+            irb.emit(tvm.tir.call_extern(
+                "int32", "VTALoadBuffer2D",
+                env.dev.command_handle,
+                src.data, offset, x_size, y_size, x_stride,
+                x_pad_before, y_pad_before,
+                x_pad_after, y_pad_after,
+                dst.access_ptr("r", "int32"), mem_type))
+            return irb.get()
+
+        else:
+            raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
+
+    return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
+
+
+def _get_gemm_intrin_buffer():
+    env = get_env()
+    wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
+    assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
+    wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
+    assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
+    inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
+    assert inp_lanes == env.BATCH * env.BLOCK_IN
+    inp_shape = (env.BATCH, env.BLOCK_IN)
+    assert inp_shape[0] * inp_shape[1] == inp_lanes
+    out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
+    assert out_lanes == env.BATCH * env.BLOCK_OUT
+    out_shape = (env.BATCH, env.BLOCK_OUT)
+    assert out_shape[0] * out_shape[1] == out_lanes
+    wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
+                         dtype="int%d" % env.WGT_WIDTH,
+                         name=env.wgt_scope)
+    inp = te.placeholder((inp_shape[0], inp_shape[1]),
+                         dtype="int%d" % env.INP_WIDTH,
+                         name=env.inp_scope)
+    k = te.reduce_axis((0, wgt_shape[1]), name="k")
+    out_dtype = "int%d" % env.ACC_WIDTH
+    out = te.compute((out_shape[0], out_shape[1]),
+                     lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
+                                         wgt[j, k].astype(out_dtype),
+                                         axis=[k]),
+                     name="out")
+    wgt_layout = tvm.tir.decl_buffer(
+        wgt.shape, wgt.dtype, env.wgt_scope,
+        scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
+    inp_layout = tvm.tir.decl_buffer(
+        inp.shape, inp.dtype, env.inp_scope,
+        scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
+    out_layout = tvm.tir.decl_buffer(
+        out.shape, out.dtype, env.acc_scope,
+        scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
+
+    return wgt_layout, inp_layout, out_layout
+
+
+def InjectConv2DTransposeSkip():
+    """Pass to skip 0-weights in conv2d transpose with stride > 1.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(func, mod, ctx):
+        env = get_env()
+        dwgt, dinp, dout = _get_gemm_intrin_buffer()
+
+        calls = []
+        selects = []
+
+        def _find_basics(op):
+            if isinstance(op, tvm.tir.BufferLoad):
+                calls.append(op)
+            elif isinstance(op, tvm.tir.Select):
+                selects.append(op)
+
+        def _do_fold(op):
+            if _match_pragma(op, "conv2d_transpose_gemm"):
+                is_init = ".init" in str(op)
+                tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
+
+                if is_init:
+                    # create inner most block
+                    irb = tvm.tir.ir_builder.create()
+                    dev = env.dev
+                    irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+                    irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
+                    irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+                                                 0, 1,
+                                                 dout.access_ptr("rw", "int32"),
+                                                 0, 0,
+                                                 0, 0, 0))
+                    inner = irb.get()
+                    # TODO(@tmoreau89): This is only a temporary fix, please take a look.
+                    body = op.body.body
+                    while isinstance(body, tvm.tir.IfThenElse):
+                        body = body.then_case
+                    args = body.indices
+                    res_buffer = body.buffer
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
+                    inner = tvm.tir.AttrStmt(
+                        [dout, res_buffer], 'buffer_bind_scope',
+                        tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                    return inner
+                else:
+                    conv_call, data_call, kernel_call = calls[-3:]
+                    pad_data_tensor = data_call.buffer
+                    kernel_tensor = kernel_call.buffer
+                    res_tensor = conv_call.buffer
+
+                    if selects:
+                        condition = selects[0].condition
+                    else:
+                        condition = tvm.tir.const(1, 'int')
+
+                    # create inner most block
+                    irb = tvm.tir.ir_builder.create()
+                    with irb.if_scope(condition):
+                        dev = env.dev
+                        irb.scope_attr(
+                            dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+                        irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
+                        irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+                                                     0, 0,
+                                                     dout.access_ptr("rw", "int32"),
+                                                     dinp.access_ptr("r", "int32"),
+                                                     dwgt.access_ptr("r", "int32"),
+                                                     0, 0, 0))
+                    inner = irb.get()
+
+                    args = conv_call.indices
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+                           1, 0, 1, 0, env.BLOCK_OUT)
+                    inner = tvm.tir.AttrStmt(
+                        [dout, res_tensor], 'buffer_bind_scope',
+                        tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                    args = kernel_call.indices
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+                           1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
+                    inner = tvm.tir.AttrStmt(
+                        [dwgt, kernel_tensor], 'buffer_bind_scope',
+                        tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                    args = data_call.indices
+                    tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+                           1, 0, 1, 0, env.BLOCK_IN)
+                    inner = tvm.tir.AttrStmt(
+                        [dinp, pad_data_tensor], 'buffer_bind_scope',
+                        tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                    return inner
+            return None
+
+        return func.with_body(tvm.tir.ir_pass.IRTransform(
+            func.body, _do_fold, None, ["AttrStmt"]))
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
+
+
+def AnnotateALUCoProcScope():
+    """Pass to insert ALU instruction.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(func, mod, ctx):
+        env = get_env()
+        def _do_fold(stmt):
+            if _match_pragma(stmt, "alu"):
+                irb = tvm.tir.ir_builder.create()
+                irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+                               env.dev.get_task_qid(env.dev.QID_COMPUTE))
+                irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
+                               tvm.tir.StringImm("VTAPushALUOp"))
+                irb.emit(stmt)
+                return irb.get()
+            if _match_pragma(stmt, "skip_alu"):
+                return tvm.tir.Evaluate(0)
+            return stmt
+
+        return func.with_body(tvm.tir.ir_pass.IRTransform(
+            func.body, None, _do_fold, ["AttrStmt"]))
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
+
+
+def InjectALUIntrin():
+    """Pass to inject ALU micro-ops.
+
+    Each ``AttrStmt`` carrying the ``alu`` pragma must wrap a perfect loop
+    nest.  Its innermost body is a store: it must expose ``buffer_var``,
+    ``index`` and ``value``.  The stored value must be an
+    ``Add``/``Sub``/``Mul``/``Min``/``Max``, a ``shift_left``/``shift_right``
+    call, or a plain ``Load``.
+
+    The nest is replaced by ``VTAUopLoopBegin``/``VTAUopLoopEnd`` calls
+    around a single ``VTAUopPush`` that encodes the ALU operation.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    def _ftransform(func, mod, ctx):
+        env = get_env()
+        idxm = tvm.tir.indexmod
+        analyzer = tvm.arith.Analyzer()
+
+        def _do_fold(stmt):
+            def _equal(x, y):
+                # True when x - y simplifies to the constant 0.
+                return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
+
+            def _flatten_loop(src_coeff, dst_coeff, extents):
+                """Merge adjacent loop levels whose strides compose.
+
+                The coefficient lists are ordered outermost first.  Each has
+                one more entry than ``extents``: its last entry is the
+                constant offset.  An outer level is folded into the next
+                inner one when both its src and dst coefficients equal the
+                inner coefficient times the inner extent; the merged extent
+                is the product of the two.  Returns the reduced lists in
+                the same order.
+                """
+                src_coeff = list(src_coeff)
+                dst_coeff = list(dst_coeff)
+                extents = list(extents)
+                # Constant offsets go first into the reversed result lists.
+                rev_src_coeff = [src_coeff.pop()]
+                rev_dst_coeff = [dst_coeff.pop()]
+                rev_extents = []
+                assert src_coeff
+                vsrc = src_coeff.pop()
+                vdst = dst_coeff.pop()
+                vext = extents.pop()
+                while src_coeff:
+                    next_src = src_coeff.pop()
+                    next_dst = dst_coeff.pop()
+                    next_ext = extents.pop()
+
+                    if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
+                        vext = analyzer.simplify(vext * next_ext)
+                    else:
+                        rev_src_coeff.append(vsrc)
+                        rev_dst_coeff.append(vdst)
+                        rev_extents.append(vext)
+                        vsrc = next_src
+                        vdst = next_dst
+                        vext = next_ext
+                rev_src_coeff.append(vsrc)
+                rev_dst_coeff.append(vdst)
+                rev_extents.append(vext)
+                rev_src_coeff.reverse()
+                rev_dst_coeff.reverse()
+                rev_extents.reverse()
+
+                return rev_src_coeff, rev_dst_coeff, rev_extents
+
+            if _match_pragma(stmt, "alu"):
+                # Get to the innermost loop body
+                loop_body = stmt.body
+                nest_size = 0
+                while isinstance(loop_body, tvm.tir.For):
+                    loop_body = loop_body.body
+                    nest_size += 1
+                # Get the src/dst arguments
+                dst_var = loop_body.buffer_var
+                dst_idx = loop_body.index
+                # Derive loop variables and extents
+                tmp_body = stmt.body
+                indices = []
+                extents = []
+                for _ in range(nest_size):
+                    indices.append(tmp_body.loop_var)
+                    extents.append(tmp_body.extent)
+                    tmp_body = tmp_body.body
+                # Derive opcode
+                if isinstance(loop_body.value, tvm.tir.Add):
+                    alu_opcode = env.dev.ALU_OPCODE_ADD
+                    lhs = loop_body.value.a
+                    rhs = loop_body.value.b
+                elif isinstance(loop_body.value, tvm.tir.Sub):
+                    alu_opcode = env.dev.ALU_OPCODE_SUB
+                    lhs = loop_body.value.a
+                    rhs = loop_body.value.b
+                elif isinstance(loop_body.value, tvm.tir.Mul):
+                    alu_opcode = env.dev.ALU_OPCODE_MUL
+                    lhs = loop_body.value.a
+                    rhs = loop_body.value.b
+                elif isinstance(loop_body.value, tvm.tir.Min):
+                    alu_opcode = env.dev.ALU_OPCODE_MIN
+                    lhs = loop_body.value.a
+                    rhs = loop_body.value.b
+                elif isinstance(loop_body.value, tvm.tir.Max):
+                    alu_opcode = env.dev.ALU_OPCODE_MAX
+                    lhs = loop_body.value.a
+                    rhs = loop_body.value.b
+                elif isinstance(loop_body.value, tvm.tir.Call):
+                    # Shifts use the SHR opcode; a left shift is expressed as
+                    # a right shift by the negated amount.
+                    if loop_body.value.name == 'shift_left':
+                        alu_opcode = env.dev.ALU_OPCODE_SHR
+                        lhs = loop_body.value.args[0]
+                        rhs = analyzer.simplify(-loop_body.value.args[1])
+                    elif loop_body.value.name == 'shift_right':
+                        alu_opcode = env.dev.ALU_OPCODE_SHR
+                        lhs = loop_body.value.args[0]
+                        rhs = loop_body.value.args[1]
+                    else:
+                        raise RuntimeError(
+                            "Function call not recognized %s" % (loop_body.value.name))
+                elif isinstance(loop_body.value, tvm.tir.Load):
+                    # A plain copy is encoded as a right shift by 0.
+                    alu_opcode = env.dev.ALU_OPCODE_SHR
+                    lhs = loop_body.value
+                    rhs = tvm.tir.const(0, "int32")
+                else:
+                    raise RuntimeError(
+                        "Expression not recognized %s, %s, %s" % (
+                            type(loop_body.value), str(loop_body.value), str(stmt)))
+
+                # Derive array index coefficients
+                # (one per loop variable, then the constant term last -- the
+                # code below treats entry [-1] as the offset).
+                dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
+                # Check if lhs/rhs is immediate
+                # (the non-immediate operand must read the destination buffer).
+                use_imm = False
+                imm_val = None
+                if isinstance(rhs, tvm.tir.IntImm):
+                    assert lhs.buffer_var.same_as(dst_var)
+                    src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
+                    use_imm = True
+                    imm_val = rhs
+                if isinstance(lhs, tvm.tir.IntImm):
+                    assert rhs.buffer_var.same_as(dst_var)
+                    src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+                    use_imm = True
+                    imm_val = lhs
+                if imm_val is None:
+                    imm_val = 0
+                    assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
+                    src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
+                    src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+                    # Determine which side has the same coefficients
+                    lhs_equal = True
+                    rhs_equal = True
+                    for i, coef in enumerate(dst_coeff):
+                        if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
+                            lhs_equal = False
+                        if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
+                            rhs_equal = False
+                    # Make sure at least one of the source is identical to the
+                    # destination (in-place computation)
+                    assert lhs_equal or rhs_equal
+                    # Assign the source coefficients
+                    if lhs_equal:
+                        src_coeff = src_rhs_coeff
+                    else:
+                        src_coeff = src_lhs_coeff
+
+                # Ensure that we have the proper tensor dimensions in the
+                # innermost loop (pattern match)
+                # Offsets must be multiples of BATCH * BLOCK_OUT, the innermost
+                # loop must have unit stride and, when BATCH > 1, the next
+                # loop must have stride BLOCK_OUT.
+                src_coeff = list(src_coeff)
+                dst_coeff = list(dst_coeff)
+                extents = list(extents)
+                assert len(src_coeff) > 1
+                assert len(dst_coeff) > 1
+                assert len(extents) != 0
+                assert tvm.ir.structural_equal(
+                    analyzer.simplify(
+                        idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+                assert tvm.ir.structural_equal(
+                    analyzer.simplify(
+                        idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+                assert tvm.ir.structural_equal(src_coeff[-2], 1)
+                assert tvm.ir.structural_equal(dst_coeff[-2], 1)
+                if env.BATCH > 1:
+                    assert len(src_coeff) > 2
+                    assert len(dst_coeff) > 2
+                    assert len(extents) > 1
+                    assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
+                    assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
+
+                # Apply tensorization of the loop coefficients
+                # Drop the loop levels covered by one tensor element, keep the
+                # offset, and rescale everything to tensor-element units.
+                src_offset = src_coeff[-1]
+                dst_offset = dst_coeff[-1]
+                if env.BATCH == 1:
+                    src_coeff = src_coeff[:-2]
+                    dst_coeff = dst_coeff[:-2]
+                    extents = extents[:-1]
+                else:
+                    src_coeff = src_coeff[:-3]
+                    dst_coeff = dst_coeff[:-3]
+                    extents = extents[:-2]
+                src_coeff.append(src_offset)
+                dst_coeff.append(dst_offset)
+                src_coeff = [
+                    analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
+                dst_coeff = [
+                    analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
+
+                # Flatten the outer loops
+                if extents:
+                    src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
+
+                # Insert ALU micro-ops
+                # One VTAUopLoopBegin per remaining loop (outermost first), a
+                # single VTAUopPush with the offsets, then matching LoopEnds.
+                irb = tvm.tir.ir_builder.create()
+                for idx, extent in enumerate(extents):
+                    irb.emit(tvm.tir.call_extern(
+                        "int32", "VTAUopLoopBegin",
+                        extent, dst_coeff[idx], src_coeff[idx], 0))
+                use_imm = int(use_imm)
+                irb.emit(tvm.tir.call_extern(
+                    "int32", "VTAUopPush",
+                    1, 0,
+                    dst_coeff[len(dst_coeff)-1],
+                    src_coeff[len(src_coeff)-1],
+                    0,
+                    alu_opcode, use_imm, imm_val))
+                for extent in extents:
+                    irb.emit(tvm.tir.call_extern(
+                        "int32", "VTAUopLoopEnd"))
+                return irb.get()
+            # Attribute statements without the alu pragma are kept unchanged.
+            return stmt
+
+        return func.with_body(tvm.tir.ir_pass.IRTransform(
+            func.body, None, _do_fold, ["AttrStmt"]))
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")


Mime
View raw message