tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [Relay] enable blocking format in x86 conv2d and fold scale axis (#5357)
Date Tue, 12 May 2020 14:37:06 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 93a9afb  [Relay] enable blocking format in x86 conv2d and fold scale axis (#5357)
93a9afb is described below

commit 93a9afbed631399ae67434b50220a7f848b4341b
Author: Menooker <Menooker@users.noreply.github.com>
AuthorDate: Tue May 12 22:36:53 2020 +0800

    [Relay] enable blocking format in x86 conv2d and fold scale axis (#5357)
---
 python/tvm/relay/op/strategy/x86.py             |  10 +
 src/relay/op/tensor/transform.h                 |   2 +
 src/relay/transforms/fold_scale_axis.cc         | 151 ++++++--
 tests/python/relay/test_pass_fold_scale_axis.py | 484 ++++++++++++++++--------
 topi/python/topi/x86/conv2d_alter_op.py         | 107 +++---
 5 files changed, 507 insertions(+), 247 deletions(-)

diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index ba0b3d2..fbc2ed2 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import logging
 
+import re
 import topi
 from tvm.te import SpecializedCondition
 from .generic import *
@@ -25,6 +26,9 @@ from .. import op as _op
 
 logger = logging.getLogger('strategy')
 
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
+
 @schedule_injective.register("cpu")
 def schedule_injective_cpu(attrs, outs, target):
     """schedule injective ops for x86"""
@@ -96,6 +100,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                     wrap_compute_conv2d(topi.x86.conv2d_nchw),
                     wrap_topi_schedule(topi.x86.schedule_conv2d_nchw),
                     name="conv2d_nchw.x86")
+        elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
+            return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             logger.warning("For x86 target, NCHW layout is recommended for conv2d.")
@@ -128,6 +135,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                     wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
                     wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
                     name="depthwise_conv2d_nchw.generic")
+        elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc
+            assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio
+            return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target)
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.")
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 62433c2..1d1f9c0 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -38,6 +38,8 @@
 namespace tvm {
 namespace relay {
 
+extern Expr MakeReshape(Expr data, Array<Integer> newshape);
+
 template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 57e3d69..4c8025a 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -29,6 +29,7 @@
 #include <tvm/relay/transform.h>
 #include <tvm/tir/data_layout.h>
 
+#include "../op/tensor/transform.h"
 #include "pass_util.h"
 #include "pattern_util.h"
 
@@ -39,6 +40,7 @@ namespace relay {
  *
  * Use namespace to reduce potential naming conflict.
  */
+
 namespace fold_scale_axis {
 
 using runtime::TypedPackedFunc;
@@ -305,6 +307,41 @@ class ForwardPrep : private ExprVisitor {
   }
 };
 
+static bool IsIntInArray(const Array<Integer>& axis, int v) {
+  for (size_t i = 0; i < axis.size(); i++) {
+    if (axis[i] == v) return true;
+  }
+  return false;
+}
+
+static Expr ReshapeToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
+                               const Array<Integer>& axis) {
+  Array<Integer> arr;
+  for (size_t i = 0; i < shape.size(); i++) {
+    if (IsIntInArray(axis, i)) {
+      auto node = shape[i].as<IntImmNode>();
+      if (!node) {
+        // if the shape is not a constant, use normal transform
+        return Expr();
+      }
+      arr.push_back(node->value);
+    } else {
+      arr.push_back(1);
+    }
+  }
+  return MakeReshape(scale, std::move(arr));
+}
+
+// if only one axis, use expand dim. Else, use reshape
+static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array<PrimExpr>& shape,
+                                       const Array<Integer>& axis) {
+  if (axis.size() > 1) {
+    return ReshapeToMatchAxis(scale, shape, axis);
+  } else {
+    return ExpandBiasToMatchAxis(scale, shape.size(), axis);
+  }
+}
+
 //----------------------------------------------
 // Per operator defs for FScaleAxisForward
 //----------------------------------------------
@@ -365,7 +402,10 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
   if (slhs != nullptr) {
     CHECK(srhs == nullptr);
     CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes));
