tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [RELAY] Non-recursive Graph Vistor and Rewriter (#4886)
Date Fri, 03 Apr 2020 21:36:12 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7de8a53  [RELAY] Non-recursive Graph Vistor and Rewriter (#4886)
7de8a53 is described below

commit 7de8a539b1e73627309308b49c6c69625efc4d5a
Author: Matthew Brookhart <matthewbrookhart@gmail.com>
AuthorDate: Fri Apr 3 14:35:55 2020 -0700

    [RELAY] Non-recursive Graph Vistor and Rewriter (#4886)
    
    * First pass a defining a non-recursive Graph Vistor and Rewriter
    
    autoformat
    
    remove a currently empty test until testing is solidfied
    
    * Make CalcDep from Dead Code Elimination non-recursive
    
    * Partially working, not passing all tests yet
    
    passes tests when disabling GetExprRefCount, I think I have a bug in visit counting
    
    fix GetExprRefCount
    
    Fix a subtle bug with nested recursive/non-recursive scopes
    
    * Refactor
    
    * improve comments
    
    * respond to review comments on comments
    
    * Fix a problem with default recursion for dataflow nodes
    
    mark DataflowVisitor methods as override
    
    * implement ScopeMutator
    
    * convert forward_rewrite to ScopeMutator, remove DataflowMutator
    
    * rewrite ExprRewriter and convert fast_math to use it
    
    * switch BiasAddSimplifier to ExprRewriter
    
    fix a clang warning
    
    fix cpp lint
    
    fix doc param error
    
    * respond to review comments
    
    * fix a typo in the iterative looping
    
    * add a regression test for GetExprRefCount issue
    
    * Normalize naming
    
    * fix lint
    
    * First pass a defining a non-recursive Graph Vistor and Rewriter
    
    autoformat
    
    remove a currently empty test until testing is solidfied
    
    * Make CalcDep from Dead Code Elimination non-recursive
    
    * Partially working, not passing all tests yet
    
    passes tests when disabling GetExprRefCount, I think I have a bug in visit counting
    
    fix GetExprRefCount
    
    Fix a subtle bug with nested recursive/non-recursive scopes
    
    * Refactor
    
    * improve comments
    
    * respond to review comments on comments
    
    * Fix a problem with default recursion for dataflow nodes
    
    mark DataflowVisitor methods as override
    
    * implement ScopeMutator
    
    * convert forward_rewrite to ScopeMutator, remove DataflowMutator
    
    * rewrite ExprRewriter and convert fast_math to use it
    
    * switch BiasAddSimplifier to ExprRewriter
    
    fix a clang warning
    
    fix cpp lint
    
    fix doc param error
    
    * respond to review comments
    
    * fix a typo in the iterative looping
    
    * add a regression test for GetExprRefCount issue
    
    * Normalize naming
    
    * fix lint
    
    * respond to review comments
---
 include/tvm/relay/analysis.h             |  11 ++
 include/tvm/relay/expr_functor.h         | 183 +++++++++++++++++++++++++++++++
 src/relay/analysis/util.cc               |   2 +-
 src/relay/ir/expr_functor.cc             | 158 +++++++++++++++++++++++++-
 src/relay/transforms/canonicalize_ops.cc |   9 +-
 src/relay/transforms/dead_code.cc        |   9 +-
 src/relay/transforms/fast_math.cc        |  18 +--
 src/relay/transforms/forward_rewrite.cc  |  66 +++++------
 src/relay/transforms/pass_util.h         |   8 --
 tests/cpp/relay_build_module_test.cc     |  17 +++
 10 files changed, 416 insertions(+), 65 deletions(-)

diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index e04b4e6..a2c0c75 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -30,6 +30,7 @@
 #include <tvm/ir/module.h>
 #include <tvm/relay/type.h>
 #include <string>
+#include <unordered_map>
 
 namespace tvm {
 namespace relay {
@@ -225,6 +226,16 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
  */
 TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
 
+/*!
+ * \brief Get reference counter of each internal ExprNode in body.
+ *
+ * \param body The body expression.
+ *
+ * \return The reference count mapping.
+ */
+TVM_DLL std::unordered_map<const Object*, size_t>
+GetExprRefCount(const Expr& body);
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index d1c5ca1..6f8ac69 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -233,6 +233,189 @@ class ExprMutator
 };
 
 /*!
+ * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
+ *
+ * MixedModeVisitor treats Expr as dataflow graph, and visits in post-DFS order
+ *
+ * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
+ * of the graph and processes them iteratively to prevent stack overflows
+ */
+class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
+ public:
+  /*! \brief The constructor of MixedModeVisitor
+   *  \param visit_limit The number of times to allow visitation to a node. Usually 1, occasionally
+   * higher (e.g., 2 for dead code elimination); must be less than 10 as a sanity check.
+   */
+  explicit MixedModeVisitor(int visit_limit = 1);
+
+  /*!
+   * \brief VisitExpr is finalized to preserve call expansion of dataflow regions
+   */
+  void VisitExpr(const Expr& expr) final;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const TupleNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
+
+ protected:
+  /*!
+   * \brief A function to apply when reaching a leaf of the graph non-recursively
+   */
+  virtual void VisitLeaf(const Expr& expr);
+  /*!
+   * \brief A function to determine if an expression has already been visited or needs to be
+   * re-visited
+   */
+  virtual bool CheckVisited(const Expr& expr);
+  /*!
+   * \brief The max number of times to visit a node
+   */
+  size_t visit_limit_;
+};
+
+/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
+ *
+ * MixedModeMutator treats Expr as dataflow graph, and only Rewrites each Expr once.
+ * The mutated results are memoized in a map and reused so that
+ * local transformation on the dataflow preserves the graph structure.
+ *
+ * MixedModeMutator provides the same recursive API as ExprMutator, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
+ * of the graph and processes them iteratively to prevent stack overflows
+ *
+ * Uses Rewrite_ API of ExprRewriter for a cleaner split between recursive and non-recursive behavior.
+ */
+class MixedModeMutator : public ::tvm::relay::ExprMutator {
+ public:
+  Expr VisitExpr(const Expr& expr) final;
+  virtual Expr DispatchVisitExpr(const Expr& expr);
+  Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
+  Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
+  Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
+  /*!
+   *  \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
+   * able to rewrite the op only with data about the original node `pre` and the same node with
+   * modified inputs `post` and should not recurse.
+   *
+   * \param pre The expression node before rewriting.
+   * \param post The expression with rewritten inputs.
+   */
+  virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
+  virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
+  virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
+
+ protected:
+  /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
+   * changed inputs.
+   */
+  template <typename T>
+  Expr Rewrite(const T* op) {
+    Expr post = ExprMutator::VisitExpr_(op);
+    return Rewrite_(op, post);
+  }
+
+  virtual void VisitLeaf(const Expr& expr);
+  virtual bool CheckVisited(const Expr& expr);
+};
+
+#define RELAY_EXPR_REWRITER_DISPATCH(OP)                                                   \
+  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, const Expr& post) { \
+    return self->Rewrite_(static_cast<const OP*>(n.get()), post);                          \
+  });
+
+#define EXPR_REWRITER_REWRITE_DEFAULT \
+  { return post; }
+
+/*! \brief A non-iterating Expression Rewriter
+ *
+ *  ExprRewriter provides a Rewrite interface for modifying graphs in Post-DFS order.
+ *
+ *  The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
+ * non-recursively unroll the graph and call Rewriting on inputs. It will then pass the original
+ * node, called `pre`, and a node recreated with any altered inputs, called `post`, to the
+ * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more complex
+ * graph rewriting.
+ */
+class ExprRewriter {
+ private:
+  using TSelf = ExprRewriter;
+  using FType = tvm::NodeFunctor<Expr(const ObjectRef& n, TSelf* self, const Expr& post)>;
+
+ public:
+  /*! \brief virtual destructor */
+  virtual ~ExprRewriter() {}
+  /*!
+   * \brief Same as call.
+   * \param pre The expression node before rewriting.
+   * \param post The expression node with rewritten inputs.
+   * \return The result of the call
+   */
+  Expr operator()(const Expr& pre, const Expr& post) {
+    return Rewrite(pre, post);
+  }
+  /*!
+   * \brief The functor call.
+   * \param pre The expression node before rewriting.
+   * \param post The expression node with rewritten inputs.
+   * \return The result of the call
+   */
+  virtual Expr Rewrite(const Expr& pre, const Expr& post) {
+    CHECK(pre.defined());
+    static FType vtable = InitVTable();
+    return vtable(pre, this, post);
+  }
+  // Functions that can be overridden by subclass, should not recurse
+  virtual Expr Rewrite_(const VarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const GlobalVarNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const ConstantNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const FunctionNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const CallNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const LetNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const IfNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const OpNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const TupleGetItemNode* pre,
+                        const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const RefCreateNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const RefReadNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const RefWriteNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const ConstructorNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+  virtual Expr Rewrite_(const MatchNode* pre, const Expr& post) EXPR_REWRITER_REWRITE_DEFAULT;
+
+ private:
+  // initialize the vtable.
+  static FType InitVTable() {
+    FType vtable;
+    // Set dispatch
+    RELAY_EXPR_REWRITER_DISPATCH(ConstantNode);
+    RELAY_EXPR_REWRITER_DISPATCH(TupleNode);
+    RELAY_EXPR_REWRITER_DISPATCH(VarNode);
+    RELAY_EXPR_REWRITER_DISPATCH(GlobalVarNode);
+    RELAY_EXPR_REWRITER_DISPATCH(FunctionNode);
+    RELAY_EXPR_REWRITER_DISPATCH(CallNode);
+    RELAY_EXPR_REWRITER_DISPATCH(LetNode);
+    RELAY_EXPR_REWRITER_DISPATCH(IfNode);
+    RELAY_EXPR_REWRITER_DISPATCH(OpNode);
+    RELAY_EXPR_REWRITER_DISPATCH(TupleGetItemNode);
+    RELAY_EXPR_REWRITER_DISPATCH(RefCreateNode);
+    RELAY_EXPR_REWRITER_DISPATCH(RefReadNode);
+    RELAY_EXPR_REWRITER_DISPATCH(RefWriteNode);
+    RELAY_EXPR_REWRITER_DISPATCH(ConstructorNode);
+    RELAY_EXPR_REWRITER_DISPATCH(MatchNode);
+    return vtable;
+  }
+};
+
+/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
+ *
+ *  PostOrderRewrite does a non-recursive traversal of the graph in Post-DFS order and calls the
+ * ExprRewriter's Rewrite functions on nodes once their inputs are rewritten. At each rewrite call,
+ * PostOrderRewrite provides the original node and the node with altered inputs for use by the
+ * ExprRewriter.
+ */
+Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
+
+/*!
  * \brief recursively visit the ir in post DFS order node, apply fvisit
  * Each node is guaranteed to be visited only once.
  * \param node The ir to be visited.
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index 6132532..a86faeb 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -330,7 +330,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
  */
 std::unordered_map<const Object*, size_t>
 GetExprRefCount(const Expr& body) {
-  class ExprRefCounter : private ExprVisitor {
+  class ExprRefCounter : private MixedModeVisitor {
    public:
     std::unordered_map<const Object*, size_t>
     Get(const Expr& body) {
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 11e85d5..cb5d06f 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -29,8 +29,162 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 
+#include <stack>
+
 namespace tvm {
 namespace relay {
+/*!
+ * \brief A function to iteratively traverse dataflow regions of a graph
+ *
+ * ExpandDataflow manually manages a stack and performs DFS to determine the processing
+ * order of nodes in an input graph.
+ *
+ * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
+ * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
+ * and continues iteratively to process the top of the stack. When it finds a node that doesn't
+ * match the dataflow types, or a node whose inputs have all been processed, it visits the current
+ * leaf via fvisit_leaf.
+ *
+ * This function should be used internally to other classes to implement mixed-mode traversals. The
+ * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
+ * hits a non-dataflow node.
+ *
+ * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
+ */
+template <typename FCheckVisited, typename FVisitLeaf>
+void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
+  std::stack<std::pair<Expr, bool>> stack;
+  auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
+    // The second state of the stack indicate whether the child has been
+    // expanded in the pre-order.
+    // NOTE: function will be inlined.
+    if (!fcheck_visited(expr)) {
+      stack.push({expr, false});
+    }
+  };
+  fpush_to_stack(expr);
+  while (stack.size() > 0) {
+    auto node = stack.top().first;
+    if (fcheck_visited(node)) {
+      // if this node was visited through another path
+      // after being added to the stack ignore it.
+      stack.pop();
+    } else if (stack.top().second) {
+      // all the children have already been expanded.
+      // we can just run post order visit on it.
+      fvisit_leaf(node);
+      stack.pop();
+    } else if (const CallNode* op = node.as<CallNode>()) {
+      // mark expanded = true
+      stack.top().second = true;
+      // push the children to the stack in reverse order
+      // to match recursive processing order
+      for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+        fpush_to_stack(*it);
+      }
+      fpush_to_stack(op->op);
+    } else if (const TupleNode* op = node.as<TupleNode>()) {
+      stack.top().second = true;
+      // push the children to the stack in reverse order
+      // to match recursive processing order
+      for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+        fpush_to_stack(*it);
+      }
+    } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
+      stack.top().second = true;
+      fpush_to_stack(op->tuple);
+    } else {
+      // No need to expand the children directly run visit.
+      fvisit_leaf(node);
+      stack.pop();
+    }
+  }
+}
+
+MixedModeVisitor::MixedModeVisitor(int visit_limit) {
+  CHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
+  CHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
+  visit_limit_ = visit_limit;
+}
+
+void MixedModeVisitor::VisitLeaf(const Expr& expr) {
+  if (visit_counter_[expr.get()] < visit_limit_) {
+    ExprFunctor::VisitExpr(expr);
+  }
+  visit_counter_[expr.get()]++;
+}
+
+bool MixedModeVisitor::CheckVisited(const Expr& expr) {
+  if (visit_counter_[expr.get()] < visit_limit_) {
+    return false;
+  } else {
+    visit_counter_[expr.get()]++;
+    return true;
+  }
+}
+
+void MixedModeVisitor::VisitExpr(const Expr& expr) {
+  auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
+  auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
+  if (visit_counter_[expr.get()] < visit_limit_) {
+    ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
+  }
+}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const CallNode* op) {}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const TupleNode* op) {}
+
+// Overwrite the VisitExpr so we don't recurse for dataflow nodes
+void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}
+
+void MixedModeMutator::VisitLeaf(const Expr& expr) {
+  if (!memo_.count(expr)) {
+    this->DispatchVisitExpr(expr);
+  }
+}
+
+bool MixedModeMutator::CheckVisited(const Expr& expr) {
+  if (memo_.count(expr)) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) {
+  return ExprMutator::VisitExpr(expr);
+}
+
+Expr MixedModeMutator::VisitExpr(const Expr& expr) {
+  auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
+  auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
+  if (memo_.count(expr)) {
+    return memo_[expr];
+  } else {
+    ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
+    Expr ret = this->DispatchVisitExpr(expr);
+    memo_[expr] = ret;
+    return ret;
+  }
+}
+
+class PostOrderRewriter : public MixedModeMutator {
+ public:
+  explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
+  Expr DispatchVisitExpr(const Expr& expr) final {
+    auto post = ExprFunctor::VisitExpr(expr);
+    return rewriter_->Rewrite(expr, post);
+  }
+ protected:
+  ExprRewriter* rewriter_;
+};
+
+Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter) {
+  return PostOrderRewriter(rewriter).VisitExpr(expr);
+}
 
 Expr ExprMutator::VisitExpr(const Expr& expr) {
   auto it = this->memo_.find(expr);
@@ -211,12 +365,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
   for (const Clause& p : m->clauses) {
     clauses.push_back(VisitClause(p));
   }
-  return Match(VisitExpr(m->data), clauses, m->complete);
+  return Match(Mutate(m->data), clauses, m->complete);
 }
 
 Clause ExprMutator::VisitClause(const Clause& c) {
   Pattern p = VisitPattern(c->lhs);
-  return Clause(p, VisitExpr(c->rhs));
+  return Clause(p, Mutate(c->rhs));
 }
 
 Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc
