tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From z...@apache.org
Subject [incubator-tvm] branch master updated: [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)
Date Fri, 10 Apr 2020 21:33:05 GMT
This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new f506c8b  [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)
f506c8b is described below

commit f506c8b19ab3a7634ac56f63298dade6d40d2d1d
Author: Cody Yu <comaniac0422@gmail.com>
AuthorDate: Fri Apr 10 14:32:56 2020 -0700

    [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes (#5277)
    
    * add target to region
    
    * refactor annotate_target
    
    * Make all unit test working
    
    * quick fix
    
    * enable BN, unit test failed
    
    * Fix vm test, unit test. Refactor annotate_target a bit.
    
    * quick fix fusion
    
    * revert fusion change
    
    * style fix
    
    * Refactor merge region pass
    
    * format
    
    * minor fix
    
    * Skip e2e test
    
    * lint
    
    * support AnnotateTarget multiple runs
    
    * Add HasAttr and revert DNNL codegen
    
    * address comment
    
    Co-authored-by: Zhi Chen <chzhi@amazon.com>
---
 python/tvm/relay/op/contrib/dnnl.py                |   9 +-
 python/tvm/relay/transform/transform.py            |  10 +-
 src/relay/analysis/annotated_region_set.cc         |  60 ++--
 src/relay/analysis/annotated_region_set.h          |  37 ++-
 src/relay/backend/contrib/dnnl/codegen.cc          |  71 ++---
 src/relay/backend/vm/compiler.cc                   |  13 +-
 src/relay/transforms/annotate_target.cc            | 326 ++++++++++++--------
 src/relay/transforms/merge_compiler_regions.cc     | 339 +++++----------------
 src/relay/transforms/partition_graph.cc            |  37 ++-
 src/runtime/contrib/dnnl/dnnl.cc                   |   6 +-
 src/runtime/contrib/dnnl/dnnl_kernel.h             |   4 +-
 tests/python/relay/test_annotated_regions.py       |  18 +-
 ...tate_target.py => test_pass_annotate_target.py} | 123 +++++++-
 .../relay/test_pass_merge_compiler_regions.py      |  71 +++--
 tests/python/relay/test_pass_partition_graph.py    |  14 +-
 15 files changed, 609 insertions(+), 529 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 45a8c83..1aa7192 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -56,10 +56,17 @@ def _register_external_op_helper(op_name, supported=True):
     return _func_wrapper
 
 
-_register_external_op_helper("nn.batch_norm")
 _register_external_op_helper("nn.conv2d")
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("add")
 _register_external_op_helper("subtract")
 _register_external_op_helper("multiply")
+
+
+@reg.register("nn.batch_norm", "target.dnnl")
+def batch_norm(attrs, args):
+    """Check if the external DNNL codegen should be used.
+    FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not supported.
+    """
+    return False
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index ce4ac79..918894f 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -587,14 +587,14 @@ def PartitionGraph():
 
 
 
-def AnnotateTarget(target):
+def AnnotateTarget(targets):
     """Annotate ops in an experession with a provied compiler/target and then
     use it for codegen.
 
     Parameters
     ----------
-    target : String
-        The target compiler used for codegen.
+    targets : str or List[str]
+        The list of target compilers used for codegen.
 
     Returns
     -------
@@ -602,7 +602,9 @@ def AnnotateTarget(target):
         The annotated pass that wrapps ops with subgraph_start and
         subgraph_end.
     """
-    return _ffi_api.AnnotateTarget(target)
+    if isinstance(targets, str):
+        targets = [targets]
+    return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])
 
 
 def Inline():
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index ad2b9e1..94c7621 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -21,6 +21,7 @@
 
 #include <tvm/relay/expr.h>
 #include <tvm/ir/error.h>
+#include <tvm/runtime/container.h>
 
 #include <unordered_map>
 #include <vector>
@@ -31,7 +32,7 @@ namespace relay {
 
 AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
   for (auto candidate : regions_) {
-    if (candidate->nodes.find(expr) != candidate->nodes.end()) {
+    if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
       return candidate;
     }
   }
@@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
   }
 
   // Merge src to dest and erase src.
-  dest->nodes.insert(src->nodes.begin(), src->nodes.end());
-  for (const auto& input : src->ins) {
-    dest->ins.push_back(input);
+  dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
+  for (const auto& input : src->ins_) {
+    dest->ins_.push_back(input);
   }
-  for (const auto& output : src->outs) {
-    dest->outs.push_back(output);
+  for (const auto& output : src->outs_) {
+    dest->outs_.push_back(output);
   }
   // if any of the outputs of src are inputs of dest, they become internal nodes
   // so remove them from outs
   std::vector<Expr> ins_to_remove;
-  for (const auto& input : dest->ins) {
+  for (const auto& input : dest->ins_) {
     auto call = Downcast<Call>(input);
-    auto it = src->nodes.find(call->args[0]);
-    if (it != src->nodes.end()) {
-      dest->outs.remove(*it);
+    auto it = src->nodes_.find(call->args[0]);
+    if (it != src->nodes_.end()) {
+      dest->outs_.remove(*it);
       ins_to_remove.push_back(input);
     }
   }
   for (const auto& input : ins_to_remove) {
-    dest->ins.remove(input);
+    dest->ins_.remove(input);
   }
   regions_.erase(src);
 }
@@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr)
   if (src.defined()) {
     MergeRegions(src, dest);
   } else {
-    dest->nodes.insert(expr);
+    dest->nodes_.insert(expr);
   }
 }
 
-AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() {
+AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
   auto ret = regions_.emplace(AnnotatedRegion());
-  (*ret.first)->id = region_id_++;
+  (*ret.first)->id_ = region_id_++;
+  (*ret.first)->target_ = target;
   return *ret.first;
 }
 
 class AnnotatedRegionSet::Creator : public ExprVisitor {
  public:
-  Creator(const Op& region_begin_op, const Op& region_end_op) :
-    begin_op_(region_begin_op), end_op_(region_end_op) {}
-
-  AnnotatedRegionSet Create(const Expr& expr) {
-    VisitExpr(expr);
-    return std::move(region_set_);
-  }
+  Creator(const Op& region_begin_op, const Op& region_end_op)
+      : begin_op_(region_begin_op), end_op_(region_end_op) {}
 
   void VisitExpr_(const CallNode* call) {
     auto op_node = call->op.as<OpNode>();
@@ -115,24 +112,35 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
                       << "Cannot find the corresponding region for start annotation:\n"
                       << AsText(GetRef<Call>(call), false));
       }
-      region->ins.push_back(GetRef<Call>(call));
+      region->ins_.push_back(GetRef<Call>(call));
     } else {
       CHECK_EQ(call->op, end_op_);
       // The annotation node is inserted on edge so it must have only one argument.
       CHECK_EQ(call->args.size(), 1U);
+      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
 
       // Check if the argument already belongs to a region
       auto region = region_set_->GetRegion(call->args[0]);
       if (!region.defined()) {
-        region = region_set_->MakeRegion();
-        region->nodes.insert(call->args[0]);
+        // Create a new region if the argument does not belong to any region yet.
+        region = region_set_->MakeRegion(target);
+        region->nodes_.insert(call->args[0]);
+      } else {
+        // If the argument belongs to a region, it must have the same target.
+        // Otherwise we should see a region_begin op.
+        CHECK_EQ(region->GetTarget(), target);
       }
-      region->nodes.insert(GetRef<Call>(call));
-      region->outs.push_back(GetRef<Call>(call));
+      region->nodes_.insert(GetRef<Call>(call));
+      region->outs_.push_back(GetRef<Call>(call));
     }
     ExprVisitor::VisitExpr_(call);
   }
 
