tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #4886: [RELAY] Non-recursive Graph Vistor and Rewriter
Date Fri, 03 Apr 2020 17:16:54 GMT
icemelon9 commented on a change in pull request #4886: [RELAY] Non-recursive Graph Vistor and
Rewriter
URL: https://github.com/apache/incubator-tvm/pull/4886#discussion_r403160592
 
 

 ##########
 File path: include/tvm/relay/expr_functor.h
 ##########
 @@ -232,6 +232,188 @@ class ExprMutator
   std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
 };
 
+/*!
+ * \brief A wrapper around ExprVisitor which traverses the Dataflow Normal AST.
+ *
+ * MixedModeVisitor treats Expr as a dataflow graph and visits it in post-DFS order.
+ *
+ * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow
+ * regions of the graph and processes them iteratively to prevent stack overflows.
+ */
+class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
+ public:
+  /*! \brief The constructor of MixedModeVisitor
+   *  \param visit_limit The number of times a node may be visited. Usually 1, occasionally
+   * higher (e.g., 2 for dead code elimination), limited to 10 as a sanity check.
+   *  NOTE(review): this single-argument constructor is not `explicit`, so an `int` converts
+   * implicitly to MixedModeVisitor; consider marking it `explicit`.
+   */
+  MixedModeVisitor(int visit_limit = 1);
+
+  /*!
+   * \brief VisitExpr is finalized to preserve call expansion of dataflow regions
+   */
+  void VisitExpr(const Expr& expr) final;
+  /*! \brief Per-node visitors for Call, Tuple and TupleGetItem; left overridable (not final). */
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const TupleNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
+
+ protected:
+  /*!
+   * \brief A function to apply when reaching a leaf of the graph non-recursively
+   */
+  virtual void VisitLeaf(const Expr& expr);
+  /*!
+   * \brief A function to determine if an expression has already been visited or needs to be
+   * re-visited
+   */
+  virtual bool CheckVisited(const Expr& expr);
+  /*!
+   * \brief The max number of times to visit a node
+   * NOTE(review): stored as size_t while the constructor takes an int; confirm the constructor
+   * rejects negative values before converting.
+   */
+  size_t visit_limit_;
+};
+
+/*! \brief Non-recursive DFS Graph Traversal for Custom Rewriting Passes
+ *
+ * MixedModeMutator treats Expr as a dataflow graph, and only rewrites each Expr once.
+ * The mutated results are memoized in a map and reused so that
+ * local transformation on the dataflow preserves the graph structure.
+ *
+ * MixedModeMutator provides the same recursive API as ExprMutator, and uses
+ * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow
+ * regions of the graph and processes them iteratively to prevent stack overflows.
+ *
+ * Uses the Rewrite_ API of ExprRewriter for a cleaner split between recursive and
+ * non-recursive behavior.
+ */
+class MixedModeMutator : public ::tvm::relay::ExprMutator {
+ public:
+  Expr VisitExpr(const Expr& expr) final;
+  /*! NOTE(review): presumably dispatches to the per-node VisitExpr_ overloads; confirm in the .cc. */
+  virtual Expr DispatchVisitExpr(const Expr& expr);
+  /*! \brief Final per-node visitors: each forwards to Rewrite(), which calls the matching Rewrite_.
+   * NOTE(review): the trailing `;` after these inline bodies is redundant (-Wextra-semi).
+   */
+  Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
+  Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
+  Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
+  /*!
+   * Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be
+   * able to rewrite the op only with data about the original node `pre` and the same node with
+   * modified inputs `post` and should not recurse.
+   * The default implementations return `post` unchanged.
+   *
+   * \param pre The expression node before rewriting.
+   * \param post The expression with rewritten inputs.
+   */
+  virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;}
+  virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
+  virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post;
}
+
+ protected:
+  /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node
+   * with changed inputs, then passing `op` (as `pre`) and `post` to the matching Rewrite_.
+   */
+  template <typename T>
+  Expr Rewrite(const T* op) {
+    Expr post = ExprMutator::VisitExpr_(op);
+    return Rewrite_(op, post);
+  }
+
+  /*! NOTE(review): traversal hooks; presumably mirror MixedModeVisitor's hooks of the same
+   * names (leaf handling / revisit check); confirm in the implementation.
+   */
+  virtual void VisitLeaf(const Expr& expr);
+  virtual bool CheckVisited(const Expr& expr);
+};
+
+/*! \brief A non-iterating Expression Rewriter
+ *
+ * ExprRewriter provides a Rewrite interface for modifying graphs in post-DFS order.
+ * The expectation is that ExprRewriter objects will be passed to PostOrderRewrite, which will
+ * non-recursively unroll the graph and call rewriting on inputs. It will then pass the original
+ * node, called `pre`, and a node recreated with any altered inputs, called `post`, to the
+ * ExprRewriter. The ExprRewriter can then use the information in those two nodes to do more
+ * complex graph rewriting.
+ */
 
 Review comment:
   Move this document to directly on top of `class ExprRewriter`? Also, fix the indent starting from line 325.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message