tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zha...@apache.org
Subject [incubator-tvm] branch master updated: [RFC] Improve quantized convolution performance for armv8 architectures (#5754)
Date Tue, 23 Jun 2020 05:20:16 GMT
This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b94e8b7  [RFC] Improve quantized convolution performance for armv8 architectures (#5754)
b94e8b7 is described below

commit b94e8b7290c5ced98728e730634ec73727c53c51
Author: Giuseppe Rossini <giuseppe.rossini@arm.com>
AuthorDate: Tue Jun 23 06:20:05 2020 +0100

    [RFC] Improve quantized convolution performance for armv8 architectures (#5754)
    
    * Improve quantized conv2d performance for armv8
    
    Signed-off-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
    Change-Id: I3a3d29f5332dd9b3354e8e0dfb24677a521f9c8f
    
    * Add ASF header to conv2d_gemm.py
    
    Change-Id: I33853279e39c849ae1b555a9c91d7557985a0a35
    
    * Run clang-format-10 on c++ files
    
    Change-Id: Ieee22f032e595dabfc1616ab33466fcbf8d94365
    
    * Fix pylint errors/warnings
    
    Change-Id: I435d4d7bca7500db99547f4401fdc0d0995a1ff4
    
    * Fix pylint errors/warnings in topi
    
    Change-Id: I2fc1ad8453e9020072ab967c849df5390c2967b5
    
    * Fix legalizations tests for aarch64
    
    Change-Id: I0a67a49a7849f52ef7d57b9292ce9125bbb7cb2c
    
    * Reintroduce conv2d_nhwc_spatial_pack.arm_cpu and int16 cast
    
    Change-Id: I91b67fabd475e90a9b75f2dd5ecfee851265e0bb
    
    * Switch type of legalization depending on the strategy used
    
    Change-Id: I9a03040a8c40a6cd2658ed14c3751e05a8e19f2b
    
    * Revert last commit
    
    Change-Id: Ice34101e358e3ce8ebfb12c58f73e910ba5de8e8
    
    * Fix the auto-tuner by registering the correct schedules
    
    Change-Id: Id9273688b2620e1ea849ab01b4c46af8fbf37fd0
    
    * Address review comments
    
    Change-Id: Ia1755a0af7b6d159072d9f0c93c932c481101e48
    
    * Improve usability and readability of conv2d_gemm_weight_transform
    
    Change-Id: I3333186bbc2fe4054b58ce15d910e3be7b315482
    
    * Change variable name to weight in Conv2DGemmWeightTransformRel
    
    Change-Id: Ifb5f1f33af7512fe67c6b049b20a42a0bb2d26c9
    
    * Fix clang-10 linting errors
    
    Change-Id: I25ccc844d9cee23766096e1daddb6180abc413a6
    
    * Trigger tests
    
    Change-Id: Id37706fb7cf77a87a3cc817ecf8046297d9ca95a
---
 include/tvm/relay/attrs/nn.h                |  11 +
 python/tvm/relay/op/nn/_nn.py               |  17 ++
 python/tvm/relay/op/nn/nn.py                |  91 ++++++++
 python/tvm/relay/op/strategy/arm_cpu.py     |  42 ++++
 python/tvm/relay/op/strategy/generic.py     |  13 ++
 python/tvm/relay/qnn/op/legalizations.py    |   8 +-
 src/relay/op/nn/convolution.cc              |  82 +++++++
 src/relay/op/nn/convolution.h               | 131 +++++++++++
 topi/python/topi/arm_cpu/conv2d_alter_op.py |  65 +++++-
 topi/python/topi/arm_cpu/conv2d_gemm.py     | 174 ++++++++++++++
 topi/python/topi/arm_cpu/conv2d_int8.py     |  38 +++-
 topi/python/topi/arm_cpu/tensor_intrin.py   | 339 ++++++++++++++++++++++++++++
 topi/python/topi/generic/nn.py              |  19 ++
 topi/python/topi/nn/conv2d.py               |  49 ++++
 14 files changed, 1065 insertions(+), 14 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index abe63e5..5f1ee2f 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -187,6 +187,17 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig
   }
 };
 
+/*! \brief Attributes of the GEMM convolution weight transformation (weight tile sizes) */
+struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
+  int tile_rows;  // tile extent along the output-channel (N) axis of the weight matrix
+  int tile_cols;  // tile extent along the reduction (K) axis of the weight matrix
+
+  TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
+    TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
+    TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
+  }
+};
+
 /*! \brief Attributes used in convolution operators with winograd algorithm */
 struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
   int tile_size;
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 1c76f57..564d6f7 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -446,6 +446,23 @@ reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
 reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# conv2d_gemm related operators (conv2d on weights pre-transformed by the op below)
+reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform",
+                      strategy.conv2d_gemm_without_weight_transform_strategy)
+reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
+                     OpPattern.OUT_ELEMWISE_FUSABLE)
+# Weight transform for the gemm conv2d: compute, schedule and fusion pattern.
+@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
+def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
+    """Compute definition of contrib_conv2d_gemm_weight_transform (via topi.nn)"""
+    out = topi.nn.conv2d_gemm_weight_transform(
+        inputs[0], attrs.tile_rows, attrs.tile_cols)
+    return [out]
+
+reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform",
+                      strategy.schedule_conv2d_gemm_weight_transform)
+reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform",
+                     OpPattern.OUT_ELEMWISE_FUSABLE)
 
 @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
 def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 34d07dc..3c47cf7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2046,6 +2046,74 @@ def contrib_conv2d_winograd_without_weight_transform(data,
         kernel_layout, out_layout, out_dtype)
 
 