+  AnnotatedRegionSet Create(const Expr& expr) {
+    VisitExpr(expr);
+    return std::move(region_set_);
+  }
+
   void VisitExpr_(const TupleNode* op) {
     auto region = region_set_->GetRegion(GetRef<Tuple>(op));
     if (region.defined()) {
diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h
index 0b93011..3bd5693 100644
--- a/src/relay/analysis/annotated_region_set.h
+++ b/src/relay/analysis/annotated_region_set.h
@@ -32,6 +32,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/runtime/container.h>
 #include <tvm/relay/transform.h>
 
 #include <string>
@@ -49,33 +50,39 @@ class AnnotatedRegionSet;
 class AnnotatedRegionNode : public Object {
  public:
   void VisitAttrs(AttrVisitor* v) {
-    v->Visit("id", &id);
-    Array<Expr> nodes_array(nodes.begin(), nodes.end());
+    v->Visit("id", &id_);
+    v->Visit("target", &target_);
+    Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
     v->Visit("nodes", &nodes_array);
-    Array<Expr> args_array(ins.begin(), ins.end());
+    Array<Expr> args_array(ins_.begin(), ins_.end());
     v->Visit("args", &args_array);
-    Array<Expr> rets_array(outs.begin(), outs.end());
+    Array<Expr> rets_array(outs_.begin(), outs_.end());
     v->Visit("rets", &rets_array);
   }
 
   /*! \brief Get the region ID. */
   int GetID() const {
-    return id;
+    return id_;
+  }
+
+  /*! \brief Get the region target. */
+  std::string GetTarget() const {
+    return target_;
   }
 
   /*! \brief Get the region's inputs. */
   std::list<Expr> GetInputs() const {
-    return ins;
+    return ins_;
   }
 
   /*! \brief Get the region's outputs. */
   std::list<Expr> GetOutputs() const {
-    return outs;
+    return outs_;
   }
 
   /*! \brief Get the region's nodes. */
   std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
-    return nodes;
+    return nodes_;
   }
 
   static constexpr const char* _type_key = "relay.AnnotatedRegion";
@@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object {
 
  protected:
   /*! \brief The region ID. */
-  int id{-1};
+  int id_{-1};
+  /*! \brief The target for this region. */
+  std::string target_ = "default";
   /*! \brief The inputs to this region. */
-  std::list<Expr> ins;
+  std::list<Expr> ins_;
   /*! \brief The outputs of this region */
-  std::list<Expr> outs;
+  std::list<Expr> outs_;
   /*! \brief Nodes in this region. */
-  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes_;
 
   friend class AnnotatedRegionSet;
   friend class AnnotatedRegionSetNode;
@@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object {
   void AddToRegion(AnnotatedRegion dest, const Expr& expr);
 
   /*!
-   * \brief Make a new region.
+   * \brief Make a new region for a target.
    *
    * \return The new region.
    */
-  AnnotatedRegion MakeRegion();
+  AnnotatedRegion MakeRegion(const std::string& target);
 
   std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_;
   /*! \brief The next region ID to assign. */
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index cd6412c..7371174 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   }
 
   void VisitExpr_(const TupleGetItemNode* op) final {
-    VisitExpr(op->tuple);
-    CHECK(out_.size() > static_cast<size_t>(op->index));
-
-    // Only keep the item we want for the child node.
-    // FIXME(@comaniac): The other items should still be requried for the primary outputs.
-    auto item = out_[op->index];
-    out_.clear();
-    out_.push_back(item);
+    // Do nothing
   }
 
   void VisitExpr_(const CallNode* call) final {
     std::ostringstream decl_stream;
-
+    std::ostringstream buf_stream;
     // Args: ID
     std::vector<std::string> args;
 
@@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
       }
     }
 
-    // Analyze the output buffers
-    std::vector<Type> out_types;
-    if (call->checked_type()->IsInstance<TupleTypeNode>()) {
-      auto type_node = call->checked_type().as<TupleTypeNode>();
-      for (auto field : type_node->fields) {
-        CHECK(field->IsInstance<TensorTypeNode>());
-        out_types.push_back(field);
-      }
-    } else if (call->checked_type()->IsInstance<TensorTypeNode>()) {
-      CHECK(call->checked_type()->IsInstance<TensorTypeNode>());
-      out_types.push_back(call->checked_type());
-    } else {
-      LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false);
-    }
-
-    out_.clear();
-    for (auto out_type : out_types) {
-      const auto& dtype = GetDtypeString(out_type.as<TensorTypeNode>());
-
-      std::string out = "buf_" + std::to_string(buf_idx_++);
-      auto out_shape = GetShape(out_type);
-      int out_size = 1;
-      for (size_t i = 0; i < out_shape.size(); ++i) {
-        out_size *= out_shape[i];
-      }
-      this->PrintIndents();
-      std::ostringstream buf_stream;
-      buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
-      buf_decl_.push_back(buf_stream.str());
-      decl_stream << ", " << out;
-
-      // Update output buffer
-      Output output;
-      output.name = out;
-      output.dtype = dtype;
-      output.need_copy = true;
-      output.size = out_size;
-      out_.push_back(output);
+    // Analyze the output buffer
+    auto type_node = call->checked_type().as<TensorTypeNode>();
+    CHECK(type_node);
+    const auto& dtype = GetDtypeString(type_node);
+    std::string out = "buf_" + std::to_string(buf_idx_++);
+    auto out_shape = GetShape(call->checked_type());
+    int out_size = 1;
+    for (size_t i = 0; i < out_shape.size(); ++i) {
+      out_size *= out_shape[i];
     }
+    this->PrintIndents();
+    buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
+    buf_decl_.push_back(buf_stream.str());
+    decl_stream << ", " << out;
 
     // Attach attribute arguments
     for (size_t i = 0; i < args.size(); ++i) {
@@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
     }
     decl_stream << ");";
     ext_func_body.push_back(decl_stream.str());
+
+    // Update output buffer
+    out_.clear();
+    Output output;
+    output.name = out;
+    output.dtype = dtype;
+    output.need_copy = true;
+    output.size = out_size;
+    out_.push_back(output);
   }
 
   std::string JIT(void) {
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index e2b0fff..3e020bb 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
-  // Manifest the allocations.
-  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
-  // Compute away possibly introduced constant computation.
-  pass_seqs.push_back(transform::FoldConstant());
-  // Fuse the shape functions.
-  pass_seqs.push_back(transform::FuseOps());
-
   // Inline the functions that are lifted to the module scope. We perform this
   // pass after all other optimization passes but before the memory allocation
   // pass. This is because memory allocation pass will insert `invoke_tvm_op`
@@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   // external codegen.
   pass_seqs.push_back(transform::Inline());
 
+  // Manifest the allocations.
+  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+  // Compute away possibly introduced constant computation.
+  pass_seqs.push_back(transform::FoldConstant());
+  // Fuse the shape functions.
+  pass_seqs.push_back(transform::FuseOps());
   // Manifest the allocations needed for the shape functions.
   pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
 
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 44ef35a..44d7b54 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -19,132 +19,203 @@
 
 /*!
  * \file src/relay/transforms/annotate_target.cc
- * \brief Wraps a call with compiler_begin and compiler_end to indicate that
- * the op of this call node will use external compiler.
+ * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
+ * this expr should be handled by the external compiler.
  */
 
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
 
 namespace tvm {
 namespace relay {
 namespace annotate_target {
 
-// Cache compiler_begin op for equivalence check.
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+const PackedFunc* make_begin_op =
+    runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
 
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
 class AnnotateTargetWrapper : public ExprMutator {
  public:
-  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
-
-  Expr Annotate(const Expr& expr) {
-    return InsertEnd(Mutate(expr));
-  }
-
-  bool IsSupported(const Expr& expr) {
-    if (expr->IsInstance<CallNode>()) {
-      Call call = Downcast<Call>(expr);
-      auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
-      if (call->op->IsInstance<OpNode>()) {
-        Op op = Downcast<Op>(call->op);
-        CHECK(op.defined());
-        if (fannotate.count(op)) {
-          return fannotate[op](call->attrs, call->args);
-        }
-      } else if (call->op->IsInstance<FunctionNode>()) {
-        // handle composite functions
-        Function func = Downcast<Function>(call->op);
-        CHECK(func.defined());
-        auto comp_name = func->GetAttr<String>(attr::kComposite);
-        if (comp_name.defined()) {
-          std::string comp_name_str = comp_name;
-          size_t i = comp_name_str.find('.');
-          if (i != std::string::npos) {
-            std::string target = comp_name_str.substr(0, i);
-            if (target == target_) return true;
-          }
+  explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
+
+  /*!
+   * \brief This function annotates a compiler end and a compiler begin to all arguments.
+   *
+   *  The compiler end is based on the arg target while the compiler begin is based on the given
+   *  target. If target is not given and all arguments are going to the same target, then we will
+   *  use that target; otherwise we use default for this op. Note that all arg exprs must be
+   *  available in op_expr_to_target before calling this function.
+   *
+   * \param args An array of arguments of the given node.
+   * \param target The target of the current node.
+   * \return A pair of target and annotated argument expressions.
+   */
+  std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
+                                                   const std::string& target = "") {
+    std::string ref_target = "";
+    Array<Expr> compiler_ends;
+    for (auto arg : args) {
+      std::string arg_target = "defualt";
+      const CallNode* call = arg.as<CallNode>();
+
+      if (call && call->op == compiler_begin_op) {
+        // Argument is already compiler begin node meaning that this is not the first time
+        // running this pass, so we simply remove it and will add a new one later.
+        CHECK_EQ(call->args.size(), 1U);
+        const CallNode* end = call->args[0].as<CallNode>();
+        if (end->op == compiler_end_op) {
+          arg_target = end->attrs.as<CompilerAttrs>()->compiler;
         }
+        compiler_ends.push_back(call->args[0]);
+      } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
+        arg_target = op_expr_to_target_[arg];
+        compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
+      } else {
+        // Input vars.
+        compiler_ends.push_back(arg);
       }
-    }
-    if (expr->IsInstance<TupleGetItemNode>()) {
-      TupleGetItem get = Downcast<TupleGetItem>(expr);
-      if (get->tuple->IsInstance<CallNode>() &&
-          get->tuple.as<CallNode>()->op == compiler_begin_op) {
-        return true;
+
+      // Maintain reference target in case the target of the current node is unassigned.
+      if (ref_target == "") {
+        ref_target = arg_target;
+      } else if (ref_target != arg_target) {
+        ref_target = "default";
       }
     }
-    return false;
-  }
 
-  Expr InsertEnd(const Expr& arg) {
-    if (IsSupported(arg)) {
-      const auto *end_op =
-        runtime::Registry::Get("relay.op.annotation._make.compiler_end");
-      CHECK(end_op);
-      Expr end = (*end_op)(arg, target_);
-      return end;
+    // Determine compiler begin target.
+    std::string op_target = (target == "") ? ref_target : target;
+
+    Array<Expr> compiler_begins;
+    for (const auto& end : compiler_ends) {
+      compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
     }
-    return arg;
+
+    return {op_target, compiler_begins};
   }
 
-  Expr VisitExpr_(const CallNode* cn) {
-    auto new_e = ExprMutator::VisitExpr_(cn);
+  Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+    Expr new_op = (*ann_op)(expr, target);
+    new_op->checked_type_ = expr->checked_type_;
+    return new_op;
+  }
 
-    Call call = Downcast<Call>(new_e);
+  Expr VisitExpr_(const CallNode* cn) final {
+    // Supported targets for this node. The order implies the priority.
+    std::vector<std::string> supported_targets;
+
+    auto op_node = cn->op.as<OpNode>();
+
+    // This graph has annotations, meaning that this is not the first time running this pass.
+    if (op_node && cn->op == compiler_begin_op) {
+      // Bypass compiler begin due to lack of target information. It will be processed
+      // when the following op handles its arguments.
+      CHECK_EQ(cn->args.size(), 1U);
+      return VisitExpr(cn->args[0]);
+    } else if (op_node && cn->op == compiler_end_op) {
+      // Override compiler end with the new target.
+      CHECK_EQ(cn->args.size(), 1U);
+      auto input_expr = VisitExpr(cn->args[0]);
+      CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
+      return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
+    }
 
-    // add end annotations if the args are supported
-    Array<Expr> compiler_ends;
-    for (const auto& it : call->args) {
-      compiler_ends.push_back(InsertEnd(it));
+    // Peek the first argument. If it is compiler begin then this node had been annotated by
+    // another target before, so we also consider that target as a supported target.
+    const CallNode* first_arg_call = cn->args[0].as<CallNode>();
+    if (first_arg_call && first_arg_call->op == compiler_begin_op) {
+      std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
+      if (arg_target != "default") {
+        supported_targets.push_back(arg_target);
+      }
     }
-    call = Call(call->op, compiler_ends, call->attrs);
-
-    // add begin annotations if the call node is supported
-    if (IsSupported(call)) {
-      tvm::Array<tvm::relay::Expr> compiler_begins;
-      const auto* begin_op =
-        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
-      for (const auto& it : call->args) {
-        CHECK(begin_op);
-        Expr begin = (*begin_op)(it, target_);
-        compiler_begins.push_back(begin);
+
+    // Check which targets this op can be offloaded to.
+    if (op_node) {
+      // TVM operators: Check target specific op checking function and add to supported_targets
+      // if it is supported.
+      Op op = Downcast<Op>(cn->op);
+      CHECK(op.defined());
+      for (const auto& target : this->targets_) {
+        if (!Op::HasAttr("target." + std::string(target))) {
+          continue;
+        }
+        auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
+        if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
+          supported_targets.push_back(target);
+        }
+      }
+    } else if (cn->op->IsInstance<FunctionNode>()) {
+      // Composite function: Add the target of a composite function to supported_targets
+      // if it is in the target list.
+      Function func = Downcast<Function>(cn->op);
+      CHECK(func.defined());
+      auto comp_name = func->GetAttr<String>(attr::kComposite);
+      if (comp_name.defined()) {
+        std::string comp_name_str = comp_name;
+        size_t i = comp_name_str.find('.');
+        if (i != std::string::npos) {
+          std::string comp_target = comp_name_str.substr(0, i);
+          for (const auto& target : this->targets_) {
+            if (std::string(target) == comp_target) {
+              supported_targets.push_back(comp_target);
+              break;
+            }
+          }
+        }
       }
-      call = Call(call->op, compiler_begins, call->attrs);
     }
+    supported_targets.push_back("default");  // Make default as the last option.
+
+    // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
+    // the highest priority, but we should preserve all supported targets so that
+    // we can make a better decision.
+    std::string target = supported_targets[0];
+
+    // Visit and mutate arguments after the target of this op has been determined.
+    auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));
+
+    // Add annotations to each arg.
+    auto target_n_args = AnnotateArgs(new_call->args, target);
+    Array<Expr> compiler_begins = std::get<1>(target_n_args);
+    Call call = Call(new_call->op, compiler_begins, new_call->attrs);
+    call->checked_type_ = cn->checked_type_;
+
+    // Update the target map.
+    op_expr_to_target_[call] = target;
 
     return std::move(call);
   }
 
-  Expr VisitExpr_(const TupleNode* op) {
+  Expr VisitExpr_(const TupleNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
+    auto expr = Downcast<Tuple>(new_e);
 
-    auto tup = Downcast<Tuple>(new_e);
-    Array<Expr> new_fields;
-    for (auto field : tup->fields) {
-      new_fields.push_back(InsertEnd(field));
-    }
-    return Tuple(new_fields);
+    auto target_n_args = AnnotateArgs(expr->fields);
+    auto new_expr = Tuple(std::get<1>(target_n_args));
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op) {
+  Expr VisitExpr_(const TupleGetItemNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
+    auto expr = Downcast<TupleGetItem>(new_e);
 
-    auto get = Downcast<TupleGetItem>(new_e);
-    if (IsSupported(get->tuple)) {
-      const auto* begin_op =
-        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
-      CHECK(begin_op);
-      return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index);
-    } else {
-      return TupleGetItem(InsertEnd(get->tuple), get->index);
-    }
+    auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
+    auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const FunctionNode* fn) {
+  Expr VisitExpr_(const FunctionNode* fn) final {
     Function func;
     Expr new_body;
     // don't step into composite functions
@@ -154,84 +225,93 @@ class AnnotateTargetWrapper : public ExprMutator {
     } else {
       auto new_e = ExprMutator::VisitExpr_(fn);
       func = Downcast<Function>(new_e);
-      new_body = InsertEnd(func->body);
+      new_body = func->body;
+      if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
+        new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
+        op_expr_to_target_[new_body] = op_expr_to_target_[func->body];
+      }
     }
-
-    return Function(
-      func->params,
-      new_body,
-      func->ret_type,
-      func->type_params,
-      func->attrs);
+    return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
   }
 
-  Expr VisitExpr_(const LetNode* op) {
+  Expr VisitExpr_(const LetNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
-
     auto let = Downcast<Let>(new_e);
-    return Let(
-      let->var,
-      InsertEnd(let->value),
-      InsertEnd(let->body));
+
+    auto target_n_args = AnnotateArgs({let->value, let->body});
+    auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const IfNode* op) {
+  Expr VisitExpr_(const IfNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
-
-    auto iff = Downcast<If>(new_e);
-    return If(
-      InsertEnd(iff->cond),
-      InsertEnd(iff->true_branch),
-      InsertEnd(iff->false_branch));
+    auto expr = Downcast<If>(new_e);
+
+    auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
+    CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
+    auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1],
+                       std::get<1>(target_n_args)[2]);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefCreateNode* op) {
+  Expr VisitExpr_(const RefCreateNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
+    auto expr = Downcast<RefCreate>(new_e);
 
-    auto create = Downcast<RefCreate>(new_e);
-    return RefCreate(InsertEnd(create->value));
+    auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
+    auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefReadNode* op) {
+  Expr VisitExpr_(const RefReadNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
+    auto expr = Downcast<RefRead>(new_e);
 
-    auto read = Downcast<RefRead>(new_e);
-    return RefRead(InsertEnd(read->ref));
+    auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
+    auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefWriteNode* op) {
+  Expr VisitExpr_(const RefWriteNode* op) final {
     auto new_e = ExprMutator::VisitExpr_(op);
+    auto expr = Downcast<RefWrite>(new_e);
 
-    auto write = Downcast<RefWrite>(new_e);
-    return RefWrite(
-      InsertEnd(write->ref),
-      InsertEnd(write->value));
+    auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
+    auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
+    op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
+    return std::move(new_expr);
   }
 
  private:
-  std::string target_;
+  /*! \brief The target backends for annotation. */
+  Array<runtime::String> targets_;
+  /*! \brief Maintain the decision of the target for each op expr. */
+  std::unordered_map<Expr, std::string, ObjectHash, ObjectEqual> op_expr_to_target_;
 };
 
-Expr AnnotateTarget(const Expr& expr, const std::string& target) {
-  return AnnotateTargetWrapper(target).Annotate(expr);
+Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
+  return AnnotateTargetWrapper(targets).Mutate(expr);
 }
 
 }  // namespace annotate_target
 
 namespace transform {
 
-Pass AnnotateTarget(const std::string& target) {
+Pass AnnotateTarget(const Array<runtime::String>& targets) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
+        return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, targets));
       };
   auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
                                       {"InferType"});
   return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
-.set_body_typed(AnnotateTarget);
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget);
 
 }  // namespace transform
 
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index 5253010..601be0f 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -46,216 +46,13 @@
 
 namespace tvm {
 namespace relay {
-namespace partitioning {
+namespace merge_compiler_region {
 
 // Cache compiler_begin and compiler_end annotation ops for equivalence check to
 // reduce registry lookup overhead.
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
-/*! \brief This is a pre-requisite pass to merge-supported pass.
- *  The AnnotateRestDefault pass will put "default" Compiler Annotations to
- *  nodes that are not annotated already. This is there to ensure that the
- *  user will not leave un-annotated nodes MergeCompilerRegions pass is run.
- *  Why? Because, MergeCompilerRegions pass assumes every node to be annotated.
- */
-class AnnotateRestDefault : public ExprMutator {
- public:
-  explicit AnnotateRestDefault(const Expr& expr) {
-    regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
-  }
-
-  Expr Annotate(const Expr& expr) {
-    // Its a function that is being passed on to annotate
-    func_ = Downcast<Function>(expr);
-
-    // Corner Case CC1 : If the last node does not belong
-    // to a region node to add a compiler_end
-    auto region = regions_->GetRegion(func_->body);
-    auto mutated_expr = this->VisitExpr(expr);
-    if (!region.defined()) {
-      func_ = Downcast<Function>(mutated_expr);
-      // CC1 : add that compiler end after mutation
-      auto body = InsertEnd(func_->body);
-      func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs());
-      return Downcast<Expr>(func_);
-    }
-    return mutated_expr;
-  }
-
-  /*! \brief This function adds compiler ends to nodes that
-   * don't belong to a region already (default).
-   * \param expr The expression to add a compiler end to.
-   * \return expr The expression with or without a compiler end added.
-   */
-  Expr InsertEnd(const Expr& expr) {
-    if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance<VarNode>() &&
-        !expr->IsInstance<ConstantNode>()) {
-      const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
-      CHECK(end_op);
-      Expr end = (*end_op)(expr, target_);
-      return end;
-    }
-    return expr;
-  }
-
-  /*! \brief This function adds compiler begins to nodes that
-   * don't belong to a region already (default).
-   * \param expr The expression to add a compiler begin to.
-   * \return expr The expression with or without a compiler begin added.
-   */
-  Expr InsertBegin(const Expr& expr) {
-    const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
-    CHECK(begin_op);
-    Expr begin = (*begin_op)(expr, target_);
-    annotated_nodes_.insert(begin);
-    return begin;
-  }
-
-  Expr VisitExpr_(const CallNode* cn) final {
-    auto region = regions_->GetRegion(GetRef<Call>(cn));
-    auto new_e = ExprMutator::VisitExpr_(cn);
-    Call call = Downcast<Call>(new_e);
-
-    // Add compiler ends if the parent isn't annotated
-    Array<Expr> args;
-    for (auto arg : call->args) {
-      args.push_back(InsertEnd(arg));
-    }
-
-    Expr updated_call = Call(call->op, args, call->attrs);
-    if (!region.defined()) {
-      // if the current node does not belong to annotated region
-      // annotate the all incoming edges (args)
-      // with "default" compiler_begin annotations.
-      Array<Expr> compiler_begins;
-      for (auto arg : args) {
-        compiler_begins.push_back(InsertBegin(arg));
-      }
-      updated_call = Call(call->op, compiler_begins, call->attrs);
-    } else {
-      annotated_nodes_.insert(updated_call);
-    }
-    return updated_call;
-  };
-
-  Expr VisitExpr_(const TupleNode* op) {
-    auto region = regions_->GetRegion(GetRef<Tuple>(op));
-    auto new_e = ExprMutator::VisitExpr_(op);
-    Tuple tup = Downcast<Tuple>(new_e);
-
-    Array<Expr> fields;
-    for (auto field : tup->fields) {
-      fields.push_back(InsertEnd(field));
-    }
-
-    Expr updated_tuple = Tuple(fields);
-    if (!region.defined()) {
-      Array<Expr> compiler_begins;
-      for (const auto& field : fields) {
-        compiler_begins.push_back(InsertBegin(field));
-      }
-      updated_tuple = Tuple(compiler_begins);
-    } else {
-      annotated_nodes_.insert(updated_tuple);
-    }
-    return updated_tuple;
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* op) {
-    auto region = regions_->GetRegion(GetRef<TupleGetItem>(op));
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto get = Downcast<TupleGetItem>(new_e);
-
-    auto updated_tuple = InsertEnd(get->tuple);
-    Expr updated_get = TupleGetItem(updated_tuple, get->index);
-    if (!region.defined()) {
-      updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index);
-    } else {
-      annotated_nodes_.insert(updated_get);
-    }
-    return updated_get;
-  }
-
-  Expr VisitExpr_(const IfNode* op) {
-    auto region = regions_->GetRegion(GetRef<If>(op));
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto iff = Downcast<If>(new_e);
-
-    if (!region.defined()) {
-      return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)),
-                InsertBegin(InsertEnd(iff->false_branch)));
-    } else {
-      Expr updated_iff =
-          If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch));
-      annotated_nodes_.insert(updated_iff);
-      return updated_iff;
-    }
-  }
-
-  Expr VisitExpr_(const LetNode* op) {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto let = Downcast<Let>(new_e);
-    return Let(let->var, InsertEnd(let->value), InsertEnd(let->body));
-  }
-
-  Expr VisitExpr_(const RefCreateNode* op) {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto create = Downcast<RefCreate>(new_e);
-    return RefCreate(InsertEnd(create->value));
-  }
-
-  Expr VisitExpr_(const RefReadNode* op) {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto read = Downcast<RefRead>(new_e);
-    return RefRead(InsertEnd(read->ref));
-  }
-
-  Expr VisitExpr_(const RefWriteNode* op) {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto write = Downcast<RefWrite>(new_e);
-    return RefWrite(InsertEnd(write->ref), InsertEnd(write->value));
-  }
-
- private:
-  AnnotatedRegionSet regions_;
-  const std::string target_ = "default";
-  Function func_;
-  std::unordered_set<Expr, ObjectHash, ObjectEqual> annotated_nodes_;
-};
-
-class MergeAnnotations : public ExprMutator {
- public:
-  explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
-
-  Expr VisitExpr_(const CallNode* call) final {
-    // remove 'default' annotations
-    auto attrs = call->attrs.as<CompilerAttrs>();
-    if (attrs != nullptr && attrs->compiler == "default") {
-      return VisitExpr(call->args[0]);
-    }
-    // Merge annotations which are now internal to a region.
-    // This happens if we see a compiler begin next to a
-    // compiler end and they're both in the same region.
-    if (call->op == compiler_begin_op) {
-      if (call->args[0]->IsInstance<CallNode>()) {
-        auto arg = Downcast<Call>(call->args[0]);
-        if (arg->op == compiler_end_op) {
-          auto region1 = regions_->GetRegion(GetRef<Call>(call));
-          auto region2 = regions_->GetRegion(arg);
-          if (region1 == region2) {
-            return VisitExpr(arg->args[0]);
-          }
-        }
-      }
-    }
-    return ExprMutator::VisitExpr_(call);
-  }
-
- private:
-  AnnotatedRegionSet regions_;
-};
-
 class RegionMerger : public ExprVisitor {
  public:
   explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
@@ -263,62 +60,74 @@ class RegionMerger : public ExprVisitor {
   void VisitExpr_(const CallNode* call) final {
     if (call->op == compiler_end_op) {
       auto region = regions_->GetRegion(GetRef<Call>(call));
-      if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return;
-      // set the region target
+
+      // Skip this region if it has already been processed.
+      if (merged_regions_.find(region->GetID()) != merged_regions_.end()) {
+        return;
+      }
+
+      // Check the region target.
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
-      region_targets_[region->GetID()] = compiler_attrs->compiler;
-      // first look at the region args to determine the parent regions
+      CHECK_EQ(region->GetTarget(), compiler_attrs->compiler);
+
+      // Visit the unmerged parent regions.
       for (const auto& arg : region->GetInputs()) {
-        // all args should be begin annotations
+        // Region inputs must be begin annotations, and the region of
+        // the begin annotation's argument is the parent region.
         auto begin = Downcast<Call>(arg);
         CHECK_EQ(begin->op, compiler_begin_op);
-        // the arguments of the begin annotations will be in the parent regions
         auto parent_region = regions_->GetRegion(begin->args[0]);
-        // if there is no parent region, move on
-        if (!parent_region.defined()) continue;
-        // merge the parent region if it hasn't been done already
-        if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
+
+        // Skip inputs that have no parent region; visit parents not yet processed.
+        if (!parent_region.defined()) {
+          continue;
+        } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
           VisitExpr(begin->args[0]);
         }
       }
-      // get the mergeable regions now all the parents have been visited
+
+      // Collect unmerged parent regions.
       std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> mergeable_regions;
       for (const auto& arg : region->GetInputs()) {
         auto begin = Downcast<Call>(arg);
         CHECK_EQ(begin->op, compiler_begin_op);
         auto parent_region = regions_->GetRegion(begin->args[0]);
-        if (!parent_region.defined()) continue;
-        mergeable_regions.insert(parent_region);
+        if (parent_region.defined()) {
+          mergeable_regions.insert(parent_region);
+        }
       }
+
+      // Propagate all the parent restrictions to the current region.
       auto& region_restrictions = region_restrictions_[region->GetID()];
       for (const auto& parent_region : mergeable_regions) {
-        // add all the parent restrictions to the current region
         auto parent_restrictions = region_restrictions_[parent_region->GetID()];
         region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
       }
+
       for (const auto& parent_region : mergeable_regions) {
-        bool merged = false;
-        // check the parent region has the same target
-        if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) {
-          // check the parent region isn't in the restrictions
-          if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) {
-            // merge the parent region into the current region
-            regions_->MergeRegions(parent_region, region);
-            // update the restrictions of all other regions to reflect the
-            // change in id
-            for (const auto& r : regions_) {
-              auto& restrictions = region_restrictions_[r->GetID()];
-              if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
-                restrictions.erase(parent_region->GetID());
-                restrictions.insert(region->GetID());
-              }
-            }
-            merged = true;
+        // Skip the parent region with a different target.
+        if (parent_region->GetTarget() != compiler_attrs->compiler) {
+          region_restrictions.insert(parent_region->GetID());
+          continue;
+        }
+
+        // Skip the parent region if it is in the restriction set.
+        if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) {
+          continue;
+        }
+
+        // Merge the parent region to the current one.
+        regions_->MergeRegions(parent_region, region);
+
+        // Replace the parent region ID with the current region for all
+        // other regions' restriction sets.
+        for (const auto& r : regions_) {
+          auto& restrictions = region_restrictions_[r->GetID()];
+          if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
+            restrictions.erase(parent_region->GetID());
+            restrictions.insert(region->GetID());
           }
         }
-        // if the parent wasn't merged, add it as a restriction to the current
-        // region
-        if (!merged) region_restrictions.insert(parent_region->GetID());
       }
       merged_regions_.insert(region->GetID());
     }
@@ -328,42 +137,58 @@ class RegionMerger : public ExprVisitor {
  private:
   AnnotatedRegionSet regions_;
   std::unordered_set<int> merged_regions_;
-  std::map<int, std::unordered_set<int>> region_restrictions_;
-  std::map<int, std::string> region_targets_;
+  std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
 };
 
-Expr MergeCompilerRegions(const Expr& expr) {
-  // Annotate all the nodes that aren't annotated as 'default'.
-  AnnotateRestDefault anno_default(expr);
-  auto expr_all_annotated = anno_default.Annotate(expr);
+class MergeAnnotations : public ExprMutator {
+ public:
+  explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
+
+  Expr VisitExpr_(const CallNode* call) final {
+    // Merge annotations which are now internal to a region.
+    // This happens if we see a compiler begin next to a
+    // compiler end and they're both in the same region.
+    if (call->op == compiler_begin_op && call->args[0]->IsInstance<CallNode>()) {
+      auto arg = Downcast<Call>(call->args[0]);
+      if (arg->op == compiler_end_op) {
+        auto region1 = regions_->GetRegion(GetRef<Call>(call));
+        auto region2 = regions_->GetRegion(arg);
+        if (region1 == region2) {
+          return VisitExpr(arg->args[0]);
+        }
+      }
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  AnnotatedRegionSet regions_;
+};
 
+Expr MergeCompilerRegions(const Expr& expr) {
   // Create regions using the annotations.
-  AnnotatedRegionSet regions =
-      AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op);
+  AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op);
 
-  // By now, all the nodes have some sort of annotation.
-  // Region merger is an ExprVisitor that will update the
-  // AnnotatedRegionSet, merging all the regions that can be merged.
+  // Analyze the graph to explore the opportunities of merging regions.
   RegionMerger merger(regions);
-  merger.VisitExpr(expr_all_annotated);
+  merger.VisitExpr(expr);
 
-  // This updates the expression to remove annotations that are now
-  // 'internal' to a merged region.
+  // Remove annotations that are not in the region boundaries.
   MergeAnnotations merge_anno(regions);
-  return merge_anno.Mutate(expr_all_annotated);
+  return merge_anno.Mutate(expr);
 }
 
-}  // namespace partitioning
+}  // namespace merge_compiler_region
 
 namespace transform {
 
 Pass MergeCompilerRegions() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(partitioning::MergeCompilerRegions(f));
+        return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f));
       };
