tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From z...@apache.org
Subject [incubator-tvm] branch master updated: [BYOC] Use Non-Recursive Visitor/Mutator (#5410)
Date Thu, 23 Apr 2020 20:56:54 GMT
This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new ba87604  [BYOC] Use Non-Recursive Visitor/Mutator (#5410)
ba87604 is described below

commit ba8760462bb56c3a5571bb00a2f64355d98d1b43
Author: Cody Yu <comaniac0422@gmail.com>
AuthorDate: Thu Apr 23 13:56:43 2020 -0700

    [BYOC] Use Non-Recursive Visitor/Mutator (#5410)
    
    * Non-Recursive AnnotatedTarget and MergeAnnotation
    
    * Non-Recursive AnnotatedRegionSet and RegionMerger
---
 src/relay/analysis/annotated_region_set.cc      | 133 ++++++++++++------------
 src/relay/transforms/annotate_target.cc         |  85 +++++++--------
 src/relay/transforms/merge_compiler_regions.cc  |  14 +--
 tests/python/relay/test_pass_partition_graph.py |  67 ++++++------
 4 files changed, 144 insertions(+), 155 deletions(-)

diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index 94c7621..103ddcb 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
   return *ret.first;
 }
 
-class AnnotatedRegionSet::Creator : public ExprVisitor {
+class AnnotatedRegionSet::Creator : protected MixedModeVisitor {
  public:
   Creator(const Op& region_begin_op, const Op& region_end_op)
       : begin_op_(region_begin_op), end_op_(region_end_op) {}
 
+  AnnotatedRegionSet Create(const Expr& expr) {
+    VisitExpr(expr);
+    return std::move(region_set_);
+  }
+
+  void AddToArgRegion(Expr expr, Array<Expr> args) {
+    // Merge argument regions and add itself to the region.
+
+    // Find the first open region.
+    AnnotatedRegion region;
+    for (auto arg : args) {
+      const CallNode* end = arg.as<CallNode>();
+      if (end && end->op == end_op_) {  // Ignore closed regions.
+          continue;
+      }
+
+      region = region_set_->GetRegion(arg);
+      if (region.defined()) {
+          break;
+      }
+    }
+
+    // Try to merge open regions.
+    for (auto arg : args) {
+      const CallNode* end = arg.as<CallNode>();
+      if (end && end->op == end_op_) {  // Ignore closed regions.
+          continue;
+      }
+
+      auto arg_region = region_set_->GetRegion(arg);
+      CHECK_EQ(region.defined(), arg_region.defined())
+          << "Arg regions are inconsistent: " << AsText(expr);
+      if (region.defined() && region != arg_region) {
+        region_set_->MergeRegions(arg_region, region);
+      }
+    }
+    if (region.defined()) {
+      region_set_->AddToRegion(region, expr);
+    }
+  }
+
   void VisitExpr_(const CallNode* call) {
     auto op_node = call->op.as<OpNode>();
 
     if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
-      // Propagate region to arguments
-      auto region = region_set_->GetRegion(GetRef<Call>(call));
-      if (region.defined()) {
-        for (auto arg : call->args) {
-          region_set_->AddToRegion(region, arg);
-        }
-      }
+      AddToArgRegion(GetRef<Call>(call), call->args);
     } else if (call->op == begin_op_) {
       // The annotation node is inserted on edge so it must have only one argument.
       CHECK_EQ(call->args.size(), 1U);
+      std::string target = call->attrs.as<CompilerAttrs>()->compiler;
 
+      // Check if the argument already belongs to a region
       auto region = region_set_->GetRegion(GetRef<Call>(call));
-      if (!region.defined()) {
-        throw Error(ErrorBuilder()
-                      << "Cannot find the corresponding region for start annotation:\n"
-                      << AsText(GetRef<Call>(call), false));
-      }
+      CHECK(!region.defined());
+
+      // Create a new region.
+      region = region_set_->MakeRegion(target);
+      region->nodes_.insert(GetRef<Call>(call));
       region->ins_.push_back(GetRef<Call>(call));
     } else {
       CHECK_EQ(call->op, end_op_);
@@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
       // Check if the argument already belongs to a region
       auto region = region_set_->GetRegion(call->args[0]);
       if (!region.defined()) {
-        // Create a new region if the argument is not belonged to any regions yet.
-        region = region_set_->MakeRegion(target);
-        region->nodes_.insert(call->args[0]);
+        throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n"
+                                   << AsText(GetRef<Call>(call), false));
       } else {
         // If the argument is belonged to a region, it must have the same target.
         // Otherwise we should see a region_begin op.
@@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
       region->nodes_.insert(GetRef<Call>(call));
       region->outs_.push_back(GetRef<Call>(call));
     }
-    ExprVisitor::VisitExpr_(call);
-  }
-
-  AnnotatedRegionSet Create(const Expr& expr) {
-    VisitExpr(expr);
-    return std::move(region_set_);
   }
 
   void VisitExpr_(const TupleNode* op) {
-    auto region = region_set_->GetRegion(GetRef<Tuple>(op));
-    if (region.defined()) {
-      for (auto field : op->fields) {
-        region_set_->AddToRegion(region, field);
-      }
-    }
-    ExprVisitor::VisitExpr_(op);
+    AddToArgRegion(GetRef<Tuple>(op), op->fields);
   }
 
   void VisitExpr_(const TupleGetItemNode* g) {
-    auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, g->tuple);
-    }
-    ExprVisitor::VisitExpr_(g);
-  }
-
-  void VisitExpr_(const FunctionNode* op) {
-    auto region = region_set_->GetRegion(GetRef<Function>(op));
-    if (region.defined()) {
-      for (auto param : op->params) {
-        region_set_->AddToRegion(region, param);
-      }
-    }
-    ExprVisitor::VisitExpr_(op);
+    Array<Expr> args = {g->tuple};
+    AddToArgRegion(GetRef<TupleGetItem>(g), args);
   }
 
   void VisitExpr_(const LetNode* op) {
-    auto region = region_set_->GetRegion(GetRef<Let>(op));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, op->var);
-      region_set_->AddToRegion(region, op->value);
-      region_set_->AddToRegion(region, op->body);
-    }
+    Array<Expr> args = {op->var, op->value, op->body};
+    AddToArgRegion(GetRef<Let>(op), args);
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const IfNode* op) {
-    auto region = region_set_->GetRegion(GetRef<If>(op));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, op->cond);
-      region_set_->AddToRegion(region, op->true_branch);
-      region_set_->AddToRegion(region, op->false_branch);
-    }
+    Array<Expr> args = {op->cond, op->true_branch, op->false_branch};
+    AddToArgRegion(GetRef<If>(op), args);
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const RefCreateNode* op) {
-    auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, op->value);
-    }
+    Array<Expr> args = {op->value};
+    AddToArgRegion(GetRef<RefCreate>(op), args);
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const RefReadNode* op) {
-    auto region = region_set_->GetRegion(GetRef<RefRead>(op));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, op->ref);
-    }
+    Array<Expr> args = {op->ref};
+    AddToArgRegion(GetRef<RefRead>(op), args);
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const RefWriteNode* op) {
-    auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
-    if (region.defined()) {
-      region_set_->AddToRegion(region, op->ref);
-    }
+    Array<Expr> args = {op->ref};
+    AddToArgRegion(GetRef<RefWrite>(op), args);
     ExprVisitor::VisitExpr_(op);
   }
 
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index bc6b4b9..4caac04 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -42,9 +42,9 @@ const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._mak
 
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
-class AnnotateTargetWrapper : public ExprMutator {
+class AnnotateTargetRewriter : public ExprRewriter {
  public:
-  explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
+  explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}
 
   /*!
    * \brief This function annotates a compiler end and a compiler begin to all arguments.
@@ -108,29 +108,29 @@ class AnnotateTargetWrapper : public ExprMutator {
     return new_op;
   }
 
-  Expr VisitExpr_(const CallNode* cn) final {
+  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
     // Supported targets for this node. The order implies the priority.
     std::vector<std::string> supported_targets;
 
-    auto op_node = cn->op.as<OpNode>();
+    auto op_node = pre->op.as<OpNode>();
 
     // This graph has annotations, meaning that this is not the first time running this pass.
-    if (op_node && cn->op == compiler_begin_op) {
+    if (op_node && pre->op == compiler_begin_op) {
       // Bypass compiler begin due to lack of target information. It will be processed
       // when the following op handling arguments.
-      CHECK_EQ(cn->args.size(), 1U);
-      return VisitExpr(cn->args[0]);
-    } else if (op_node && cn->op == compiler_end_op) {
+      CHECK_EQ(pre->args.size(), 1U);
+      return post.as<CallNode>()->args[0];
+    } else if (op_node && pre->op == compiler_end_op) {
       // Override compiler end with the new target.
-      CHECK_EQ(cn->args.size(), 1U);
-      auto input_expr = VisitExpr(cn->args[0]);
+      CHECK_EQ(pre->args.size(), 1U);
+      auto input_expr = post.as<CallNode>()->args[0];
       CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
       return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
     }
 
     // Peek the first argument. If it is compiler begin then this node had annotated by
     // another target before, so we also consider that target as a supported target.
-    const CallNode* first_arg_call = cn->args[0].as<CallNode>();
+    const CallNode* first_arg_call = pre->args[0].as<CallNode>();
     if (first_arg_call && first_arg_call->op == compiler_begin_op) {
       std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
       if (arg_target != "default") {
@@ -142,21 +142,21 @@ class AnnotateTargetWrapper : public ExprMutator {
     if (op_node) {
       // TVM operators: Check target specific op checking function and add to supported_targets
       // if it is supported.
-      Op op = Downcast<Op>(cn->op);
+      Op op = Downcast<Op>(pre->op);
       CHECK(op.defined());
       for (const auto& target : this->targets_) {
         if (!Op::HasAttr("target." + std::string(target))) {
           continue;
         }
         auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
-        if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
+        if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
           supported_targets.push_back(target);
         }
       }
-    } else if (cn->op->IsInstance<FunctionNode>()) {
+    } else if (pre->op->IsInstance<FunctionNode>()) {
       // Composite function: Add the target of a composite function to supported_targets
       // if it is in the target list.
-      Function func = Downcast<Function>(cn->op);
+      Function func = Downcast<Function>(pre->op);
       CHECK(func.defined());
 
       if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
@@ -181,23 +181,22 @@ class AnnotateTargetWrapper : public ExprMutator {
     std::string target = supported_targets[0];
 
     // Visit and mutate arguments after the target of this op has been determined.
-    auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));
+    Call post_call = Downcast<Call>(post);
 
     // Add annotations to each arg.
-    auto target_n_args = AnnotateArgs(new_call->args, target);
+    auto target_n_args = AnnotateArgs(post_call->args, target);
     Array<Expr> compiler_begins = std::get<1>(target_n_args);
-    Call call = Call(new_call->op, compiler_begins, new_call->attrs);
-    call->checked_type_ = cn->checked_type_;
+    Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
+    new_call->checked_type_ = pre->checked_type_;
 
     // Update the target map.
-    op_expr_to_target_[call] = target;
+    op_expr_to_target_[new_call] = target;
 
-    return std::move(call);
+    return std::move(new_call);
   }
 
-  Expr VisitExpr_(const TupleNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<Tuple>(new_e);
+  Expr Rewrite_(const TupleNode* op, const Expr& post) final {
+    auto expr = Downcast<Tuple>(post);
 
     auto target_n_args = AnnotateArgs(expr->fields);
     auto new_expr = Tuple(std::get<1>(target_n_args));
@@ -205,9 +204,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<TupleGetItem>(new_e);
+  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+    auto expr = Downcast<TupleGetItem>(post);
 
     auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
     auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
@@ -215,7 +213,7 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const FunctionNode* fn) final {
+  Expr Rewrite_(const FunctionNode* fn, const Expr& post) final {
     Function func;
     Expr new_body;
     // don't step into composite functions
@@ -223,8 +221,7 @@ class AnnotateTargetWrapper : public ExprMutator {
       func = GetRef<Function>(fn);
       new_body = func->body;
     } else {
-      auto new_e = ExprMutator::VisitExpr_(fn);
-      func = Downcast<Function>(new_e);
+      func = Downcast<Function>(post);
       new_body = func->body;
       if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
         new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
@@ -234,9 +231,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
   }
 
-  Expr VisitExpr_(const LetNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto let = Downcast<Let>(new_e);
+  Expr Rewrite_(const LetNode* op, const Expr& post) final {
+    auto let = Downcast<Let>(post);
 
     auto target_n_args = AnnotateArgs({let->value, let->body});
     auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
@@ -244,9 +240,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const IfNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<If>(new_e);
+  Expr Rewrite_(const IfNode* op, const Expr& post) final {
+    auto expr = Downcast<If>(post);
 
     auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
     CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
@@ -256,9 +251,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefCreateNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<RefCreate>(new_e);
+  Expr Rewrite_(const RefCreateNode* op, const Expr& post) final {
+    auto expr = Downcast<RefCreate>(post);
 
     auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
     auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
@@ -266,9 +260,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefReadNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<RefRead>(new_e);
+  Expr Rewrite_(const RefReadNode* op, const Expr& post) final {
+    auto expr = Downcast<RefRead>(post);
 
     auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
     auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
@@ -276,9 +269,8 @@ class AnnotateTargetWrapper : public ExprMutator {
     return std::move(new_expr);
   }
 
-  Expr VisitExpr_(const RefWriteNode* op) final {
-    auto new_e = ExprMutator::VisitExpr_(op);
-    auto expr = Downcast<RefWrite>(new_e);
+  Expr Rewrite_(const RefWriteNode* op, const Expr& post) final {
+    auto expr = Downcast<RefWrite>(post);
 
     auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
     auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
@@ -294,7 +286,8 @@ class AnnotateTargetWrapper : public ExprMutator {
 };
 
 Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
-  return AnnotateTargetWrapper(targets).Mutate(expr);
+  auto rewriter = AnnotateTargetRewriter(targets);
+  return PostOrderRewrite(expr, &rewriter);
 }
 
 }  // namespace annotate_target
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index 601be0f..6fbd0d5 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -53,7 +53,7 @@ namespace merge_compiler_region {
 static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
 static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
 
-class RegionMerger : public ExprVisitor {
+class RegionMerger : public MixedModeVisitor {
  public:
   explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
 
@@ -131,7 +131,6 @@ class RegionMerger : public ExprVisitor {
       }
       merged_regions_.insert(region->GetID());
     }
-    ExprVisitor::VisitExpr_(call);
   }
 
  private:
@@ -140,11 +139,11 @@ class RegionMerger : public ExprVisitor {
   std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
 };
 
-class MergeAnnotations : public ExprMutator {
+class MergeAnnotations : public ExprRewriter {
  public:
   explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
 
-  Expr VisitExpr_(const CallNode* call) final {
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
     // Merge annotations which are now internal to a region.
     // This happens if we see a compiler begin next to a
     // compiler end and they're both in the same region.
@@ -154,11 +153,12 @@ class MergeAnnotations : public ExprMutator {
         auto region1 = regions_->GetRegion(GetRef<Call>(call));
         auto region2 = regions_->GetRegion(arg);
         if (region1 == region2) {
-          return VisitExpr(arg->args[0]);
+          auto post_arg = post.as<CallNode>()->args[0];
+          return post_arg.as<CallNode>()->args[0];
         }
       }
     }
-    return ExprMutator::VisitExpr_(call);
+    return post;
   }
 
  private:
@@ -175,7 +175,7 @@ Expr MergeCompilerRegions(const Expr& expr) {
 
   // Remove annotations that are not in the region boundaries.
   MergeAnnotations merge_anno(regions);
-  return merge_anno.Mutate(expr);
+  return PostOrderRewrite(expr, &merge_anno);
 }
 
 }  // namespace merge_compiler_region
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 14d57a9..2a4fd31 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -522,8 +522,8 @@ def test_function_lifting():
         bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
         func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
                                bn.astuple())
-        func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
-        gv0 = relay.GlobalVar("test_compiler_0")
+        func0 = set_func_attr(func0, "test_compiler", "test_compiler_2")
+        gv0 = relay.GlobalVar("test_compiler_2")
         mod[gv0] = func0
 
         # function for conv2d
@@ -536,8 +536,8 @@ def test_function_lifting():
             channels=16,
             padding=(1, 1))
         func1 = relay.Function([data1, weight1], conv)
-        func1 = set_func_attr(func1, "test_compiler", "test_compiler_1")
-        gv1 = relay.GlobalVar("test_compiler_1")
+        func1 = set_func_attr(func1, "test_compiler", "test_compiler_0")
+        gv1 = relay.GlobalVar("test_compiler_0")
         mod[gv1] = func1
 
         # main function
@@ -630,7 +630,6 @@ def test_constant_propagation():
 
     def expected():
         mod = tvm.IRModule()
-        x = relay.const(ones)
         y = relay.var("y", shape=(8, 8))
         x0 = relay.const(ones)
         y0 = relay.var("y0", shape=(8, 8))
@@ -712,12 +711,12 @@ def test_multiple_outputs():
         mod = tvm.IRModule()
 
         # function 0
-        data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32"))
-        weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32"))
-        bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32"))
-        bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32"))
-        bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32"))
-        bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32"))
+        data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
+        bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
+        bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))
 
         conv_o = relay.nn.conv2d(
             data=data,
@@ -730,12 +729,12 @@ def test_multiple_outputs():
                                    bn_var)
 
         relu_o = relay.nn.relu(bn_o[0])
-        tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o))
+        tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))
 
         func0 = relay.Function([data, weight, bn_gamma, bn_beta,
                                 bn_mean, bn_var], tuple_o)
-        func0 = set_func_attr(func0, "test_target", "test_target_2")
-        gv0 = relay.GlobalVar("test_target_2")
+        func0 = set_func_attr(func0, "test_target", "test_target_0")
+        gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0
 
         # body
@@ -747,9 +746,9 @@ def test_multiple_outputs():
         bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))
 
         f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
-        f0_relu_o = relay.TupleGetItem(f0_o, 2)
+        f0_relu_o = relay.TupleGetItem(f0_o, 0)
         f0_mean_o = relay.TupleGetItem(f0_o, 1)
-        f0_var_o = relay.TupleGetItem(f0_o, 0)
+        f0_var_o = relay.TupleGetItem(f0_o, 2)
 
         f0_mean_abs = relay.abs(f0_mean_o)
         f0_var_abs = relay.abs(f0_var_o)
@@ -791,22 +790,22 @@ def test_mixed_single_multiple_outputs():
         mod = tvm.IRModule()
 
         # function 1
-        f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10))
+        f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10))
         f1_O_1 = relay.abs(f1_cb1)
         f1_O_2 = relay.nn.relu(f1_O_1)
         f1_out = relay.Tuple((f1_O_2, f1_O_1))
         func1 = relay.Function([f1_cb1], f1_out)
-        func1 = set_func_attr(func1, "test_target", "test_target_1")
-        gv1 = relay.GlobalVar("test_target_1")
+        func1 = set_func_attr(func1, "test_target", "test_target_0")
+        gv1 = relay.GlobalVar("test_target_0")
         mod[gv1] = func1
 
         # function 0
-        f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10))
-        f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
+        f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10))
+        f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10))
         f2_O_3 = relay.add(f2_cb3, f2_cb4)
         func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
-        func0 = set_func_attr(func0, "test_target", "test_target_0")
-        gv0 = relay.GlobalVar("test_target_0")
+        func0 = set_func_attr(func0, "test_target", "test_target_1")
+        gv0 = relay.GlobalVar("test_target_1")
         mod[gv0] = func0
 
         # body
@@ -1109,22 +1108,22 @@ def test_duplicate_merge_and_tuplegetitem():
         mod = tvm.IRModule()
 
         # function 0
-        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
-        f0_i1 = relay.var(target+"_1_i1")
-        f0_i2 = relay.var(target+"_1_i2")
-        f0_i3 = relay.var(target+"_1_i3")
-        f0_i4 = relay.var(target+"_1_i4")
+        f0_i0 = relay.var(target + "_0_i0", shape=(10, 10))
+        f0_i1 = relay.var(target + "_0_i1")
+        f0_i2 = relay.var(target + "_0_i2")
+        f0_i3 = relay.var(target + "_0_i3")
+        f0_i4 = relay.var(target + "_0_i4")
         f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
         f0_n1 = f0_n0[1]
         f0_n2 = relay.nn.relu(f0_n0[0])
-        f0_o0 = relay.Tuple([f0_n1, f0_n2])
+        f0_o0 = relay.Tuple([f0_n2, f0_n1])
         func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
 
         func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler", target)
-        func0 = func0.with_attr("global_symbol", target+"_1")
-        gv0 = relay.GlobalVar(target+"_1")
+        func0 = func0.with_attr("global_symbol", target + "_0")
+        gv0 = relay.GlobalVar(target + "_0")
         mod[gv0] = func0
 
         # body
@@ -1136,9 +1135,9 @@ def test_duplicate_merge_and_tuplegetitem():
         function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
         get_out0 = relay.TupleGetItem(function_out, 0)
         get_out1 = relay.TupleGetItem(function_out, 1)
-        out_2 = relay.tanh(get_out0)
-        out_3 = relay.log(get_out0)
-        out = relay.Tuple([get_out1, out_2, out_3])
+        out_2 = relay.tanh(get_out1)
+        out_3 = relay.log(get_out1)
+        out = relay.Tuple([get_out0, out_2, out_3])
         func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
         mod["main"] = func
         return mod


Mime
View raw message