+def contrib_conv2d_gemm_without_weight_transform(data,
+                                                 weight,
+                                                 strides=(1, 1),
+                                                 padding=(0, 0),
+                                                 dilation=(1, 1),
+                                                 groups=1,
+                                                 channels=None,
+                                                 kernel_size=None,
+                                                 data_layout="NCHW",
+                                                 kernel_layout="OIHW",
+                                                 out_layout="",
+                                                 out_dtype=""):
+    r"""2D convolution with gemm algorithm.
+
+    The basic parameters are the same as the ones in vanilla conv2d.
+    It assumes the weight is pre-transformed by nn.contrib_conv2d_gemm_weight_transform.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expression, as produced by contrib_conv2d_gemm_weight_transform.
+
+    strides : tuple of int, optional
+        The strides of convolution.
+
+    padding : tuple of int, optional
+        The padding of convolution on both sides of inputs before convolution.
+
+    dilation : tuple of int, optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution (must be set for this op).
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions (height, width) of the convolution kernel (must be set).
+
+    data_layout : str, optional
+        Layout of the input (the arm_cpu strategy only implements NHWC).
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : str, optional
+        Layout of the output, by default, out_layout is the same as data_layout
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    # convert 2-way padding to 4-way padding
+    padding = get_pad_tuple2d(padding)
+    return _make.contrib_conv2d_gemm_without_weight_transform(
+        data, weight, strides, padding, dilation,
+        groups, channels, kernel_size, data_layout,
+        kernel_layout, out_layout, out_dtype)
+
+
 def contrib_conv2d_nchwc(data,
                          kernel,
                          strides=(1, 1),
@@ -2204,6 +2272,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
     return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
 
 
+def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
+    r"""Weight Transformation part for 2D convolution with gemm algorithm.
+
+    We separate this as a single op to enable pre-compute for inference.
+    Use this together with nn.contrib_conv2d_gemm_without_weight_transform.
+
+    Parameters
+    ----------
+    weights : tvm.relay.Expr
+        The weight expression (a 4D tensor in HWIO layout).
+    tile_rows : int
+        Tile rows (output-channel direction) of the weight transformation for ConvGemm.
+    tile_cols : int
+        Tile columns (reduction direction) of the weight transformation for ConvGemm.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The tiled and interleaved weights, to be fed to the gemm convolution op.
+    """
+    return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
+
+
 def contrib_conv3d_winograd_weight_transform(weight,
                                              tile_size):
     r"""Weight Transformation part for 3D convolution with winograd algorithm.
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 6bdec67..d682aad 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -112,6 +112,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd),
                     name='conv2d_direct_simd.micro_dev')
             elif kernel_layout == "HWIO":
+                is_aarch64 = "aarch64" in str(isa.target)
+                # AArch64 int8/uint8: offer the quantized NHWC impl besides spatial pack.
+                if is_aarch64 and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
+                        name="conv2d_NHWC_quantized.arm_cpu")
+
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
                     wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
@@ -246,6 +254,40 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
                            format(layout))
     return strategy
 
+def wrap_compute_conv2d_gemm(topi_compute):
+    """Adapt a topi conv2d_gemm compute to the relay (attrs, inputs, out_type) signature"""
+    def _compute_conv2d_gemm(attrs, inputs, out_type):
+        data, weight = inputs[0], inputs[1]
+        requested_dtype = attrs.get_str("out_dtype")
+        if requested_dtype in ("same", ""):
+            requested_dtype = data.dtype
+        strides = attrs.get_int_tuple("strides")
+        padding = attrs.get_int_tuple("padding")
+        dilation = attrs.get_int_tuple("dilation")
+        result = topi_compute(data, weight, strides, padding, dilation,
+                              requested_dtype, attrs['kernel_size'], attrs['channels'])
+        return [result]
+
+    return _compute_conv2d_gemm
+
+@conv2d_gemm_without_weight_transform_strategy.register("arm_cpu")
+def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """conv2d_gemm_without_weight_transform arm cpu strategy"""
+    layout = attrs.data_layout
+    data = inputs[0]
+    strategy = _op.OpStrategy()
+    # Only NHWC int8/uint8 data has an implementation on arm_cpu; anything else raises.
+    if layout == "NHWC" and data.dtype in ['int8', 'uint8']:
+        strategy.add_implementation(
+            wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
+            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
+            name="conv2d_NHWC_quantized_without_transform.arm_cpu")
+    else:
+        raise RuntimeError(
+            "Unsupported conv2d_gemm_without_weight_transform layout {0} with datatype {1}".
+            format(layout, data.dtype))
+    return strategy
+
 @conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
 def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv2d_transpose arm cpu strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index b1fb421..a0dd6bf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -266,6 +266,12 @@ def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, t
     """conv2d_winograd_without_weight_transfrom generic strategy"""
     raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")
 
+# conv2d_gemm_without_weight_transform
+@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
+def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
+    """conv2d_gemm_without_weight_transform generic strategy (no generic implementation)"""
+    raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")
+
 # conv2d_winograd_weight_transform
 @generic_func
 def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
@@ -280,6 +286,13 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
     with target:
         return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
 
+# conv2d_gemm_weight_transform
+@generic_func
+def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
+    """Schedule conv2d_gemm_weight_transform via topi.generic under the given target"""
+    with target:
+        return topi.generic.schedule_conv2d_gemm_weight_transform(outs)
+
 # deformable_conv2d
 def wrap_compute_deformable_conv2d(topi_compute):
     """wrap deformable_conv2d topi compute"""
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index d3b0e44..7246214 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -237,6 +237,11 @@ def is_fast_int8_on_arm():
     target = tvm.target.Target.current(allow_none=False)
     return '+v8.2a,+dotprod' in ' '.join(target.options)
 
+def is_aarch64_arm():
+    """ Returns True if the current target's options mention 'aarch64'. """
+    current_options = tvm.target.Target.current(allow_none=False).options
+    return 'aarch64' in ' '.join(current_options)
+
 ########################
 # ARM CPU legalizations.
 ########################
@@ -244,10 +249,11 @@ def is_fast_int8_on_arm():
 @qnn_conv2d_legalize.register('arm_cpu')
 def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
-    # ARM prefers the dtypes to be same.
+    # ARM prefers the dtypes to be same (AArch64 NHWC int8 alter-layout paths assert it).
-    if is_fast_int8_on_arm():
+    if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
         return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
     return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
 
+
 @qnn_dense_legalize.register('arm_cpu')
 def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
     # ARM prefers the dtypes to be same.
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 6c6eb1e..f63c489 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -77,6 +77,26 @@ Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> st
   return Call(op, {data, weight}, Attrs(attrs), {});
 }
 
+template <typename T>  // T: the conv attrs node type (Conv2DAttrs for the conv2d gemm op)
+Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+                  Array<IndexExpr> dilation, int groups, IndexExpr channels,
+                  Array<IndexExpr> kernel_size, std::string data_layout, std::string kernel_layout,
+                  std::string out_layout, DataType out_dtype, std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->channels = std::move(channels);
+  attrs->kernel_size = std::move(kernel_size);
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get(op_name);  // looks up the registered relay op by name
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
 Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
   auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
   attrs->tile_size = tile_size;
@@ -84,6 +104,14 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_
   return Call(op, {weight}, Attrs(attrs), {});
 }
 
+Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
+  auto attrs = make_object<ConvGemmWeightTransformAttrs>();  // carries only the tile sizes
+  attrs->tile_rows = tile_rows;
+  attrs->tile_cols = tile_cols;
+  const Op& op = Op::Get(op_name);  // e.g. "nn.contrib_conv2d_gemm_weight_transform"
+  return Call(op, {weight}, Attrs(attrs), {});
+}
+
 template <typename T>
 Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                        Array<IndexExpr> dilation, int groups, IndexExpr channels,
@@ -504,6 +532,60 @@ weight transformation in advance.
     .set_support_level(10)
     .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
 
