tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] comaniac commented on a change in pull request #5752: Make batch matrix multiplication on GPU tunable
Date Tue, 09 Jun 2020 17:53:50 GMT

comaniac commented on a change in pull request #5752:
URL: https://github.com/apache/incubator-tvm/pull/5752#discussion_r437597643



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
     """batch_matmul cuda strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
         wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
+        name="batch_matmul.gpu",

Review comment:
       Why change to `gpu`?

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:

Review comment:
      Missing `unroll_explicit` in the fallback schedule. Based on your comment, this might
cause a problem on LLVM-based backends.
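
For example, the fallback could be filled in along these lines (a sketch only; the split factors and unroll step below are illustrative guesses, not taken from this PR):

    if cfg.is_fallback:
        y_bn = get_max_power2_factor(M, 64)
        x_bn = get_max_power2_factor(N, 64)
        y_nthreads = min(y_bn, 8)
        x_nthreads = min(x_bn, 8)
        # fallback tiling: [derived outer, thread count, inner]
        cfg["tile_y"] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
        cfg["tile_x"] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
        cfg["auto_unroll_max_step"] = OtherOptionEntity(16)
        # also cover unroll_explicit so the LLVM-based backends
        # (nvptx, rocm) get explicit unrolling in the fallback
        cfg["unroll_explicit"] = OtherOptionEntity(1)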

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -14,13 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.gpu")
+def batch_matmul(cfg, x, y):
+    """Compute conv2d with NCHW layout"""
+    return nn.batch_matmul(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul.gpu")

Review comment:
       ditto.

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -51,55 +62,73 @@ def _schedule(op):
             C = s.outputs[0].output(0)
 
         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:

Review comment:
      This reminds me that you may also need to add a batch_matmul op strategy for ROCm;
otherwise ROCm will never use this schedule.
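
A registration in `python/tvm/relay/op/strategy/rocm.py` mirroring the CUDA strategy above should be enough (a sketch, assuming rocm.py follows the same `wrap_compute_batch_matmul`/`wrap_topi_schedule` pattern as the other strategy files):

    @batch_matmul_strategy.register("rocm")
    def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
        """batch_matmul rocm strategy"""
        strategy = _op.OpStrategy()
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.cuda.batch_matmul),
            wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
            name="batch_matmul.cuda",
        )
        return strategy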

##########
File path: topi/python/topi/cuda/batch_matmul.py
##########
@@ -14,13 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.gpu")

Review comment:
      We use `XXX.cuda` in all other places, so I'd suggest changing this to `batch_matmul.cuda`.
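
Concretely, that would look something like this (the schedule signature below is assumed from the rest of the diff):

    @autotvm.register_topi_compute("batch_matmul.cuda")
    def batch_matmul(cfg, x, y):
        """Compute batch matrix multiplication of x and y"""
        return nn.batch_matmul(x, y)


    @autotvm.register_topi_schedule("batch_matmul.cuda")
    def schedule_batch_matmul(cfg, outs):
        ...

together with the matching `name="batch_matmul.cuda"` in the CUDA op strategy.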