-    Expr scale = ExpandBiasToMatchAxis(slhs->scale, tlhs->shape.size(), slhs->axes);
+    Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes);
+    if (!scale.defined()) {
+      return Expr();
+    }
     Expr rhs = Divide(new_args[1], scale);
     rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args);
     rnode->scale = slhs->scale;
@@ -373,7 +413,10 @@ Expr AddSubForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
   } else {
     CHECK(srhs != nullptr);
     CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes));
-    Expr scale = ExpandBiasToMatchAxis(srhs->scale, trhs->shape.size(), srhs->axes);
+    Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes);
+    if (!scale.defined()) {
+      return Expr();
+    }
     Expr lhs = Divide(new_args[0], scale);
     rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args);
     rnode->scale = srhs->scale;
@@ -445,7 +488,6 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
 
   CHECK_GE(c_big_axis, 0);
   Message none = NullValue<Message>();
-  AxesSet data_axes = NullValue<AxesSet>();
   // For now, we only support simple pattern (no folded weight/data)
   // More general layout can be supported under the current framework.
   // By using a unified layout transformation.
@@ -454,12 +496,17 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
-      (param->groups == 1 || is_depthwise_conv2d)) {
-    data_axes = {c_big_axis};
-  }
-  if (data_axes.defined()) {
-    return {Message(data_axes, false), none};
+  if (param->groups == 1 || is_depthwise_conv2d) {
+    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
+    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
+    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
+        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
+      Array<Integer> arr{c_big_axis};
+      if (c_small_axis >= 0) {
+        arr.push_back(c_small_axis);
+      }
+      return {Message(arr, false), none};
+    }
   }
   return {none, none};
 }
@@ -478,12 +525,14 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
   Layout kernel_layout(param->kernel_layout);
   int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C'));
   CHECK_GE(c_big_axis, 0);
-  // For now, we only support simple pattern (no folded weight/data)
-  // TODO(tvm-team) support general data layout
-  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
-  CHECK(sdata->axes.size() == 1 && c_big_axis == sdata->axes[0]->value);
-  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
-  int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
+  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
+  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
+  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
+  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
+
+  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
+  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
+  CHECK(is_simple || is_blocking);
 
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout);
@@ -493,11 +542,26 @@ Expr Conv2DForwardRewrite(const Call& ref_call, const Array<Expr>& new_args,
 
   // match the ic_axis
   if (is_depthwise_conv2d) {
-    Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_oc_axis});
-    weight = Multiply(weight, scale);
+    if (is_simple) {
+      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis});
+      weight = Multiply(weight, scale);
+    } else {
+      weight = Multiply(weight,
+                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
+                                           {big_ko_axis, small_ko_axis}));
+      if (!weight.defined()) return Expr();
+    }
+
   } else {
-    Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ic_axis});
-    weight = Multiply(weight, scale);
+    if (is_simple) {
+      Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis});
+      weight = Multiply(weight, scale);
+    } else {
+      weight = Multiply(weight,
+                        ReshapeToMatchAxis(sdata->scale, weight->type_as<TensorTypeNode>()->shape,
+                                           {big_ki_axis, small_ki_axis}));
+      if (!weight.defined()) return Expr();
+    }
   }
   // return transformed conv2d
   return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
