tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From junrus...@apache.org
Subject [tvm] 01/01: [TensorIR][Minor] Allow Tuple/Array in TE lowering
Date Fri, 03 Sep 2021 04:47:50 GMT
This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch junrushao1994-patch-1
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b992246bdc0e53b642b98e9a17190eb73d0a624f
Author: Junru Shao <junrushao1994@gmail.com>
AuthorDate: Thu Sep 2 21:47:31 2021 -0700

    [TensorIR][Minor] Allow Tuple/Array in TE lowering
---
 python/tvm/te/operation.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 6af3429..a0b9b43 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -22,13 +22,13 @@ from typing import List
 import tvm._ffi
 import tvm.tir
 import tvm.tir._ffi_api
-
 from tvm._ffi.base import string_types
+from tvm.ir import Array
 from tvm.runtime import convert
 
+from . import _ffi_api
 from . import tag as _tag
 from . import tensor as _tensor
-from . import _ffi_api
 
 
 def placeholder(shape, dtype=None, name="placeholder"):
@@ -431,6 +431,7 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
 
 def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
     """Create a TensorIR PrimFunc from tensor expression
+
     Parameters
     ----------
     ops : List[Tensor]
@@ -473,6 +474,6 @@ def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
     func : tir.PrimFunc
         The created function.
     """
-    if not isinstance(ops, list):
+    if not isinstance(ops, (list, tuple, Array)):
         ops = [ops]
     return _ffi_api.CreatePrimFunc(ops)

Mime
View raw message