tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] windclarion commented on a change in pull request #5277: [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes
Date Thu, 09 Apr 2020 01:36:18 GMT
windclarion commented on a change in pull request #5277: [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes
URL: https://github.com/apache/incubator-tvm/pull/5277#discussion_r405904292
 
 

 ##########
 File path: src/relay/transforms/annotate_target.cc
 ##########
 @@ -19,131 +19,155 @@
 
 /*!
  * \file src/relay/transforms/annotate_target.cc
- * \brief Wraps a call with compiler_begin and compiler_end to indicate that
- * the op of this call node will use external compiler.
+ * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
+ * this expr should be handled by the external compiler.
  */
 
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
 
 namespace tvm {
 namespace relay {
 namespace annotate_target {
 
-// Cache compiler_begin op for equivalence check.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
 
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
 class AnnotateTargetWrapper : public ExprMutator {
  public:
-  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
+  explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets))
{}
+
+  /*!
+   * \brief This function annotates a compiler end and a compiler begin to all arguments.
+   *
+   *  The compiler end is based on the arg target while the compiler begin is based on the
given
+   *  target. If target is not given and all arguments are going to the same target, then
we will
+   *  use that target; otherwise we use default for this op. Note that all arg exprs must
be
+   *  available in op_expr_to_target before calling this function.
+   *
+   * \param args An array of arguments of the given node.
+   * \param target The target of the current node.
+   * \return A pair of target and annotated argument expressions.
+   */
+  std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>&
args,
+                                                   const std::string& target = "") {
+    std::string ref_target = "";
+    Array<Expr> compiler_ends;
+    for (auto arg : args) {
+      if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
+        std::string arg_target = op_expr_to_target_[arg];
+        compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op));
+        if (ref_target == "") {
+          ref_target = arg_target;
+        } else if (ref_target != arg_target) {
+          ref_target = "default";
+        }
+      } else {
+        // Input vars.
+        compiler_ends.push_back(arg);
+      }
+    }
+
+    // Determine compiler begin target.
+    std::string op_target = (target == "") ? ref_target : target;
+
+    Array<Expr> compiler_begins;
+    for (const auto& end : compiler_ends) {
+      compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op));
+    }
 
-  Expr Annotate(const Expr& expr) {
-    return InsertEnd(Mutate(expr));
+    return {op_target, compiler_begins};
   }
 
-  bool IsSupported(const Expr& expr) {
-    if (expr->IsInstance<CallNode>()) {
-      Call call = Downcast<Call>(expr);
-      auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
-      if (call->op->IsInstance<OpNode>()) {
-        Op op = Downcast<Op>(call->op);
-        CHECK(op.defined());
-        if (fannotate.count(op)) {
-          return fannotate[op](call->attrs, call->args);
+  Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc*
ann_op) {
+    Expr new_op = (*ann_op)(expr, target);
+    new_op->checked_type_ = expr->checked_type_;
+    return new_op;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) final {
+    // Supported targets for this node. The order implies the priority.
+    std::vector<std::string> supported_targets;
+
+    // Check which targets this op can be offloaded.
+    if (cn->op->IsInstance<OpNode>()) {
+      // TVM operators: Check target specific op checking function and add to supported_targets
+      // if it is supported.
+      Op op = Downcast<Op>(cn->op);
+      CHECK(op.defined());
+      for (const auto& target : this->targets_) {
+        auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
 
 Review comment:
  I understand, but what I said may not have been clear. I know a composite function doesn't go
into the OpNode branch. My point is that I only use composite functions, so I don't register an FTVMAnnotateTarget
attr for any op under target.xxxx. As a result, for any OpNode, `auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target))`
will fail, because no FTVMAnnotateTarget attr was ever registered for that target. The annotation mechanism
can handle both ops and composite functions, and the two are independent of each other.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message