From commits-return-11394-archive-asf-public=cust-asf.ponee.io@tvm.apache.org Tue Apr 14 18:34:05 2020 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [207.244.88.153]) by mx-eu-01.ponee.io (Postfix) with SMTP id 9A13E18066D for ; Tue, 14 Apr 2020 20:34:05 +0200 (CEST) Received: (qmail 37546 invoked by uid 500); 14 Apr 2020 18:34:02 -0000 Mailing-List: contact commits-help@tvm.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@tvm.apache.org Delivered-To: mailing list commits@tvm.apache.org Received: (qmail 37379 invoked by uid 99); 14 Apr 2020 18:34:01 -0000 Received: from ec2-52-202-80-70.compute-1.amazonaws.com (HELO gitbox.apache.org) (52.202.80.70) by apache.org (qpsmtpd/0.29) with ESMTP; Tue, 14 Apr 2020 18:34:01 +0000 From: GitBox To: commits@tvm.apache.org Subject: [GitHub] [incubator-tvm] comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple Message-ID: <158688924173.3072.6945129617663346370.gitbox@gitbox.apache.org> References: In-Reply-To: Date: Tue, 14 Apr 2020 18:34:01 -0000 Content-Type: text/plain; charset=utf-8 Content-Transfer-Encoding: 8bit comaniac commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple URL: https://github.com/apache/incubator-tvm/pull/5320#discussion_r408345921 ########## File path: src/relay/transforms/partition_graph.cc ########## @@ -206,97 +206,16 @@ class Partitioner : public ExprMutator { // (each annotated regions) --> created function if (region_function_calls.find(region) != region_function_calls.end()) { - // This section is executed only if there are multiple outputs in the - // region Thus, the function is always created and at the end there - // would be a tuple node Therefore, we insert a tuple get item node. 
- - // Use the already created tuple node - auto sg_call = region_function_calls[region]; - int index = GetRetIdx(region, GetRef<Call>(call)); - CHECK_NE(index, -1); - - auto tuple_get_item_ = TupleGetItem(sg_call, index); - tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_; - return std::move(tuple_get_item_); + // This section is executed if there are multiple outputs in the region + // or if the output of the function is being accessed multiple times by + // different nodes. + return GetFunctionOutput(region, GetRef<Call>(call)); } else { - // First time this region is encountered in the traversal - // Creating the function - - Array<Expr> fields; - - for (auto ret : region->GetOutputs()) { - auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]); - fields.push_back(ret_expr); - } - int index = GetRetIdx(region, GetRef<Call>(call)); - CHECK_NE(index, -1); - - Array<Var> params; - Array<Expr> param_expr; - std::unordered_map<std::string, runtime::NDArray> params_bind; - - for (auto pair : region_args[region]) { - params.push_back(pair.first); - if (const auto* cn = pair.second.as<ConstantNode>()) { - params_bind[pair.first->name_hint()] = cn->data; - } else { - param_expr.push_back(pair.second); - } - } - - Function global_region_func; - if (region->GetOutputs().size() == 1) { - // If there are only a single output; no need to add a tuple - global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); - } else { - auto tuple = Tuple(fields); - global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); - } - - std::string target = call->attrs.as<CompilerAttrs>()->compiler; - std::string name = target + "_" + std::to_string(region->GetID()); - - global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, - runtime::String(name)); - global_region_func = - WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); - global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::runtime::String(target)); - 
global_region_func = - WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); - - // Constant propagation - if (!params_bind.empty()) { - global_region_func = backend::BindParamsByName(global_region_func, params_bind); - } - - std::string fname = name; - CHECK(!module_->ContainGlobalVar(fname)) - << "Global function " << fname << " already exists"; - // Create a global function and add it to the IRModule for the region. - // This way we lift the functions that should be handled by external - // codegen to the module scope and rely on the pass manager to prevent - // relay function level passes (i.e. simplify inference and fusion) - // optimizing it. - GlobalVar glob_func(fname); - module_->Add(glob_func, global_region_func); - - // The return type of callnode is the same as the type of the - // compiler_end node. - auto ret = Call(glob_func, param_expr); - region_function_calls[region] = ret; - - if (region->GetOutputs().size() == 1) { - // If there is only a single output; no need to add a tuplegetitem - // node - return std::move(ret); - } else { - // Add a tuplegetitem node to select this output out of many - auto tuple_get_item_ = TupleGetItem(ret, index); - tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_; - return std::move(tuple_get_item_); - } + // First time this region is encountered in the traversal. + // Creating the function. + CreateFunction(region, call); + // Retrieve particular output. + return GetFunctionOutput(region, GetRef<Call>(call)); Review comment: Return could be moved out of this branch as it is a common statement. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services