@@ -752,14 +816,20 @@ Expr AddSubBackwardTransform(const Call& call, const Message& message, const Exp
     CHECK(equal(message->axes, lhs_message->axes));
     Expr lhs = transformer->Transform(call->args[0], message, scale);
     Expr rhs = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
-    Expr rhs_scale = ExpandBiasToMatchAxis(scale, tlhs->shape.size(), message->axes);
+    Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes);
+    if (!rhs_scale.defined()) {
+      return transformer->NormalCallTransform(call.operator->());
+    }
     rhs = Multiply(rhs, rhs_scale);
     return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
   } else if (rhs_message.defined()) {
     CHECK(equal(message->axes, rhs_message->axes));
     Expr lhs = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
     Expr rhs = transformer->Transform(call->args[1], message, scale);
-    Expr lhs_scale = ExpandBiasToMatchAxis(scale, trhs->shape.size(), message->axes);
+    Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes);
+    if (!lhs_scale.defined()) {
+      return transformer->NormalCallTransform(call.operator->());
+    }
     lhs = Multiply(lhs, lhs_scale);
     return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
   } else {
@@ -829,13 +899,19 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
   // only handle depthwise or full conv2d.
   // TODO(tvm-team) handle grouped conv by reshape + bcast
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
-  if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 &&
-      kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && c_small_axis < 0 &&
-      (param->groups == 1 || is_depthwise_conv2d)) {
-    return Message({c_big_axis}, false);
-  } else {
-    return NullValue<Message>();
+  if (param->groups == 1 || is_depthwise_conv2d) {
+    auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
+    auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
+    if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) ||     // simple layout
+        (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) {  // blocked layout
+      Array<Integer> arr{c_big_axis};
+      if (c_small_axis >= 0) {
+        arr.push_back(c_small_axis);
+      }
+      return Message(arr, false);
+    }
   }
+  return NullValue<Message>();
 }
 
 // Conv2D consumes the scale axis during transformation.
@@ -852,19 +928,28 @@ Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Exp
   CHECK_GE(c_big_axis, 0);
   // For now, we only support simple pattern (no folded weight/data)
   // TODO(tvm-team) support general data layout
-  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1);
-  CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1);
-  CHECK(message->axes.size() == 1 && c_big_axis == message->axes[0]->value);
-
-  int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
+  int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o'));
+  int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i'));
+  int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I'));
+  int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O'));
   // Check it must be depthwise or full conv2d.
   bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout);
   CHECK(param->groups == 1 || is_depthwise_conv2d);
+  bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0);
+  bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0);
+  CHECK(is_simple || is_blocking);
 
   Expr data = transformer->Transform(call->args[0], NullValue<Message>(), NullValue<Expr>());
   Expr weight = transformer->Transform(call->args[1], NullValue<Message>(), NullValue<Expr>());
   // scale on input for deptwise.
-  Expr wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_oc_axis});
+  Expr wscale;
+  if (is_simple) {
+    wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis});
+  } else {
+    wscale = ReshapeToMatchAxis(scale, weight->type_as<TensorTypeNode>()->shape,
+                                {big_ko_axis, small_ko_axis});
+    if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->());
+  }
   weight = Multiply(weight, wscale);
   return Call(call->op, {data, weight}, call->attrs, call->type_args);
 }
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index d7c437a..8aecf3f 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -35,58 +35,75 @@ def run_opt_pass(expr, opt_pass):
 
 def test_fold_fwd_simple():
     """Simple testcase."""
-    def before(x, conv_weight, in_bias, in_scale, channels):
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         args = [x, conv_weight, in_bias]
-        in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
         x = relay.multiply(x, in_scale)
         x = relay.nn.relu(x)
         x = relay.add(x, in_bias)
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW")
 
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, in_bias, in_scale, channels):
+    def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, in_bias]
-        in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2)
-        squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
-        x = relay.nn.relu(x)
-        in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
-        x = relay.add(x, in_bias)
-        conv_weight = relay.multiply(
-            conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        if blocking:
+            squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(in_bias, 
+                relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0]))) #NCHWc
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(conv_weight,
+                relay.reshape(squeezed_scale, (1, in_channels//2, 1, 1, 2, 1))) #OIHWio
+        else:
+            squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
+            x = relay.nn.relu(x)
+            in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+            x = relay.add(x, in_bias)
+            conv_weight = relay.multiply(
+                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW")
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
         weight = relay.var("weight")
-        in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
-        y1 = before(x, weight, in_bias, in_scale, channels)
+        if blocking:
+            in_channels = shape[1] * shape[4]
+            in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0]))
+            in_scale = relay.const(_get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0])))
+        else:
+            in_channels = shape[1]
+            in_bias = relay.var("in_bias", shape=(in_channels, 1, 1))
+            in_scale = relay.const(_get_positive_scale((in_channels, 1, 1)))
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-        y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking)
 
         y1_folded = run_opt_pass(y1_folded, transform.InferType())
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 2)
-
+    check((2, 4, 10, 10), 2, None)
+    check((2, 2, 10, 10, 2), 8, (2, 4))
 
 def test_fold_fwd_dual_path():
     """scale axis being consumed by two consumers"""