-  auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
-  return Sequential({partitioned, InferType()});
+  auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
+  return Sequential({merged, InferType()});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 8eeac17..fa9c8c4 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -477,13 +477,48 @@ class Partitioner : public ExprMutator {
   IRModule module_;
 };
 
+class DefaultRemover : public ExprMutator {
+ public:
+  explicit DefaultRemover(const IRModule& module) : module_(module) {}
+
+  IRModule Remove() {
+    auto glob_funcs = module_->functions;
+    for (const auto& pair : glob_funcs) {
+      if (auto* fn = pair.second.as<FunctionNode>()) {
+        auto func = GetRef<Function>(fn);
+        func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
+                        func->attrs);
+        module_->Update(pair.first, func);
+      }
+    }
+    return module_;
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    auto attrs = call->attrs.as<CompilerAttrs>();
+    if (attrs != nullptr && attrs->compiler == "default") {
+      return VisitExpr(call->args[0]);
+    }
+    return ExprMutator::VisitExpr_(call);
+  }
+
+ private:
+  IRModule module_;
+};
+
 }  // namespace partitioning
 
 namespace transform {
 
 Pass PartitionGraph() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
-      [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
+      [=](IRModule m, PassContext pc) {
+        // TODO(@comaniac, @zhiics): We should also handle annotations with the "default" attribute
+        // by treating them as un-annotated, but that is not supported yet. This workaround pass
+        // removes all "default" annotations and should be deleted in the future.
+        auto new_m = partitioning::DefaultRemover(m).Remove();
+        return partitioning::Partitioner(new_m).Partition();
+  };
   auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});
 }
diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc
index 4dc023f..cc430b2 100644
--- a/src/runtime/contrib/dnnl/dnnl.cc
+++ b/src/runtime/contrib/dnnl/dnnl.cc
@@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
   read_from_dnnl_memory(out, dst_memory);
 }
 
-extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
-                        float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
+extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
+                        float* variance, float* out, int p_N_, int p_C_,
                         int p_H_, int p_W_, int p_E_) {
-  // FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update
-  // the rest two because no one cares about them for now. Should update it in the future.
   using tag = memory::format_tag;
   using dt = memory::data_type;
 
diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h
index cf474f9..4d0b100 100644
--- a/src/runtime/contrib/dnnl/dnnl_kernel.h
+++ b/src/runtime/contrib/dnnl/dnnl_kernel.h
@@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p
 extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
 
 extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
-                                float* variance, float* out, float* new_mean, float* new_variance,
-                                int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
+                                float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
+                                int p_e_);
 
 extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
                                  int p_h_, int p_w_);
diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py
index a246398..f3c157d 100644
--- a/tests/python/relay/test_annotated_regions.py
+++ b/tests/python/relay/test_annotated_regions.py
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
+import tvm
 from tvm import relay
 from tvm.relay.op.annotation import compiler_begin, compiler_end
 
 