index bcb7f9d..97a128d 100644
--- a/src/relay/transforms/canonicalize_ops.cc
+++ b/src/relay/transforms/canonicalize_ops.cc
@@ -32,12 +32,12 @@
 namespace tvm {
 namespace relay {
 
-class BiasAddSimplifier : public ExprMutator {
+class BiasAddSimplifier : public ExprRewriter {
  public:
   BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
 
-  Expr VisitExpr_(const CallNode* n) {
-    auto new_n = ExprMutator::VisitExpr_(n);
+  Expr Rewrite_(const CallNode* n, const Expr& post) override {
+    auto new_n = post;
     if (n->op == bias_add_op_) {
       Call call = Downcast<Call>(new_n);
       CHECK_EQ(call->args.size(), 2);
@@ -63,7 +63,8 @@ class BiasAddSimplifier : public ExprMutator {
 };
 
 Expr CanonicalizeOps(const Expr& e) {
-  return BiasAddSimplifier().Mutate(e);
+  auto rewriter = BiasAddSimplifier();
+  return PostOrderRewrite(e, &rewriter);
 }
 
 namespace transform {
diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc
index f4058b2..a0d093f 100644
--- a/src/relay/transforms/dead_code.cc
+++ b/src/relay/transforms/dead_code.cc
@@ -92,7 +92,7 @@ class Eliminator : private ExprMutator {
 };
 
 // calculate the dependency graph from expression
-class CalcDep : private ExprVisitor {
+class CalcDep : protected MixedModeVisitor {
  public:
   static Expr Eliminate(const Expr& e, bool inline_once) {
     FindDef fd;
@@ -104,11 +104,14 @@ class CalcDep : private ExprVisitor {
   }
 
  private:
-  explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
+  explicit CalcDep(const VarMap<Expr>& expr_map)
+      : MixedModeVisitor(2), expr_map_(expr_map) {}
   VarMap<Expr> expr_map_;
   VarMap<size_t> use_map_;
 
-  void VisitExpr(const Expr& e) final {
+  using MixedModeVisitor::VisitExpr_;
+
+  void VisitLeaf(const Expr& e) final {
     visit_counter_[e.get()]++;
     // The dce code seprate variable into three parts:
     // used 0 times (remove)
diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index 898f760..861566f 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -31,20 +31,19 @@
 namespace tvm {
 namespace relay {
 
-class FastMathMutator : public ExprMutator {
+class FastMathMutator : public ExprRewriter {
  public:
   FastMathMutator()
       : exp_op_(Op::Get("exp")),
         tanh_op_(Op::Get("tanh")) {}
 
-  Expr VisitExpr_(const CallNode* n) {
-    auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op == exp_op_) {
-      return FastExp(new_n.as<CallNode>()->args[0]);
-    } else if (n->op == tanh_op_) {
-      return FastTanh(new_n.as<CallNode>()->args[0]);
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (pre->op == exp_op_) {
+      return FastExp(post.as<CallNode>()->args[0]);
+    } else if (pre->op == tanh_op_) {
+      return FastTanh(post.as<CallNode>()->args[0]);
     }
-    return new_n;
+    return post;
   }
 
  private:
@@ -56,7 +55,8 @@ class FastMathMutator : public ExprMutator {
 };
 
 Expr FastMath(const Expr& e) {
-  return FastMathMutator().Mutate(e);
+  auto rewriter = FastMathMutator();
+  return PostOrderRewrite(e, &rewriter);
 }
 
 namespace transform {
diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc
index 1d9d2b6..f01c4fa 100644
--- a/src/relay/transforms/forward_rewrite.cc
+++ b/src/relay/transforms/forward_rewrite.cc
@@ -22,6 +22,7 @@
  * \file forward_rewrite.cc
  * \brief Apply rewriting rules in a forward fashion.
  */
+#include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
@@ -33,32 +34,25 @@ namespace relay {
 // Realizer class that realizes the expression
 // Note that we can take benefit of its internal memo
 // so that calling realize repeatively won't hurt perf.
-class TempRealizer : private ExprMutator {
+class TempRealizer : private MixedModeMutator {
  public:
   Expr Realize(Expr expr) {
-    return VisitExpr(expr);
+    return Mutate(expr);
   }
 
  private:
-  Expr VisitExpr(const Expr& expr) final {
-    auto it = memo_.find(expr);
-    if (it != memo_.end()) {
-      return it->second;
+  Expr DispatchVisitExpr(const Expr& expr) final {
+    Expr res;
+    if (const auto* temp = expr.as<TempExprNode>()) {
+      res = temp->Realize();
     } else {
-      Expr res;
-      if (const auto* temp = expr.as<TempExprNode>()) {
-        res = temp->Realize();
-
-      } else {
-        res = ExprFunctor::VisitExpr(expr);
-      }
-      memo_[res] = res;
-      return res;
+      res = MixedModeMutator::DispatchVisitExpr(expr);
     }
+    return res;
   }
 };
 
-class ForwardRewriter : private ExprMutator {
+class ForwardRewriter : private MixedModeMutator {
  public:
   ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
                   std::function<ObjectRef(const Call&)> fcontext,
@@ -76,11 +70,11 @@ class ForwardRewriter : private ExprMutator {
 
 
   // Transform expression.
-  Expr Rewrite(Expr expr) {
+  Expr Rewrite(const Expr& expr) {
     if (fmulti_ref_trigger_ != nullptr) {
       ref_counter_ = GetExprRefCount(expr);
     }
-    return this->VisitExpr(expr);
+    return realizer_.Realize(this->VisitExpr(expr));
   }
 
  private:
@@ -96,15 +90,10 @@ class ForwardRewriter : private ExprMutator {
   // internal realizer
   TempRealizer realizer_;
 
-  Expr VisitExpr(const Expr& expr) final {
-    // by default always realize.
-    return realizer_.Realize(ExprMutator::VisitExpr(expr));
-  }
-
   // Visit and allow non-realized version.
-  Expr GetTempExpr(const Expr& expr)  {
+  Expr GetTempExpr(const Expr& expr, const Expr& post)  {
     if (fmulti_ref_trigger_ != nullptr) {
-      Expr ret = ExprMutator::VisitExpr(expr);
+      Expr ret = post;
       auto it = ref_counter_.find(expr.get());
       CHECK(it != ref_counter_.end());
       if (it->second > 1) {
@@ -112,13 +101,13 @@ class ForwardRewriter : private ExprMutator {
       }
       return ret;
     } else {
-      return ExprMutator::VisitExpr(expr);
+      return post;
     }
   }
 
   // Automatic fold TupleGetItem.
-  Expr VisitExpr_(const TupleGetItemNode* op) final {
-    Expr tuple = this->GetTempExpr(op->tuple);
+  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+    Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple);
     if (const auto* ptuple = tuple.as<TupleNode>()) {
       return ptuple->fields[op->index];
     } else {
@@ -130,13 +119,14 @@ class ForwardRewriter : private ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const TupleNode* op) final {
+  Expr Rewrite_(const TupleNode* op, const Expr& post) final {
     tvm::Array<Expr> fields;
     bool all_fields_unchanged = true;
-    for (auto field : op->fields) {
-      auto new_field = this->GetTempExpr(field);
+    const auto* post_node = post.as<TupleNode>();
+    for (size_t i = 0; i < op->fields.size(); ++i) {
+      auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
       fields.push_back(new_field);
-      all_fields_unchanged &= new_field.same_as(field);
+      all_fields_unchanged &= new_field.same_as(op->fields[i]);
     }
 
     if (all_fields_unchanged) {
@@ -146,7 +136,7 @@ class ForwardRewriter : private ExprMutator {
     }
   }
 
-  Expr VisitExpr_(const CallNode* call_node) final {
+  Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
     const Call& ref_call = GetRef<Call>(call_node);
     PackedFunc frewrite;
     if (rewrite_func_) {
@@ -155,17 +145,17 @@ class ForwardRewriter : private ExprMutator {
       CHECK(rewrite_map_);
       frewrite = rewrite_map_->get(call_node->op, nullptr);
     }
-
-    auto new_op = this->Mutate(call_node->op);
+    const auto* post_node = post.as<CallNode>();
+    auto new_op = post_node->op;
     bool unchanged = call_node->op.same_as(new_op);
 
     Array<Expr> call_args;
-    for (auto arg : call_node->args) {
-      Expr new_arg = this->GetTempExpr(arg);
+    for (size_t i = 0; i < call_node->args.size(); ++i) {
+      Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
       if (frewrite == nullptr) {
         new_arg = realizer_.Realize(new_arg);
       }
-      unchanged &= new_arg.same_as(arg);
+      unchanged &= new_arg.same_as(call_node->args[i]);
       call_args.push_back(new_arg);
     }
     // try to rewrite.
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 6a69cf9..56b0645 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -35,14 +35,6 @@ namespace tvm {
 namespace relay {
 
 /*!
- * \brief Get reference counter of each internal ExprNode in body.
- * \param body The body expression.
- * \return The reference count mapping.
- */
-std::unordered_map<const Object*, size_t>
-GetExprRefCount(const Expr& body);
-
-/*!
  * \brief Check if expr is positive constant.
  * \param expr The expression to be checked.
  * \return Whether all elements of expr is positive constant.
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index fa94271..f5658fb 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -161,6 +161,23 @@ TEST(Relay, BuildModule) {
   }
 }
 
+TEST(Relay, GetExprRefCount) {
+  auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
+  auto a = relay::Var("a", tensor_type);
+  auto add_op = relay::Op::Get("add");
+  auto relu_op = relay::Op::Get("nn.relu");
+  auto x = relay::Call(relu_op, {a}, tvm::Attrs(), {});
+  auto y = relay::Call(relu_op, {x}, tvm::Attrs(), {});
+  auto z = relay::Call(add_op, {y, x}, tvm::Attrs(), {});
+  auto ref_count = GetExprRefCount(z);
+  CHECK(ref_count[a.get()] == 1);
+  CHECK(ref_count[relu_op.get()] == 2);
+  CHECK(ref_count[add_op.get()] == 1);
+  CHECK(ref_count[x.get()] == 2);
+  CHECK(ref_count[y.get()] == 1);
+  CHECK(ref_count[z.get()] == 1);
+}
+
 int main(int argc, char ** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";


Mime
View raw message