-    def before(x, conv_weight, in_bias, in_scale, channels):
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         args = [x, conv_weight, in_bias]
         x = relay.multiply(in_scale, x)
         x = relay.nn.relu(x)
@@ -94,363 +111,474 @@ def test_fold_fwd_dual_path():
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
-                             kernel_layout="HWIO",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              groups=channels,
                              padding=(1, 1))
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
-                             kernel_layout="HWIO",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
         return relay.Function(args, z)
 
-    def expected(x, conv_weight, in_bias, in_scale, channels):
+    def expected(x, conv_weight, in_bias, in_scale, channels, blocking):
         args = [x, conv_weight, in_bias]
         x = relay.nn.relu(x)
-        in_bias = relay.divide(in_bias, in_scale)
+        if blocking:
+            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], blocking[0])) #NHWCc
+        else:
+            _in_scale = in_scale
+        in_bias = relay.divide(in_bias, _in_scale)
         x = relay.subtract(x, in_bias)
+        if blocking:
+            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio
         y1 = relay.nn.conv2d(x,
-                             relay.multiply(conv_weight, in_scale),
+                             relay.multiply(conv_weight, _in_scale),
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
-                             kernel_layout="HWIO",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              groups=channels,
                              padding=(1, 1))
+        if blocking:
+            _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio
         y2 = relay.nn.conv2d(x,
-                             relay.multiply(conv_weight, in_scale),
+                             relay.multiply(conv_weight, _in_scale),
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
-                             kernel_layout="HWIO",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              groups=channels,
                              padding=(1, 1))
         z = relay.add(y1, y2)
         return relay.Function(args, z)
 
-    def check(dshape, channels):
+    def check(dshape, channels, blocking):
         x =  relay.var("x", shape=dshape)
-        in_channels = dshape[-1]
+        if blocking:
+            in_channels = dshape[3] * dshape[4]
+            wshape = (3, 3, 1, channels//blocking[1], 1, blocking[1]) # HWIOio
+            weight = relay.var("weight", shape=wshape)
+            in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0]))
+            in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0])))
+        else:
+            in_channels = dshape[-1]
+            wshape = (3, 3, 1, channels) # HWIO
+            weight = relay.var("weight", shape=wshape)
+            in_bias = relay.var("in_bias", shape=(in_channels,))
+            in_scale = relay.const(_get_positive_scale(in_channels,))
+        
         # test depthwise
         assert in_channels == channels
-        wshape = (3, 3, 1, channels) # HWIO
-        weight = relay.var("weight", shape=wshape)
-        in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(_get_positive_scale(in_channels,))
-        y1 = before(x, weight, in_bias, in_scale, channels)
+
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
-        y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 3), 3)
-
+    check((2, 4, 10, 3), 3, None)
+    check((2, 4, 10, 2, 2), 4, (2, 2))
 
 def test_fold_fwd_fail():
     """testcase where we canont fold"""
-    def before(x, conv_weight, in_bias, in_scale, channels):
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         x = relay.multiply(x, in_scale)
         xx = relay.nn.leaky_relu(x, alpha=0.1)
         y1 = relay.nn.conv2d(xx, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              padding=(1, 1))
         z = relay.add(y1, x)
         return relay.Function(relay.analysis.free_vars(z), z)
 
-    def check(shape, channels):
+    def check(shape, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[-1]
+        if blocking:
+            in_channels = shape[3] * shape[4]
+            in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0]))
+            in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0])))
+        else:
+            in_channels = shape[-1]
+            in_bias = relay.var("in_bias", shape=(in_channels,))
+            in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
         # test depthwise
         assert in_channels == channels
         weight = relay.var("weight")
-        in_bias = relay.var("in_bias", shape=(in_channels,))
-        in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
-        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         assert tvm.ir.structural_equal(y1, y1_folded)
 
-    check((2, 11, 10, 4), 4)
-
+    check((2, 11, 10, 4), 4, None)
+    check((2, 11, 10, 2, 2), 4, (2,2))
 
 def test_fold_fwd_relu_fail():
     """testcase where we canont fold because scale can not pass relu"""
