tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
Date Tue, 14 Apr 2020 18:34:01 GMT
comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
URL: https://github.com/apache/incubator-tvm/pull/5320#discussion_r408348156
 
 

 ##########
 File path: src/relay/transforms/partition_graph.cc
 ##########
 @@ -456,18 +375,109 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called first time that we encounter a compiler_end
+   * node to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields which is a unique list of outputs. Also populate
+    // region_return_indices_ map which maps parent of compiler_end node to
+    // corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
       }
-      idx++;
     }
-    return -1;
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there are only a single output; no need to add a tuple
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of callnode is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
+  }
+
+  /*!
+   * \brief Get the return(output) of the function for compiler end node "arg".
 
 Review comment:
  It would be better to improve the description a bit, e.g. by stating that this function returns
either a Call (for a single output) or a TupleGetItem (for multiple outputs), depending on the given end op.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message