tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
Date Tue, 14 Apr 2020 18:34:01 GMT
comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
URL: https://github.com/apache/incubator-tvm/pull/5320#discussion_r408345921
 
 

 ##########
 File path: src/relay/transforms/partition_graph.cc
 ##########
 @@ -206,97 +206,16 @@ class Partitioner : public ExprMutator {
       // (each annotated regions) --> created function
 
       if (region_function_calls.find(region) != region_function_calls.end()) {
-        // This section is executed only if there are multiple outputs in the
-        // region Thus, the function is always created and at the end there
-        // would be a tuple node Therefore, we insert a tuple get item node.
-
-        // Use the already created tuple node
-        auto sg_call = region_function_calls[region];
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        auto tuple_get_item_ = TupleGetItem(sg_call, index);
-        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-        return std::move(tuple_get_item_);
+        // This section is executed if there are multiple outputs in the region
+        // or if the output of the function is being accessed multiple times by
+        // different nodes.
+        return GetFunctionOutput(region, GetRef<Call>(call));
       } else {
-        // First time this region is encountered in the traversal
-        // Creating the function
-
-        Array<Expr> fields;
-
-        for (auto ret : region->GetOutputs()) {
-          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
-          fields.push_back(ret_expr);
-        }
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        Array<Var> params;
-        Array<Expr> param_expr;
-        std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-        for (auto pair : region_args[region]) {
-          params.push_back(pair.first);
-          if (const auto* cn = pair.second.as<ConstantNode>()) {
-            params_bind[pair.first->name_hint()] = cn->data;
-          } else {
-            param_expr.push_back(pair.second);
-          }
-        }
-
-        Function global_region_func;
-        if (region->GetOutputs().size() == 1) {
-          // If there are only a single output; no need to add a tuple
-          global_region_func =
-              Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
-        } else {
-          auto tuple = Tuple(fields);
-          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
-        }
-
-        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-        std::string name = target + "_" + std::to_string(region->GetID());
-
-        global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
-                                      runtime::String(name));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                      tvm::runtime::String(target));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
-
-        // Constant propagation
-        if (!params_bind.empty()) {
-          global_region_func = backend::BindParamsByName(global_region_func, params_bind);
-        }
-
-        std::string fname = name;
-        CHECK(!module_->ContainGlobalVar(fname))
-            << "Global function " << fname << " already exists";
-        // Create a global function and add it to the IRModule for the region.
-        // This way we lift the functions that should be handled by external
-        // codegen to the module scope and rely on the pass manager to prevent
-        // relay function level passes (i.e. simplify inference and fusion)
-        // optimizing it.
-        GlobalVar glob_func(fname);
-        module_->Add(glob_func, global_region_func);
-
-        // The return type of callnode is the same as the type of the
-        // compiler_end node.
-        auto ret = Call(glob_func, param_expr);
-        region_function_calls[region] = ret;
-
-        if (region->GetOutputs().size() == 1) {
-          // If there is only a single output; no need to add a tuplegetitem
-          // node
-          return std::move(ret);
-        } else {
-          // Add a tuplegetitem node to select this output out of many
-          auto tuple_get_item_ = TupleGetItem(ret, index);
-          tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-          return std::move(tuple_get_item_);
-        }
+        // First time this region is encountered in the traversal.
+        // Creating the function.
+        CreateFunction(region, call);
+        // Retrieve particular output.
+        return GetFunctionOutput(region, GetRef<Call>(call));
 
 Review comment:
  The `return GetFunctionOutput(region, GetRef<Call>(call));` statement could be moved out of the if/else, since both branches end with it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message