tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: create function.py (#5087)
Date Wed, 18 Mar 2020 15:57:07 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ca3212  create function.py (#5087)
7ca3212 is described below

commit 7ca3212f06a56eb95420060aa822a56860d114fd
Author: Zhi <5145158+zhiics@users.noreply.github.com>
AuthorDate: Wed Mar 18 08:56:55 2020 -0700

    create function.py (#5087)
---
 docs/api/python/relay/expr.rst                     |  3 -
 docs/langref/relay_expr.rst                        |  2 +-
 python/tvm/autotvm/graph_tuner/base_graph_tuner.py |  4 +-
 .../autotvm/graph_tuner/utils/traverse_graph.py    |  3 +-
 python/tvm/autotvm/task/relay_integration.py       |  6 +-
 python/tvm/relay/__init__.py                       |  3 +-
 python/tvm/relay/_parser.py                        |  7 +-
 python/tvm/relay/analysis/analysis.py              |  2 +-
 python/tvm/relay/backend/compile_engine.py         |  4 +-
 python/tvm/relay/backend/interpreter.py            |  3 +-
 python/tvm/relay/build_module.py                   | 13 ++--
 python/tvm/relay/expr.py                           | 66 +----------------
 python/tvm/relay/expr_functor.py                   |  3 +-
 python/tvm/relay/frontend/caffe2.py                |  7 +-
 python/tvm/relay/frontend/common.py                |  5 +-
 python/tvm/relay/frontend/coreml.py                |  3 +-
 python/tvm/relay/frontend/darknet.py               |  3 +-
 python/tvm/relay/frontend/keras.py                 |  3 +-
 python/tvm/relay/frontend/mxnet.py                 |  5 +-
 python/tvm/relay/frontend/onnx.py                  |  7 +-
 python/tvm/relay/frontend/tensorflow.py            |  3 +-
 python/tvm/relay/frontend/tflite.py                |  3 +-
 python/tvm/relay/function.py                       | 86 ++++++++++++++++++++++
 python/tvm/relay/loops.py                          |  3 +-
 python/tvm/relay/prelude.py                        |  3 +-
 python/tvm/relay/testing/nat.py                    |  3 +-
 python/tvm/relay/testing/py_converter.py           |  3 +-
 src/relay/ir/function.cc                           |  9 +--
 28 files changed, 152 insertions(+), 113 deletions(-)

diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst
index 57a4a25..cfb6df0 100644
--- a/docs/api/python/relay/expr.rst
+++ b/docs/api/python/relay/expr.rst
@@ -35,9 +35,6 @@ tvm.relay.expr
 .. autoclass:: tvm.relay.expr.Tuple
     :members:
 
-.. autoclass:: tvm.relay.expr.Function
-    :members:
-
 .. autoclass:: tvm.relay.expr.Call
     :members:
 
diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst
index 66bfe43..3b93360 100644
--- a/docs/langref/relay_expr.rst
+++ b/docs/langref/relay_expr.rst
@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
 function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
 which are values like tensors and tuples.
 
-See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes.
+See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.
 
 Syntax
 ~~~~~~
diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
index f1a0756..e7b4694 100644
--- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
         target_op in the input graph and layout transformation benchmark need to be
         executed before initialization.
 
-        graph : tvm.relay.Expr.Function
+        graph : tvm.relay.function.Function
             Input graph
 
         input_shapes : dict of str to tuple.
@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
         if isinstance(graph, tvm.IRModule):
             graph = graph["main"]
 
-        if isinstance(graph, relay.expr.Function):
+        if isinstance(graph, relay.function.Function):
             node_dict = {}
             graph = bind_inputs(graph, input_shapes, dtype)
             expr2graph(graph, self._target_ops, node_dict, self._node_list)
diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
index f1dd404..8470fb6 100644
--- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
+++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -21,7 +21,8 @@ import threading
 import tvm
 from tvm import relay, autotvm
 from tvm.relay import transform
-from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.function import Function
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
 
diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
index cd8d32f..a7cbef7 100644
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
 
     Parameters
     ----------
-    mod: tvm.IRModule or relay.expr.Function
+    mod: tvm.IRModule or relay.function.Function
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
 
     Parameters
     ----------
-    mods: List[tvm.IRModule] or List[relay.expr.Function]
+    mods: List[tvm.IRModule] or List[relay.function.Function]
         The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
         logger.disabled = True
 
         for mod, param in zip(mods, params):
-            if isinstance(mod, relay.expr.Function):
+            if isinstance(mod, relay.function.Function):
                 mod = tvm.IRModule.from_expr(mod)
             assert isinstance(mod, tvm.IRModule), \
                 "only support relay Module or Function to be tuned"
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index b1aac3e..95545c8 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -22,6 +22,7 @@ from sys import setrecursionlimit
 from . import base
 from . import ty
 from . import expr
+from . import function
 from . import type_functor
 from . import expr_functor
 from . import adt
@@ -87,7 +88,7 @@ Constant = expr.Constant
 Tuple = expr.Tuple
 Var = expr.Var
 GlobalVar = expr.GlobalVar
-Function = expr.Function
+Function = function.Function
 Call = expr.Call
 Let = expr.Let
 If = expr.If
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 49bdbb3..4a73e57 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -43,6 +43,7 @@ from tvm.ir import IRModule
 from .base import Span, SourceName
 from . import adt
 from . import expr
+from . import function
 from . import ty
 from . import op
 
@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
     def mk_func(
             self,
             ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-            -> expr.Function:
+            -> function.Function:
         """Construct a function from either a Func or Defn."""
         # Enter var scope early to put params in scope.
         self.enter_var_scope()
@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
         self.exit_var_scope()
 
         attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
-        return expr.Function(var_list, body, ret_type, type_params, attrs)
+        return function.Function(var_list, body, ret_type, type_params, attrs)
 
     @spanify
-    def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
+    def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
         return self.mk_func(ctx)
 
     # TODO: how to set spans for definitions?
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index beb3c65..722f3b0 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
 
     Returns
     -------
-    ret : Dict[int, tvm.relay.ir.expr.Function]
+    ret : Dict[int, tvm.relay.function.Function]
         A module containing only fused primitive functions
     """
     ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 03d91d5..3e35bd2 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -25,7 +25,7 @@ from tvm import te
 from tvm.runtime import Object
 from ... import target as _target
 from ... import autotvm
-from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import ty as _ty
 from . import _backend
@@ -65,7 +65,7 @@ class CCacheValue(Object):
 
 
 def _get_cache_key(source_func, target):
-    if isinstance(source_func, _expr.Function):
+    if isinstance(source_func, _function.Function):
         if isinstance(target, str):
             target = _target.create(target)
             if not target:
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index ab39f7c..9c4be29 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -27,7 +27,8 @@ from tvm.ir import IRModule
 from . import _backend
 from .. import _make, analysis, transform
 from ... import nd
-from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
+from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
+from ..function import Function
 from ..scope_builder import ScopeBuilder
 
 
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index d1add27..30c5971 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
 from . import ty as _ty
 from . import expr as _expr
+from . import function as _function
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
 
@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter mod (tvm.relay.expr.Function)",
+            "instead of deprecated parameter mod (tvm.relay.function.Function)",
             DeprecationWarning)
 
     target = _update_target(target)
@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            "instead of deprecated parameter func (tvm.relay.function.Function)",
             DeprecationWarning)
 
     target = _update_target(target)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 380cdf7..ff13683 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -22,8 +22,8 @@ from numbers import Number as _Number
 import numpy as _np
 import tvm._ffi
 from tvm._ffi import base as _base
-from tvm.runtime import NDArray, convert, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar, BaseFunc
+from tvm.runtime import NDArray, ndarray as _nd
+from tvm.ir import RelayExpr, GlobalVar
 
 from .base import RelayNode
 from . import _ffi_api
@@ -225,68 +225,6 @@ class Var(ExprWithOp):
         return name
 
 
-@tvm._ffi.register_object("relay.Function")
-class Function(BaseFunc):
-    """A function declaration expression.
-
-    Parameters
-    ----------
-    params: List[tvm.relay.Var]
-        List of input parameters to the function.
-
-    body: tvm.relay.Expr
-        The body of the function.
-
-    ret_type: Optional[tvm.relay.Type]
-        The return type annotation of the function.
-
-    type_params: Optional[List[tvm.relay.TypeParam]]
-        The additional type parameters, this is only
-        used in advanced usecase of template functions.
-    """
-    def __init__(self,
-                 params,
-                 body,
-                 ret_type=None,
-                 type_params=None,
-                 attrs=None):
-        if type_params is None:
-            type_params = convert([])
-
-        self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, type_params, attrs)
-
-    def __call__(self, *args):
-        """Invoke the global function.
-
-        Parameters
-        ----------
-        args: List[relay.Expr]
-            Arguments.
-        """
-        return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
-
-
-
 @tvm._ffi.register_object("relay.Call")
 class Call(ExprWithOp):
     """Function call node in Relay.
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index 8d69239..874a3a7 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -17,7 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The expression functor of Relay."""
 
-from .expr import Function, Call, Let, Var, GlobalVar
+from .function import Function
+from .expr import Call, Let, Var, GlobalVar
 from .expr import If, Tuple, TupleGetItem, Constant
 from .expr import RefCreate, RefRead, RefWrite
 from .adt import Constructor, Match, Clause
diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py
index da0cc64..f4fcd92 100644
--- a/python/tvm/relay/frontend/caffe2.py
+++ b/python/tvm/relay/frontend/caffe2.py
@@ -21,6 +21,7 @@ from tvm.ir import IRModule
 
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import AttrCvt, Renamer
@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
         else:
             outputs = out[0]
 
-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         self._mod["main"] = func
 
         return self._mod, self._params
@@ -517,7 +518,7 @@ class Caffe2NetDef(object):
         ----------
         op_type : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of input inputs.
         args : dict
             Dict of operator attributes
@@ -530,7 +531,7 @@ class Caffe2NetDef(object):
 
         Returns
         -------
-        func : tvm.relay.expr.Function
+        func : tvm.relay.function.Function
             Converted relay function
         """
         identity_list = identity_list if identity_list else _identity_list
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index d427fe9..6185121 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from topi.util import get_const_tuple
 
 from .. import expr as _expr
+from .. import function as _function
 from .. import transform as _transform
 from .. import op as _op
 from .. import analysis
@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
         new_mod.update(mod)
     new_mod = _transform.InferType()(new_mod)
     entry = new_mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
+    return entry if isinstance(node, _function.Function) else entry.body
 
 def infer_shape(inputs, mod=None):
     """A method to get the output type of an intermediate node in the graph."""
@@ -491,7 +492,7 @@ def infer_value(input_val, params):
     # Check that all free variables have associated parameters.
     assert all(var.name_hint in params.keys() for var in analysis.free_vars(
         input_val)), "All inputs to infer must be available in params."
-    func = _expr.Function(analysis.free_vars(input_val), input_val)
+    func = _function.Function(analysis.free_vars(input_val), input_val)
     with tvm.relay.build_config(opt_level=0):
         graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
     ctx = tvm.cpu(0)
diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py
index 0e5b64c..6658803 100644
--- a/python/tvm/relay/frontend/coreml.py
+++ b/python/tvm/relay/frontend/coreml.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from ..._ffi import base as _base
@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
                for o in spec.description.output]
     # for now return first output
     outexpr = outexpr[0]
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py
index 0dae645..936d7c0 100644
--- a/python/tvm/relay/frontend/darknet.py
+++ b/python/tvm/relay/frontend/darknet.py
@@ -26,6 +26,7 @@ from tvm.ir import IRModule
 
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .common import get_relay_op, new_var
 
 __all__ = ['from_darknet']
@@ -821,7 +822,7 @@ class GraphProto(object):
 
         outputs = _as_list(sym) + self._outs
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        sym = _expr.Function(analysis.free_vars(outputs), outputs)
+        sym = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(sym), self._tvmparams
 
 def from_darknet(net,
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index adb28c4..090bd4c 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -23,6 +23,7 @@ from tvm.ir import IRModule
 
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import ExprTable, new_var
@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
     outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
                for oc in model._output_coordinates]
     outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index ba93bb2..17be368 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -25,6 +25,7 @@ from tvm import relay
 from topi.util import get_const_tuple
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import scope_builder as _scope_builder
 from ... import nd as _nd
@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
         else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
         else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
         sb.ret(_expr.Call(else_func, else_args))
-    func = _expr.Function(input_args, sb.get())
+    func = _function.Function(input_args, sb.get())
     ret = _expr.Call(func, inputs)
     if num_outputs > 1:
         ret = _expr.TupleWrapper(ret, num_outputs)
@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
 
     outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     return func
 
 
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7f417d3..e1b0a7f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
@@ -1708,7 +1709,7 @@ class GraphProto(object):
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
@@ -1774,7 +1775,7 @@ class GraphProto(object):
         ----------
         op_name : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of inputs.
         attrs : dict
             Dict of operator attributes
@@ -1783,7 +1784,7 @@ class GraphProto(object):
 
         Returns
         -------
-        sym : tvm.relay.expr.Function
+        sym : tvm.relay.function.Function
             Converted relay function
         """
         convert_map = _get_convert_map(opset)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 3dca365..e0da863 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
 
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ..expr_functor import ExprMutator
 from .common import AttrCvt, get_relay_op
@@ -2461,7 +2462,7 @@ class GraphProto(object):
                 out.append(out_rnn)
 
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
-        func = _expr.Function(analysis.free_vars(out), out)
+        func = _function.Function(analysis.free_vars(out), out)
         self._mod["main"] = func
         return self._mod, self._params
 
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 95f7579..aa51570 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from tvm import relay
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
 from ... import nd as _nd
@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
     params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
     outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     mod = IRModule.from_expr(func)
     return mod, params
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
new file mode 100644
index 0000000..786a7f4
--- /dev/null
+++ b/python/tvm/relay/function.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name, unused-import
+"""The expression nodes of Relay."""
+from __future__ import absolute_import
+
+import tvm._ffi
+from tvm.runtime import convert
+from tvm.ir import BaseFunc
+
+from .expr import Call
+from . import _ffi_api
+
+@tvm._ffi.register_object("relay.Function")
+class Function(BaseFunc):
+    """A function declaration expression.
+
+    Parameters
+    ----------
+    params: List[tvm.relay.Var]
+        List of input parameters to the function.
+
+    body: tvm.relay.Expr
+        The body of the function.
+
+    ret_type: Optional[tvm.relay.Type]
+        The return type annotation of the function.
+
+    type_params: Optional[List[tvm.relay.TypeParam]]
+        The additional type parameters, this is only
+        used in advanced usecase of template functions.
+    """
+    def __init__(self,
+                 params,
+                 body,
+                 ret_type=None,
+                 type_params=None,
+                 attrs=None):
+        if type_params is None:
+            type_params = convert([])
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.Function, params, body, ret_type, type_params, attrs)
+
+    def __call__(self, *args):
+        """Invoke the global function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function and update the attribute
+
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        return _ffi_api.FunctionWithAttr(
+            self, attr_key, convert(attr_value))
diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py
index 8e066ab..9af6811 100644
--- a/python/tvm/relay/loops.py
+++ b/python/tvm/relay/loops.py
@@ -20,6 +20,7 @@ Utilities for building Relay loops.
 """
 from .scope_builder import ScopeBuilder
 from . import expr as _expr
+from . import function as _function
 
 def while_loop(cond, loop_vars, loop_bodies):
     """
@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
     with sb.else_scope():
         sb.ret(_expr.Tuple(fresh_vars))
 
-    func = _expr.Function(fresh_vars, sb.get())
+    func = _function.Function(fresh_vars, sb.get())
     let = _expr.Let(loop, func, loop)
     return let
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index 5288a2e..0e64a2f 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -19,7 +19,8 @@
 from tvm.ir import IRModule
 
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
-from .expr import Var, Function, GlobalVar, If, const
+from .expr import Var, GlobalVar, If, const
+from .function import Function
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py
index eb71120..4906eef 100644
--- a/python/tvm/relay/testing/nat.py
+++ b/python/tvm/relay/testing/nat.py
@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
 
 from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
 from tvm.relay.backend.interpreter import ConstructorValue
-from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.expr import Var, GlobalVar
+from tvm.relay.function import Function
 from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
 
 def define_nat_adt(prelude):
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
index eacfe37..e850000 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -23,7 +23,8 @@ import tvm
 from tvm import relay
 from tvm.relay.adt import Pattern
 from tvm.relay.backend import compile_engine
-from tvm.relay.expr import Expr, Function, GlobalVar, Var
+from tvm.relay.expr import Expr, GlobalVar, Var
+from tvm.relay.function import Function
 from tvm.relay.expr_functor import ExprFunctor
 
 OUTPUT_VAR_NAME = '_py_out'
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index b251645..48cb4d8 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -27,10 +27,10 @@ namespace tvm {
 namespace relay {
 
 Function::Function(tvm::Array<Var> params,
-                  Expr body,
-                  Type ret_type,
-                  tvm::Array<TypeVar> type_params,
-                  DictAttrs attrs) {
+                   Expr body,
+                   Type ret_type,
+                   tvm::Array<TypeVar> type_params,
+                   DictAttrs attrs) {
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   CHECK(params.defined());
   CHECK(type_params.defined());
@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
   return Function(params, body, ret_type, ty_params, attrs);
 });
 
-
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
   auto* node = static_cast<const FunctionNode*>(ref.get());


Mime
View raw message