tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mas...@apache.org
Subject [incubator-tvm] branch master updated: [BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)
Date Sat, 11 Apr 2020 03:29:30 GMT
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 3616ebe  [BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)
3616ebe is described below

commit 3616ebee6890950c1324b6a0733866f404af7041
Author: masahi <masahi129@gmail.com>
AuthorDate: Sat Apr 11 12:29:20 2020 +0900

    [BYOC] Add example of Composite + Annotate for DNNL fused op (#5272)
    
    * merge change from dev branch
    
    * fix string issue
    
    * bring comanic's change back
---
 python/tvm/relay/op/contrib/dnnl.py             |   9 +-
 src/relay/backend/contrib/codegen_c/codegen.cc  |   5 +-
 src/relay/backend/contrib/codegen_c/codegen_c.h |  35 ---
 src/relay/backend/contrib/dnnl/codegen.cc       | 361 ++++++++++++++++--------
 src/relay/backend/utils.h                       |  72 ++++-
 src/relay/backend/vm/compiler.cc                |  13 +-
 src/runtime/contrib/dnnl/dnnl.cc                |  70 +++--
 src/runtime/contrib/dnnl/dnnl_kernel.h          |  15 +-
 tests/python/relay/test_pass_partition_graph.py | 130 ++++++++-
 9 files changed, 502 insertions(+), 208 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 1aa7192..45a8c83 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -56,17 +56,10 @@ def _register_external_op_helper(op_name, supported=True):
     return _func_wrapper
 
 
+_register_external_op_helper("nn.batch_norm")
 _register_external_op_helper("nn.conv2d")
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("add")
 _register_external_op_helper("subtract")
 _register_external_op_helper("multiply")
-
-
-@reg.register("nn.batch_norm", "target.dnnl")
-def batch_norm(attrs, args):
-    """Check if the external DNNL codegen should be used.
-    FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs.
-    """
-    return False
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index 97231df..500e0dc 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -19,19 +19,22 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/type.h>
-#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/object.h>
 
 #include <fstream>
 #include <sstream>
 
+#include "../../utils.h"
 #include "codegen_c.h"
 
 namespace tvm {
 namespace relay {
 namespace contrib {
 
+using namespace backend;
+
 /*!
  * \brief An example codegen that is only used for quick prototyping and testing
  * purpose. Only several binary options are covered. Users
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 1db3f20..1b953f3 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -170,41 +170,6 @@ class CodegenCBase {
   virtual std::string JIT() = 0;
 
   /*!
-   * \brief Extract the shape from a Relay tensor type.
-   *
-   * \param type The provided type.
-   *
-   * \return The extracted shape in a list.
-   */
-  std::vector<int> GetShape(const Type& type) const {
-    const auto* ttype = type.as<TensorTypeNode>();
-    CHECK(ttype) << "Expect TensorTypeNode";
-    std::vector<int> shape;
-    for (size_t i = 0; i < ttype->shape.size(); ++i) {
-      auto* val = ttype->shape[i].as<IntImmNode>();
-      CHECK(val);
-      shape.push_back(val->value);
-    }
-    return shape;
-  }
-
-  /*!
-   * \brief Check if a call has the provided name.
-   *
-   * \param call A Relay call node.
-   * \param op_name The name of the expected call.
-   *
-   * \return true if the call's name is equivalent to the given name. Otherwise,
-   * false.
-   */
-  bool IsOp(const CallNode* call, const std::string& op_name) const {
-    const auto* op_node = call->op.as<OpNode>();
-    CHECK(op_node) << "Expects a single op.";
-    Op op = GetRef<Op>(op_node);
-    return op == Op::Get(op_name);
-  }
-
-  /*!
    * \brief A common interface that is used by various external runtime to
    * generate the wrapper to invoke external kernels.
    *
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 7371174..7f3aabf 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -30,14 +30,102 @@
 #include <tvm/runtime/registry.h>
 
 #include <fstream>
+#include <numeric>
 #include <sstream>
 
+#include "../../utils.h"
 #include "../codegen_c/codegen_c.h"
 
 namespace tvm {
 namespace relay {
 namespace contrib {
 
+using namespace backend;
+
+inline size_t GetShape1DSize(const Type& type) {
+  const auto shape = GetShape(type);
+  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
+}
+
+std::vector<std::string> Conv2d(const CallNode* call) {
+  std::vector<std::string> args;
+  const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
+  CHECK(conv2d_attr);
+
+  auto ishape = GetShape(call->args[0]->checked_type());
+  auto wshape = GetShape(call->args[1]->checked_type());
+
+  // Args: N, C, H, W
+  for (auto s : ishape) {
+    args.push_back(std::to_string(s));
+  }
+
+  // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
+  args.push_back(std::to_string(wshape[0]));
+  args.push_back(std::to_string(conv2d_attr->groups));
+  args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
+  args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
+  args.push_back(std::to_string(wshape[2]));
+  args.push_back(std::to_string(wshape[3]));
+  args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
+  args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
+
+  return args;
+}
+
+std::vector<std::string> Dense(const CallNode* call) {
+  std::vector<std::string> args;
+  auto ishape = GetShape(call->args[0]->checked_type());
+  auto wshape = GetShape(call->args[1]->checked_type());
+
+  // Args: N, C, O
+  args.push_back(std::to_string(ishape[0]));
+  args.push_back(std::to_string(ishape[1]));
+  args.push_back(std::to_string(wshape[0]));
+
+  return args;
+}
+
+std::vector<std::string> Relu(const CallNode* call) {
+  std::vector<std::string> args;
+  auto ishape = GetShape(call->args[0]->checked_type());
+
+  // Args: N, C, H, W
+  for (auto s : ishape) {
+    args.push_back(std::to_string(s));
+  }
+
+  return args;
+}
+
+std::vector<std::string> BatchNorm(const CallNode* call) {
+  std::vector<std::string> args;
+  const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
+  auto ishape = GetShape(call->args[0]->checked_type());
+
+  // Args: N, C, H, W
+  for (auto s : ishape) {
+    args.push_back(std::to_string(s));
+  }
+
+  // Args: epsilon
+  args.push_back(std::to_string(bn_attr->epsilon));
+
+  return args;
+}
+
+std::vector<std::string> Add(const CallNode* call) {
+  std::vector<std::string> args;
+  auto ishape = GetShape(call->args[0]->checked_type());
+
+  // Args: H, W
+  for (auto s : ishape) {
+    args.push_back(std::to_string(s));
+  }
+
+  return args;
+}
+
 // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
 // all utilities and make a base class for users to implement.
 class CodegenDNNL : public ExprVisitor, public CodegenCBase {
@@ -53,79 +141,64 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   }
 
   void VisitExpr_(const TupleGetItemNode* op) final {
-    // Do nothing
+    VisitExpr(op->tuple);
+    CHECK(out_.size() > static_cast<size_t>(op->index));
+
+    // Only keep the item we want for the child node.
+    // FIXME(@comaniac): The other items should still be requried for the primary outputs.
+    auto item = out_[op->index];
+    out_.clear();
+    out_.push_back(item);
   }
 
-  void VisitExpr_(const CallNode* call) final {
-    std::ostringstream decl_stream;
-    std::ostringstream buf_stream;
-    // Args: ID
-    std::vector<std::string> args;
-
-    // Get the arguments for various DNNL kernels.
-    if (IsOp(call, "nn.conv2d")) {
-      decl_stream << "dnnl_conv2d";
-      args = Conv2d(call);
-    } else if (IsOp(call, "nn.dense")) {
-      decl_stream << "dnnl_dense";
-      args = Dense(call);
-    } else if (IsOp(call, "nn.relu")) {
-      decl_stream << "dnnl_relu";
-      args = Relu(call);
-    } else if (IsOp(call, "nn.batch_norm")) {
-      decl_stream << "dnnl_bn";
-      args = BatchNorm(call);
-    } else if (IsOp(call, "add")) {
-      decl_stream << "dnnl_add";
-      args = Add(call);
-    } else {
-      LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
+  void VisitExpr_(const ConstantNode* cn) final {
+    Constant constant = GetRef<Constant>(cn);
+    if (visited_.count(constant)) {
+      out_.push_back(visited_[constant]);
+      return;
     }
 
-    // Make function call with input buffers when visiting arguments
-    bool first = true;
-    decl_stream << "(";
-    for (size_t i = 0; i < call->args.size(); ++i) {
-      VisitExpr(call->args[i]);
-      for (auto out : out_) {
-        if (!first) {
-          decl_stream << ", ";
-        }
-        first = false;
-        decl_stream << out.name;
-      }
-    }
+    out_.clear();
+    Output output;
+    output.name = "const_" + std::to_string(const_idx_++);
+    output.dtype = "float";
+    out_.push_back(output);
+    visited_[constant] = output;
+
+    runtime::NDArray array = cn->data;
 
-    // Analyze the output buffer
-    auto type_node = call->checked_type().as<TensorTypeNode>();
+    // Get the number of elements.
+    int64_t num_elems = 1;
+    for (auto i : array.Shape()) num_elems *= i;
+
+    const auto* type_node = cn->checked_type().as<TensorTypeNode>();
     CHECK(type_node);
-    const auto& dtype = GetDtypeString(type_node);
-    std::string out = "buf_" + std::to_string(buf_idx_++);
-    auto out_shape = GetShape(call->checked_type());
-    int out_size = 1;
-    for (size_t i = 0; i < out_shape.size(); ++i) {
-      out_size *= out_shape[i];
+    CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
+
+    std::ostringstream buf_stream;
+    buf_stream << "float* " << output.name << " = (float*)std::malloc(4 * " << num_elems << ");\n";
+    const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
+    for (int64_t i = 0; i < num_elems; i++) {
+      buf_stream << "  " << output.name << "[" << i << "] = " << ptr[i] << ";\n";
     }
-    this->PrintIndents();
-    buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");";
-    buf_decl_.push_back(buf_stream.str());
-    decl_stream << ", " << out;
 
-    // Attach attribute arguments
-    for (size_t i = 0; i < args.size(); ++i) {
-      decl_stream << ", " << args[i];
+    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+  }
+
+  void VisitExpr_(const CallNode* call) final {
+    GenerateBodyOutput ret;
+    if (const auto* func = call->op.as<FunctionNode>()) {
+      ret = GenerateCompositeFunctionCall(func, call);
+    } else {
+      ret = GenerateOpCall(call);
     }
-    decl_stream << ");";
-    ext_func_body.push_back(decl_stream.str());
 
-    // Update output buffer
     out_.clear();
-    Output output;
-    output.name = out;
-    output.dtype = dtype;
-    output.need_copy = true;
-    output.size = out_size;
-    out_.push_back(output);
+    for (size_t i = 0; i < ret.outputs.size(); ++i) {
+      buf_decl_.push_back(ret.buffers[i]);
+      out_.push_back(ret.outputs[i]);
+    }
+    ext_func_body.push_back(ret.decl);
   }
 
   std::string JIT(void) {
@@ -133,83 +206,121 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   }
 
  private:
-  std::vector<std::string> Conv2d(const CallNode* call) {
-    std::vector<std::string> args;
-    const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
-    CHECK(conv2d_attr);
-
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
-
-    // Args: N, C, H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+  struct GenerateBodyOutput {
+    std::string decl;
+    std::vector<std::string> buffers;
+    std::vector<Output> outputs;
+  };
+
+  std::vector<std::string> GetArgumentNames(const CallNode* call) {
+    std::vector<std::string> arg_names;
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      VisitExpr(call->args[i]);
+      for (auto out : out_) {
+        arg_names.push_back(out.name);
+      }
     }
-
-    // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
-    args.push_back(std::to_string(wshape[0]));
-    args.push_back(std::to_string(conv2d_attr->groups));
-    args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImmNode>()->value));
-    args.push_back(std::to_string(wshape[2]));
-    args.push_back(std::to_string(wshape[3]));
-    args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImmNode>()->value));
-    args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImmNode>()->value));
-
-    return args;
+    return arg_names;
   }
 
-  std::vector<std::string> Dense(const CallNode* call) {
-    std::vector<std::string> args;
-    auto ishape = GetShape(call->args[0]->checked_type());
-    auto wshape = GetShape(call->args[1]->checked_type());
-
-    // Args: N, C, O
-    args.push_back(std::to_string(ishape[0]));
-    args.push_back(std::to_string(ishape[1]));
-    args.push_back(std::to_string(wshape[0]));
+  GenerateBodyOutput GenerateOpCall(const CallNode* call) {
+    const auto* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "Expect OpNode, but got " << call->op->GetTypeKey();
+
+    using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
+    static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
+        {"nn.conv2d", {"dnnl_conv2d", Conv2d}},
+        {"nn.dense", {"dnnl_dense", Dense}},
+        {"nn.relu", {"dnnl_relu", Relu}},
+        {"nn.batch_norm", {"dnnl_bn", BatchNorm}},
+        {"add", {"dnnl_add", Add}},
+    };
+
+    const auto op_name = GetRef<Op>(op_node)->name;
+    const auto iter = op_map.find(op_name);
+    if (iter != op_map.end()) {
+      return GenerateBody(call, iter->second.first, iter->second.second(call));
+    }
 
-    return args;
+    LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
+    return {};
   }
 
-  std::vector<std::string> Relu(const CallNode* call) {
-    std::vector<std::string> args;
-    auto ishape = GetShape(call->args[0]->checked_type());
-
-    // Args: N, C, H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+  GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
+                                                   const CallNode* caller) {
+    const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
+    CHECK(pattern_name.defined()) << "Only functions with composite attribute supported";
+
+    if (pattern_name == "dnnl.conv2d_bias_relu") {
+      const auto* conv_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
+      return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
+                          Conv2d(conv_call));
+    } else if (pattern_name == "dnnl.conv2d_relu") {
+      const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
+      return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
+                          Conv2d(conv_call));
     }
 
-    return args;
+    LOG(FATAL) << "Unknown composite function:" << pattern_name;
+    return {};
   }
 
-  std::vector<std::string> BatchNorm(const CallNode* call) {
-    std::vector<std::string> args;
-    const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
-    auto ishape = GetShape(call->args[0]->checked_type());
+  GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
+                                  const std::vector<std::string>& attribute_args) {
+    return GenerateBody(root_call, func_name, GetArgumentNames(root_call), attribute_args);
+  }
 
-    // Args: N, C, H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+  GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name,
+                                  const std::vector<std::string>& func_args,
+                                  const std::vector<std::string>& attribute_args) {
+    // Make function call with input buffers when visiting arguments
+    CHECK_GT(func_args.size(), 0);
+    std::ostringstream decl_stream;
+    decl_stream << "(" << func_args[0];
+    for (size_t i = 1; i < func_args.size(); ++i) {
+      decl_stream << ", " << func_args[i];
     }
 
-    // Args: epsilon
-    args.push_back(std::to_string(bn_attr->epsilon));
-
-    return args;
-  }
-
-  std::vector<std::string> Add(const CallNode* call) {
-    std::vector<std::string> args;
-    auto ishape = GetShape(call->args[0]->checked_type());
+    // Analyze the output buffers
+    std::vector<Type> out_types;
+    if (root_call->checked_type()->IsInstance<TupleTypeNode>()) {
+      auto type_node = root_call->checked_type().as<TupleTypeNode>();
+      for (auto field : type_node->fields) {
+        CHECK(field->IsInstance<TensorTypeNode>());
+        out_types.push_back(field);
+      }
+    } else if (root_call->checked_type()->IsInstance<TensorTypeNode>()) {
+      CHECK(root_call->checked_type()->IsInstance<TensorTypeNode>());
+      out_types.push_back(root_call->checked_type());
+    } else {
+      LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false);
+    }
 
-    // Args: H, W
-    for (auto s : ishape) {
-      args.push_back(std::to_string(s));
+    GenerateBodyOutput ret;
+    for (const auto& out_type : out_types) {
+      this->PrintIndents();
+      const std::string out = "buf_" + std::to_string(buf_idx_++);
+      const auto out_size = GetShape1DSize(out_type);
+      decl_stream << ", " << out;
+
+      Output output;
+      output.name = out;
+      output.size = out_size;
+      output.dtype = GetDtypeString(out_type.as<TensorTypeNode>());
+      output.need_copy = true;
+      ret.buffers.push_back("float* " + out + " = (float*)std::malloc(4 * " +
+                            std::to_string(out_size) + ");");
+      ret.outputs.push_back(output);
     }
 
-    return args;
+    // Attach attribute arguments
+    for (size_t i = 0; i < attribute_args.size(); ++i) {
+      decl_stream << ", " << attribute_args[i];
+    }
+    decl_stream << ");";
+    ret.decl = func_name + decl_stream.str();
+    return ret;
   }
 
   /*! \brief The id of the external dnnl ext_func. */
@@ -219,6 +330,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
    * output to a buffer that may be consumed by other kernels.
    */
   int buf_idx_{0};
+  /*! \brief The index of global constants. */
+  int const_idx_ = 0;
   /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
   Array<Var> ext_func_args_;
   /*! \brief statement of the function that will be compiled using DNNL kernels. */
@@ -227,6 +340,8 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> buf_decl_;
   /*! \brief The name of the the outputs. */
   std::vector<Output> out_;
+  /*! \brief The cached expressions. */
+  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
 };
 
 /*!
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 7171589..a96ffe4 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -25,18 +25,19 @@
 #define TVM_RELAY_BACKEND_UTILS_H_
 
 #include <dmlc/json.h>
+#include <tvm/driver/driver_api.h>
 #include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
 #include <tvm/relay/transform.h>
-#include <tvm/driver/driver_api.h>
+#include <tvm/relay/type.h>
 #include <tvm/target/codegen.h>
-#include <tvm/tir/ir_pass.h>
 #include <tvm/te/operation.h>
+#include <tvm/tir/ir_pass.h>
 
-#include <typeinfo>
 #include <string>
+#include <typeinfo>
 #include <unordered_map>
 #include <unordered_set>
+#include <vector>
 
 namespace tvm {
 namespace relay {
@@ -59,7 +60,7 @@ inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
  */
 template <typename R, typename... Args>
 inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
-  auto *pf = GetPackedFunc(func_name);
+  auto* pf = GetPackedFunc(func_name);
   CHECK(pf != nullptr) << "can not find packed function";
   return runtime::TypedPackedFunc<R(Args...)>(*pf);
 }
@@ -90,9 +91,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
  * \param params params dict
  * \return relay::Function
  */
-inline relay::Function
-BindParamsByName(relay::Function func,
-                 const std::unordered_map<std::string, runtime::NDArray>& params) {
+inline relay::Function BindParamsByName(
+    relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
   std::unordered_map<std::string, relay::Var> name_dict;
   std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
   for (auto arg : func->params) {
@@ -122,8 +122,64 @@ BindParamsByName(relay::Function func,
   return ret;
 }
 
+/*!
+ * \brief Extract the shape from a Relay tensor type.
+ * \param type The provided type.
+ * \return The extracted shape in a list.
+ */
+inline std::vector<int> GetShape(const Type& type) {
+  const auto* ttype = type.as<TensorTypeNode>();
+  CHECK(ttype) << "Expect TensorTypeNode";
+  std::vector<int> shape;
+  for (size_t i = 0; i < ttype->shape.size(); ++i) {
+    auto* val = ttype->shape[i].as<IntImmNode>();
+    CHECK(val);
+    shape.push_back(val->value);
+  }
+  return shape;
+}
+
+/*!
+ * \brief Check if a call has the provided name.
+ * \param call A Relay call node.
+ * \param op_name The name of the expected call.
+ * \return true if the call's name is equivalent to the given name. Otherwise,
+ * false.
+ */
+inline bool IsOp(const CallNode* call, const std::string& op_name) {
+  const auto* op_node = call->op.as<OpNode>();
+  CHECK(op_node) << "Expects a single op.";
+  Op op = GetRef<Op>(op_node);
+  return op == Op::Get(op_name);
+}
+
+/*!
+ * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
+ * \param call A Relay call node. Typically nn.relu when called the first time.
+ * \param depth The number of calls before the root op, counting from current_call.
+ * \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add",
+ * "nn.relu"}
+ * \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
+ */
+
+inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
+                                   const std::vector<std::string>& expected_op_names) {
+  CHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
+        IsOp(current_call, expected_op_names[depth]));
+
+  if (depth == 0) {
+    return current_call;
+  }
+
+  CHECK_GT(current_call->args.size(), 0);
+
+  const auto* next_call = current_call->args[0].as<CallNode>();
+  return GetRootCall(next_call, depth - 1, expected_op_names);
+}
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
 
+
 #endif  // TVM_RELAY_BACKEND_UTILS_H_
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 3e020bb..e2b0fff 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -924,6 +924,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
+  // Manifest the allocations.
+  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+  // Compute away possibly introduced constant computation.
+  pass_seqs.push_back(transform::FoldConstant());
+  // Fuse the shape functions.
+  pass_seqs.push_back(transform::FuseOps());
+
   // Inline the functions that are lifted to the module scope. We perform this
   // pass after all other optimization passes but before the memory allocation
   // pass. This is because memory allocation pass will insert `invoke_tvm_op`
@@ -931,12 +938,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   // external codegen.
   pass_seqs.push_back(transform::Inline());
 
-  // Manifest the allocations.
-  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
-  // Compute away possibly introduced constant computation.
-  pass_seqs.push_back(transform::FoldConstant());
-  // Fuse the shape functions.
-  pass_seqs.push_back(transform::FuseOps());
   // Manifest the allocations needed for the shape functions.
   pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
 
diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc
index cc430b2..0922ac1 100644
--- a/src/runtime/contrib/dnnl/dnnl.cc
+++ b/src/runtime/contrib/dnnl/dnnl.cc
@@ -52,10 +52,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
   std::copy(src, src + bytes, reinterpret_cast<uint8_t*>(handle));
 }
 
-extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
-                            int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
-                            int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
-                            int p_Sh_, int p_Sw_) {
+void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
+                        int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
+                        int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) {
   using tag = memory::format_tag;
   using dt = memory::data_type;
   engine eng(engine::kind::cpu, 0);
@@ -65,21 +64,15 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
   memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_};
   if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_};
   memory::dims conv2d_bias_tz = {p_O_};