-    def before(x, conv_weight, in_bias, in_scale, channels):
+    def before(x, conv_weight, in_bias, in_scale, channels, blocking):
         x = relay.multiply(x, in_scale)
         xx = relay.nn.relu(x)
         y1 = relay.nn.conv2d(xx, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NHWC",
+                             data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC",
+                             kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO",
                              padding=(1, 1))
         z = relay.add(y1, x)
         return relay.Function(relay.analysis.free_vars(z), z)
 
-    def check(shape, channels, in_scale):
+    def check(shape, channels, blocking, in_scale):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[-1]
-        # test depthwise
-        assert in_channels == channels
         weight = relay.var("weight")
-        in_bias = relay.var("in_bias", shape=(in_channels,))
-        y1 = before(x, weight, in_bias, in_scale, channels)
+        if blocking:
+            in_channels = shape[3] * shape[4]
+            in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0]))
+        else:
+            in_channels = shape[-1]
+            in_bias = relay.var("in_bias", shape=(in_channels,))
+
+        assert in_channels == channels
+        y1 = before(x, weight, in_bias, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         assert tvm.ir.structural_equal(y1, y1_folded)
 
     in_scale = relay.var("in_scale", shape=(4,))
-    check((2, 11, 10, 4), 4, in_scale)
+    check((2, 11, 10, 4), 4, None, in_scale)
     in_scale = relay.const(-_get_positive_scale((4,)))
-    check((2, 11, 10, 4), 4, in_scale)
+    check((2, 11, 10, 4), 4, None, in_scale)
+
+    in_scale = relay.var("in_scale", shape=(1,1,1,2,2))
+    check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
+    in_scale = relay.const(-_get_positive_scale((1,1,1,2,2)))
+    check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
+
+
 
 
 def test_fold_fwd_negative_scale():
     """Testcase of folding negative scale"""
-    def before(x, conv_weight, in_scale, channels):
+    def before(x, conv_weight, in_scale, channels, blocking):
         args = [x, conv_weight]
         x = relay.multiply(x, in_scale)
         y = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW")
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, in_scale, channels):
+    def expected(x, conv_weight, in_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight]
-        squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
-        conv_weight = relay.multiply(
-            conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
+        if blocking:
+            squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3])
+            conv_weight = relay.multiply(
+                conv_weight , relay.reshape(squeezed_scale, (1, in_channels//4, 1, 1, 4, 1)))
+            #blocking by "i" in OIHWio
+        else:
+            squeezed_scale = relay.squeeze(in_scale, axis=[1,2])
+            conv_weight = relay.multiply(
+                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
         y = relay.nn.conv2d(x,
                              conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW")
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
-        in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
+        if blocking:
+            in_channels = shape[1] * shape[4]
+            in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4])))
+        else:
+            in_channels = shape[1]
+            in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1)))
         weight = relay.var("weight")
-        y1 = before(x, weight, in_scale, channels)
+        y1 = before(x, weight, in_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-        y1_expected = expected(x, weight, in_scale, channels)
+        y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 4)
-
+    check((2, 4, 10, 10), 4, None)
+    check((2, 2, 10, 10, 2), 8, (2, 2))
 
 def test_fold_bwd_simple():
     """Simple testcase."""
-    def before(x, conv_weight, out_bias, out_scale, channels):
+    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1]))
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y = relay.add(y, out_bias)
         y = relay.nn.relu(y)