+// relay.nn.contrib_conv2d_gemm_without_weight_transform
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform")
+    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
+                       Array<IndexExpr> kernel_size, std::string data_layout,
+                       std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+      return MakeConvGemm<Conv2DAttrs>(
+          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
+          kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform");
+    });
+// The weight layout is backend specific, so Conv2DGemmRel infers the output from attrs only.
+RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
+    .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
+                 This operator assumes the weight tensor is already pre-transformed by
+                 nn.contrib_conv2d_gemm_weight_transform.
+
+- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
+- **weight**: Any shape
+            We do not check the shape for this input tensor, since different backends
+            have different layout strategies.
+
+- **out**:  Output is 4D array of shape (batch_size, out_height, out_width, channels)
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<Conv2DAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_support_level(10)
+    .add_type_rel("Conv2DGemm", Conv2DGemmRel<Conv2DAttrs>)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
+
+// relay.nn.contrib_conv2d_gemm_weight_transform
+// Output shape (tiled, interleaved weights) is computed by Conv2DGemmWeightTransformRel.
+TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform")
+    .set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
+      return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
+                                         "nn.contrib_conv2d_gemm_weight_transform");
+    });
+
+RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform")
+    .describe(R"code(Weight transformation of GEMM convolution algorithm.
+
+Separate this into another operator in order to enable Precompute Pass to compute the
+weight transformation in advance.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<ConvGemmWeightTransformAttrs>()
+    .set_num_inputs(1)
+    .add_argument("weights", "Tensor", "The weights tensor.")
+    .set_support_level(10)
+    .add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel);
+
 // Positional relay function to create conv2d NCHWc operator
 // used by frontend FFI.
 TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index 0c5b20a..f53f4e0 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -383,6 +383,65 @@ inline bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_i
   return true;
 }
 
+// Gemm convolution shape relations
+// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W.
+// The high level idea is to subdivide W into tiles of tile_cols x tile_rows, and transpose and
+// interleave them. The output is a [N_pad//tile_rows, K_pad//tile_cols, tile_rows, tile_cols]
+// matrix that we call W_interleaved_t (N_pad, K_pad: N, K rounded up to tile multiples).
+//
+// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed
+// for tile_rows = 4 and tile_cols = 16
+//
+//              W[0,0,:,:]                        W_interleaved_t[0,0,:,:]
+//  +-------------------------------+     +------------------------------------+
+//  |W[0,0]  W[0,1]  W[0,2]  W[0,3] |     |W[0,0]  W[1,0]  W[2,0]  ...  W[15,0]|
+//  |W[1,0]  W[1,1]  W[1,2]  W[1,3] | --\ |W[0,1]  W[1,1]  W[2,1]  ...  W[15,1]|
+//  |W[2,0]  W[2,1]  W[2,2]  W[2,3] | --/ |W[0,2]  W[1,2]  W[2,2]  ...  W[15,2]|
+//  |  ...     ...    ...      ...  |     |W[0,3]  W[1,3]  W[2,3]  ...  W[15,3]|
+//  |  ...     ...    ...      ...  |     +------------------------------------+
+//  |W[15,0] W[15,1] W[15,2] W[15,3]|
+//  +-------------------------------+
+//
+// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements
+// at a time, we should set tile_cols = k.
+// Tile rows is connected with the number of registers available for the given target.
+//
+inline bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs,
+                                         const Attrs& attrs, const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* weight = types[0].as<TensorTypeNode>();
+  if (weight == nullptr) return false;
+
+  const ConvGemmWeightTransformAttrs* param = attrs.as<ConvGemmWeightTransformAttrs>();
+  CHECK(param != nullptr);
+  int n = param->tile_rows;
+  int k = param->tile_cols;
+
+  CHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout";
+  // View the HWIO weight as a K x N matrix: K = H * W * I, N = O.
+  const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
+  const auto N = weight->shape[3];
+  // Round K and N up to multiples of k and n; the pad is 0 when already divisible.
+  auto K_mod_k = indexmod(K, k);
+  auto N_mod_n = indexmod(N, n);
+
+  auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
+  auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));
+
+  const auto N_padded = N + pad_N;
+  const auto K_padded = K + pad_K;
+
+  Array<IndexExpr> oshape{
+      indexdiv(N_padded, n),
+      indexdiv(K_padded, k),
+      n,
+      k,
+  };
+
+  reporter->Assign(types[1], TensorType(oshape, weight->dtype));
+  return true;
+}
+
 inline bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
                                              const Attrs& attrs, const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
@@ -520,6 +579,78 @@ bool Conv2DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& at
 }
 
 template <typename AttrType>
+bool Conv2DGemmRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  static const Layout kNHWC("NHWC");
+  static const Layout kHWIO("HWIO");
+
+  const AttrType* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC);
+  CHECK(trans_in_layout.defined())
+      << "Conv only support input layouts that are convertible from NHWC."
+      << " But got " << in_layout;
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO);
+  CHECK(trans_kernel_layout.defined())
+      << "Conv only support kernel layouts that are convertible from HWIO."
+      << " But got " << kernel_layout;
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC);
+  CHECK(trans_out_layout.defined())
+      << "Conv only support output layouts that are convertible from NHWC."
+      << " But got " << out_layout;
+
+  Array<IndexExpr> dshape_nhwc = trans_in_layout.ForwardShape(data->shape);
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+  CHECK(param->kernel_size.defined() && param->channels.defined())
+      << "The kernel size and channels of a Conv must be set or inferred by previous pass";
+
+  CHECK_EQ(param->kernel_size.size(), 2);
+  CHECK_EQ(param->dilation.size(), 2);
+
+  channels = param->channels;
+  dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+  dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+
+  // NOTE: Do not check weight shape here: the weight is pre-transformed by a backend.
+
+  // Output shape in NHWC: H and W are filled below (an Any input dim is propagated as-is).
+  Array<IndexExpr> oshape({dshape_nhwc[0], 0, 0, channels});
+
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  if (!dshape_nhwc[1].as<tir::AnyNode>()) {
+    oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1);
+  } else {
+    oshape.Set(1, dshape_nhwc[1]);
+  }
+  if (!dshape_nhwc[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1);
+  } else {
+    oshape.Set(2, dshape_nhwc[2]);
+  }
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = trans_out_layout.BackwardShape(oshape);
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return true;
+}
+
+template <typename AttrType>
 bool Conv3DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 3);
diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py
index 3206168..99fdf21 100644
--- a/topi/python/topi/arm_cpu/conv2d_alter_op.py
+++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py
@@ -59,10 +59,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     data, kernel = tinfos
     out_dtype = out_type.dtype
 
-    # We only perform layout alteration for NCHW data layout.
-    if data_layout == "NHWC":
-        return None
-
     # Extract data types
     data_tensor, kernel_tensor = tinfos
     data_dtype = data_tensor.dtype
@@ -70,6 +66,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
 
     idxd = tvm.tir.indexdiv
 
+    # Only NHWC convs with quantized (int8/uint8) data are altered; skip other dtypes.
+    if data_layout == "NHWC" and data_dtype not in ['uint8', 'int8']:
+        return None
+
     if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
         N, CI, H, W = get_const_tuple(data.shape)