-  memory::dims conv2d_dst_tz = {p_N_, p_O_,
-                                (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
+  memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_,
                                 (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_};
   memory::dims conv2d_strides = {p_Sh_, p_Sw_};
   memory::dims conv2d_padding = {p_Ph_, p_Pw_};
 
-  std::vector<float> conv2d_bias(p_O_, 0);
-
-  auto user_src_memory =
-      memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
-  auto user_weights_memory = memory(
-      {{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng,
-      weights);
-  auto conv2d_user_bias_memory =
-      memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data());
+  auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data);
+  auto user_weights_memory =
+      memory({{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, weights);
+  auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
 
   auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
   auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
@@ -87,10 +80,9 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
   auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw);
 
   auto conv2d_desc = convolution_forward::desc(
-      prop_kind::forward_inference, algorithm::convolution_direct,
-      conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md,
-      conv2d_strides, conv2d_padding, conv2d_padding);
-  auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng);
+      prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
+      conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding);
+  auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng);
 
   auto conv2d_src_memory = user_src_memory;
   auto conv2d_weights_memory = user_weights_memory;
@@ -105,6 +97,42 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_,
   read_from_dnnl_memory(out, conv2d_dst_memory);
 }
 
+extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_,
+                            int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_,
+                            int p_Kw_, int p_Sh_, int p_Sw_) {
+  primitive_attr attr;
+  std::vector<float> bias(p_O_, 0);
+  return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
+                            p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr);
+}
+
+primitive_attr create_attr_with_relu_post_op() {
+  post_ops ops;
+  ops.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);
+
+  primitive_attr attr;
+  attr.set_post_ops(ops);
+
+  return attr;
+}
+
+extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_,
+                                       int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
+                                       int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) {
+  std::vector<float> bias(p_O_, 0);
+  return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
+                            p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
+                            create_attr_with_relu_post_op());
+}
+
+extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
+                                            int p_N_, int p_C_, int p_H_, int p_W_, int p_O_,
+                                            int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
+                                            int p_Sh_, int p_Sw_) {
+  return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_,
+                            p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op());
+}
+
 extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_,
                            int p_I_, int p_O_) {
   using tag = memory::format_tag;
@@ -169,8 +197,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_,
   read_from_dnnl_memory(out, dst_memory);
 }
 
-extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
-                        float* variance, float* out, int p_N_, int p_C_,
+extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance,
+                        float* out, float* new_mean, float* new_variance, int p_N_, int p_C_,
                         int p_H_, int p_W_, int p_E_) {
   using tag = memory::format_tag;
   using dt = memory::data_type;
diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h
index 4d0b100..f92d767 100644
--- a/src/runtime/contrib/dnnl/dnnl_kernel.h
+++ b/src/runtime/contrib/dnnl/dnnl_kernel.h
@@ -38,14 +38,25 @@ extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int
                                     int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_,
                                     int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_);
 
+extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_,
+                                               int p_C_, int p_H_, int p_W_, int p_O_, int p_G_,
+                                               int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_,
+                                               int p_Sh_, int p_Sw_);
+
+extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias,
+                                                    float* out, int p_N_, int p_C_, int p_H_,
+                                                    int p_W_, int p_O_, int p_G_, int p_Ph_,
+                                                    int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_,
+                                                    int p_Sw_);
+
 extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
                                    int p_O_);
 
 extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
 
 extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
-                                float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_,
-                                int p_e_);
+                                float* variance, float* out, float* new_mean, float* new_variance,
+                                int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);
 
 extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
                                  int p_h_, int p_w_);
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index fb21682..c7d9626 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -27,11 +27,10 @@ from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
 from tvm.contrib import util