+        if blocking:
+            out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1]))
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, out_bias, out_scale, channels):
+    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, out_bias]
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
-        squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
-        conv_weight = relay.multiply(
-            conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+        if blocking:
+            out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1]))
+            squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3])
+            conv_weight = relay.multiply(
+                conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+        else:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+            conv_weight = relay.multiply(
+                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
 
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
-        out_bias = relay.multiply(out_bias,
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
+        if blocking:
+            out_bias = relay.multiply(out_bias,
+                                  relay.reshape(squeezed_scale, (1, channels//blocking[1], 1, 1, blocking[1])))
+        else:
+            out_bias = relay.multiply(out_bias,
                                   relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2))
         y = relay.add(y, out_bias)
         y = relay.nn.relu(y)
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, in_channels, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
         weight = relay.var("weight")
         out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
-
-        y1 = before(x, weight, out_bias, out_scale, channels)
+        if blocking:
+            out_scale = relay.const(_get_positive_scale((channels,)))
+        else:
+            out_scale = relay.const(_get_positive_scale((channels,1, 1)))
+        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 8)
+    check((2, 4, 10, 10), 4, 8, None)
+    check((2, 2, 10, 10, 16), 32, 64, (16, 16))
 
 
 def test_fold_bwd_dual_path():
     """Dual path testcase."""
-    def before(x, conv_weight, out_bias, out_scale, channels):
+    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, out_bias, out_scale, channels):
+    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, out_bias]
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
+        if not blocking:
+            out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
         def fold_conv_weight():
-            return  relay.multiply(
-                conv_weight ,
-                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+            if blocking:
+                return relay.multiply(
+                    conv_weight ,
+                    relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+            else:
+                return relay.multiply(
+                    conv_weight ,
+                    relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
         y1 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, in_channels, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
         weight = relay.var("weight")
-        out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
-
-        y1 = before(x, weight, out_bias, out_scale, channels)
+        if blocking:
+            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+        else:
+            out_bias = relay.var("out_bias", shape=(channels,))
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
+
+        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 8)
-
+    check((2, 4, 10, 10), 4, 8, None)
+    check((2, 2, 10, 10, 2), 4, 8, (2, 2))
 
 def test_fold_bwd_dual_consumer():
-    def before(x, conv_weight, out_bias, out_scale, channels):
+    def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
         y0 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y0 = relay.multiply(y0, out_scale)
         y0 = relay.nn.relu(y0)
 
         y1 = relay.nn.conv2d(y0, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y1 = relay.multiply(y1, out_scale)
         y1 = relay.nn.relu(y1)
 
         y2 = relay.nn.conv2d(y0, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y2 = relay.multiply(y2, out_scale)
         y2 = relay.nn.relu(y2)
 
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, out_bias, out_scale, channels):
+    def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight, out_bias]
         def fold_conv_weight():
             squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
-            return  relay.multiply(
-                conv_weight ,
-                relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+            if blocking:
+                return relay.multiply(
+                    conv_weight ,
+                    relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+            else:
+                return relay.multiply(
+                    conv_weight ,
+                    relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
         y0 = relay.nn.conv2d(x, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y0 = relay.nn.relu(y0)
         y1 = relay.nn.conv2d(y0, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.conv2d(y0, fold_conv_weight(),
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y2 = relay.nn.relu(y2)
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, in_channels, channels, blocking):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
         weight = relay.var("weight")
-        out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.const(_get_positive_scale((channels,1, 1)))
-
-        y1 = before(x, weight, out_bias, out_scale, channels)
+        if blocking:
+            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+        else:
+            out_bias = relay.var("out_bias", shape=(channels,))
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
+
+        y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 4)
-
+    check((2, 4, 10, 10), 4, 4, None)
+    check((2, 2, 10, 10, 2), 4, 4, (2, 2))
 
 def test_fold_bwd_fail():
     """Dual path testcase."""
-    def fail1(x, conv_weight, out_bias, out_scale, channels):
+    def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y1 = relay.nn.relu(y1)
         y2 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
                              padding=(1, 1),
-                             out_layout="CNHW")
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW",
+                             out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW")
         # fold will fail because the axis from two path
         # differs from each other.
         y2 = relay.nn.relu(y2)
@@ -458,99 +586,123 @@ def test_fold_bwd_fail():
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
-    def fail2(x, conv_weight, out_bias, out_scale, channels):
+    def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
-        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
         y1 = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y2 = relay.nn.relu(y1)
         # fold will fail because y1 is referred also by y2
         y1 = relay.multiply(y1, out_scale)
         y = relay.add(y1, y2)
         return relay.Function(args, y)
 
-    def check(shape, channels, fbefore):
+    def check(shape, in_channels, channels, blocking, fbefore):
         x =  relay.var("x", shape=shape)
-        in_channels = shape[1]
         weight = relay.var("weight")
-        out_bias = relay.var("out_bias", shape=(channels,))
-        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
-        y1 = fbefore(x, weight, out_bias, out_scale, channels)
+        if blocking:
+            out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1]))
+            out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1])))
+        else:
+            out_bias = relay.var("out_bias", shape=(channels, 1, 1))
+            out_scale = relay.const(_get_positive_scale((channels, 1, 1)))
+        y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         assert tvm.ir.structural_equal(y1_folded, y1)
 
