tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] zhiics commented on a change in pull request #5616: [RELAY][BYOC] Fix the creation of tuple of tuples in PartitionGraph
Date Tue, 19 May 2020 16:30:04 GMT

zhiics commented on a change in pull request #5616:
URL: https://github.com/apache/incubator-tvm/pull/5616#discussion_r427432193



##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -404,21 +404,96 @@ IRModule RemoveDefaultAnnotations(IRModule module) {
   return module;
 }
 
+/*! \brief There can be regions with multiple outputs where each output
+ *  could be a tuple output. Such tuple outputs needs to be flattened
+ *  otherwise the function would create tuples of tuples.
+ */
+
+// New annotations would be required to be added for each flattened output
+const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+
+IRModule FlattenTupleOutputs(IRModule module) {
+  class TupleOutFlattener : public ExprRewriter {
+   public:
+    TupleOutFlattener() = default;
+
+    Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+      Expr new_op = (*ann_op)(expr, target);
+      new_op->checked_type_ = expr->checked_type_;
+      return new_op;
+    }
+
+    Expr Rewrite_(const CallNode* call, const Expr& post) final {
+      if (call->op == compiler_end_op) {
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        // Arguments of annotation ops should be 1
+        CHECK_EQ(call->args.size(), 1U);
+        auto annotated_op = Downcast<Call>(post)->args[0];
+        if (annotated_op->IsInstance<TupleNode>()) {

Review comment:
       You can write `if (const auto* tn = annotated_op.as<TupleNode>())` here and drop the `auto tn = annotated_op.as<TupleNode>();` line that follows.
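
A rough illustration of the idiom suggested above, using hypothetical stand-in types (`Node`, `TupleNode`, `Ref`) rather than TVM's real classes. The assumption is that `as<T>()` yields a null pointer on a type mismatch, so the type test and the downcast can share a single `if` condition:

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not TVM's classes.
struct Node {
  virtual ~Node() = default;
};

struct TupleNode : Node {
  std::vector<std::shared_ptr<Node>> fields;
};

struct Ref {
  std::shared_ptr<Node> node;

  // Mimics the shape of an `as<T>()` accessor: nullptr on a type mismatch.
  template <typename T>
  const T* as() const {
    return dynamic_cast<const T*>(node.get());
  }
};

// Returns the number of tuple fields, or 0 when `ref` does not hold a tuple.
std::size_t NumTupleFields(const Ref& ref) {
  // The check and the cast happen together, and `tn` is scoped to the branch.
  if (const auto* tn = ref.as<TupleNode>()) {
    return tn->fields.size();
  }
  return 0;
}
```

Compared with a separate `IsInstance` test followed by a cast, this form cannot dereference a pointer that was never checked.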

##########
File path: src/relay/transforms/partition_graph.cc
##########
@@ -404,21 +404,96 @@ IRModule RemoveDefaultAnnotations(IRModule module) {
   return module;
 }
 
+/*! \brief There can be regions with multiple outputs where each output
+ *  could be a tuple output. Such tuple outputs needs to be flattened
+ *  otherwise the function would create tuples of tuples.
+ */
+
+// New annotations would be required to be added for each flattened output
+const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+
+IRModule FlattenTupleOutputs(IRModule module) {
+  class TupleOutFlattener : public ExprRewriter {
+   public:
+    TupleOutFlattener() = default;
+
+    Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+      Expr new_op = (*ann_op)(expr, target);
+      new_op->checked_type_ = expr->checked_type_;
+      return new_op;
+    }
+
+    Expr Rewrite_(const CallNode* call, const Expr& post) final {
+      if (call->op == compiler_end_op) {
+        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+        // Arguments of annotation ops should be 1
+        CHECK_EQ(call->args.size(), 1U);
+        auto annotated_op = Downcast<Call>(post)->args[0];
+        if (annotated_op->IsInstance<TupleNode>()) {
+          auto tn = annotated_op.as<TupleNode>();
+          Array<Expr> new_fields;
+
+          // Here each input of the tuple will be annotated with compiler_ends
+          for (auto& tn_arg : tn->fields) {
+            auto nf = InsertAnnotation(tn_arg, target, make_end_op);

Review comment:
       You probably don't need the `InsertAnnotation` helper; just do:
   
   ```c++
   new_fields.push_back((*make_end_op)(tn_arg, target));
   ```
   because it is always an `end` op and `Update` below will do type inference.
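
A minimal sketch of the loop shape this suggestion leads to, with hypothetical stand-ins: `Expr` here is a plain struct and `MakeEnd` is an invented placeholder for the end-annotation call, not TVM's API. Each field is wrapped inline and no helper copies type information:

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in; not TVM's Expr.
struct Expr {
  std::string text;
};

// Hypothetical placeholder for invoking the end-annotation constructor.
Expr MakeEnd(const Expr& e, const std::string& target) {
  return Expr{"compiler_end(" + e.text + ", " + target + ")"};
}

// Wraps every tuple field in its own end annotation, building each call inline.
std::vector<Expr> AnnotateFields(const std::vector<Expr>& fields, const std::string& target) {
  std::vector<Expr> new_fields;
  new_fields.reserve(fields.size());
  for (const auto& field : fields) {
    new_fields.push_back(MakeEnd(field, target));
  }
  return new_fields;
}
```

The sketch only mirrors the control flow; in the real pass, type information would come from the later inference step mentioned in the comment rather than from the loop itself.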




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