-from tvm.relay import transform
 from tvm.relay.backend import compile_engine
 from tvm.relay.expr_functor import ExprMutator
 from tvm.relay.op.annotation import compiler_begin, compiler_end
-from tvm.runtime import container
+from tvm.relay.build_module import bind_params_by_name
 
 
 # Leverage the pass manager to write a simple white list based annotator
@@ -456,7 +455,7 @@ def test_extern_dnnl_mobilenet():
     mod, params = relay.testing.mobilenet.get_workload(
         batch_size=1, dtype='float32')
 
-    mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
+    mod["main"] = bind_params_by_name(mod["main"], params)
     mod = transform.AnnotateTarget(["dnnl"])(mod)
     mod = transform.MergeCompilerRegions()(mod)
     mod = transform.PartitionGraph()(mod)
@@ -663,7 +662,7 @@ def test_constant_propagation():
     add = x + y
     log = relay.log(add)
     f = relay.Function([x, y], log)
-    f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
+    f = bind_params_by_name(f, {"x": tvm.nd.array(ones)})
     mod = tvm.IRModule()
     mod["main"] = f
     mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
@@ -852,6 +851,128 @@ def test_mixed_single_multiple_outputs():
     partitioned = transform.PartitionGraph()(mod)
     assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
+
+def test_dnnl_fuse():
+    def make_pattern(with_bias=True):
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight")
+        bias = relay.var("bias")
+        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                               channels=8, padding=(1, 1))
+        if with_bias:
+            conv_out = relay.add(conv, bias)
+        else:
+            conv_out = conv
+        return relay.nn.relu(conv_out)
+
+    conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
+    conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
+    dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
+
+    def get_blocks(prefix, data, in_channel, out_channel,
+                   include_bn=True, include_sigmoid=False):
+        weight = relay.var(prefix + "weight")
+        bn_gamma = relay.var(prefix + "bn_gamma")
+        bn_beta = relay.var(prefix + "bn_beta")
+        bn_mmean = relay.var(prefix + "bn_mean")
+        bn_mvar = relay.var(prefix + "bn_var")
+
+        layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                                channels=out_channel, padding=(1, 1))
+        if include_bn:
+            bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
+                                            bn_mmean, bn_mvar)
+            layer = bn_output[0]
+        if include_sigmoid:
+            # dummy layer to prevent pattern detection
+            layer = relay.sigmoid(layer)
+        layer = relay.nn.relu(layer)
+        return layer
+
+    def get_net(include_bn=True, include_sigmoid=False):
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        block1 = get_blocks("block1_", data, 3, 8, include_bn, include_sigmoid)
+        # The second block is always conv + relu, to make it more interesting
+        block2 = get_blocks("block2_", block1, 8, 8, False, include_sigmoid)
+        return relay.Function(relay.analysis.free_vars(block2), block2)
+
+    def get_partitoned_mod(mod, params, pattern_table):
+        # This is required for constant folding
+        mod["main"] = bind_params_by_name(mod["main"], params)
+
+        remove_bn_pass = transform.Sequential([
+            transform.InferType(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.FoldScaleAxis(),
+        ])
+        composite_partition = transform.Sequential([
+            remove_bn_pass,
+            transform.MergeComposite(pattern_table),
+            transform.AnnotateTarget("dnnl"),
+            transform.PartitionGraph()
+        ])
+
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            return composite_partition(mod)
+
+    def test_detect_pattern(pattern_table, include_bn, include_sigmoid,
+                            num_expected_partition):
+        net = get_net(include_bn, include_sigmoid)
+        mod, params = tvm.relay.testing.create_workload(net)
+        mod = get_partitoned_mod(mod, params, pattern_table)
+        assert(len(mod.functions) - 1 == num_expected_partition)  # -1 for main
+
+    def test_partition():
+        # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
+        test_detect_pattern([conv2d_bias_relu_pat], True, False, 3)
+        # conv + bn + relu, conv + relu -> conv, bias, relu, and fused conv_relu
+        test_detect_pattern([conv2d_relu_pat], True, False, 4)
+        # conv + bn + relu, conv + relu -> fused conv_bias_relu, and fused conv_relu
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, False, 2)
+        # conv + relu, conv + relu -> two fused conv_relu
+        test_detect_pattern([conv2d_relu_pat], False, False, 2)
+        # conv + relu, conv + relu -> no fusion, 4 partition each with a single op
+        test_detect_pattern([conv2d_bias_relu_pat], False, False, 4)
+        # conv + bn + sigmoid + relu, conv + sigmoid + relu -> no fusion
+        test_detect_pattern([conv2d_bias_relu_pat, conv2d_relu_pat], True, True, 5)
+
+    def test_partition_mobilenet():
+        mod, params = relay.testing.mobilenet.get_workload()
+        mod = get_partitoned_mod(mod, params, dnnl_patterns)
+        # 27 fused conv + bn + relu and one dense
+        assert(len(mod.functions) - 1 == 28)  # -1 for main
+
+    def test_exec(mod, params, ref_mod, ref_params, out_shape):
+        ishape = (1, 3, 224, 224)
+        i_data = np.random.randn(*ishape).astype(np.float32)
+        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+        ref_res = ref_ex.evaluate()(i_data, **ref_params)
+        compile_engine.get().clear()
+
+        mod = get_partitoned_mod(mod, params, dnnl_patterns)
+
+        check_result(mod, {"data": i_data},
+                     out_shape, ref_res.asnumpy(), tol=1e-5, params=params)
+
+    test_partition()
+    test_partition_mobilenet()
+
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    net = get_net()
+    mod, params = tvm.relay.testing.create_workload(net)
+    ref_mod, ref_params = tvm.relay.testing.create_workload(net)
+    test_exec(mod, params, ref_mod, ref_params, (1, 8, 224, 224))
+
+    # exec test on mobilenet is not possible due to manually inlined constants
+    # mod, params = relay.testing.mobilenet.get_workload()
+    # ref_mod, ref_params = relay.testing.mobilenet.get_workload()
+    # test_exec(mod, params, ref_mod, ref_params, (1, 1000))
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -865,3 +986,4 @@ if __name__ == "__main__":
     test_constant_propagation()
     test_multiple_outputs()
     test_mixed_single_multiple_outputs()
+    test_dnnl_fuse()


Mime
View raw message