-    check((4, 4, 10, 10), 4, fail1)
-    check((4, 4, 10, 10), 4, fail2)
+    check((4, 4, 10, 10), 4, 4, None, fail1)
+    check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1)
+    check((4, 4, 10, 10), 4, 4, None, fail2)
+    check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2)
 
 
 def test_fold_bwd_relu_fail():
     """testcase where we canont fold because scale can not pass relu"""
-    def before(x, conv_weight, out_scale, channels):
+    def before(x, conv_weight, out_scale, channels, blocking):
         y = relay.nn.conv2d(x, conv_weight,
                              channels=channels,
                              kernel_size=(3, 3),
-                             data_layout="NCHW",
-                             padding=(1, 1))
+                             padding=(1, 1),
+                             data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                             kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y = relay.nn.relu(y)
         y = relay.multiply(x, out_scale)
         return relay.Function(relay.analysis.free_vars(y), y)
 
-    def check(shape, channels, out_scale):
+    def check(shape, channels, blocking, out_scale):
         x =  relay.var("x", shape=shape)
         in_channels = shape[1]
         weight = relay.var("weight")
-        y1 = before(x, weight, out_scale, channels)
+        y1 = before(x, weight, out_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         assert tvm.ir.structural_equal(y1, y1_folded)
 
     out_scale = relay.var("in_scale", shape=(4, 1, 1))
-    check((4, 4, 10, 10), 4, out_scale)
+    check((4, 4, 10, 10), 4, None, out_scale)
     out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32")
-    check((4, 4, 10, 10), 4, out_scale)
+    check((4, 4, 10, 10), 4, None, out_scale)
+
+    out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2))
+    check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)
+    out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype("float32")
+    check((4, 2, 10, 10, 2), 4, (2, 2), out_scale)
 
 
 def test_fold_bwd_negative_scale():
     """Testcase of folding negative scale"""
-    def before(x, conv_weight, out_scale, channels):
+    def before(x, conv_weight, out_scale, channels, blocking):
         args = [x, conv_weight]
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         y = relay.multiply(y, out_scale)
         return relay.Function(args, y)
 
-    def expected(x, conv_weight, out_scale, channels):
+    def expected(x, conv_weight, out_scale, channels, blocking):
         # use a fixed order of args so alpha equal check can pass
         args = [x, conv_weight]
-        squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
-        conv_weight = relay.multiply(
-            conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
+        if blocking:
+            squeezed_scale = relay.squeeze(out_scale, axis=[0,2,3])
+            conv_weight = relay.multiply(
+                conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1])))
+        else:
+            squeezed_scale = relay.squeeze(out_scale, axis=[1,2])
+            conv_weight = relay.multiply(
+                conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3))
         y = relay.nn.conv2d(x, conv_weight,
                             channels=channels,
                             kernel_size=(3, 3),
-                            padding=(1, 1))
+                            padding=(1, 1),
+                            data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW",
+                            kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW")
         return relay.Function(args, y)
 
-    def check(shape, channels):
+    def check(shape, channels, blocking):
         x =  relay.var("x", shape=shape)
         weight = relay.var("weight")
