tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] zhiics commented on a change in pull request #5272: [BYOC] Add example of Composite + Annotate for DNNL fused op
Date Wed, 08 Apr 2020 21:15:18 GMT
zhiics commented on a change in pull request #5272: [BYOC] Add example of Composite + Annotate for DNNL fused op
URL: https://github.com/apache/incubator-tvm/pull/5272#discussion_r405813179
 
 

 ##########
 File path: src/relay/backend/contrib/dnnl/codegen.cc
 ##########
 @@ -133,83 +209,100 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   }
 
  private:
-  std::vector<std::string> Conv2d(const CallNode* call) {
-    std::vector<std::string> args;
-    const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
-    CHECK(conv2d_attr);
-
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
+  struct GenerateBodyOutput {
+    std::string decl, buf;
+    int out_size = 1;
+    std::string out;
+  };
 
-    // Args: N, C, H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+  std::vector<std::string> GetArgumentNames(const CallNode* call) {
+    std::vector<std::string> arg_names;
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      VisitExpr(call->args[i]);
+      for (auto out : out_) {
+        arg_names.push_back(out.name);
+      }
     }
-
-    // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
-    args.push_back(std::to_string(wshape[0]));
-    args.push_back(std::to_string(conv2d_attr->groups));
-    args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
-    args.push_back(std::to_string(wshape[2]));
-    args.push_back(std::to_string(wshape[3]));
-    args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
-
-    return args;
+    return arg_names;
   }
 
-  std::vector<std::string> Dense(const CallNode* call) {
-    std::vector<std::string> args;
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
-
-    // Args: N, C, O
-    args.push_back(std::to_string(ishape[0]));
-    args.push_back(std::to_string(ishape[1]));
-    args.push_back(std::to_string(wshape[0]));
+  GenerateBodyOutput GenerateOpCall(const CallNode* call) {
+    const auto* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "OpNode expected, got something else";
 
 Review comment:
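  Let's also print out what the op actually is. Note that `op_node` is null
  exactly when this check fails, so the type key has to come from `call->op`
  (streaming `op_node->GetTypeKey()` would dereference a null pointer and crash
  instead of reporting the error):
  ```
  CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
  ```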

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message