tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [REFACTOR][TIR] Introduce PrimFuncPass. (#5139)
Date Tue, 24 Mar 2020 15:16:03 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a0e58b  [REFACTOR][TIR] Introduce PrimFuncPass. (#5139)
0a0e58b is described below

commit 0a0e58bfa4c87b2cbff0be2b401da0b3a08fcfe8
Author: Tianqi Chen <tqchen@users.noreply.github.com>
AuthorDate: Tue Mar 24 08:15:54 2020 -0700

    [REFACTOR][TIR] Introduce PrimFuncPass. (#5139)
    
    * [REFACTOR][TIR] Introduce PrimFuncPass.
    
    - Introduce PrimFuncPass
    - Convert one pass to the unified Pass API.
    
    * Address comments
    
    * Fix comments
---
 docs/api/python/tir.rst                            |   9 ++
 include/tvm/ir/expr.h                              |   2 +-
 include/tvm/ir/type_functor.h                      |   4 +
 include/tvm/tir/transform.h                        |  72 ++++++++++
 python/tvm/tir/__init__.py                         |   1 +
 python/tvm/tir/transform/__init__.py               |  21 +++
 python/tvm/tir/transform/_ffi_api.py               |  21 +++
 python/tvm/tir/transform/function_pass.py          | 149 +++++++++++++++++++++
 python/tvm/tir/transform/transform.py              |  31 +++++
 src/ir/module.cc                                   |   4 +-
 src/ir/type_functor.cc                             |  14 ++
 src/relay/analysis/alpha_equal.cc                  |  16 +++
 src/target/codegen.cc                              |   3 +
 src/tir/ir/function.cc                             |   3 +
 src/tir/ir/transform.cc                            | 145 ++++++++++++++++++++
 .../{pass => transforms}/combine_context_call.cc   |  19 +++
 ... => test_tir_transform_combine_context_call.py} |  12 +-
 .../unittest/test_tir_transform_prim_func_pass.py  |  50 +++++++
 18 files changed, 570 insertions(+), 6 deletions(-)

diff --git a/docs/api/python/tir.rst b/docs/api/python/tir.rst
index ea1ac66..dd08758 100644
--- a/docs/api/python/tir.rst
+++ b/docs/api/python/tir.rst
@@ -22,3 +22,12 @@ tvm.tir
    :imported-members:
    :exclude-members: PrimExpr, const
    :autosummary:
+
+
+
+tvm.tir.transform
+-----------------
+.. automodule:: tvm.tir.transform
+   :members:
+   :imported-members:
+   :autosummary:
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 85b3937..44244df 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode {
   /*!
    * \return The checked_type
    */
-  const Type& checked_type() const;
+  inline const Type& checked_type() const;
   /*!
    * \brief Check if the inferred(checked) type of the Expr
    *  is backed by a TTypeNode and return it.
diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h
index 476538c..5507191 100644
--- a/include/tvm/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
   virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
+  virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
   virtual R VisitTypeDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
     throw;  // unreachable, written to stop compiler warning
@@ -115,6 +116,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
     TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
     TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
     TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
+    TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
     return vtable;
   }
 };
@@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor :
   void VisitType_(const TypeCallNode* op) override;
   void VisitType_(const TypeDataNode* op) override;
   void VisitType_(const PrimTypeNode* op) override;
+  void VisitType_(const PointerTypeNode* op) override;
 };
 
 /*!
@@ -158,6 +161,7 @@ class TVM_DLL TypeMutator :
   Type VisitType_(const TypeCallNode* op) override;
   Type VisitType_(const TypeDataNode* op) override;
   Type VisitType_(const PrimTypeNode* op) override;
+  Type VisitType_(const PointerTypeNode* op) override;
 
  private:
   Array<Type> MutateArray(Array<Type> arr);
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
new file mode 100644
index 0000000..5149677
--- /dev/null
+++ b/include/tvm/tir/transform.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/transform.h
+ * \brief TIR specific transformation passes.
+ */
+#ifndef TVM_TIR_TRANSFORM_H_
+#define TVM_TIR_TRANSFORM_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+
+#include <string>
+
+namespace tvm {
+namespace tir {
+namespace transform {
+
+using tvm::transform::Pass;
+using tvm::transform::PassNode;
+using tvm::transform::PassInfo;
+using tvm::transform::PassInfoNode;
+using tvm::transform::PassContext;
+using tvm::transform::PassContextNode;
+using tvm::transform::Sequential;
+
+/*!
+ * \brief Create a function pass that optimizes PrimFuncs.
+ *
+ * \param pass_func The packed function applied to each PrimFunc in the module.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent on.
+ *
+ * \return The created function pass.
+ */
+TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
+                                PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
+                                int opt_level,
+                                const std::string& name,
+                                const tvm::Array<tvm::PrimExpr>& required);
+
+/*!
+ * \brief Create PrimFuncPass to combine context calls in the host function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CombineContextCall();
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_TRANSFORM_H_
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index d4d389a..f0d4d93 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -45,3 +45,4 @@ from .op import comm_reducer, min, max, sum
 
 from . import ir_builder
 from . import ir_pass
+from . import transform
diff --git a/python/tvm/tir/transform/__init__.py b/python/tvm/tir/transform/__init__.py
new file mode 100644
index 0000000..5947f41
--- /dev/null
+++ b/python/tvm/tir/transform/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Namespace of all TIR transformations"""
+# pylint: disable=wildcard-import, invalid-name
+
+from .function_pass import prim_func_pass, PrimFuncPass
+from .transform import *
diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py
new file mode 100644
index 0000000..86f7bdf
--- /dev/null
+++ b/python/tvm/tir/transform/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.tir.transform"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("tir.transform", __name__)
diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py
new file mode 100644
index 0000000..93bb996
--- /dev/null
+++ b/python/tvm/tir/transform/function_pass.py
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""TIR specific function pass support."""
+import inspect
+import functools
+
+import tvm._ffi
+from tvm.ir.transform import Pass, PassInfo
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("tir.PrimFuncPass")
+class PrimFuncPass(Pass):
+    """A pass that works on each :py:class:`tvm.tir.PrimFunc` in a module. A function
+    pass class should be created through :py:func:`tvm.tir.transform.prim_func_pass`.
+    """
+
+
+def _wrap_class_function_pass(pass_cls, pass_info):
+    """Wrap a python class (providing transform_function) as a PrimFuncPass subclass."""
+    class PyFunctionPass(PrimFuncPass):
+        """PrimFuncPass whose pass function delegates to a pass_cls instance."""
+        def __init__(self, *args, **kwargs):
+            # initialize handle in case pass_cls creation fails
+            self.handle = None
+            inst = pass_cls(*args, **kwargs)
+            # it is important not to capture self to
+            # avoid a cyclic dependency
+            def _pass_func(func, mod, ctx):
+                return inst.transform_function(func, mod, ctx)
+            self.__init_handle_by_constructor__(
+                _ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
+            self._inst = inst
+
+        def __getattr__(self, name):
+            # called only when normal lookup fails: fall back to the wrapped instance
+            return self._inst.__getattribute__(name)
+
+    functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
+    PyFunctionPass.__name__ = pass_cls.__name__
+    PyFunctionPass.__doc__ = pass_cls.__doc__
+    PyFunctionPass.__module__ = pass_cls.__module__
+    return PyFunctionPass
+
+
+def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
+    """Decorate a function pass.
+
+    This function returns a decorator when pass_func
+    is not provided. Otherwise, it returns the created function pass using the
+    given optimization function.
+
+    Parameters
+    ----------
+    pass_func : Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]
+        The transformation function or class.
+
+    opt_level : int
+        The optimization level of this function pass.
+
+    name : Optional[str]
+        The name of the function pass. The name could be empty. In this case, the
+        name of the optimization function will be used as the pass name.
+
+    required : Optional[List[str]]
+        The list of passes that the function pass is dependent on.
+
+    Returns
+    -------
+    create_function_pass : Union[Callable, PrimFuncPass]
+
+        A decorator will be returned if pass_func is not provided,
+        otherwise return the decorated result.
+        The returned decorator has two behaviors depending on the input:
+        A new PrimFuncPass will be returned when we decorate a pass function.
+        A new PrimFuncPass class will be returned when we decorate a class type.
+
+    Examples
+    --------
+    The following code block decorates a function pass class.
+
+    .. code-block:: python
+
+        @tvm.tir.transform.prim_func_pass(opt_level=1)
+        class TestReplaceFunc:
+            def __init__(self, new_func):
+                self.new_func = new_func
+
+            def transform_function(self, func, mod, ctx):
+                # just for demo purposes
+                # transform func to new_func
+                return self.new_func
+
+    The following code creates a function pass by decorating
+    a user defined transform function.
+
+    .. code-block:: python
+
+        @tvm.tir.transform.prim_func_pass(opt_level=2)
+        def transform(func, mod, ctx):
+            # my transformations here.
+            return func
+
+        function_pass = transform
+        assert isinstance(function_pass, tvm.tir.transform.PrimFuncPass)
+        assert function_pass.info.opt_level == 2
+
+        # Given a module m, the optimization could be invoked as the following:
+        updated_mod = function_pass(m)
+        # Now transform should have been applied to every PrimFunc in
+        # the provided module m. And the updated module will be returned.
+    """
+
+    if opt_level is None:
+        raise ValueError("Please provide opt_level for the function pass.")
+
+    required = required if required else []
+    if not isinstance(required, (list, tuple)):
+        raise TypeError("Required is expected to be the type of " +
+                        "list/tuple.")
+
+    def create_function_pass(pass_arg):
+        """Internal function that creates a PrimFunc pass from a function or class."""
+        if not callable(pass_arg):
+            raise TypeError("pass_func must be a callable for PrimFunc pass")
+        fname = name if name else pass_arg.__name__
+        info = PassInfo(opt_level, fname, required)
+        if inspect.isclass(pass_arg):
+            return _wrap_class_function_pass(pass_arg, info)
+        return _ffi_api.CreatePrimFuncPass(pass_arg, info)
+
+    if pass_func is not None:
+        return create_function_pass(pass_func)
+    return create_function_pass
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
new file mode 100644
index 0000000..1eec94e
--- /dev/null
+++ b/python/tvm/tir/transform/transform.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Wrapping existing transformations."""
+# pylint: disable=invalid-name
+
+from . import _ffi_api
+
+
+def CombineContextCall():
+    """Combine context calls in the host function.
+
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The created PrimFuncPass that applies the combination to every PrimFunc.
+    """
+    return _ffi_api.CombineContextCall()
diff --git a/src/ir/module.cc b/src/ir/module.cc
index a78a752..4ac769b 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -170,13 +170,13 @@ void IRModuleNode::Add(const GlobalVar& var,
                                 GetRef<relay::Function>(ptr));
   }
 
-  auto type = checked_func->checked_type();
+  Type type = checked_func->checked_type();
   CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
 
   if (functions.find(var) != functions.end()) {
     CHECK(update)
         << "Already have definition for " << var->name_hint;
-    auto old_type = functions[var].as<relay::FunctionNode>()->checked_type();
+    auto old_type = functions[var]->checked_type();
     CHECK(relay::AlphaEqual(type, old_type))
         << "Module#update changes type, not possible in this mode.";
   }
diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc
index cbd3538..9d9167f 100644
--- a/src/ir/type_functor.cc
+++ b/src/ir/type_functor.cc
@@ -93,6 +93,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
 void TypeVisitor::VisitType_(const PrimTypeNode* op) {
 }
 
+void TypeVisitor::VisitType_(const PointerTypeNode* op) {
+  VisitType(op->element_type);  // recurse into the pointee type
+}
+
 Type TypeMutator::VisitType(const Type& t) {
   return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
 }
@@ -209,6 +213,16 @@ Type TypeMutator::VisitType_(const PrimTypeNode* op) {
   return GetRef<Type>(op);
 }
 
+Type TypeMutator::VisitType_(const PointerTypeNode* op) {
+  // Mutate the pointee; reuse the original node when nothing changed.
+  Type new_elem = VisitType(op->element_type);
+  if (!new_elem.same_as(op->element_type)) {
+    return PointerType(new_elem);
+  }
+
+  return GetRef<Type>(op);
+}
+
 // Implements bind.
 class TypeBinder : public TypeMutator {
  public:
diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc
index 5402848..28c7681 100644
--- a/src/relay/analysis/alpha_equal.cc
+++ b/src/relay/analysis/alpha_equal.cc
@@ -202,6 +202,22 @@ class AlphaEqualHandler:
     return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
   }
 
+  bool VisitType_(const PrimTypeNode* lhs, const Type& other) final {
+    // Primitive types compare equal iff the other side is also a
+    // PrimType carrying the same dtype.
+    const auto* rhs = other.as<PrimTypeNode>();
+    return rhs != nullptr && lhs->dtype == rhs->dtype;
+  }
+
+  bool VisitType_(const PointerTypeNode* lhs, const Type& other) final {
+    // Pointer types compare equal iff their element types are TypeEqual.
+    const auto* rhs = other.as<PointerTypeNode>();
+    if (rhs == nullptr) {
+      return false;
+    }
+    return TypeEqual(lhs->element_type, rhs->element_type);
+  }
+
   bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
     if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
       if (lhs->kind != rhs->kind) return false;
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index 7dc23b6..e9ff234 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -310,6 +310,9 @@ TVM_REGISTER_GLOBAL("target.Build")
     }
   });
 
+TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
+.set_body_typed(ToIRModule);  // temp adapter: LoweredFuncs -> IRModule for new-style pass tests
+
 // Export two auxiliary function to the runtime namespace.
 TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
 .set_body_typed(PackImportsToC);
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 7464e3a..0891c47 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -23,10 +23,12 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
 
 namespace tvm {
 namespace tir {
 
+// Construct a PrimFunc; checked_type_ is initialized from func_type_annotation().
 PrimFunc::PrimFunc(Array<tir::Var> params,
                    Stmt body,
                    Type ret_type,
@@ -43,6 +45,7 @@ PrimFunc::PrimFunc(Array<tir::Var> params,
   n->ret_type = std::move(ret_type);
   n->buffer_map = std::move(buffer_map);
   n->attrs = std::move(attrs);
+  n->checked_type_ = n->func_type_annotation();
   data_ = std::move(n);
 }
 
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
new file mode 100644
index 0000000..f991e90
--- /dev/null
+++ b/src/tir/ir/transform.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/transform.cc
+ * \brief TIR specific transformation passes.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/tir/transform.h>
+
+
+namespace tvm {
+namespace tir {
+namespace transform {
+
+/*!
+ * \brief Function level pass that applies transformations to all
+ *        TIR functions within the module.
+ */
+class PrimFuncPassNode : public PassNode {
+ public:
+  /*! \brief The pass meta data. */
+  PassInfo pass_info;
+
+  /*! \brief The pass function called on each PrimFunc in the module. */
+  runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pass_info", &pass_info);
+  }
+
+  /*!
+   * \brief Run a function pass on given pass context.
+   *
+   * \param mod The module that an optimization pass is applied on.
+   * \param pass_ctx The context that an optimization pass executes on.
+   *
+   * \return Return the updated module.
+   */
+  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
+
+  /*!
+   * \brief Get the pass information/meta data.
+   */
+  PassInfo Info() const override { return pass_info; }
+
+  static constexpr const char* _type_key = "tir.PrimFuncPass";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
+};
+
+/*! \brief Reference class for PrimFuncPassNode. */
+class PrimFuncPass : public Pass {
+ public:
+  /*!
+   * \brief The constructor
+   * \param pass_func The packed function which implements a pass.
+   * \param pass_info The pass info.
+   */
+  TVM_DLL PrimFuncPass(
+      runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
+      PassInfo pass_info);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
+};
+
+PrimFuncPass::PrimFuncPass(
+    runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  auto n = make_object<PrimFuncPassNode>();
+  n->pass_func = std::move(pass_func);
+  n->pass_info = std::move(pass_info);
+  data_ = std::move(n);
+}
+
+// Perform Module -> Module optimizations at the PrimFunc level.
+IRModule PrimFuncPassNode::operator()(const IRModule& mod,
+                                      const PassContext& pass_ctx) const {
+  const PassInfo& pass_info = Info();
+  CHECK(mod.defined());
+  pass_ctx.Trace(mod, pass_info, true);
+  // Execute the pass function and return a new module.
+  IRModule updated_mod = IRModule(
+      mod->functions, mod->type_definitions, mod->Imports());
+  std::vector<std::pair<GlobalVar, PrimFunc> > updates;
+  for (const auto& it : updated_mod->functions) {
+    // only picks up tir::PrimFunc; other functions are left unchanged.
+    if (auto* n = it.second.as<PrimFuncNode>()) {
+      PrimFunc func = GetRef<PrimFunc>(n);
+      auto updated_func =
+          pass_func(func, updated_mod, pass_ctx);
+      updates.push_back({it.first, updated_func});
+    }
+  }
+  for (const auto& pair : updates) {
+    updated_mod->Add(pair.first, pair.second, true);
+  }
+  pass_ctx.Trace(updated_mod, pass_info, false);
+  return updated_mod;
+}
+
+Pass CreatePrimFuncPass(
+    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
+    int opt_level,
+    const std::string& name,
+    const tvm::Array<tvm::PrimExpr>& required) {
+  PassInfo pass_info = PassInfo(opt_level, name, required);
+  return PrimFuncPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
+
+TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
+.set_body_typed([](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
+    PassInfo pass_info) {
+  return PrimFuncPass(pass_func, pass_info);
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+  auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
+  const PassInfo info = node->Info();
+  p->stream << "PrimFuncPass(" << info->name
+            << ", opt_level=" << info->opt_level << ")";
+});
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/pass/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc
similarity index 88%
rename from src/tir/pass/combine_context_call.cc
rename to src/tir/transforms/combine_context_call.cc
index 5f043bc..ed352c1 100644
--- a/src/tir/pass/combine_context_call.cc
+++ b/src/tir/transforms/combine_context_call.cc
@@ -25,7 +25,11 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/runtime/registry.h>
+
 #include <tvm/tir/ir_pass.h>
+
 #include <map>
 
 namespace tvm {
@@ -114,5 +118,20 @@ LoweredFunc CombineContextCall(LoweredFunc f) {
   return LoweredFunc(n);
 }
 
+namespace transform {
+
+Pass CombineContextCall() {
+  auto fcombine = [](PrimFunc func, IRModule, PassContext) {
+    auto* node = func.CopyOnWrite();  // copy-on-write before mutating the body
+    node->body = ContextCallCombiner().Combine(node->body);
+    return func;
+  };
+  return CreatePrimFuncPass(fcombine, 0, "CombineContextCall", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
+.set_body_typed(CombineContextCall);
+
+}  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_pass_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py
similarity index 84%
rename from tests/python/unittest/test_tir_pass_combine_context_call.py
rename to tests/python/unittest/test_tir_transform_combine_context_call.py
index e51d4d8..e76fb33 100644
--- a/tests/python/unittest/test_tir_pass_combine_context_call.py
+++ b/tests/python/unittest/test_tir_transform_combine_context_call.py
@@ -37,9 +37,15 @@ def test_for():
                     ("int32", "fadd", device_context(0), A))
     body = ib.get()
     f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
-    f = tvm.tir.ir_pass.CombineContextCall(f)
-    assert f.body.value.dtype == "handle"
-    assert f.body.body.value.dtype == "handle"
+
+    # temp adapter to convert loweredFunc to IRModule
+    # to test passes in the new style.
+    mod = tvm.testing.LoweredFuncsToIRModule([f])
+
+    mod = tvm.tir.transform.CombineContextCall()(mod)
+
+    assert mod["func"].body.value.dtype == "handle"
+    assert mod["func"].body.body.value.dtype == "handle"
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_transform_prim_func_pass.py b/tests/python/unittest/test_tir_transform_prim_func_pass.py
new file mode 100644
index 0000000..87aecd1
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_prim_func_pass.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+
+
+def test_prim_func_pass():
+    @tvm.tir.transform.prim_func_pass(opt_level=1)
+    class TestReplaceFunc:
+        """Simple test pass that replaces every PrimFunc with new_func."""
+        def __init__(self, new_func):
+            self.new_func = new_func
+
+        def transform_function(self, func, mod, ctx):
+            return self.new_func
+
+    x = te.var('x')
+    y = te.var('y')
+    b = tvm.tir.decl_buffer((x,), "float32")
+    stmt = tvm.tir.LetStmt(
+        x, 10, tvm.tir.Evaluate(x + 1))
+
+    func = tvm.tir.PrimFunc(
+        [x, y, b], stmt)
+
+    new_func = tvm.tir.PrimFunc(
+        [x, y, b], tvm.tir.Evaluate(0))
+
+    mod = tvm.IRModule({"main": func})
+    mod = TestReplaceFunc(new_func)(mod)
+
+    assert tvm.tir.ir_pass.Equal(mod["main"].body, new_func.body)
+
+
+if __name__ == "__main__":
+    test_prim_func_pass()


Mime
View raw message