From commits-return-8849-archive-asf-public=cust-asf.ponee.io@tvm.apache.org Wed Mar 18 15:57:09 2020 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [207.244.88.153]) by mx-eu-01.ponee.io (Postfix) with SMTP id 8A63F18025F for ; Wed, 18 Mar 2020 16:57:08 +0100 (CET) Received: (qmail 42298 invoked by uid 500); 18 Mar 2020 15:57:07 -0000 Mailing-List: contact commits-help@tvm.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@tvm.apache.org Delivered-To: mailing list commits@tvm.apache.org Received: (qmail 42289 invoked by uid 99); 18 Mar 2020 15:57:07 -0000 Received: from ec2-52-202-80-70.compute-1.amazonaws.com (HELO gitbox.apache.org) (52.202.80.70) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 18 Mar 2020 15:57:07 +0000 Received: by gitbox.apache.org (ASF Mail Server at gitbox.apache.org, from userid 33) id 722728DACA; Wed, 18 Mar 2020 15:57:07 +0000 (UTC) Date: Wed, 18 Mar 2020 15:57:07 +0000 To: "commits@tvm.apache.org" Subject: [incubator-tvm] branch master updated: create function.py (#5087) MIME-Version: 1.0 Content-Type: text/plain; charset=utf-8 Content-Transfer-Encoding: 8bit Message-ID: <158454702702.2684.6351300044486831136@gitbox.apache.org> From: tqchen@apache.org X-Git-Host: gitbox.apache.org X-Git-Repo: incubator-tvm X-Git-Refname: refs/heads/master X-Git-Reftype: branch X-Git-Oldrev: 06bbc7c9e4941713b4012d344b5424d4a32a9228 X-Git-Newrev: 7ca3212f06a56eb95420060aa822a56860d114fd X-Git-Rev: 7ca3212f06a56eb95420060aa822a56860d114fd X-Git-NotificationType: ref_changed_plus_diff X-Git-Multimail-Version: 1.5.dev Auto-Submitted: auto-generated This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git The following commit(s) were added to refs/heads/master by this push: new 7ca3212 create function.py (#5087) 7ca3212 is described below commit 7ca3212f06a56eb95420060aa822a56860d114fd Author: Zhi <5145158+zhiics@users.noreply.github.com> AuthorDate: Wed Mar 18 08:56:55 2020 -0700 create function.py (#5087) --- docs/api/python/relay/expr.rst | 3 - docs/langref/relay_expr.rst | 2 +- python/tvm/autotvm/graph_tuner/base_graph_tuner.py | 4 +- .../autotvm/graph_tuner/utils/traverse_graph.py | 3 +- python/tvm/autotvm/task/relay_integration.py | 6 +- python/tvm/relay/__init__.py | 3 +- python/tvm/relay/_parser.py | 7 +- python/tvm/relay/analysis/analysis.py | 2 +- python/tvm/relay/backend/compile_engine.py | 4 +- python/tvm/relay/backend/interpreter.py | 3 +- python/tvm/relay/build_module.py | 13 ++-- python/tvm/relay/expr.py | 66 +---------------- python/tvm/relay/expr_functor.py | 3 +- python/tvm/relay/frontend/caffe2.py | 7 +- python/tvm/relay/frontend/common.py | 5 +- python/tvm/relay/frontend/coreml.py | 3 +- python/tvm/relay/frontend/darknet.py | 3 +- python/tvm/relay/frontend/keras.py | 3 +- python/tvm/relay/frontend/mxnet.py | 5 +- python/tvm/relay/frontend/onnx.py | 7 +- python/tvm/relay/frontend/tensorflow.py | 3 +- python/tvm/relay/frontend/tflite.py | 3 +- python/tvm/relay/function.py | 86 ++++++++++++++++++++++ python/tvm/relay/loops.py | 3 +- python/tvm/relay/prelude.py | 3 +- python/tvm/relay/testing/nat.py | 3 +- python/tvm/relay/testing/py_converter.py | 3 +- src/relay/ir/function.cc | 9 +-- 28 files changed, 152 insertions(+), 113 deletions(-) diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst index 57a4a25..cfb6df0 100644 --- a/docs/api/python/relay/expr.rst +++ b/docs/api/python/relay/expr.rst @@ -35,9 +35,6 @@ tvm.relay.expr .. autoclass:: tvm.relay.expr.Tuple :members: -.. autoclass:: tvm.relay.expr.Function - :members: - .. autoclass:: tvm.relay.expr.Call :members: diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst index 66bfe43..3b93360 100644 --- a/docs/langref/relay_expr.rst +++ b/docs/langref/relay_expr.rst @@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection), which are values like tensors and tuples. -See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes. +See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes. Syntax ~~~~~~ diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index f1a0756..e7b4694 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -69,7 +69,7 @@ class BaseGraphTuner(object): target_op in the input graph and layout transformation benchmark need to be executed before initialization. - graph : tvm.relay.Expr.Function + graph : tvm.relay.function.Function Input graph input_shapes : dict of str to tuple. @@ -143,7 +143,7 @@ class BaseGraphTuner(object): if isinstance(graph, tvm.IRModule): graph = graph["main"] - if isinstance(graph, relay.expr.Function): + if isinstance(graph, relay.function.Function): node_dict = {} graph = bind_inputs(graph, input_shapes, dtype) expr2graph(graph, self._target_ops, node_dict, self._node_list) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index f1dd404..8470fb6 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,7 +21,8 @@ import threading import tvm from tvm import relay, autotvm from tvm.relay import transform -from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple +from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple +from tvm.relay.function import Function from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index cd8d32f..a7cbef7 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): Parameters ---------- - mod: tvm.IRModule or relay.expr.Function + mod: tvm.IRModule or relay.function.Function The module or function to tune params: dict of str to numpy array The associated parameters of the program @@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No Parameters ---------- - mods: List[tvm.IRModule] or List[relay.expr.Function] + mods: List[tvm.IRModule] or List[relay.function.Function] The list of modules or functions to tune params: List of dict of str to numpy array The associated parameters of the programs @@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No logger.disabled = True for mod, param in zip(mods, params): - if isinstance(mod, relay.expr.Function): + if isinstance(mod, relay.function.Function): mod = tvm.IRModule.from_expr(mod) assert isinstance(mod, tvm.IRModule), \ "only support relay Module or Function to be tuned" diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index b1aac3e..95545c8 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -22,6 +22,7 @@ from sys import setrecursionlimit from . import base from . import ty from . import expr +from . import function from . import type_functor from . import expr_functor from . import adt @@ -87,7 +88,7 @@ Constant = expr.Constant Tuple = expr.Tuple Var = expr.Var GlobalVar = expr.GlobalVar -Function = expr.Function +Function = function.Function Call = expr.Call Let = expr.Let If = expr.If diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 49bdbb3..4a73e57 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -43,6 +43,7 @@ from tvm.ir import IRModule from .base import Span, SourceName from . import adt from . import expr +from . import function from . import ty from . import op @@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor): def mk_func( self, ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \ - -> expr.Function: + -> function.Function: """Construct a function from either a Func or Defn.""" # Enter var scope early to put params in scope. self.enter_var_scope() @@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor): self.exit_var_scope() attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None - return expr.Function(var_list, body, ret_type, type_params, attrs) + return function.Function(var_list, body, ret_type, type_params, attrs) @spanify - def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function: + def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function: return self.mk_func(ctx) # TODO: how to set spans for definitions? diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index beb3c65..722f3b0 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -421,7 +421,7 @@ def extract_fused_functions(mod): Returns ------- - ret : Dict[int, tvm.relay.ir.expr.Function] + ret : Dict[int, tvm.relay.function.Function] A module containing only fused primitive functions """ ret_mod = _ffi_api.ExtractFusedFunctions()(mod) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 03d91d5..3e35bd2 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -25,7 +25,7 @@ from tvm import te from tvm.runtime import Object from ... import target as _target from ... import autotvm -from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import ty as _ty from . import _backend @@ -65,7 +65,7 @@ class CCacheValue(Object): def _get_cache_key(source_func, target): - if isinstance(source_func, _expr.Function): + if isinstance(source_func, _function.Function): if isinstance(target, str): target = _target.create(target) if not target: diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ab39f7c..9c4be29 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -27,7 +27,8 @@ from tvm.ir import IRModule from . import _backend from .. import _make, analysis, transform from ... import nd -from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const +from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const +from ..function import Function from ..scope_builder import ScopeBuilder diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d1add27..30c5971 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt from . import _build_module from . import ty as _ty from . import expr as _expr +from . import function as _function from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - if not isinstance(mod, (IRModule, _expr.Function)): + if not isinstance(mod, (IRModule, _function.Function)): raise ValueError("Type of input parameter mod must be tvm.IRModule") - if isinstance(mod, _expr.Function): + if isinstance(mod, _function.Function): if params: mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter mod (tvm.relay.expr.Function)", + "instead of deprecated parameter mod (tvm.relay.function.Function)", DeprecationWarning) target = _update_target(target) @@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None): params : dict The parameters of the final graph. """ - if not isinstance(mod, (IRModule, _expr.Function)): + if not isinstance(mod, (IRModule, _function.Function)): raise ValueError("Type of input parameter mod must be tvm.IRModule") - if isinstance(mod, _expr.Function): + if isinstance(mod, _function.Function): if params: mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter func (tvm.relay.expr.Function)", + "instead of deprecated parameter func (tvm.relay.function.Function)", DeprecationWarning) target = _update_target(target) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 380cdf7..ff13683 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -22,8 +22,8 @@ from numbers import Number as _Number import numpy as _np import tvm._ffi from tvm._ffi import base as _base -from tvm.runtime import NDArray, convert, ndarray as _nd -from tvm.ir import RelayExpr, GlobalVar, BaseFunc +from tvm.runtime import NDArray, ndarray as _nd +from tvm.ir import RelayExpr, GlobalVar from .base import RelayNode from . import _ffi_api @@ -225,68 +225,6 @@ class Var(ExprWithOp): return name -@tvm._ffi.register_object("relay.Function") -class Function(BaseFunc): - """A function declaration expression. - - Parameters - ---------- - params: List[tvm.relay.Var] - List of input parameters to the function. - - body: tvm.relay.Expr - The body of the function. - - ret_type: Optional[tvm.relay.Type] - The return type annotation of the function. - - type_params: Optional[List[tvm.relay.TypeParam]] - The additional type parameters, this is only - used in advanced usecase of template functions. - """ - def __init__(self, - params, - body, - ret_type=None, - type_params=None, - attrs=None): - if type_params is None: - type_params = convert([]) - - self.__init_handle_by_constructor__( - _ffi_api.Function, params, body, ret_type, type_params, attrs) - - def __call__(self, *args): - """Invoke the global function. - - Parameters - ---------- - args: List[relay.Expr] - Arguments. - """ - return Call(self, args, None, None) - - def with_attr(self, attr_key, attr_value): - """Create a new copy of the function and update the attribute - - Parameters - ---------- - attr_key : str - The attribute key to use. - - attr_value : Object - The new attribute value. - - Returns - ------- - func : Function - A new copy of the function - """ - return _ffi_api.FunctionWithAttr( - self, attr_key, convert(attr_value)) - - - @tvm._ffi.register_object("relay.Call") class Call(ExprWithOp): """Function call node in Relay. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 8d69239..874a3a7 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -17,7 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression functor of Relay.""" -from .expr import Function, Call, Let, Var, GlobalVar +from .function import Function +from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index da0cc64..f4fcd92 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -21,6 +21,7 @@ from tvm.ir import IRModule from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... import nd as _nd from .common import AttrCvt, Renamer @@ -451,7 +452,7 @@ class Caffe2NetDef(object): else: outputs = out[0] - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) self._mod["main"] = func return self._mod, self._params @@ -517,7 +518,7 @@ class Caffe2NetDef(object): ---------- op_type : str Operator name, such as Convolution, FullyConnected - inputs : list of tvm.relay.expr.Function + inputs : list of tvm.relay.function.Function List of input inputs. args : dict Dict of operator attributes @@ -530,7 +531,7 @@ class Caffe2NetDef(object): Returns ------- - func : tvm.relay.expr.Function + func : tvm.relay.function.Function Converted relay function """ identity_list = identity_list if identity_list else _identity_list diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index d427fe9..6185121 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from topi.util import get_const_tuple from .. import expr as _expr +from .. import function as _function from .. import transform as _transform from .. import op as _op from .. import analysis @@ -459,7 +460,7 @@ def infer_type(node, mod=None): new_mod.update(mod) new_mod = _transform.InferType()(new_mod) entry = new_mod["main"] - return entry if isinstance(node, _expr.Function) else entry.body + return entry if isinstance(node, _function.Function) else entry.body def infer_shape(inputs, mod=None): """A method to get the output type of an intermediate node in the graph.""" @@ -491,7 +492,7 @@ def infer_value(input_val, params): # Check that all free variables have associated parameters. assert all(var.name_hint in params.keys() for var in analysis.free_vars( input_val)), "All inputs to infer must be available in params." - func = _expr.Function(analysis.free_vars(input_val), input_val) + func = _function.Function(analysis.free_vars(input_val), input_val) with tvm.relay.build_config(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) ctx = tvm.cpu(0) diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 0e5b64c..6658803 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... import nd as _nd from ..._ffi import base as _base @@ -503,6 +504,6 @@ def from_coreml(model, shape=None): for o in spec.description.output] # for now return first output outexpr = outexpr[0] - func = _expr.Function(analysis.free_vars(outexpr), outexpr) + func = _function.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 0dae645..936d7c0 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -26,6 +26,7 @@ from tvm.ir import IRModule from .. import analysis from .. import expr as _expr +from .. import function as _function from .common import get_relay_op, new_var __all__ = ['from_darknet'] @@ -821,7 +822,7 @@ class GraphProto(object): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - sym = _expr.Function(analysis.free_vars(outputs), outputs) + sym = _function.Function(analysis.free_vars(outputs), outputs) return IRModule.from_expr(sym), self._tvmparams def from_darknet(net, diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index adb28c4..090bd4c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -23,6 +23,7 @@ from tvm.ir import IRModule from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ... import nd as _nd from .common import ExprTable, new_var @@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'): outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ for oc in model._output_coordinates] outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) - func = _expr.Function(analysis.free_vars(outexpr), outexpr) + func = _function.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return IRModule.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index ba93bb2..17be368 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -25,6 +25,7 @@ from tvm import relay from topi.util import get_const_tuple from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import scope_builder as _scope_builder from ... import nd as _nd @@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs): else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args] else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info) sb.ret(_expr.Call(else_func, else_args)) - func = _expr.Function(input_args, sb.get()) + func = _function.Function(input_args, sb.get()) ret = _expr.Call(func, inputs) if num_outputs > 1: ret = _expr.TupleWrapper(ret, num_outputs) @@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) return func diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7f417d3..e1b0a7f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from ... import nd as _nd from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels @@ -1708,7 +1709,7 @@ class GraphProto(object): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) return IRModule.from_expr(func), self._params def _parse_value_proto(self, value_proto): @@ -1774,7 +1775,7 @@ class GraphProto(object): ---------- op_name : str Operator name, such as Convolution, FullyConnected - inputs : list of tvm.relay.expr.Function + inputs : list of tvm.relay.function.Function List of inputs. attrs : dict Dict of operator attributes @@ -1783,7 +1784,7 @@ class GraphProto(object): Returns ------- - sym : tvm.relay.expr.Function + sym : tvm.relay.function.Function Converted relay function """ convert_map = _get_convert_map(opset) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 3dca365..e0da863 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from ..expr_functor import ExprMutator from .common import AttrCvt, get_relay_op @@ -2461,7 +2462,7 @@ class GraphProto(object): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _expr.Function(analysis.free_vars(out), out) + func = _function.Function(analysis.free_vars(out), out) self._mod["main"] = func return self._mod, self._params diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 95f7579..aa51570 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from tvm import relay from .. import analysis from .. import expr as _expr +from .. import function as _function from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd @@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict): params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()} outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(analysis.free_vars(outputs), outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) mod = IRModule.from_expr(func) return mod, params diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py new file mode 100644 index 0000000..786a7f4 --- /dev/null +++ b/python/tvm/relay/function.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, invalid-name, unused-import +"""The expression nodes of Relay.""" +from __future__ import absolute_import + +import tvm._ffi +from tvm.runtime import convert +from tvm.ir import BaseFunc + +from .expr import Call +from . import _ffi_api + +@tvm._ffi.register_object("relay.Function") +class Function(BaseFunc): + """A function declaration expression. + + Parameters + ---------- + params: List[tvm.relay.Var] + List of input parameters to the function. + + body: tvm.relay.Expr + The body of the function. + + ret_type: Optional[tvm.relay.Type] + The return type annotation of the function. + + type_params: Optional[List[tvm.relay.TypeParam]] + The additional type parameters, this is only + used in advanced usecase of template functions. + """ + def __init__(self, + params, + body, + ret_type=None, + type_params=None, + attrs=None): + if type_params is None: + type_params = convert([]) + + self.__init_handle_by_constructor__( + _ffi_api.Function, params, body, ret_type, type_params, attrs) + + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relay.Expr] + Arguments. + """ + return Call(self, args, None, None) + + def with_attr(self, attr_key, attr_value): + """Create a new copy of the function and update the attribute + + Parameters + ---------- + attr_key : str + The attribute key to use. + + attr_value : Object + The new attribute value. + + Returns + ------- + func : Function + A new copy of the function + """ + return _ffi_api.FunctionWithAttr( + self, attr_key, convert(attr_value)) diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py index 8e066ab..9af6811 100644 --- a/python/tvm/relay/loops.py +++ b/python/tvm/relay/loops.py @@ -20,6 +20,7 @@ Utilities for building Relay loops. """ from .scope_builder import ScopeBuilder from . import expr as _expr +from . import function as _function def while_loop(cond, loop_vars, loop_bodies): """ @@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies): with sb.else_scope(): sb.ret(_expr.Tuple(fresh_vars)) - func = _expr.Function(fresh_vars, sb.get()) + func = _function.Function(fresh_vars, sb.get()) let = _expr.Let(loop, func, loop) return let diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 5288a2e..0e64a2f 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -19,7 +19,8 @@ from tvm.ir import IRModule from .ty import GlobalTypeVar, TensorType, Any, scalar_type -from .expr import Var, Function, GlobalVar, If, const +from .expr import Var, GlobalVar, If, const +from .function import Function from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index eb71120..4906eef 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -21,7 +21,8 @@ test cases for recursion and pattern matching.""" from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.backend.interpreter import ConstructorValue -from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.expr import Var, GlobalVar +from tvm.relay.function import Function from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType def define_nat_adt(prelude): diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index eacfe37..e850000 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -23,7 +23,8 @@ import tvm from tvm import relay from tvm.relay.adt import Pattern from tvm.relay.backend import compile_engine -from tvm.relay.expr import Expr, Function, GlobalVar, Var +from tvm.relay.expr import Expr, GlobalVar, Var +from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor OUTPUT_VAR_NAME = '_py_out' diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index b251645..48cb4d8 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -27,10 +27,10 @@ namespace tvm { namespace relay { Function::Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array type_params, - DictAttrs attrs) { + Expr body, + Type ret_type, + tvm::Array type_params, + DictAttrs attrs) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") return Function(params, body, ret_type, ty_params, attrs); }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get());