@@ -88,21 +88,27 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         return relay.nn.conv2d(*inputs, **new_attrs)
 
     if topi_tmpl == "conv2d_nhwc_spatial_pack.arm_cpu":
+        assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
+                data.dtype == 'uint8' and kernel.dtype == 'uint8')
+
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
-        N, H, W, CI = get_const_tuple(data.shape)
-        KH, KW, _, CO = get_const_tuple(kernel.shape)
-        VC = cfg['tile_co'].size[-1]
 
-        new_attrs['kernel_layout'] = 'OHWI%do' % VC
+        data_expr, kernel_expr = inputs
+        # Widen int8/uint8 operands to int16; the conv2d and autotvm workload below use int16.
+        data_int16 = relay.cast(data_expr, dtype='int16')
+        kernel_int16 = relay.cast(kernel_expr, dtype='int16')
+
+        new_attrs = {k : attrs[k] for k in attrs.keys()}
+        # Register the existing cfg under the int16 workload of the same template.
+        new_data = te.placeholder(data.shape, 'int16')
+        new_kernel = te.placeholder(kernel.shape, 'int16')
 
-        new_data = data
-        new_kernel = te.placeholder((idxd(CO, VC), KH, KW, CI, VC), dtype=kernel.dtype)
         new_workload = autotvm.task.args_to_workload(
             [new_data, new_kernel, strides, padding, dilation, out_dtype],
-            "conv2d_nhwc_spatial_pack.arm_cpu")
+            'conv2d_nhwc_spatial_pack.arm_cpu')
         dispatch_ctx.update(target, new_workload, cfg)
 
-        return relay.nn.conv2d(*inputs, **new_attrs)
+        return relay.nn.conv2d(data_int16, kernel_int16, **new_attrs)
 
     if topi_tmpl == "conv2d_nchw_winograd.arm_cpu":
         assert data_layout == "NCHW" and kernel_layout == "OIHW"
@@ -235,5 +241,40 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
              new_attrs['out_layout'], out_dtype], topi_tmpl)
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
+    if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
+        assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
+                data.dtype == 'uint8' and kernel.dtype == 'uint8')
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, IC, OC = get_const_tuple(kernel.shape)  # HWIO kernel layout
+        K = KH * KW * IC
+        N = OC
+        # Tile sizes of the interleaved weights (see Conv2DGemmWeightTransformRel).
+        tile_rows, tile_cols = 4, 16
+        pad_K = 0
+        pad_N = 0
+        # Pad N and K up to multiples of the tile sizes, as the weight transform does.
+        if N % tile_rows != 0:
+            pad_N = tile_rows - (N % tile_rows)
+        if K % tile_cols != 0:
+            pad_K = tile_cols - (K % tile_cols)
+
+        N_padded = N + pad_N
+        K_padded = K + pad_K
+        kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
+        new_kernel = te.placeholder((N_padded // tile_rows,
+                                     K_padded // tile_cols,
+                                     tile_rows,
+                                     tile_cols), kernel.dtype)
+        # Key the tuning config to the *_without_transform implementation used after rewrite.
+        new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
+        new_workload = autotvm.task.args_to_workload([data, new_kernel,
+                                                      strides, padding, dilation,
+                                                      out_dtype, (KH, KW), OC],
+                                                     new_workload_name)
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return relay.nn.contrib_conv2d_gemm_without_weight_transform(inputs[0],
+                                                                     kernel_expr,
+                                                                     **new_attrs)
 
     return None
diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py
new file mode 100644
index 0000000..2b61229
--- /dev/null
+++ b/topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin
+"""GEMM Convolution schedule on ARM"""
+import tvm
+from tvm import te
+from topi import nn
+from ..util import get_const_tuple
+from ..nn.util import get_pad_tuple
+from .tensor_intrin import gemv_quantized, gemv_quantized_impl
+
+
+# Compute function
+def compute_conv2d_gemm_without_weight_transform(cfg,
+                                                 data, B_interleaved_t, strides, padding, dilation,
+                                                 out_dtype, kernel_size, output_channels):
+    """Compute conv2d by transforming the input,
+    executing GEMM and transforming the output back
+
+    The (padded) input is im2col-transformed into an (M, K) matrix per batch,
+    padded to a multiple of 4 rows and 16 columns, interleaved into 4x16 tiles
+    and multiplied with the pre-interleaved weights. The (M, N) result is then
+    reshaped back to NHWC.
+
+    Parameters
+    ----------
+    cfg : autotvm config
+        Not read by this compute.
+    data : tvm.te.Tensor
+        4-D input in NHWC layout.
+    B_interleaved_t : tvm.te.Tensor
+        Transformed weights, indexed as [n_tile, k_tile, 4, 16].
+    strides : int or tuple of two ints
+    padding : int, str or tuple
+        Anything accepted by get_pad_tuple.
+    dilation : int or tuple of two ints
+    out_dtype : str
+        Accumulation and output dtype.
+    kernel_size : tuple of two ints
+        (KH, KW) of the original kernel.
+    output_channels : int
+        Number of output channels (OC) of the original kernel.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        4-D output with shape (batches, OH, OW, OC).
+    """
+    batches, IH, IW, IC = get_const_tuple(data.shape)
+
+    KH, KW = kernel_size
+    OC = output_channels
+
+    K_AREA = KH * KW
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = \
+        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+    # Pad whenever ANY side needs it. With asymmetric padding (e.g. an explicit
+    # padding tuple whose top/left entries are 0) pad_down/pad_right can be
+    # non-zero on their own. OH/OW above already include those rows/columns,
+    # so skipping the pad would index past the end of `data`.
+    if pad_top or pad_left or pad_down or pad_right:
+        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+                          name="data_pad")
+    else:
+        data_pad = data
+
+    # --- Im2col
+    M = OH * OW
+    K = IC * K_AREA
+    N = OC
+
+    A_shape = (batches, M, K)
+    if K_AREA == 1:
+        A = te.compute(A_shape, lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
+                       name='data_flatten')
+    else:
+        # Column y enumerates (kh, kw, ic) with ic fastest: y // IC is the kernel
+        # tap, split into kh = tap // KW and kw = tap % KW. The dilation factor
+        # must scale kh / kw themselves, hence the explicit parentheses.
+        A = te.compute(A_shape, lambda n, x, y:
+                       data_pad[n,
+                                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
+                                WSTR * (x % OW) + dilation_w * ((y // IC) % KW), y % IC],
+                       name='data_im2col')
+    N_transformed = B_interleaved_t.shape[0]
+
+    # --- Pad if necessary
+    # The 4 (rows) x 16 (columns) tiling must match the tile_rows / tile_cols
+    # used when the weights were transformed.
+    idxm = tvm.tir.indexmod
+
+    pad_m = 0
+    pad_k = 0
+
+    if M % 4 != 0:
+        pad_m = 4 - (M % 4)
+
+    if K % 16 != 0:
+        pad_k = 16 - (K % 16)
+
+    M_padded = M + pad_m
+    K_padded = K + pad_k
+
+    pad_before = (0, 0, 0)
+    pad_after = (0, pad_m, pad_k)
+
+    if pad_m != 0 or pad_k != 0:
+        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
+
+    # --- GEMM: A*B'
+    k = te.reduce_axis((0, K_padded), "k")
+
+    A_interleaved = te.compute((batches, M_padded // 4, K_padded // 16, 4, 16),
+                               lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
+                               name='A_interleaved')
+
+    C_interleaved = te.compute((batches, M_padded // 4, N_transformed, 4, 4),
+                               lambda b, x, y, w, z:
+                               te.sum(A_interleaved[b, x, k//16, w, idxm(k, 16)].astype(out_dtype)*
+                                      B_interleaved_t[y, k//16, z, idxm(k, 16)].astype(out_dtype),
+                                      axis=k),
+                               name='C_interleaved')
+
+    # --- Unpack C
+    C = te.compute((batches, M, N),
+                   lambda b, x, y:
+                   C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
+                   name="C", tag='injective')
+
+    # --- Produce the conv output
+    out_shape = (batches, OH, OW, OC)
+    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
+                     name='conv2d_gemm_output')
+
+    return out
+
+# Schedules
+def schedule_conv2d_gemm(cfg, s, out):
+    """Create schedule for tensors
+
+    Walks back from the GEMM-convolution output through C, C_interleaved and
+    A_interleaved (each reached via input_tensors[0]) and schedules them.
+    When the output dtype is int32, the innermost 4x4 block is tensorized
+    with the hand-written gemv kernel.
+
+    Parameters
+    ----------
+    cfg : autotvm config
+        Not read here.
+    s : tvm.te.Schedule
+        Schedule modified in place.
+    out : tvm.te.Tensor
+        The conv2d GEMM output tensor (presumably the 'conv2d_gemm_output'
+        stage; the input_tensors[0] chain assumes that structure).
+
+    Returns
+    -------
+    s : tvm.te.Schedule
+        The same schedule object that was passed in.
+    """
+    C = out.op.input_tensors[0]
+    C_interleaved = C.op.input_tensors[0]
+    A_interleaved = C_interleaved.op.input_tensors[0]
+
+    # Input transform
+    A_interleaved_input = A_interleaved.op.input_tensors[0]
+    # The im2col matrix is wrapped in an "A_padded" stage only when M or K had
+    # to be padded to the tile grid.
+    if A_interleaved_input.op.name == "A_padded":
+        # NOTE(review): compute_inline() below overrides this compute_at (the
+        # last attach call wins), so the padded stage ends up inlined and the
+        # compute_at / vectorize calls look redundant - confirm intent.
+        s[A_interleaved_input].compute_at(s[A_interleaved], A_interleaved.op.axis[3])
+        s[A_interleaved_input].vectorize(A_interleaved_input.op.axis[2])
+        s[A_interleaved_input].compute_inline()
+        data_im2col = A_interleaved_input.op.input_tensors[0]
+    else:
+        data_im2col = A_interleaved_input
+
+    b, m, n = data_im2col.op.axis
+    if data_im2col.op.name == "data_im2col":
+        # Materialize im2col, vectorizing 16 columns at a time (the K tile width).
+        n_outer, n_inner = s[data_im2col].split(n, 16)
+        s[data_im2col].unroll(n_outer)
+        s[data_im2col].vectorize(n_inner)
+    else:
+        # 1x1 kernels ("data_flatten") are a pure reshape: inline them.
+        s[data_im2col].compute_inline()
+
+    # Computation(through tensorize)
+    # Loop order becomes b, xo, yo, yi, xi (+ reduction); xo is parallelized and
+    # the A tiles for each xo are interleaved inside that loop.
+    b, xo, yo, xi, yi = C_interleaved.op.axis
+    s[C_interleaved].reorder(xo, yo, yi, xi)
+    s[C_interleaved].parallel(xo)
+    s[A_interleaved].compute_at(s[C_interleaved], xo)
+    s[A_interleaved].vectorize(A_interleaved.op.axis[4])
+
+    in_type = A_interleaved.dtype
+    out_type = C.dtype
+    # Only int32 accumulation is tensorized; other dtypes keep the plain compute.
+    if out_type == 'int32':
+        # Width of the (possibly padded) im2col matrix.
+        K = A_interleaved_input.shape[2]
+        # NOTE(review): M and N are the full GEMM dims of C (OH*OW and OC);
+        # the intrinsic only uses them through min(4, M) / min(4, N).
+        _, M, N = C.shape
+        assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"
+
+        gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type)
+        # Link the LLVM IR of the hand-written kernel so the extern call made
+        # by the tensor intrinsic can be resolved.
+        s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type))
+        s[C_interleaved].tensorize(yi, gem_v_dotprod)
+
+    # Output transform
+    # Only OW is used from this unpacking.
+    N, OH, OW, OC = out.shape
+    s[C].split(C.op.axis[1], OW)
+    s[C].compute_at(s[out], out.op.axis[3])
+
+    return s
diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py
index 06412b6..5a895c0 100644
--- a/topi/python/topi/arm_cpu/conv2d_int8.py
+++ b/topi/python/topi/arm_cpu/conv2d_int8.py
@@ -19,11 +19,12 @@
 from tvm import te
 from tvm import autotvm
 from .. import tag
-from ..util import get_const_tuple
+from ..util import traverse_inline, get_const_tuple
 from ..generic import conv2d as conv2d_generic
 from .. import nn
 from ..nn.conv2d import _get_workload as _get_conv2d_workload
 from .tensor_intrin import dot_int8_int8_int32
+from .conv2d_gemm import compute_conv2d_gemm_without_weight_transform, schedule_conv2d_gemm
 
 
 def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
@@ -109,3 +110,38 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
 
     traverse(outs[0].op)
     return s
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_quantized.arm_cpu")
+def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute a quantized NHWC conv2d via im2col + GEMM.
+
+    The kernel (unpacked as KH, KW, IC, OC, i.e. HWIO) is transformed into
+    the interleaved tile layout inside the graph, and the GEMM-based compute
+    is then applied.
+
+    Parameters
+    ----------
+    cfg : autotvm config
+    data : tvm.te.Tensor
+        4-D input in NHWC layout.
+    kernel : tvm.te.Tensor
+        4-D kernel in HWIO layout.
+    strides, padding, dilation :
+        Convolution parameters, forwarded unchanged.
+    out_dtype : str
+        Accumulation and output dtype.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        4-D output in NHWC layout.
+    """
+    KH, KW, _, OC = get_const_tuple(kernel.shape)
+    # Tile sizes expected by the 4x16 GEMM compute and the gemv intrinsic.
+    tile_rows = 4
+    tile_cols = 16
+    kernel = nn.conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols)
+    return compute_conv2d_gemm_without_weight_transform(cfg,
+                                                        data, kernel, strides, padding,
+                                                        dilation, out_dtype, (KH, KW), OC)
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_quantized_without_transform.arm_cpu")
+def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, padding,
+                                                    dilation, out_dtype, kernel_size=None,
+                                                    output_channels=None):
+    """Compute a quantized NHWC conv2d whose weights are already transformed.
+
+    B is presumably the output of nn.conv2d_gemm_weight_transform (the
+    interleaved, padded tile layout). kernel_size and output_channels supply
+    the original (KH, KW) and OC, which B's padded shape no longer carries.
+
+    NOTE(review): the alter-layout hook in this patch builds its workload
+    under "conv2d_NHWC_int8_without_tranform.arm_cpu", which differs from the
+    name registered here - confirm which one is intended.
+    """
+    gemm_args = (data, B, strides, padding, dilation, out_dtype,
+                 kernel_size, output_channels)
+    return compute_conv2d_gemm_without_weight_transform(cfg, *gemm_args)
+
+
+@autotvm.register_topi_schedule("conv2d_NHWC_quantized.arm_cpu")
+def schedule_conv2d_NHWC_quantized(cfg, outs):
+    """Create schedule for tensors
+
+    Builds a schedule over all output ops, then walks the graph with
+    traverse_inline and applies schedule_conv2d_gemm to the op named
+    "conv2d_gemm_output".
+    """
+    sch = te.create_schedule([out.op for out in outs])
+
+    def _visit(op):
+        """Schedule the GEMM convolution output once it is reached."""
+        if op.name != "conv2d_gemm_output":
+            return
+        schedule_conv2d_gemm(cfg, sch, op.output(0))
+
+    traverse_inline(sch, outs[0].op, _visit)
+    return sch
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
index da9c71a..6ef2548 100644
--- a/topi/python/topi/arm_cpu/tensor_intrin.py
+++ b/topi/python/topi/arm_cpu/tensor_intrin.py
@@ -19,6 +19,345 @@
 
 import tvm
 from tvm import te
+from tvm.contrib import util, clang
+
+def gemv_quantized_impl(M, N, data_type='uint8'):
+    """ Assembly implementation of a blocked gemv. Given
+    a block a of shape (4, k) and a block b' of shape (4, k)
+    produces the output block c = a*b of shape (4,4)
+
+    Generates C source containing AArch64 NEON inline assembly for an extern
+    function named ``gemv_<data_type>_<data_type>_int32_<stepA>_<stepB>``
+    (stepA = min(4, M), stepB = min(4, N)) and compiles it to LLVM IR.
+
+    Parameters
+    ----------
+    M : int
+        Rows of a; only min(4, M) rows are loaded and stored.
+    N : int
+        Rows of b' (columns of c); only min(4, N) rows of b' are loaded.
+    data_type : str, {'uint8', 'int8'}
+        Input element type; selects unsigned (umull/uadalp) or signed
+        (smull/sadalp) instructions.
+
+    Returns
+    -------
+    ll_code : str
+        LLVM IR text produced by clang.create_llvm for aarch64-linux-gnu.
+    """
+
+    # Number of a rows / b' rows handled by one call, capped at the 4x4 tile.
+    stepA = min(4, M)
+    stepB = min(4, N)
+    assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation'
+
+    # Function prototype. The symbol name must match the one emitted by
+    # gemv_quantized; m and n are accepted but never read by the kernel body.
+    cc_code = """
+          extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer,
+                                                    unsigned char *a_buffer,
+                                                    unsigned char *b_buffer,
+                                                    int K, int m, int n)
+              """.format(data_type, stepA, stepB)
+
+    # Prologue: zero the 16 accumulators v16-v31 (v16-v19: row 0 of a against
+    # rows 0-3 of b', v20-v23: row 1, v24-v27: row 2, v28-v31: row 3) and open
+    # the main loop at label 1. The loop runs k = K / 16 times as a do-while
+    # (cbnz at the bottom), so K is assumed to be a positive multiple of 16.
+    cc_code += """
+    {
+            unsigned char * a_ptr = a_buffer;
+            unsigned char * b_ptr = b_buffer;
+            int * c_ptr = c_buffer;
+
+            int k = K / 16;
+
+            __asm__  __volatile__ (
+                "movi v16.4s, #0\\n"
+                "movi v17.4s, #0\\n"
+                "movi v18.4s, #0\\n"
+                "movi v19.4s, #0\\n"
+                "movi v20.4s, #0\\n"
+                "movi v21.4s, #0\\n"
+                "movi v22.4s, #0\\n"
+                "movi v23.4s, #0\\n"
+                "movi v24.4s, #0\\n"
+                "movi v25.4s, #0\\n"
+                "movi v26.4s, #0\\n"
+                "movi v27.4s, #0\\n"
+                "movi v28.4s, #0\\n"
+                "movi v29.4s, #0\\n"
+                "movi v30.4s, #0\\n"
+                "movi v31.4s, #0\\n"
+            "1:"
+    """
+
+    # Load up to 4 rows of a (16 bytes each) into v0-v3; rows beyond M are
+    # zeroed so they contribute nothing to the accumulators.
+    cc_code += ' "ldr q0, [%[a_ptr]]\\n" '
+
+    if M > 1:
+        cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" '
+    else:
+        cc_code += ' "movi v1.4s, #0\\n" '
+
+    if M > 2:
+        cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" '
+    else:
+        cc_code += ' "movi v2.4s, #0\\n" '
+
+    if M > 3:
+        cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" '
+    else:
+        cc_code += ' "movi v3.4s, #0\\n" '
+
+    # Load up to 4 rows of b' into v4-v7.
+    # NOTE(review): unlike a, rows beyond N are NOT zeroed, yet v5-v7 are still
+    # read by the multiplies below, so result columns >= N come from stale
+    # register contents - confirm callers only use N >= 4 or ignore them.
+    cc_code += ' "ldr q4, [%[b_ptr]]\\n" '
+
+    if N > 1:
+        cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" '
+
+    if N > 2:
+        cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" '
+
+    if N > 3:
+        cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" '
+
+    # Loop body: umull/umull2 widen 8-bit products to 16 bits (low / high 8
+    # lanes respectively) and uadalp pairwise-adds them into the 32-bit
+    # accumulators (v8-v15 are scratch).
+    # NOTE(review): in the "First half" C comments, umull is labelled "Higher
+    # part" and umull2 "Lower part"; on AArch64 umull reads the low 8 bytes and
+    # umull2 the high 8 bytes, so those labels are swapped (codegen unaffected).
+    cc_code += """
+                // First half
+                // Higher part of a0 * {b0,b1,b2,b3}
+                "umull v8.8h, v0.8b, v4.8b\\n"
+                "umull v9.8h, v0.8b, v5.8b\\n"
+                "umull v10.8h, v0.8b, v6.8b\\n"
+                "umull v11.8h, v0.8b, v7.8b\\n"
+
+                // Higher part of a1 * {b0,b1,b2,b3}
+                "umull v12.8h, v1.8b, v4.8b\\n"
+                "umull v13.8h, v1.8b, v5.8b\\n"
+                "umull v14.8h, v1.8b, v6.8b\\n"
+                "umull v15.8h, v1.8b, v7.8b\\n"
+
+                // Accumulate
+                "uadalp v16.4s, v8.8h\\n"
+                "uadalp v17.4s, v9.8h\\n"
+                "uadalp v18.4s, v10.8h\\n"
+                "uadalp v19.4s, v11.8h\\n"
+                "uadalp v20.4s, v12.8h\\n"
+                "uadalp v21.4s, v13.8h\\n"
+                "uadalp v22.4s, v14.8h\\n"
+                "uadalp v23.4s, v15.8h\\n"
+
+                // Lower part of a0 * {b0,b1,b2,b3}
+                "umull2 v8.8h, v0.16b, v4.16b\\n"
+                "umull2 v9.8h, v0.16b, v5.16b\\n"
+                "umull2 v10.8h, v0.16b, v6.16b\\n"
+                "umull2 v11.8h, v0.16b, v7.16b\\n"
+
+                // Lower part of a1 * {b0,b1,b2,b3}
+                "umull2 v12.8h, v1.16b, v4.16b\\n"
+                "umull2 v13.8h, v1.16b, v5.16b\\n"
+                "umull2 v14.8h, v1.16b, v6.16b\\n"
+                "umull2 v15.8h, v1.16b, v7.16b\\n"
+
+                 // Accumulate again
+                "uadalp v16.4s, v8.8h\\n"
+                "uadalp v17.4s, v9.8h\\n"
+                "uadalp v18.4s, v10.8h\\n"
+                "uadalp v19.4s, v11.8h\\n"
+                "uadalp v20.4s, v12.8h\\n"
+                "uadalp v21.4s, v13.8h\\n"
+                "uadalp v22.4s, v14.8h\\n"
+                "uadalp v23.4s, v15.8h\\n"
+
+                // Second half
+
+                // Lower part of a2 * {b0,b1,b2,b3}
+                "umull v8.8h, v2.8b, v4.8b\\n"
+                "umull v9.8h, v2.8b, v5.8b\\n"
+                "umull v10.8h, v2.8b, v6.8b\\n"
+                "umull v11.8h, v2.8b, v7.8b\\n"
+
+                // Lower part of a3 * {b0,b1,b2,b3}
+                "umull v12.8h, v3.8b, v4.8b\\n"
+                "umull v13.8h, v3.8b, v5.8b\\n"
+                "umull v14.8h, v3.8b, v6.8b\\n"
+                "umull v15.8h, v3.8b, v7.8b\\n"
+
+                // Accumulate
+                "uadalp v24.4s, v8.8h\\n"
+                "uadalp v25.4s, v9.8h\\n"
+                "uadalp v26.4s, v10.8h\\n"
+                "uadalp v27.4s, v11.8h\\n"
+                "uadalp v28.4s, v12.8h\\n"
+                "uadalp v29.4s, v13.8h\\n"
+                "uadalp v30.4s, v14.8h\\n"
+                "uadalp v31.4s, v15.8h\\n"
+
+                // Higher part of a2 * {b0,b1,b2,b3}
+                "umull2 v8.8h, v2.16b, v4.16b\\n"
+                "umull2 v9.8h, v2.16b, v5.16b\\n"
+                "umull2 v10.8h, v2.16b, v6.16b\\n"
+                "umull2 v11.8h, v2.16b, v7.16b\\n"
+
+                // Higher part of a3 * {b0,b1,b2,b3}
+                "umull2 v12.8h, v3.16b, v4.16b\\n"
+                "umull2 v13.8h, v3.16b, v5.16b\\n"
+                "umull2 v14.8h, v3.16b, v6.16b\\n"
+                "umull2 v15.8h, v3.16b, v7.16b\\n"
+
+                // Accumulate again
+                "uadalp v24.4s, v8.8h\\n"
+                "uadalp v25.4s, v9.8h\\n"
+                "uadalp v26.4s, v10.8h\\n"
+                "uadalp v27.4s, v11.8h\\n"
+                "uadalp v28.4s, v12.8h\\n"
+                "uadalp v29.4s, v13.8h\\n"
+                "uadalp v30.4s, v14.8h\\n"
+                "uadalp v31.4s, v15.8h\\n"
+    """
+    # Bytes consumed per iteration: 16 per row actually present (max 4 rows).
+    blockA = min(64, M * 16)
+    blockB = min(64, N * 16)
+
+    cc_code += """
+                // Increment pointers and decrement k
+                "add %[a_ptr], %[a_ptr], #{0}\\n"
+                "add %[b_ptr], %[b_ptr], #{1}\\n"
+                "subs %w[k], %w[k], #1\\n"
+    """.format(blockA, blockB)
+
+    # Row stride of the output block, in int32 elements.
+    stepC = min(4, N)
+
+    # Loop back-edge, then reduce each accumulator's 4 partial sums with addp
+    # and store row 0 of c.
+    cc_code += """
+                "cbnz %w[k], 1b\\n"
+
+                // Final additions
+
+                // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+                // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+                // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+                // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
+                "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
+                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+                // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+                // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+                // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+                // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
+                "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
+                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+                // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+                // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+                // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+                // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b, c+d, e+f, g+h)
+                "addp v25.4s, v26.4s, v27.4s\\n"  // v25 = (i+j, k+l, m+n, o+p)
+                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+                // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+                // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+                // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+                // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
+                "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
+                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+                "str q16, [%[c_ptr]]\\n"
+            """
+
+    # Store rows 1-3 of c. NOTE(review): each `str q` writes 16 bytes (4 int32)
+    # while rows are placed stepC * 4 bytes apart. For N < 4 consecutive rows
+    # overlap, and the row-3 store ends at 12 * stepC + 16 bytes, past the
+    # 16 * stepC bytes of a 4 x stepC block - confirm the destination is large
+    # enough.
+    if M > 1:
+        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)
+
+    if M > 2:
+        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)
+
+    if M > 3:
+        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)
+
+    # Operands: the pointers and k are read-write; flags, memory and every
+    # vector register are declared clobbered.
+    cc_code += """
+             : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
+             :
+             : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+                    "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+                    "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+                    "v27", "v28", "v29", "v30", "v31"
+             );
+        return 0;
+        }
+    """
+
+    # The signed variant is derived from the unsigned template by textual
+    # substitution of the pointer type and the widening/accumulate mnemonics.
+    if data_type == 'int8':
+        cc_code = cc_code.replace('unsigned char', 'char')
+        cc_code = cc_code.replace('umull', 'smull')
+        cc_code = cc_code.replace('uadalp', 'sadalp')
+
+    # Compile to LLVM IR text (consumed via the "import_llvm" pragma).
+    temp = util.tempdir()
+    ll_path = temp.relpath("temp.ll")
+    # Create LLVM ir from c source code
+    ll_code = clang.create_llvm(cc_code,
+                                options=["--target=aarch64-linux-gnu -mattr=+neon"],
+                                output=ll_path)
+    return ll_code
+
+
+def gemv_quantized(M, N, K, in_type, out_type):
+    """
+    Use integer ARM v8 instructions in order to produce a block c of 4x4 elements
+    given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final
+    result is c = a*b (where '*' indicates the matrix product)
+
+    Every row of the matrix c is obtained (for uint8) by a sequence of
+
+          umull -> uadalp -> umull2 -> uadalp
+
+    The block size is constrained by the number of registers available in armv8. This
+    function returns a TensorIntrin that can be used to tensorize
+    a schedule.
+
+    The intrinsic lowers to a call to the extern routine
+    ``gemv_<in_type>_<in_type>_int32_<stepA>_<stepB>`` generated by
+    gemv_quantized_impl, with stepA = min(4, M) and stepB = min(4, N).
+
+    Parameters
+    ----------
+    M: int
+        rows of the matrix A
+    N: int
+        columns of the matrix B
+    K: int
+        columns of matrix A
+    in_type: str, {'uint8', 'int8'}
+    out_type: str, {'uint32', 'int32'}
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
+    """
+    # Same restriction as gemv_quantized_impl, whose symbol we call below.
+    assert in_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation'
+
+    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name='A')
+    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name='B')
+
+    idxm = tvm.tir.indexmod
+
+    k = te.reduce_axis((0, K), "k")
+
+    C = te.compute((te.var("m"), te.var("n")),
+                   lambda x, y: te.sum(A[k // 16, x, idxm(k, 16)].astype(out_type) *
+                                       B[k // 16, y, idxm(k, 16)].astype(out_type),
+                                       axis=k), name="C")
+
+    a_buffer = tvm.tir.decl_buffer(A.shape, dtype=in_type, name="a_buffer",
+                                   offset_factor=1, strides=[te.var('sa_1'), te.var('sa_2'), 1])
+
+    b_buffer = tvm.tir.decl_buffer(B.shape, dtype=in_type, name="b_buffer",
+                                   offset_factor=1, strides=[te.var('sb_1'), te.var('sb_2'), 1])
+
+    c_buffer = tvm.tir.decl_buffer(C.shape, dtype=out_type, name="c_buffer",
+                                   offset_factor=1, strides=[te.var('sc'), 1])
+
+    def _intrin_func(ins, outs):
+
+        def _instr():
+            ib = tvm.tir.ir_builder.create()
+            stepA = min(4, M)
+            stepB = min(4, N)
+            # Must match the symbol built by gemv_quantized_impl:
+            # gemv_{data_type}_{data_type}_int32_{stepA}_{stepB}
+            func_name = "gemv_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB)
+            # The extern routine is declared as (c, a, b, K, m, n) for both the
+            # signed and unsigned variants, so all six arguments are always
+            # passed (previously the int8 call omitted m and n).
+            ib.emit(tvm.tir.call_extern("int32",
+                                        func_name,
+                                        c_buffer.access_ptr("w"),
+                                        a_buffer.access_ptr("r"),
+                                        b_buffer.access_ptr("r"),
+                                        K,
+                                        C.shape[0],  # m, very useful for debug
+                                        C.shape[1]))  # n, very useful for debug
+            return ib.get()
+
+        # body, reset, update
+        return _instr()
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(C.op, _intrin_func,
+                                 binds={A:a_buffer, B:b_buffer, C:c_buffer},
+                                 default_buffer_params=buffer_params)
+
 
 def dot_int8_int8_int32(int32_lanes, dtype='uint'):
     """
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 767087b..7645588 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -187,6 +187,25 @@ def schedule_conv2d_winograd_weight_transform(outs):
     return s
 
 
+def schedule_conv2d_gemm_weight_transform(outs):
+    """Schedule for weight transformation of gemm
+
+    No optimization is applied: a default schedule over the output ops is
+    returned.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of this operator
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    # Typically this is computed in PreCompute pass, so a default schedule
+    # is sufficient.
+    output_ops = [tensor.op for tensor in outs]
+    return te.create_schedule(output_ops)
+
+
 def schedule_conv3d_winograd_weight_transform(outs):
     """Schedule for weight transformation of 3D winograd
 
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 4c7941b..5928889 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -590,6 +590,55 @@ def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layou
                       name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
 
 
+def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
+    """Weight transformation for ConvGemm
+
+    Flattens the kernel into a (K, N) = (KH*KW*IC, OC) matrix. It then
+    zero-pads K up to a multiple of ``tile_cols`` and N up to a multiple of
+    ``tile_rows``, and reorders the result into tiles of ``tile_rows`` output
+    channels by ``tile_cols`` reduction elements.
+
+    Parameters
+    ----------
+    kernel: Tensor
+        The raw kernel tensor with layout "HWIO".
+    tile_rows: int
+        Tile rows of the weight transformation for ConvGemm.
+    tile_cols: int
+        Tile columns of the weight transformation for ConvGemm.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape [N_padded // tile_rows, K_padded // tile_cols,
+        tile_rows, tile_cols], where K = KH*KW*IC and N = OC before padding.
+    """
+    KH, KW, IC, OC = get_const_tuple(kernel.shape)
+    K = KH * KW * IC
+    N = OC
+
+    # Row x of the flattened kernel enumerates (kh, kw, ic) with ic fastest.
+    kernel_flat = te.compute((K, N), lambda x, y:
+                             kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y],
+                             'weight_flatten')
+
+    pad_K = 0
+    pad_N = 0
+
+    if N % tile_rows != 0:
+        pad_N = tile_rows - (N % tile_rows)
+
+    if K % tile_cols != 0:
+        pad_K = tile_cols - (K % tile_cols)
+
+    N_padded = N + pad_N
+    K_padded = K + pad_K
+
+    # Zero-pad so both dimensions split evenly into tiles; otherwise the floor
+    # divisions below would silently drop the tail of K and/or N.
+    if pad_K != 0 or pad_N != 0:
+        kernel_flat = pad(kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N),
+                          name='weight_padding')
+
+    # Element [x, y, z, w] is output channel z + tile_rows * x at reduction
+    # index w + tile_cols * y.
+    return te.compute((N_padded // tile_rows,
+                       K_padded // tile_cols,
+                       tile_rows,
+                       tile_cols), lambda x, y, z, w:
+                      kernel_flat[w + tile_cols * y, z + tile_rows * x],
+                      name='weight_block_reshape')
+
+
 def conv2d_winograd_weight_transform(kernel, tile_size):
     """Weight transformation for winograd
 


Mime
View raw message