-        out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
-        y1 = before(x, weight, out_scale, channels)
+        if blocking:
+            out_scale = relay.const(-_get_positive_scale((1,channels//blocking[1], 1, 1, blocking[1])))
+        else:
+            out_scale = relay.const(-_get_positive_scale((channels, 1, 1)))
+        y1 = before(x, weight, out_scale, channels, blocking)
         y1 = run_opt_pass(y1, transform.InferType())
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        y1_expected = expected(x, weight, out_scale, channels)
+        y1_expected = expected(x, weight, out_scale, channels, blocking)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
         assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
-    check((2, 4, 10, 10), 8)
-
+    check((2, 4, 10, 10), 8, None)
+    check((2, 2, 10, 10, 2), 8, (2, 2))
 
 if __name__ == "__main__":
     test_fold_fwd_simple()
diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py
index 5ee691b..d1c607f 100644
--- a/topi/python/topi/x86/conv2d_alter_op.py
+++ b/topi/python/topi/x86/conv2d_alter_op.py
@@ -19,6 +19,7 @@
 
 import logging
 
+import re
 import tvm
 from tvm import te
 from tvm import relay
@@ -31,6 +32,9 @@ from ..nn.util import get_pad_tuple
 
 logger = logging.getLogger('topi')
 
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")
+
 @conv2d_alter_layout.register("cpu")
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     target = tvm.target.Target.current(allow_none=False)
@@ -64,30 +68,33 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
     if topi_tmpl == "conv2d_NCHWc.x86":
         # we only convert conv2d_NCHW to conv2d_NCHWc for x86
-        assert data_layout == "NCHW" and kernel_layout == "OIHW"
-        if cfg.is_fallback:
-            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                out_dtype, False, data_layout)
-        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
-        out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
-        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-
-        # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
-        # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
-        new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
-
-        # Store altered operator's config
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
-                                     kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
-             new_attrs["out_layout"], out_dtype], topi_tmpl)
-        dispatch_ctx.update(target, new_workload, cfg)
+        if data_layout == "NCHW" and kernel_layout == "OIHW":
+            if cfg.is_fallback:
+                _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
+                                    out_dtype, False, data_layout)
+            batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+            out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
+            ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+
+            # update new attrs
+            new_attrs['channels'] = out_channel
+            new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+            # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
+            new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
+            new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+            # Store altered operator's config
+            new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                      dtype=data_dtype)
+            new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn,
+                                         kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"],
+                 new_attrs["out_layout"], out_dtype], topi_tmpl)
+            dispatch_ctx.update(target, new_workload, cfg)
+        else:
+            assert _NCHWc_matcher.match(data_layout)
+            assert _OIHWio_matcher.match(kernel_layout)
         return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
     if topi_tmpl == "conv2d_NCHWc_int8.x86":
@@ -136,30 +143,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs)
 
     if topi_tmpl == "depthwise_conv2d_NCHWc.x86":
-        assert data_layout == "NCHW" and kernel_layout == "OIHW"
-        if cfg.is_fallback:
-            _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
-                                out_dtype, True, data_layout)
-
-        batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
-        out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
-        ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
-        assert channel_multiplier == 1
-
-        # update new attrs
-        new_attrs['channels'] = out_channel
-        new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
-        new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
-        new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
-
-        # Store altered operator's config.
-        new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
-                                  dtype=data_dtype)
-        new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype)
-        new_workload = autotvm.task.args_to_workload(
-            [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
-             new_attrs['out_layout'], out_dtype], topi_tmpl)
-        dispatch_ctx.update(target, new_workload, cfg)
+        if data_layout == "NCHW" and kernel_layout == "OIHW":
+            if cfg.is_fallback:
+                _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding,
+                                    out_dtype, True, data_layout)
+
+            batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+            out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
+            ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+            assert channel_multiplier == 1
+
+            # update new attrs
+            new_attrs['channels'] = out_channel
+            new_attrs['data_layout'] = 'NCHW%dc' % ic_bn
+            new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn
+            new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
+
+            # Store altered operator's config.
+            new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
+                                      dtype=data_dtype)
+            new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn),
+                                        dtype=kernel_dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'],
+                 new_attrs['out_layout'], out_dtype], topi_tmpl)
+            dispatch_ctx.update(target, new_workload, cfg)
+        else:
+            assert _NCHWc_matcher.match(data_layout)
+            assert _OIHWio_matcher.match(kernel_layout)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
 
     return None


Mime
View raw message