-def check_region(region_set, args, nodes, rets):
+def check_region(region_set, target, args, nodes, rets):
     region = region_set.get_region(args[0])
     assert region
+    assert target == region.target
     assert set(args) == set(region.args)
     assert set(nodes) == set(region.nodes)
     assert set(rets) == set(region.rets)
@@ -51,24 +53,28 @@ def test_region_set_creator_diamond():
     assert len(region_set) == 4
     check_region(
         region_set,
+        'test_target',
         [cb_1],
         [cb_1, O_1, ce_1, ce_2],
         [ce_1, ce_2],
     )
     check_region(
         region_set,
+        'test_target',
         [cb_2],
         [cb_2, O_2, ce_3],
         [ce_3],
     )
     check_region(
         region_set,
+        'default',
         [cb_d],
         [cb_d, X, ce_d],
         [ce_d],
     )
     check_region(
         region_set,
+        'test_target',
         [cb_3, cb_4],
         [cb_3, cb_4, O_3, ce_4],
         [ce_4],
@@ -88,7 +94,9 @@ def test_region_set_creator_merged():
     cb_3 = compiler_begin(ce_3, 'test_target')
     cb_4 = compiler_begin(ce_d, 'test_target')
     O_3 = relay.add(cb_3, cb_4)
-    ce_4 = compiler_end(O_3, 'test_target')
+    O_4 = relay.add(cb_3, cb_4)
+    O_5 = relay.Tuple([O_3, O_4])
+    ce_4 = compiler_end(O_5, 'test_target')
     merged = relay.Function([data], ce_4)
 
     region_set = relay.analysis.AnnotatedRegionSet(merged,
@@ -97,20 +105,23 @@ def test_region_set_creator_merged():
     assert len(region_set) == 3
     check_region(
         region_set,
+        'test_target',
         [cb_1],
         [cb_1, O_1, O_2, ce_2, ce_3],
         [ce_2, ce_3],
     )
     check_region(
         region_set,
+        'default',
         [cb_d],
         [cb_d, X, ce_d],
         [ce_d],
     )
     check_region(
         region_set,
+        'test_target',
         [cb_3, cb_4],
-        [cb_3, cb_4, O_3, ce_4],
+        [cb_3, cb_4, O_3, O_4, O_5, ce_4],
         [ce_4],
     )
 
@@ -118,4 +129,3 @@ def test_region_set_creator_merged():
 if __name__ == "__main__":
     test_region_set_creator_diamond()
     test_region_set_creator_merged()
-
diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
similarity index 72%
rename from tests/python/relay/test_annotate_target.py
rename to tests/python/relay/test_pass_annotate_target.py
index dd00d7e..705a261 100644
--- a/tests/python/relay/test_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet():
                  (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
 
 
-@reg.register("nn.relu", "target.test")
-def relu(attrs, args):
-    return True
-
-
 def test_multiple_ends():
+    @reg.register("nn.relu", "target.test")
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
     def before():
         x = relay.var("x", shape=(10, 10))
         r = relay.nn.relu(x)
@@ -208,10 +207,17 @@ def test_multiple_ends():
         r = relay.nn.relu(cb_1)
         ce_1 = relay.annotation.compiler_end(r, "test")
         ce_2 = relay.annotation.compiler_end(r, "test")
-        a_1 = relay.abs(ce_1)
-        a_2 = relay.abs(ce_2)
-        out = relay.add(a_1, a_2)
-        f = relay.Function([x], out)
+        cb_2 = relay.annotation.compiler_begin(ce_1, "default")
+        cb_3 = relay.annotation.compiler_begin(ce_2, "default")
+        a_1 = relay.abs(cb_2)
+        a_2 = relay.abs(cb_3)
+        ce_3 = relay.annotation.compiler_end(a_1, "default")
+        ce_4 = relay.annotation.compiler_end(a_2, "default")
+        cb_4 = relay.annotation.compiler_begin(ce_3, "default")
+        cb_5 = relay.annotation.compiler_begin(ce_4, "default")
+        out = relay.add(cb_4, cb_5)
+        ce_6 = relay.annotation.compiler_end(out, "default")
+        f = relay.Function([x], ce_6)
         mod = tvm.IRModule.from_expr(f)
         return mod
 
@@ -220,6 +226,72 @@ def test_multiple_ends():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_type_propagation():
+    target = "test_type_propagation"
+
+    @reg.register("nn.relu", "target." + target)
+    def relu(attrs, args): # pylint: disable=unused-variable
+        return args[0].checked_type.dtype == "float32"
+
+    def before():
+        x = relay.var("x", shape=(10, 10))
+        r = relay.nn.relu(x)
+        out = relay.nn.relu(r)
+        f = relay.Function([x], out)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    # If the type isn't propagated, then the relu checker function will fail to get the dtype.
+    assert transform.AnnotateTarget(target)(before())
+
+
+def test_tuple():
+    target = "test_tuple"
+
+    @reg.register("nn.relu", "target." + target)
+    def relu(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    @reg.register("concatenate", "target." + target)
+    def concatenate(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    """Test that TupleNode is included in annotation when surrounded by supported nodes."""
+    def before():
+        x = relay.var("x", shape=(10, 5))
+        y = relay.var("y", shape=(10, 5))
+        a_1 = relay.nn.relu(x)
+        a_2 = relay.nn.relu(y)
+        out = relay.concatenate((a_1, a_2), axis=1)
+        f = relay.Function([x, y], out)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    def after():
+        x = relay.var("x", shape=(10, 5))
+        y = relay.var("y", shape=(10, 5))
+        cb_1 = relay.annotation.compiler_begin(x, target)
+        cb_2 = relay.annotation.compiler_begin(y, target)
+        a_1 = relay.nn.relu(cb_1)
+        a_2 = relay.nn.relu(cb_2)
+        ce_1 = relay.annotation.compiler_end(a_1, target)
+        ce_2 = relay.annotation.compiler_end(a_2, target)
+        cb_3 = relay.annotation.compiler_begin(ce_1, target)
+        cb_4 = relay.annotation.compiler_begin(ce_2, target)
+        tup = relay.Tuple([cb_3, cb_4])
+        ce_3 = relay.annotation.compiler_end(tup, target)
+        cb_3 = relay.annotation.compiler_begin(ce_3, target)
+        out = relay.op._make.concatenate(cb_3, 1)
+        ce_4 = relay.annotation.compiler_end(out, target)
+        f = relay.Function([x, y], ce_4)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    result = transform.AnnotateTarget(target)(before())
+    expected = transform.InferType()(after())
+    assert tvm.ir.structural_equal(expected, result)
+
+
 def test_composite_function():
     def before():
         a = relay.var('a', shape=(10, 10))
@@ -265,8 +337,37 @@ def test_composite_function():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_multiple_runs():
+    @reg.register("nn.relu", "target.A")
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @reg.register("add", "target.B")
+    def add(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    def before():
+        x = relay.var("x", shape=(10, 5))
+        a_1 = relay.nn.relu(x)
+        a_2 = relay.abs(a_1)
+        a_3 = relay.nn.relu(a_1)
+        out = relay.add(a_2, a_3)
+
+        f = relay.Function([x], out)
+        mod = tvm.IRModule.from_expr(f)
+        return mod
+
+    mod = transform.AnnotateTarget("A")(before())
+    mod = transform.AnnotateTarget("B")(mod)
+    expected = transform.AnnotateTarget(["A", "B"])(before())
+    assert tvm.ir.structural_equal(expected, mod)
+
+
 if __name__ == "__main__":
-    test_multiple_ends()
     test_extern_dnnl()
-    #test_extern_dnnl_mobilenet()
     test_composite_function()
+    #test_extern_dnnl_mobilenet()
+    test_multiple_ends()
+    test_type_propagation()
+    test_tuple()
+    test_multiple_runs()
diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py
index f316a41..7d7db35 100644
--- a/tests/python/relay/test_pass_merge_compiler_regions.py
+++ b/tests/python/relay/test_pass_merge_compiler_regions.py
@@ -30,9 +30,9 @@ def test_diamond_graph_fanouts():
     X = not supported by target
 
        O         O
-      / \       /               \
+      / \\      /               \\
      O   X --> O    +       +    X
-      \ /              \ /
+     \\ /             \\ /
        O                O
 
     Note that we can't just merge the three supported operators together,
@@ -45,17 +45,20 @@ def test_diamond_graph_fanouts():
         ce_1 = compiler_end(O_1, "test")
         ce_2 = compiler_end(O_1, "test")
         cb_2 = compiler_begin(ce_1, "test")
+        cb_3 = compiler_begin(ce_2, "default")
         O_2 = relay.nn.relu(cb_2)
         ce_3 = compiler_end(O_2, "test")
 
-        X = relay.tanh(ce_2)
 
-        cb_3 = compiler_begin(ce_3, "test")
-        cb_4 = compiler_begin(X, "test")
-        O_3 = relay.add(cb_3, cb_4)
-        ce_4 = compiler_end(O_3, "test")
+        X = relay.tanh(cb_3)
+        ce_4 = compiler_end(X, "default")
 
-        diamond = relay.Function([data], ce_4)
+        cb_4 = compiler_begin(ce_3, "test")
+        cb_5 = compiler_begin(ce_4, "test")
+        O_3 = relay.add(cb_4, cb_5)
+        ce_5 = compiler_end(O_3, "test")
+
+        diamond = relay.Function([data], ce_5)
         return diamond
 
     def expected():
@@ -66,14 +69,16 @@ def test_diamond_graph_fanouts():
         O_2 = relay.nn.relu(O_1)
         ce_3 = compiler_end(O_2, "test")
 
-        X = relay.tanh(ce_2)
+        cb_3 = compiler_begin(ce_2, "default")
+        X = relay.tanh(cb_3)
+        ce_4 = compiler_end(X, "default")
 
-        cb_3 = compiler_begin(ce_3, "test")
-        cb_4 = compiler_begin(X, "test")
-        O_3 = relay.add(cb_3, cb_4)
-        ce_4 = compiler_end(O_3, "test")
+        cb_4 = compiler_begin(ce_3, "test")
+        cb_5 = compiler_begin(ce_4, "test")
+        O_3 = relay.add(cb_4, cb_5)
+        ce_5 = compiler_end(O_3, "test")
 
-        func = relay.Function([data], ce_4)
+        func = relay.Function([data], ce_5)
         return func
 
     result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions())
@@ -85,7 +90,7 @@ def test_example_graph():
     """This tests the merging algorithm on the example used in the RFC.
 
     See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830
-    Blue nodes are adds, red nodes are subtracts.
+    Blue nodes are adds (target: test), red nodes are subtracts (target: default).
     """
     def annotated():
         in_1 = relay.var('in_1', shape=(10, 10), dtype='float32')
@@ -112,21 +117,30 @@ def test_example_graph():
         node2 = relay.add(begin4, begin5)
         end2 = compiler_end(node2, "test")
 
-        node3 = relay.subtract(in_5, in_6)
-        node4 = relay.subtract(in_7, node3)
+        dbegin0 = compiler_begin(in_5, "default")
+        dbegin1 = compiler_begin(in_6, "default")
+        node3 = relay.subtract(dbegin0, dbegin1)
+        dbegin2 = compiler_begin(in_7, "default")
+        dend1 = compiler_end(node3, "default")
+        dbegin3 = compiler_begin(dend1, "default")
+        node4 = relay.subtract(dbegin2, dbegin3)
+        dend2 = compiler_end(node4, "default")
 
         begin6 = compiler_begin(end2, "test")
-        begin7 = compiler_begin(node4, "test")
+        begin7 = compiler_begin(dend2, "test")
         node5 = relay.add(begin6, begin7)
         end3 = compiler_end(node5, "test")
         end4 = compiler_end(node5, "test")
-        node6 = relay.subtract(in_8, end3)
+        dbegin4 = compiler_begin(in_8, "default")
+        dbegin5 = compiler_begin(end3, "default")
+        node6 = relay.subtract(dbegin4, dbegin5)
         begin8 = compiler_begin(in_9, "test")
         begin9 = compiler_begin(end4, "test")
         node7 = relay.add(begin8, begin9)
         end5 = compiler_end(node7, "test")
 
-        begin10 = compiler_begin(node6, "test")
+        dend3 = compiler_end(node6, "default")
+        begin10 = compiler_begin(dend3, "test")
         begin11 = compiler_begin(end5, "test")
         node8 = relay.add(begin10, begin11)
         end6 = compiler_end(node8, "test")
@@ -159,20 +173,27 @@ def test_example_graph():
         node1 = relay.add(begin2, begin3)
         node2 = relay.add(node0, node1)
 
-        node3 = relay.subtract(in_5, in_6)
-        node4 = relay.subtract(in_7, node3)
+        dbegin0 = compiler_begin(in_5, "default")
+        dbegin1 = compiler_begin(in_6, "default")
+        dbegin2 = compiler_begin(in_7, "default")
+        node3 = relay.subtract(dbegin0, dbegin1)
+        node4 = relay.subtract(dbegin2, node3)
+        dend0 = compiler_end(node4, "default")
 
-        begin4 = compiler_begin(node4, "test")
+        begin4 = compiler_begin(dend0, "test")
         begin5 = compiler_begin(in_9, "test")
         node5 = relay.add(node2, begin4)
         end1 = compiler_end(node5, "test")
 
-        node6 = relay.subtract(in_8, end1)
+        dbegin4 = compiler_begin(end1, "default")
+        dbegin5 = compiler_begin(in_8, "default")
+        node6 = relay.subtract(dbegin5, dbegin4)
+        dend1 = compiler_end(node6, "default")
 
         node7 = relay.add(begin5, node5)
         end2 = compiler_end(node7, "test")
         begin6 = compiler_begin(end2, "test")
-        begin7 = compiler_begin(node6, "test")
+        begin7 = compiler_begin(dend1, "test")
 
         node8 = relay.add(begin7, begin6)
 
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 1968f34..fb21682 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -17,6 +17,7 @@
 """Unit tests for graph partitioning."""
 import os
 import sys
+
 import numpy as np
 import pytest
 
@@ -26,8 +27,12 @@ from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
 from tvm.contrib import util
-from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay import transform
+from tvm.relay.backend import compile_engine
 from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.runtime import container
+
 
 # Leverage the pass manager to write a simple white list based annotator
 @transform.function_pass(opt_level=0)
@@ -188,6 +193,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         return lib
 
     def check_vm_result():
+        compile_engine.get().clear()
         with relay.build_config(opt_level=3):
             exe = relay.vm.compile(mod, target=target, params=params)
         code, lib = exe.save()
@@ -199,6 +205,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
 
     def check_graph_runtime_result():
+        compile_engine.get().clear()
         with relay.build_config(opt_level=3):
             json, lib, param = relay.build(mod, target=target, params=params)
         lib = update_lib(lib)
@@ -449,9 +456,9 @@ def test_extern_dnnl_mobilenet():
     mod, params = relay.testing.mobilenet.get_workload(
         batch_size=1, dtype='float32')
 
-    op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"]
     mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
-    mod = WhiteListAnnotator(op_list, "dnnl")(mod)
+    mod = transform.AnnotateTarget(["dnnl"])(mod)
+    mod = transform.MergeCompilerRegions()(mod)
     mod = transform.PartitionGraph()(mod)
     i_data = np.random.uniform(0, 1, ishape).astype(dtype)
 
@@ -851,6 +858,7 @@ if __name__ == "__main__":
     test_extern_ccompiler_default_ops()
     test_extern_ccompiler()
     test_extern_dnnl()
+    # TODO(@comaniac, @zhiics): Fix the constant node issue and re-enable this test case.
     #test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()


Mime
View raw message