tvm-commits mailing list archives

From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)
Date Sat, 28 Mar 2020 05:21:06 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 997a14e  [NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)
997a14e is described below

commit 997a14eda9aec3b343e742e55c3018f9dc23d8c3
Author: Tianqi Chen <tqchen@users.noreply.github.com>
AuthorDate: Fri Mar 27 22:21:00 2020 -0700

    [NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)
    
    * [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.
    
    This PR introduces a new, extensible way to handle structural equality
    for both TIR and Relay nodes.
    
    - Each object can now register an optional SEqualReduce function, which
      describes how to reduce its structural equality with another instance
      to the equality of their children.
    - Optionally, the object can allow remapping of vars (e.g. function parameters)
      by calling DefEqual.
    - We implemented a non-recursive structural equality checker that
      traverses the objects iteratively, without native recursion, and
      performs the structural equality check.
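The three ideas above (per-node reduction to children, DefEqual remapping of bound vars, and a checker driven by an explicit work stack) can be illustrated with a minimal sketch. It uses a hypothetical toy IR (`Node`, `ToySEqual`) that is NOT TVM's Object or SEqualReducer API, and it assumes, for simplicity, that free vars are equal only when they are the same object.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is NOT TVM's Object or
// SEqualReducer API.
struct Node {
  enum Kind { kVar, kAdd, kFn } kind;
  // kVar: {}; kAdd: {a, b}; kFn: {param_0, ..., param_k, body}.
  std::vector<const Node*> children;
};

class ToySEqual {
 public:
  using Task = std::pair<const Node*, const Node*>;

  // Checks structural equality with an explicit work stack instead of
  // native recursion, so depth is not limited by the call stack.
  bool Equal(const Node* lhs, const Node* rhs) {
    l2r_.clear();
    r2l_.clear();
    std::vector<Task> stack;
    stack.emplace_back(lhs, rhs);
    while (!stack.empty()) {
      Task task = stack.back();
      stack.pop_back();
      if (!Reduce(task.first, task.second, &stack)) return false;
    }
    return true;
  }

 private:
  // DefEqual analogue: bind lhs var `l` to rhs var `r`, one-to-one.
  bool DefEqual(const Node* l, const Node* r) {
    if (l->kind != Node::kVar || r->kind != Node::kVar) return false;
    auto it = l2r_.find(l);
    if (it != l2r_.end()) return it->second == r;
    if (r2l_.count(r)) return false;  // r is already bound to another lhs var
    l2r_[l] = r;
    r2l_[r] = l;
    return true;
  }

  // SEqualReduce analogue: reduce equality of (l, r) to that of the children.
  bool Reduce(const Node* l, const Node* r, std::vector<Task>* stack) {
    if (l->kind != r->kind || l->children.size() != r->children.size()) {
      return false;
    }
    switch (l->kind) {
      case Node::kVar: {
        auto it = l2r_.find(l);
        if (it != l2r_.end()) return it->second == r;  // bound: follow mapping
        return l == r && r2l_.count(r) == 0;           // free: same object
      }
      case Node::kFn: {
        assert(!l->children.empty() && "a function needs a body");
        std::size_t n = l->children.size() - 1;  // last child is the body
        // Define the parameters first, then compare the body under that mapping.
        for (std::size_t i = 0; i < n; ++i) {
          if (!DefEqual(l->children[i], r->children[i])) return false;
        }
        stack->emplace_back(l->children[n], r->children[n]);
        return true;
      }
      case Node::kAdd:
        for (std::size_t i = 0; i < l->children.size(); ++i) {
          stack->emplace_back(l->children[i], r->children[i]);
        }
        return true;
    }
    return false;
  }

  std::unordered_map<const Node*, const Node*> l2r_, r2l_;
};
```

In this sketch `fn (x, y) { x + y }` and `fn (a, b) { a + b }` compare equal because the parameters are remapped before the body is visited, while `fn (a, b) { b + a }` does not.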
    
    This PR also fixes a few potential problems in the previous Relay AlphaEqual.
    
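    - In particular, the new structural equality relation is commutative
      (symmetric in its two arguments).
    - It can be dangerous to use the same_as relation as a quick equality check,
      as demonstrated by the following case, where (%x, %y) are vars shared
      between the two functions:
    
    - function0: fn (%x, %y) { %x + %y }
    - function1: fn (%y, %x) { %x + %y }
    
    Because the vars are the same objects in both functions, a same_as shortcut
    treats %x in function0 as equal to %x in function1. Under parameter
    remapping, however, function0's %x (its first parameter) corresponds to
    function1's %y, so the two functions are not structurally equal.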
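The same_as pitfall can be reproduced with a minimal sketch. The types `Var`, `Add`, `Fn` and the function `FnEqual` are hypothetical and do not belong to TVM. Parameters are mapped by position, which is the role DefEqual plays, and bodies are assumed to reference only parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical toy IR, NOT TVM's API: a function has parameters and a body
// that adds two vars.
struct Var {};
struct Add { const Var* a; const Var* b; };
struct Fn { std::vector<const Var*> params; const Add* body; };

// Compares two functions after mapping parameters by position (the role
// DefEqual plays). With `same_as_shortcut`, identical var objects are
// accepted before the mapping is consulted, which reproduces the pitfall.
bool FnEqual(const Fn& lhs, const Fn& rhs, bool same_as_shortcut) {
  assert(lhs.body != nullptr && rhs.body != nullptr);
  if (lhs.params.size() != rhs.params.size()) return false;
  std::unordered_map<const Var*, const Var*> l2r;
  for (std::size_t i = 0; i < lhs.params.size(); ++i) {
    l2r[lhs.params[i]] = rhs.params[i];
  }
  auto var_eq = [&](const Var* l, const Var* r) {
    if (same_as_shortcut && l == r) return true;  // the unsafe shortcut
    auto it = l2r.find(l);
    // Assumes bodies only reference parameters; other vars compare unequal.
    return it != l2r.end() && it->second == r;
  };
  return var_eq(lhs.body->a, rhs.body->a) && var_eq(lhs.body->b, rhs.body->b);
}
```

Building function0 and function1 from the same `x` and `y`, `FnEqual(f0, f1, true)` returns true, while the mapping-only check `FnEqual(f0, f1, false)` returns false, in either argument order.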
    
    The new structural equality is intended to supersede AlphaEqual and AttrsEqual.
    
    Follow-up PRs should redirect the existing usages and remove
    the corresponding implementations.
    
    * Update the rule to distinguish between graph node and non-graph nodes.
    
    * Refactor the test cases to use structural equal.
    
    * address comments
    
    * Mark more relay::Expr as graph nodes, fix a test case issue (a bug that was not caught by the previous AlphaEqual)
    
    * Remove unrelated comment
    
    * Fix file comment
    
    * Address review comment
    
    * Relax condition to fit flaky case
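The graph-node rule referred to in the bullets above says that a graph node in lhs can be mapped as equal to one and only one graph node in rhs. A minimal sketch of that bookkeeping follows; `GraphNodeMatcher` is a hypothetical helper, not TVM's implementation, and it assumes the one-to-one constraint is enforced in both directions.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical helper, NOT TVM's implementation: records matched graph nodes
// and rejects any pair that would make the matching stop being one-to-one.
// Assumption: the constraint holds in both directions (lhs -> rhs, rhs -> lhs).
class GraphNodeMatcher {
 public:
  bool Match(const void* lhs, const void* rhs) {
    assert(lhs != nullptr && rhs != nullptr);
    auto it = l2r_.find(lhs);
    if (it != l2r_.end()) return it->second == rhs;  // already matched
    if (r2l_.count(rhs) != 0) return false;          // rhs taken by another lhs
    l2r_.emplace(lhs, rhs);
    r2l_.emplace(rhs, lhs);
    return true;
  }

 private:
  std::unordered_map<const void*, const void*> l2r_;
  std::unordered_map<const void*, const void*> r2l_;
};
```

For example, if lhs reuses one node `shared` in two places while rhs uses two distinct copies `copy_a` and `copy_b`, the first match succeeds and the second is rejected, even if each copy on its own has the same content as `shared`. Normal (non-graph) nodes skip this bookkeeping and are compared by content alone.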
---
 include/tvm/arith/analyzer.h                       |   8 +
 include/tvm/arith/int_set.h                        |   1 +
 include/tvm/ir/adt.h                               |  15 ++
 include/tvm/ir/attrs.h                             |  41 +++
 include/tvm/ir/env_func.h                          |   5 +
 include/tvm/ir/expr.h                              |  21 ++
 include/tvm/ir/module.h                            |   3 +
 include/tvm/ir/op.h                                |   5 +
 include/tvm/ir/span.h                              |  11 +
 include/tvm/ir/tensor_type.h                       |   6 +
 include/tvm/ir/transform.h                         |   2 +
 include/tvm/ir/type.h                              |  43 ++++
 include/tvm/ir/type_relation.h                     |  14 ++
 include/tvm/node/container.h                       |  20 +-
 include/tvm/node/node.h                            |   2 +
 include/tvm/node/reflection.h                      | 148 +++++++++--
 include/tvm/node/structural_equal.h                | 225 +++++++++++++++++
 include/tvm/relay/adt.h                            |  32 +++
 include/tvm/relay/expr.h                           |  71 ++++++
 include/tvm/relay/function.h                       |  11 +
 include/tvm/runtime/ndarray.h                      |  22 ++
 include/tvm/runtime/object.h                       |   4 +
 include/tvm/tir/buffer.h                           |  15 ++
 include/tvm/tir/expr.h                             | 153 ++++++++++-
 include/tvm/tir/function.h                         |  11 +
 include/tvm/tir/stmt.h                             | 103 ++++++++
 python/tvm/ir/__init__.py                          |   1 +
 python/tvm/ir/base.py                              |  73 ++++++
 src/ir/attr_functor.h                              |   4 +-
 src/ir/expr.cc                                     |   8 +-
 src/ir/module.cc                                   |  19 +-
 src/node/container.cc                              | 140 +++++++++++
 src/node/reflection.cc                             |   2 +-
 src/node/structural_equal.cc                       | 241 ++++++++++++++++++
 src/tir/ir/expr.cc                                 |  18 +-
 tests/python/frontend/tensorflow/test_forward.py   |   2 +-
 tests/python/relay/test_ir_parser.py               | 109 ++++----
 ..._alpha_equal.py => test_ir_structural_equal.py} | 280 +++++++++++----------
 .../relay/test_pass_dead_code_elimination.py       |  14 +-
 tests/python/relay/test_pass_partial_eval.py       |  26 +-
 tests/python/relay/test_pass_qnn_legalize.py       |   8 +-
 tests/python/relay/test_pass_to_a_normal_form.py   |   4 +-
 tests/python/relay/test_pass_to_cps.py             |   2 +-
 tests/python/relay/test_type_infer.py              |   3 +-
 tests/python/unittest/test_node_reflection.py      |   4 +-
 tests/python/unittest/test_tir_structural_equal.py | 102 ++++++++
 46 files changed, 1781 insertions(+), 271 deletions(-)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 31f2216..e7f5ede 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
     v->Visit("max_value", &max_value);
   }
 
+  bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
+    return equal(min_value, other->min_value) && equal(max_value, other->max_value);
+  }
+
   /*! \brief Number to represent +inf */
   static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
   /*!
@@ -170,6 +174,10 @@ class ModularSetNode : public Object {
     v->Visit("base", &base);
   }
 
+  bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
+    return equal(coeff, other->coeff) && equal(base, other->base);
+  }
+
   static constexpr const char* _type_key = "arith.ModularSet";
   TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
 };
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 8b73f87..86ef906 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -59,6 +59,7 @@ enum SignType {
 class IntSetNode : public Object {
  public:
   static constexpr const char* _type_key = "IntSet";
+  static constexpr bool _type_has_method_sequal_reduce = false;
   TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
 };
 
diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h
index 67cfb8d..2601614 100644
--- a/include/tvm/ir/adt.h
+++ b/include/tvm/ir/adt.h
@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
+    // Use namehint for now to be consistent with the legacy relay impl
+    // TODO(tvm-team) revisit, need to check the type var.
+    return
+        equal(name_hint, other->name_hint) &&
+        equal(inputs, other->inputs);
+  }
+
   static constexpr const char* _type_key = "relay.Constructor";
   TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
 };
@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
+    return
+        equal.DefEqual(header, other->header) &&
+        equal.DefEqual(type_vars, other->type_vars) &&
+        equal(constructors, other->constructors);
+  }
+
   static constexpr const char* _type_key = "relay.TypeData";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
 };
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 4413fc3..c3b5831 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
     v->Visit("type_info", &type_info);
     v->Visit("description", &description);
   }
+
   static constexpr const char* _type_key = "AttrFieldInfo";
+  static constexpr bool _type_has_method_sequal_reduce = false;
   TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
 };
 
@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
    */
   TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
 
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const char* _type_key = "Attrs";
   TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
 };
@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
   /*! \brief internal attrs map */
   Map<std::string, ObjectRef> dict;
 
+  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
+    return equal(dict, other->dict);
+  }
+
   // implementations
   void VisitAttrs(AttrVisitor* v) final;
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
@@ -401,6 +408,33 @@ class AttrsEqualVisitor {
   const AttrsEqual& equal_;
 };
 
+class AttrsSEqualVisitor {
+ public:
+  bool result_{true};
+  // constructor
+  AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
+      : lhs_(lhs), rhs_(rhs), equal_(equal) {
+  }
+  template<typename T>
+  AttrNopEntry operator()(const char* key, T* lhs_value) {
+    if (!result_) return AttrNopEntry();
+    const T* rhs_value =
+        reinterpret_cast<const T*>(
+            reinterpret_cast<const char*>(rhs_) +
+            (reinterpret_cast<const char*>(lhs_value) -
+             reinterpret_cast<const char*>(lhs_)));
+    if (!equal_(*lhs_value, *rhs_value)) {
+      result_ = false;
+    }
+    return AttrNopEntry();
+  }
+
+ private:
+  const Object* lhs_;
+  const Object* rhs_;
+  const SEqualReducer& equal_;
+};
+
 class AttrsHashVisitor {
  public:
   explicit AttrsHashVisitor(const AttrsHash& hasher)
@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
     }
   }
 
+  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
+    DerivedType* pself = self();
+    ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
+    self()->__VisitAttrs__(visitor);
+    return visitor.result_;
+  }
+
   Array<AttrFieldInfo> ListFieldInfo() const final {
     ::tvm::detail::AttrDocVisitor visitor;
     self()->__VisitAttrs__(visitor);
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index f5b17bb..1064fd1 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
     v->Visit("name", &name);
   }
 
+  bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
+    return this == other;
+  }
+
   static constexpr const char* _type_key = "EnvFunc";
+  static constexpr bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
 };
 
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 44244df..fc63da0 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -43,6 +43,7 @@ namespace tvm {
 class BaseExprNode : public Object {
  public:
   static constexpr const char* _type_key = "Expr";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
 };
 
@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
+    // name matters for global var.
+    return
+        equal(name_hint, other->name_hint) &&
+        equal.FreeVarEqualImpl(this, other);
+  }
+
   static constexpr const char* _type_key = "GlobalVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
 };
@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
     v->Visit("value", &value);
   }
 
+  bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) && equal(value, other->value);
+  }
+
   static constexpr const char* _type_key = "IntImm";
   TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
 };
@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
     v->Visit("value", &value);
   }
 
+  bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) && equal(value, other->value);
+  }
+
   static constexpr const char* _type_key = "FloatImm";
   TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
 };
@@ -353,7 +369,12 @@ class RangeNode : public Object {
     v->Visit("extent", &extent);
   }
 
+  bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
+    return equal(min, other->min) && equal(extent, other->extent);
+  }
+
   static constexpr const char* _type_key = "Range";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
 };
 
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 4613bec..38e583d 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -62,6 +62,8 @@ class IRModuleNode : public Object {
     v->Visit("global_type_var_map_", &global_type_var_map_);
   }
 
+  TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
+
   /*!
    * \brief Add a function to the global environment.
    * \param var The var of the global function.
@@ -235,6 +237,7 @@ class IRModuleNode : public Object {
   TVM_DLL std::unordered_set<std::string> Imports() const;
 
   static constexpr const char* _type_key = "IRModule";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
 
  private:
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 8a6ab77..f023e87 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
     v->Visit("support_level", &support_level);
   }
 
+  bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
+    // pointer equality is fine as there is only one op with the same name.
+    return this == other;
+  }
+
   /*!
    * \brief Check that if current op is a "primtive operator".
    * That is the arguments are all type variables, and there is a single
diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h
index 4720dfe..7194e90 100644
--- a/include/tvm/ir/span.h
+++ b/include/tvm/ir/span.h
@@ -44,6 +44,10 @@ class SourceNameNode : public Object {
   // override attr visitor
   void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
 
+  bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
+    return equal(name, other->name);
+  }
+
   static constexpr const char* _type_key = "SourceName";
   TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
 };
@@ -87,6 +91,13 @@ class SpanNode : public Object {
     v->Visit("col_offset", &col_offset);
   }
 
+  bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
+    return
+        equal(source, other->source) &&
+        equal(lineno, other->lineno) &&
+        equal(col_offset, other->col_offset);
+  }
+
   TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
 
   static constexpr const char* _type_key = "Span";
diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h
index 70a2df1..05c7a95 100644
--- a/include/tvm/ir/tensor_type.h
+++ b/include/tvm/ir/tensor_type.h
@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
+    return
+        equal(shape, other->shape) &&
+        equal(dtype, other->dtype);
+  }
+
   /*! \brief Return product of elements in the shape.
    *  \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if shape size is zero.
    */
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 1b6ea25..ecd234a 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -111,6 +111,7 @@ class PassContextNode : public Object {
   }
 
   static constexpr const char* _type_key = "transform.PassContext";
+  static constexpr bool _type_has_method_sequal_reduce = false;
   TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
 };
 
@@ -207,6 +208,7 @@ class PassInfoNode : public Object {
   }
 
   static constexpr const char* _type_key = "transform.PassInfo";
+  static constexpr bool _type_has_method_sequal_reduce = false;
   TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
 };
 
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index c23626e..dd70029 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -79,6 +79,7 @@ class TypeNode : public Object {
   mutable Span span;
 
   static constexpr const char* _type_key = "Type";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
 };
 
@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
     v->Visit("dtype", &dtype);
   }
 
+  bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype);
+  }
+
   static constexpr const char* _type_key = "PrimType";
   TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
 };
@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
     v->Visit("element_type", &element_type);
   }
 
+  bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
+    return equal(element_type, other->element_type);
+  }
+
   static constexpr const char* _type_key = "PointerType";
   TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
 };
@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
+    return
+        equal(kind, other->kind) &&
+        equal.FreeVarEqualImpl(this, other);
+  }
+
   static constexpr const char* _type_key = "TypeVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
 };
@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
     v->Visit("kind", &kind);
   }
 
+  bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const {
+    // name matters for now in global type var.
+    return
+        equal(name_hint, other->name_hint) &&
+        equal.FreeVarEqualImpl(this, other);
+  }
+
   static constexpr const char* _type_key = "GlobalTypeVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
 };
@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
+    return equal(fields, other->fields);
+  }
+
   static constexpr const char* _type_key = "TupleType";
   TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
 };
@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
+    // type params first as they define type vars.
+    return
+        equal.DefEqual(type_params, other->type_params) &&
+        equal(arg_types, other->arg_types) &&
+        equal(ret_type, other->ret_type) &&
+        equal(type_constraints, other->type_constraints);
+  }
+
   static constexpr const char* _type_key = "FuncType";
   TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
 };
@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
+    return equal(kind, other->kind);
+  }
+
   static constexpr const char* _type_key = "IncompleteType";
   TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
 };
@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
+    return equal(value, other->value);
+  }
+
   // Keep the relay prefix in the type as this type is specific
   // to the relay itself.
   static constexpr const char* _type_key = "relay.RefType";
diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h
index f7bfb68..592bf25 100644
--- a/include/tvm/ir/type_relation.h
+++ b/include/tvm/ir/type_relation.h
@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(args, other->args);
+  }
+
   static constexpr const char* _type_key = "TypeCall";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
 };
@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(args, other->args) &&
+        equal(num_inputs, other->num_inputs) &&
+        equal(attrs, other->attrs);
+  }
+
   static constexpr const char* _type_key = "TypeRelation";
   TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
 };
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index a541385..461fa11 100644
--- a/include/tvm/node/container.h
+++ b/include/tvm/node/container.h
@@ -23,7 +23,9 @@
 #ifndef TVM_NODE_CONTAINER_H_
 #define TVM_NODE_CONTAINER_H_
 
-#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/packed_func.h>
 
 #include <type_traits>
 #include <vector>
@@ -34,15 +36,19 @@
 
 namespace tvm {
 
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+using runtime::make_object;
+using runtime::ObjectHash;
+using runtime::ObjectEqual;
+
 /*! \brief array node content in array */
 class ArrayNode : public Object {
  public:
   /*! \brief the data content */
   std::vector<ObjectRef> data;
 
-  void VisitAttrs(AttrVisitor* visitor) {
-  }
-
   static constexpr const char* _type_key = "Array";
   TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
 };
@@ -50,9 +56,6 @@ class ArrayNode : public Object {
 /*! \brief map node content */
 class MapNode : public Object {
  public:
-  void VisitAttrs(AttrVisitor* visitor) {
-  }
-
   /*! \brief The corresponding conatiner type */
   using ContainerType = std::unordered_map<
     ObjectRef,
@@ -73,9 +76,6 @@ class StrMapNode : public Object {
   /*! \brief The corresponding conatiner type */
   using ContainerType = std::unordered_map<std::string, ObjectRef>;
 
-  void VisitAttrs(AttrVisitor* visitor) {
-  }
-
   /*! \brief the data content */
   ContainerType data;
 
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 3ea3d76..76e574b 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -39,6 +39,8 @@
 #include <tvm/runtime/memory.h>
 #include <tvm/node/reflection.h>
 #include <tvm/node/repr_printer.h>
+#include <tvm/node/container.h>
+#include <tvm/node/structural_equal.h>
 
 #include <string>
 #include <vector>
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
index daffeb8..d0a9304 100644
--- a/include/tvm/node/reflection.h
+++ b/include/tvm/node/reflection.h
@@ -29,13 +29,14 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/data_type.h>
+#include <tvm/node/structural_equal.h>
 
 #include <vector>
 #include <string>
+#include <type_traits>
 
 namespace tvm {
 
-// forward declaration
 using runtime::Object;
 using runtime::ObjectPtr;
 using runtime::ObjectRef;
@@ -87,6 +88,13 @@ class ReflectionVTable {
    */
   typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
   /*!
+   * \brief Equality comparison function.
+   * \note We use function pointer, instead of std::function
+   *       to reduce the dispatch overhead as field visit
+   *       does not need as much customization.
+   */
+  typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
+  /*!
    * \brief creator function.
    * \param global_key Key that identifies a global single object.
    *        If this is not empty then FGlobalKey must be defined for the object.
@@ -112,6 +120,14 @@ class ReflectionVTable {
    */
   inline std::string GetGlobalKey(Object* self) const;
   /*!
+   * \brief Dispatch the SEqualReduce function.
+   * \param self The pointer to the object.
+   * \param other The pointer to another object to be compared.
+   * \param equal The equality comparator.
+   * \return the result.
+   */
+  bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
+  /*!
    * \brief Create an initial object using default constructor
    *        by type_key and global key.
    *
@@ -139,12 +155,14 @@ class ReflectionVTable {
   TVM_DLL static ReflectionVTable* Global();
 
   class Registry;
-  template<typename T>
+  template<typename T, typename TraitName>
   inline Registry Register();
 
  private:
   /*! \brief Attribute visitor. */
   std::vector<FVisitAttrs> fvisit_attrs_;
+  /*! \brief Structural equal function. */
+  std::vector<FSEqualReduce> fsequal_;
   /*! \brief Creation function. */
   std::vector<FCreate> fcreate_;
   /*! \brief Global key function. */
@@ -182,6 +200,44 @@ class ReflectionVTable::Registry {
   uint32_t type_index_;
 };
 
+
+#define TVM_REFLECTION_REG_VAR_DEF                                     \
+  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry        \
+  __make_reflectiion
+
+/*!
+ * \brief Directly register reflection VTable.
+ * \param TypeName The name of the type.
+ * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
+ *
+ * \code
+ *
+ *  // Example SEqualReduce traits for runtime StringObj.
+ *
+ *  struct StringObjTrait {
+ *    static constexpr const std::nullptr_t VisitAttrs = nullptr;
+ *
+ *    static bool SEqualReduce(const runtime::StringObj* lhs,
+ *                             const runtime::StringObj* rhs,
+ *                             SEqualReducer equal) {
+ *      if (lhs == rhs) return true;
+ *      if (lhs->size != rhs->size) return false;
+ *      if (lhs->data == rhs->data) return true;
+ *      return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+ *    }
+ *  };
+ *
+ *  TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+ *
+ * \endcode
+ *
+ * \note This macro can be called in a different place from TVM_REGISTER_OBJECT_TYPE,
+ *       and can be used to register the related reflection functions for runtime objects.
+ */
+#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName)             \
+  TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) =             \
+      ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
+
 /*!
  * \brief Register a node type to object registry and reflection registry.
  * \param TypeName The name of the type.
@@ -189,15 +245,79 @@ class ReflectionVTable::Registry {
  */
 #define TVM_REGISTER_NODE_TYPE(TypeName)                                \
   TVM_REGISTER_OBJECT_TYPE(TypeName);                                   \
-  static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry &      \
-  __make_Node ## _ ## TypeName ## __ =                                  \
-      ::tvm::ReflectionVTable::Global()->Register<TypeName>()           \
-      .set_creator([](const std::string&) -> ObjectPtr<Object> {        \
-          return ::tvm::runtime::make_object<TypeName>();               \
-        })
+  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
+  .set_creator([](const std::string&) -> ObjectPtr<Object> {            \
+      return ::tvm::runtime::make_object<TypeName>();                   \
+    })
+
 
 // Implementation details
+namespace detail {
+
+template<typename T,
+         bool = T::_type_has_method_visit_attrs>
+struct ImplVisitAttrs {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+};
+
+template<typename T>
+struct ImplVisitAttrs<T, true> {
+  static void VisitAttrs(T* self, AttrVisitor* v) {
+    self->VisitAttrs(v);
+  }
+};
+
+template<typename T,
+         bool = T::_type_has_method_sequal_reduce>
+struct ImplSEqualReduce {
+  static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+template<typename T>
+struct ImplSEqualReduce<T, true> {
+  static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
+    return self->SEqualReduce(other, equal);
+  }
+};
+
 template<typename T>
+struct ReflectionTrait :
+      public ImplVisitAttrs<T>,
+      public ImplSEqualReduce<T> {
+};
+
+template<typename T, typename TraitName,
+         bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
+struct SelectVisitAttrs {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+};
+
+template<typename T, typename TraitName>
+struct SelectVisitAttrs<T, TraitName, false> {
+  static void VisitAttrs(Object* self, AttrVisitor* v) {
+    TraitName::VisitAttrs(static_cast<T*>(self), v);
+  }
+};
+
+template<typename T, typename TraitName,
+         bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
+struct SelectSEqualReduce {
+  static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+template<typename T, typename TraitName>
+struct SelectSEqualReduce<T, TraitName, false> {
+  static bool SEqualReduce(const Object* self,
+                           const Object* other,
+                           SEqualReducer equal) {
+    return TraitName::SEqualReduce(static_cast<const T*>(self),
+                                   static_cast<const T*>(other),
+                                   equal);
+  }
+};
+}  // namespace detail
+
+template<typename T, typename TraitName>
 inline ReflectionVTable::Registry
 ReflectionVTable::Register() {
   uint32_t tindex = T::RuntimeTypeIndex();
@@ -205,15 +325,15 @@ ReflectionVTable::Register() {
     fvisit_attrs_.resize(tindex + 1, nullptr);
     fcreate_.resize(tindex + 1, nullptr);
     fglobal_key_.resize(tindex + 1, nullptr);
+    fsequal_.resize(tindex + 1, nullptr);
   }
   // functor that implemnts the redirection.
-  struct Functor {
-    static void VisitAttrs(Object* self, AttrVisitor* v) {
-      static_cast<T*>(self)->VisitAttrs(v);
-     }
-  };
+  fvisit_attrs_[tindex] =
+      ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
+
+  fsequal_[tindex] =
+      ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
 
-  fvisit_attrs_[tindex] = Functor::VisitAttrs;
   return Registry(this, tindex);
 }
 
diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h
new file mode 100644
index 0000000..f719e24
--- /dev/null
+++ b/include/tvm/node/structural_equal.h
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/structural_equal.h
+ * \brief Structural equality comparison.
+ */
+#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
+#define TVM_NODE_STRUCTURAL_EQUAL_H_
+
+#include <tvm/runtime/data_type.h>
+#include <tvm/node/functor.h>
+#include <tvm/node/container.h>
+#include <string>
+
+namespace tvm {
+
+/*!
+ * \brief Equality definition of base value class.
+ */
+class BaseValueEqual {
+ public:
+  bool operator()(const double& lhs, const double& rhs) const {
+    // fuzzy floating-point comparison with absolute tolerance atol
+    constexpr double atol = 1e-9;
+    if (lhs == rhs) return true;
+    double diff = lhs - rhs;
+    return diff > -atol && diff < atol;
+  }
+
+  bool operator()(const int64_t& lhs, const int64_t& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const int& lhs, const int& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const bool& lhs, const bool& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const std::string& lhs, const std::string& rhs) const {
+    return lhs == rhs;
+  }
+  bool operator()(const DataType& lhs, const DataType& rhs) const {
+    return lhs == rhs;
+  }
+  template<typename ENum,
+           typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+  bool operator()(const ENum& lhs, const ENum& rhs) const {
+    return lhs == rhs;
+  }
+};
+
+/*!
+ * \brief Content-aware structural equality comparator for objects.
+ *
+ *  The structural equality is recursively defined in the DAG of IR nodes via SEqual.
+ *  There are two kinds of nodes:
+ *
+ *  - Graph node: a graph node in lhs can only be mapped as equal to
+ *    one and only one graph node in rhs.
+ *  - Normal node: equality is recursively defined without the restriction
+ *    of graph nodes.
+ *
+ *  Vars (tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
+ *  For example, this means that `%1 = %x + %y; %1 + %1` is not structurally equal
+ *  to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
+ *
+ *  A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var
+ *  with the same type if one of the following conditions holds:
+ *
+ *  - They appear at the same definition point (e.g. a function argument).
+ *  - They point to the same VarNode via the same_as relation.
+ *  - They appear at the same usage point, and map_free_vars is set to true.
+ */
+class StructuralEqual : public BaseValueEqual {
+ public:
+  // inherit operator()
+  using BaseValueEqual::operator();
+  /*!
+   * \brief Compare objects via structural equality.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return The comparison result.
+   */
+  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
+};
+
+/*!
+ * \brief A Reducer class to reduce the structural equality result of two objects.
+ *
+ * The reducer will call the SEqualReduce function of each object recursively.
+ * Importantly, the reducer may not directly use recursive calls to resolve the
+ * equality checking. Instead, it can store the necessary equality conditions
+ * and check later via an internally managed stack.
+ */
+class SEqualReducer : public BaseValueEqual {
+ public:
+  /*! \brief Internal handler that defines custom behaviors. */
+  class Handler {
+   public:
+    /*!
+     * \brief Reduce condition to equality of lhs and rhs.
+     *
+     * \param lhs The left operand.
+     * \param rhs The right operand.
+     * \param map_free_vars Whether to allow remapping variables if possible.
+     *
+     * \return false if there is an immediate failure, true otherwise.
+     * \note This function may save the equality condition of (lhs == rhs) in an internal
+     *       stack and try to resolve later.
+     */
+    virtual bool SEqualReduce(const ObjectRef& lhs,
+                              const ObjectRef& rhs,
+                              bool map_free_vars) = 0;
+    /*!
+     * \brief Look up the graph-node equality map for vars that are already mapped.
+     *
+     *  This is an auxiliary method to check the Map<Var, Value> equality.
+     * \param lhs an lhs value.
+     *
+     * \return The corresponding rhs value if any, nullptr if not available.
+     */
+    virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
+    /*!
+     * \brief Mark current comparison as graph node equal comparison.
+     */
+    virtual void MarkGraphNode() = 0;
+  };
+
+  using BaseValueEqual::operator();
+
+  /*! \brief default constructor */
+  SEqualReducer() = default;
+  /*!
+   * \brief Constructor with a specific handler.
+   * \param handler The equal handler for objects.
+   * \param map_free_vars Whether or not to map free variables.
+   */
+  explicit SEqualReducer(Handler* handler, bool map_free_vars)
+      : handler_(handler), map_free_vars_(map_free_vars) {}
+  /*!
+   * \brief Reduce condition to comparison of two objects.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
+  }
+  /*!
+   * \brief Reduce condition to comparison of two definitions,
+   *        where free vars can be mapped.
+   *
+   *  Call this function to compare definition points such as function params
+   *  and var in a let-binding.
+   *
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
+    return handler_->SEqualReduce(lhs, rhs, true);
+  }
+  /*!
+   * \brief Reduce condition to comparison of two arrays.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  template<typename T>
+  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
+    // quick specialization for Array to reduce amount of recursion
+    // depth as array comparison is pretty common.
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!(operator()(lhs[i], rhs[i]))) return false;
+    }
+    return true;
+  }
+  /*!
+   * \brief Implementation of the equality rule for var-type objects (e.g. TypeVar, tir::Var).
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the result.
+   */
+  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
+    // Vars need to be remapped, so they are graph nodes.
+    handler_->MarkGraphNode();
+    // We only map free vars if they correspond to the same address
+    // or the map_free_vars option is set to true.
+    return lhs == rhs || map_free_vars_;
+  }
+
+  /*! \return The internal handler. */
+  Handler* operator->() const {
+    return handler_;
+  }
+
+ private:
+  /*! \brief Internal class pointer. */
+  Handler* handler_;
+  /*! \brief Whether or not to map free vars. */
+  bool map_free_vars_;
+};
+
+}  // namespace tvm
+#endif  // TVM_NODE_STRUCTURAL_EQUAL_H_
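The SEqualReducer comment above notes that the reducer may store pending equality conditions and resolve them later from an internally managed stack, instead of recursing directly. The following is a minimal sketch of that idea. It assumes a hypothetical toy tree type (`ToyNode`, `Make`, and `ToyStructuralEqual` are invented for illustration and are not TVM classes), and it is not the actual handler added by this commit.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR node used only for illustration; not a TVM class.
struct ToyNode {
  std::string op;                                  // node kind or literal payload
  std::vector<std::shared_ptr<ToyNode>> children;  // operands, compared in order
};
using ToyRef = std::shared_ptr<ToyNode>;

ToyRef Make(std::string op, std::vector<ToyRef> children = {}) {
  return std::make_shared<ToyNode>(ToyNode{std::move(op), std::move(children)});
}

// Structural equality without recursion: each step checks the node's own
// fields immediately and defers the child comparisons to an explicit stack.
bool ToyStructuralEqual(const ToyRef& lhs, const ToyRef& rhs) {
  std::vector<std::pair<const ToyNode*, const ToyNode*>> pending;
  pending.emplace_back(lhs.get(), rhs.get());
  while (!pending.empty()) {
    const ToyNode* a = pending.back().first;
    const ToyNode* b = pending.back().second;
    pending.pop_back();
    if (a == b) continue;                            // same object: equal
    if (a == nullptr || b == nullptr) return false;
    if (a->op != b->op) return false;                // immediate failure
    if (a->children.size() != b->children.size()) return false;
    for (std::size_t i = 0; i < a->children.size(); ++i) {
      // Save the child condition and check it in a later iteration.
      pending.emplace_back(a->children[i].get(), b->children[i].get());
    }
  }
  return true;
}
```

The toy has no variable bindings, so the LIFO order in which deferred pairs are popped does not affect the result. In the real checker, order matters because DefEqual must establish a mapping before any use of that variable is compared.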
diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h
index 8189b21..ea13e25 100644
--- a/include/tvm/relay/adt.h
+++ b/include/tvm/relay/adt.h
@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode;
 class PatternNode : public RelayNode {
  public:
   static constexpr const char* _type_key = "relay.Pattern";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
 };
 
@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
+    return true;
+  }
+
   static constexpr const char* _type_key = "relay.PatternWildcard";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
 };
@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
+    return equal.DefEqual(var, other->var);
+  }
+
   static constexpr const char* _type_key = "relay.PatternVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
 };
@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
+    return
+        equal(constructor, other->constructor) &&
+        equal(patterns, other->patterns);
+  }
+
   static constexpr const char* _type_key = "relay.PatternConstructor";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
 };
@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode {
     v->Visit("span", &span);
   }
 
+  bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
+    return equal(patterns, other->patterns);
+  }
+
   static constexpr const char* _type_key = "relay.PatternTuple";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
 };
@@ -208,7 +227,12 @@ class ClauseNode : public Object {
     v->Visit("rhs", &rhs);
   }
 
+  bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
+    return equal(lhs, other->lhs) && equal(rhs, other->rhs);
+  }
+
   static constexpr const char* _type_key = "relay.Clause";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
 };
 
@@ -248,6 +272,14 @@ class MatchNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return
+        equal(data, other->data) &&
+        equal(clauses, other->clauses) &&
+        equal(complete, other->complete);
+  }
+
   static constexpr const char* _type_key = "relay.Match";
   TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
 };
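MatchNode and most other non-constant relay expressions call MarkGraphNode(). The structural_equal.h doc comment describes what this implies: an lhs graph node may be paired with one and only one rhs node. As a result, `%1 + %1` is not equal to `%1 + %2` even when `%1` and `%2` compute the same value.

The following is a sketch of that rule under simplifying assumptions. It treats every node as a graph node, which is stricter than TVM, where constants are exempt, and it identifies nodes by address. `GNode`, `G`, and `GraphNodeEqual` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy expression node; every node is treated as a graph node.
struct GNode {
  std::string op;
  std::vector<std::shared_ptr<GNode>> args;
};
using GRef = std::shared_ptr<GNode>;

GRef G(std::string op, std::vector<GRef> args = {}) {
  return std::make_shared<GNode>(GNode{std::move(op), std::move(args)});
}

// Each lhs node may pair with exactly one rhs node and vice versa, so the
// sharing structure of one DAG must match the sharing structure of the other.
class GraphNodeEqual {
 public:
  bool Equal(const GNode* l, const GNode* r) {
    auto it = lhs_to_rhs_.find(l);
    if (it != lhs_to_rhs_.end()) return it->second == r;  // reuse earlier pairing
    if (rhs_to_lhs_.count(r)) return false;  // r already paired with another lhs node
    if (l->op != r->op || l->args.size() != r->args.size()) return false;
    for (std::size_t i = 0; i < l->args.size(); ++i) {
      if (!Equal(l->args[i].get(), r->args[i].get())) return false;
    }
    lhs_to_rhs_[l] = r;  // record the one-to-one pairing
    rhs_to_lhs_[r] = l;
    return true;
  }

 private:
  std::map<const GNode*, const GNode*> lhs_to_rhs_;
  std::map<const GNode*, const GNode*> rhs_to_lhs_;
};
```

Keeping both directions of the map makes the check symmetric. Swapping lhs and rhs gives the same answer, which matches the commutativity goal stated in the commit message.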
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 3acb5dd..731046e 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/attrs.h>
 #include <tvm/ir/expr.h>
+#include <tvm/ir/op.h>
 #include <tvm/ir/module.h>
 #include <string>
 #include <functional>
@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
+    return equal(data, other->data);
+  }
+
   static constexpr const char* _type_key = "relay.Constant";
   TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
 };
@@ -101,6 +106,16 @@ class TupleNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
+    // Handle the empty tuple specially: it is a constant, not a graph node.
+    if (fields.size() == other->fields.size() && fields.size() == 0) {
+      return true;
+    } else {
+      equal->MarkGraphNode();
+      return equal(fields, other->fields);
+    }
+  }
+
   static constexpr const char* _type_key = "relay.Tuple";
   TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
 };
@@ -157,6 +172,12 @@ class VarNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    return
+        equal(type_annotation, other->type_annotation) &&
+        equal.FreeVarEqualImpl(this, other);
+  }
+
   TVM_DLL static Var make(std::string name_hint,
                           Type type_annotation);
 
@@ -238,6 +259,16 @@ class CallNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
+    // skip type_args check for primitive ops.
+    equal->MarkGraphNode();
+    return
+        equal(op, other->op) &&
+        equal(args, other->args) &&
+        equal(attrs, other->attrs) &&
+        (IsPrimitiveOp(op) || equal(type_args, other->type_args));
+  }
+
   static constexpr const char* _type_key = "relay.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
 };
@@ -289,6 +320,14 @@ class LetNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return
+        equal.DefEqual(var, other->var) &&
+        equal(value, other->value) &&
+        equal(body, other->body);
+  }
+
   static constexpr const char* _type_key = "relay.Let";
   TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
 };
@@ -336,6 +375,14 @@ class IfNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return
+        equal(cond, other->cond) &&
+        equal(true_branch, other->true_branch) &&
+        equal(false_branch, other->false_branch);
+  }
+
   static constexpr const char* _type_key = "relay.If";
   TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
 };
@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
+    return
+        equal(tuple, other->tuple) &&
+        equal(index, other->index);
+  }
+
   static constexpr const char* _type_key = "relay.TupleGetItem";
   TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
 };
@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(value, other->value);
+  }
+
   static constexpr const char* _type_key = "relay.RefCreate";
   TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
 };
@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(ref, other->ref);
+  }
+
   static constexpr const char* _type_key = "relay.RefRead";
   TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
 };
@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return
+        equal(ref, other->ref) &&
+        equal(value, other->value);
+  }
+
   TVM_DLL static RefWrite make(Expr ref, Expr value);
 
   static constexpr const char* _type_key = "relay.RefWrite";
@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode {
   virtual Expr Realize() const = 0;
 
   static constexpr const char* _type_key = "relay.TempExpr";
+  static constexpr const bool _type_has_method_sequal_reduce = false;
   TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
 };
 
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 5c5bd26..ed39caa 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
+    // Important: compare the definitions (DefEqual) first.
+    equal->MarkGraphNode();
+    return
+        equal.DefEqual(params, other->params) &&
+        equal.DefEqual(type_params, other->type_params) &&
+        equal(ret_type, other->ret_type) &&
+        equal(attrs, other->attrs) &&
+        equal(body, other->body);
+  }
+
   /*!
    * \brief Return the derived function annotation of this expression.
    *
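FunctionNode::SEqualReduce runs DefEqual on the parameters before comparing the body, so the body comparison sees the parameter mapping. The following sketch reproduces the shared-var case from the commit message: `fn (%x, %y) { %x + %y }` versus `fn (%y, %x) { %x + %y }`, where both functions reuse the same Var objects. A same_as shortcut on the bodies would wrongly report these as equal.

The toy types `Var`, `Add`, `Fn` and the `VarMapper` class are hypothetical. Function bodies are limited to a single addition.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy types: a Var's identity is its address (like same_as),
// and a function body is restricted to a single `a + b` for brevity.
struct Var { std::string name_hint; };
struct Add { const Var* a; const Var* b; };
struct Fn { std::vector<const Var*> params; Add body; };

// Keeps a one-to-one lhs<->rhs var mapping, in the spirit of DefEqual.
class VarMapper {
 public:
  // Definition point: bind lhs var to rhs var (or confirm an existing binding).
  bool DefEqual(const Var* l, const Var* r) {
    auto it = l2r_.find(l);
    if (it != l2r_.end()) return it->second == r;
    if (r2l_.count(r)) return false;  // r already bound to a different lhs var
    l2r_[l] = r;
    r2l_[r] = l;
    return true;
  }
  // Usage point with map_free_vars = false: bound vars must follow the
  // binding; unbound (free) vars must be the same object.
  bool UseEqual(const Var* l, const Var* r) const {
    auto it = l2r_.find(l);
    if (it != l2r_.end()) return it->second == r;
    if (r2l_.count(r)) return false;
    return l == r;
  }

 private:
  std::map<const Var*, const Var*> l2r_;
  std::map<const Var*, const Var*> r2l_;
};

bool FnEqual(const Fn& f, const Fn& g) {
  if (f.params.size() != g.params.size()) return false;
  VarMapper m;
  // Map the parameters first so the body comparison can use the bindings.
  for (std::size_t i = 0; i < f.params.size(); ++i) {
    if (!m.DefEqual(f.params[i], g.params[i])) return false;
  }
  return m.UseEqual(f.body.a, g.body.a) && m.UseEqual(f.body.b, g.body.b);
}
```

With swapped parameters, %x on the left is bound to %y on the right, so the body's `%x` versus `%x` fails even though both sides point to the same object.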
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 2441ab6..17f81a2 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -65,6 +65,8 @@ class NDArray : public ObjectRef {
   inline int use_count() const;
   /*! \return Pointer to content of DLTensor */
   inline const DLTensor* operator->() const;
+  /*! \return Whether the tensor is contiguous */
+  inline bool IsContiguous() const;
   /*!
    * \brief Copy data content from another array.
    * \param other The source array to be copied from.
@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) {
   return size;
 }
 
+/*!
+ * \brief Check whether a DLTensor is contiguous.
+ * \param arr The input DLTensor.
+ * \return The check result.
+ */
+inline bool IsContiguous(const DLTensor& arr) {
+  if (arr.strides == nullptr) return true;
+  int64_t expected_stride = 1;
+  for (int32_t i = arr.ndim; i != 0; --i) {
+    int32_t k = i - 1;
+    if (arr.strides[k] != expected_stride) return false;
+    expected_stride *= arr.shape[k];
+  }
+  return true;
+}
+
+inline bool NDArray::IsContiguous() const {
+  return ::tvm::runtime::IsContiguous(get_mutable()->dl_tensor);
+}
+
 inline void NDArray::CopyFrom(const DLTensor* other) {
   CHECK(data_ != nullptr);
   CopyFromTo(other, &(get_mutable()->dl_tensor));
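The new IsContiguous helper treats a tensor as contiguous when its strides match compact row-major layout: the innermost stride is 1, and each outer stride is the product of the inner extents. A null `strides` pointer is also accepted, since DLPack uses it to mean compact row-major.

The following sketch repeats the same loop on raw arrays so it compiles without dlpack.h. `IsCompactRowMajor` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Same check as IsContiguous(const DLTensor&), on raw arrays.
// `IsCompactRowMajor` is a hypothetical name used only for illustration.
bool IsCompactRowMajor(int32_t ndim, const int64_t* shape, const int64_t* strides) {
  if (strides == nullptr) return true;  // DLPack: null strides means compact row-major
  int64_t expected_stride = 1;
  // Walk from the innermost (last) dimension outward.
  for (int32_t k = ndim - 1; k >= 0; --k) {
    if (strides[k] != expected_stride) return false;
    expected_stride *= shape[k];
  }
  return true;
}
```

For a shape of {2, 3, 4}, the packed strides are {12, 4, 1}. A padded row stride such as {16, 4, 1}, or column-major strides {1, 2, 6}, both fail the check.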
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index fe5e30b..80b479d 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -211,11 +211,15 @@ class Object {
   static constexpr bool _type_final = false;
   static constexpr uint32_t _type_child_slots = 0;
   static constexpr bool _type_child_slots_can_overflow = true;
+  // member information
+  static constexpr bool _type_has_method_visit_attrs = true;
+  static constexpr bool _type_has_method_sequal_reduce = false;
   // NOTE: the following field is not type index of Object
   // but was intended to be used by sub-classes as default value.
   // The type index of Object is TypeIndex::kRoot
   static constexpr uint32_t _type_index = TypeIndex::kDynamic;
 
+
   // Default constructor and copy constructor
   Object() {}
   // Override the copy and assign constructors to do nothing.
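Object now declares `_type_has_method_sequal_reduce = false`, and node classes opt in by shadowing the flag. At registration time, the flag picks which SEqualReduce entry is stored for the type. The following sketch shows one way such compile-time selection can work, using partial template specialization on the flag.

All names here are hypothetical (`BaseObj`, `IntLit`, `Opaque`, `SelectReduce`). Falling back to pointer identity when the flag is false is an assumption of this sketch, not necessarily TVM's behavior.

```cpp
#include <cassert>

// Hypothetical stand-in for Object: the opt-in flag defaults to false.
struct BaseObj {
  static constexpr bool _type_has_method_sequal_reduce = false;
  virtual ~BaseObj() = default;
};

// Opts in by shadowing the flag and providing SEqualReduce.
struct IntLit : BaseObj {
  static constexpr bool _type_has_method_sequal_reduce = true;
  int value;
  explicit IntLit(int v) : value(v) {}
  bool SEqualReduce(const IntLit* other) const { return value == other->value; }
};

// Does not opt in, so it inherits the false flag from BaseObj.
struct Opaque : BaseObj {
  int payload;
  explicit Opaque(int p) : payload(p) {}
};

// Primary template, chosen when the flag is false (fallback is an assumption).
template <typename T, bool HasReduce = T::_type_has_method_sequal_reduce>
struct SelectReduce {
  static bool Reduce(const BaseObj* self, const BaseObj* other) { return self == other; }
};

// Partial specialization, chosen when the flag is true: forward to the member.
template <typename T>
struct SelectReduce<T, true> {
  static bool Reduce(const BaseObj* self, const BaseObj* other) {
    return static_cast<const T*>(self)->SEqualReduce(static_cast<const T*>(other));
  }
};

// Type-erased function pointer, like an entry stored per type index.
using ReduceFn = bool (*)(const BaseObj*, const BaseObj*);
```

Because the choice is made by the template machinery, a class that forgets to set the flag silently gets the fallback instead of a compile error. This explains why the diff sets the flag explicitly on base classes such as PatternNode, StmtNode, and BufferNode.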
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index c172316..60dd455 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -150,6 +150,20 @@ class BufferNode : public Object {
     v->Visit("buffer_type", &buffer_type);
   }
 
+  bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
+    // Use DefEqual because a buffer can define variables
+    // in its semantics; skip name since it does not affect semantics.
+    return
+        equal.DefEqual(data, other->data) &&
+        equal(dtype, other->dtype) &&
+        equal.DefEqual(shape, other->shape) &&
+        equal.DefEqual(strides, other->strides) &&
+        equal.DefEqual(elem_offset, other->elem_offset) &&
+        equal(scope, other->scope) &&
+        equal(data_alignment, other->data_alignment) &&
+        equal(buffer_type, other->buffer_type);
+  }
+
   /*! \return preferred index type for this buffer node */
   DataType DefaultIndexType() const {
     return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
@@ -169,6 +183,7 @@ class BufferNode : public Object {
                              BufferType buffer_type);
 
   static constexpr const char* _type_key = "Buffer";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
 };
 
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 90fef87..28e6186 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -75,6 +75,12 @@ class VarNode : public PrimExprNode {
     v->Visit("type_annotation", &type_annotation);
   }
 
+  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(type_annotation, other->type_annotation)) return false;
+    return equal.FreeVarEqualImpl(this, other);
+  }
+
   static constexpr const char* _type_key = "tir.Var";
   TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
 };
@@ -288,11 +294,20 @@ class IterVarNode : public Object {
     v->Visit("thread_tag", &thread_tag);
   }
 
+  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
+    return
+        equal(dom, other->dom) &&
+        equal.DefEqual(var, other->var) &&
+        equal(iter_type, other->iter_type) &&
+        equal(thread_tag, other->thread_tag);
+  }
+
   TVM_DLL static IterVar make(Range dom, Var var,
                               IterVarType iter_type,
                               std::string thread_tag = "");
 
   static constexpr const char* _type_key = "IterVar";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
 };
 
@@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode {
     v->Visit("value", &value);
   }
 
+  bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
+    return equal(value, other->value);
+  }
+
   TVM_DLL PrimExpr static make(std::string value);
 
   static constexpr const char* _type_key = "StringImm";
@@ -359,6 +378,10 @@ class CastNode : public PrimExprNode {
     v->Visit("value", &value);
   }
 
+  bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) && equal(value, other->value);
+  }
+
   TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
 
   static constexpr const char* _type_key = "Cast";
@@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode {
     v->Visit("b", &b);
   }
 
+  bool SEqualReduce(const T* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(a, other->a) &&
+        equal(b, other->b);
+  }
+
   static PrimExpr make(PrimExpr a, PrimExpr b) {
     CHECK(a.defined()) << "ValueError: a is undefined\n";
     CHECK(b.defined()) << "ValueError: b is undefined\n";
@@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode {
     v->Visit("b", &b);
   }
 
+  bool SEqualReduce(const T* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(a, other->a) &&
+        equal(b, other->b);
+  }
+
   static PrimExpr make(PrimExpr a, PrimExpr b) {
     CHECK(a.defined()) << "ValueError: a is undefined\n";
     CHECK(b.defined()) << "ValueError: b is undefined\n";
@@ -539,6 +576,13 @@ class AndNode : public PrimExprNode {
     v->Visit("b", &b);
   }
 
+  bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(a, other->a) &&
+        equal(b, other->b);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
 
   static constexpr const char* _type_key = "And";
@@ -559,6 +603,13 @@ class OrNode : public PrimExprNode {
     v->Visit("b", &b);
   }
 
+  bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(a, other->a) &&
+        equal(b, other->b);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
 
   static constexpr const char* _type_key = "Or";
@@ -576,6 +627,10 @@ class NotNode : public PrimExprNode {
     v->Visit("a", &a);
   }
 
+  bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) && equal(a, other->a);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr a);
 
   static constexpr const char* _type_key = "Not";
@@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode {
     v->Visit("false_value", &false_value);
   }
 
+  bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(condition, other->condition) &&
+        equal(true_value, other->true_value) &&
+        equal(false_value, other->false_value);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
 
   static constexpr const char* _type_key = "Select";
@@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode {
     v->Visit("predicate", &predicate);
   }
 
+  bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(buffer_var, other->buffer_var) &&
+        equal(index, other->index) &&
+        equal(predicate, other->predicate);
+  }
+
   TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
 
   static constexpr const char* _type_key = "Load";
@@ -673,6 +744,14 @@ class RampNode : public PrimExprNode {
     v->Visit("lanes", &lanes);
   }
 
+  bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(base, other->base) &&
+        equal(stride, other->stride) &&
+        equal(lanes, other->lanes);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
 
   static constexpr const char* _type_key = "Ramp";
@@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode {
     v->Visit("lanes", &lanes);
   }
 
+  bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(value, other->value) &&
+        equal(lanes, other->lanes);
+  }
+
   TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
 
   static constexpr const char* _type_key = "Broadcast";
@@ -718,6 +804,14 @@ class LetNode : public PrimExprNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal.DefEqual(var, other->var) &&
+        equal(value, other->value) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
 
   static constexpr const char* _type_key = "Let";
@@ -788,12 +882,22 @@ class CallNode : public PrimExprNode {
     v->Visit("value_index", &value_index);
   }
 
+  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(name, other->name) &&
+        equal(args, other->args) &&
+        equal(call_type, other->call_type) &&
+        equal(func, other->func) &&
+        equal(value_index, other->value_index);
+  }
+
   TVM_DLL static PrimExpr make(DataType dtype,
-                           std::string name,
-                           Array<PrimExpr> args,
-                           CallType call_type,
-                           FunctionRef func = FunctionRef(),
-                           int value_index = 0);
+                               std::string name,
+                               Array<PrimExpr> args,
+                               CallType call_type,
+                               FunctionRef func = FunctionRef(),
+                               int value_index = 0);
 
   /*! \return Whether call node is pure. */
   bool is_pure() const {
@@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode {
     v->Visit("indices", &indices);
   }
 
+  bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
+    return
+        equal(dtype, other->dtype) &&
+        equal(vectors, other->vectors) &&
+        equal(indices, other->indices);
+  }
+
   TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr> indices);
   TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
   TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
@@ -918,7 +1029,16 @@ class CommReducerNode : public Object {
     v->Visit("identity_element", &identity_element);
   }
 
+  bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
+    return
+        equal.DefEqual(lhs, other->lhs) &&
+        equal.DefEqual(rhs, other->rhs) &&
+        equal(result, other->result) &&
+        equal(identity_element, other->identity_element);
+  }
+
   static constexpr const char* _type_key = "CommReducer";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
 };
 
@@ -948,10 +1068,10 @@ class ReduceNode : public PrimExprNode {
 
   /*! \brief construct expr from op and rdom */
   TVM_DLL static PrimExpr make(CommReducer combiner,
-                           Array<PrimExpr> src,
-                           Array<IterVar> rdom,
-                           PrimExpr condition,
-                           int value_index);
+                               Array<PrimExpr> src,
+                               Array<IterVar> rdom,
+                               PrimExpr condition,
+                               int value_index);
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
@@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode {
     v->Visit("value_index", &value_index);
   }
 
+  bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
+    // check axis first so IterVars can define the necessary variables.
+    return
+        equal(dtype, other->dtype) &&
+        equal(axis, other->axis) &&
+        equal(combiner, other->combiner) &&
+        equal(source, other->source) &&
+        equal(condition, other->condition) &&
+        equal(value_index, other->value_index);
+  }
   static constexpr const char* _type_key = "Reduce";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
 };
@@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode {
 class AnyNode : public PrimExprNode {
  public:
   void VisitAttrs(AttrVisitor* v) {}
+
+  bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
+    return true;
+  }
+
   /*! \brief Convert to var. */
   Var ToVar() const {
     return Var("any_dim", DataType::Int(32));
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 63a8630..26b643a 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode {
     v->Visit("_checked_type_", &checked_type_);
   }
 
+  bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
+    // Visit params and buffer_map first as they contain definitions.
+    return
+        equal.DefEqual(params, other->params) &&
+        equal(buffer_map, other->buffer_map) &&
+        equal(ret_type, other->ret_type) &&
+        equal(body, other->body) &&
+        equal(attrs, other->attrs);
+  }
+
   /*!
    * \brief Return the derived function annotation of this function.
    *
@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode {
   TVM_DLL FuncType func_type_annotation() const;
 
   static constexpr const char* _type_key = "tir.PrimFunc";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
 };
 
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index a543737..d4b144d 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -38,6 +38,7 @@ namespace tir {
 class StmtNode : public Object {
  public:
   static constexpr const char* _type_key = "Stmt";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
   TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
 };
 
@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
+    return
+        equal.DefEqual(var, other->var) &&
+        equal(value, other->value) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
 
   static constexpr const char* _type_key = "LetStmt";
@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
+    return
+        equal(node, other->node) &&
+        equal(attr_key, other->attr_key) &&
+        equal(value, other->value) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static Stmt make(ObjectRef node,
                            std::string type_key,
                            PrimExpr value,
@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
+    return
+        equal(condition, other->condition) &&
+        equal(message, other->message) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
 
   static constexpr const char* _type_key = "AssertStmt";
@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(is_producer, other->is_producer) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
 
   static constexpr const char* _type_key = "ProducerConsumer";
@@ -194,6 +224,14 @@ class StoreNode : public StmtNode {
     v->Visit("predicate", &predicate);
   }
 
+  bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
+    return
+        equal(buffer_var, other->buffer_var) &&
+        equal(value, other->value) &&
+        equal(index, other->index) &&
+        equal(predicate, other->predicate);
+  }
+
   TVM_DLL static Stmt make(Var buffer_var,
                            PrimExpr value,
                            PrimExpr index,
@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode {
     v->Visit("args", &args);
   }
 
+  bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(value_index, other->value_index) &&
+        equal(value, other->value) &&
+        equal(args, other->args);
+  }
+
   TVM_DLL static Stmt make(FunctionRef func,
                            int value_index,
                            PrimExpr value,
@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
+    return
+        equal.DefEqual(buffer_var, other->buffer_var) &&
+        equal(dtype, other->dtype) &&
+        equal(extents, other->extents) &&
+        equal(condition, other->condition) &&
+        equal(body, other->body);
+  }
+
   TVM_DLL static Stmt make(Var buffer_var,
                            DataType dtype,
                            Array<PrimExpr> extents,
@@ -300,6 +355,11 @@ class FreeNode : public StmtNode {
     v->Visit("buffer_var", &buffer_var);
   }
 
+  bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
+    return
+        equal(buffer_var, other->buffer_var);
+  }
+
   TVM_DLL static Stmt make(Var buffer_var);
 
   static constexpr const char* _type_key = "Free";
@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode {
                            PrimExpr condition,
                            Stmt body);
 
+  bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(value_index, other->value_index) &&
+        equal(dtype, other->dtype) &&
+        equal(bounds, other->bounds) &&
+        equal(condition, other->condition) &&
+        equal(body, other->body);
+  }
+
   static constexpr const char* _type_key = "Realize";
   TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
 };
@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode {
     v->Visit("seq", &seq);
   }
 
+  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
+    return equal(seq, other->seq);
+  }
+
   static constexpr const char* _type_key = "SeqStmt";
   TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
 };
@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode {
     v->Visit("else_case", &else_case);
   }
 
+  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
+    return
+        equal(condition, other->condition) &&
+        equal(then_case, other->then_case) &&
+        equal(else_case, other->else_case);
+  }
+
   TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());
 
   static constexpr const char* _type_key = "IfThenElse";
@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode {
     v->Visit("value", &value);
   }
 
+  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
+    return equal(value, other->value);
+  }
+
   TVM_DLL static Stmt make(PrimExpr v);
 
   static constexpr const char* _type_key = "Evaluate";
@@ -562,6 +647,16 @@ class ForNode : public StmtNode {
     v->Visit("body", &body);
   }
 
+  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
+    return
+        equal.DefEqual(loop_var, other->loop_var) &&
+        equal(min, other->min) &&
+        equal(extent, other->extent) &&
+        equal(for_type, other->for_type) &&
+        equal(device_api, other->device_api) &&
+        equal(body, other->body);
+  }
+
   static constexpr const char* _type_key = "For";
   TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
 };
@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode {
     v->Visit("bounds", &bounds);
   }
 
+  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
+    return
+        equal(func, other->func) &&
+        equal(value_index, other->value_index) &&
+        equal(dtype, other->dtype) &&
+        equal(bounds, other->bounds);
+  }
+
   TVM_DLL static Stmt make(FunctionRef func,
                            int value_index,
                            DataType dtype,
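In the statement reductions above, only the nodes that bind a variable (LetStmt, Allocate, For) call DefEqual on it; everything else compares children with plain equal. The following hypothetical Python sketch (not TVM's API) models how a bound var is remapped while a free var is equal only to itself unless map_free_vars is enabled. Even with map_free_vars, the pairing stays one-to-one.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Const:
    def __init__(self, value):
        self.value = value


class Add:
    def __init__(self, a, b):
        self.a, self.b = a, b


class Let:
    def __init__(self, var, value, body):
        self.var, self.value, self.body = var, value, body


def structural_equal(lhs, rhs, map_free_vars=False):
    lhs_to_rhs, rhs_to_lhs = {}, {}  # keyed by id() of the var objects

    def var_equal(lv, rv, is_def):
        if id(lv) in lhs_to_rhs:
            return lhs_to_rhs[id(lv)] is rv
        if id(rv) in rhs_to_lhs:
            return False  # rv is already paired with a different lhs var
        if is_def or map_free_vars:
            # Definition point, or a free var while map_free_vars is on:
            # record a one-to-one pairing.
            lhs_to_rhs[id(lv)] = rv
            rhs_to_lhs[id(rv)] = lv
            return True
        return lv is rv  # unbound free var: equal only to itself

    def eq(a, b):
        if type(a) is not type(b):
            return False
        if isinstance(a, Var):
            return var_equal(a, b, is_def=False)
        if isinstance(a, Const):
            return a.value == b.value
        if isinstance(a, Add):
            return eq(a.a, b.a) and eq(a.b, b.b)
        if isinstance(a, Let):
            # Like LetStmtNode::SEqualReduce: DefEqual(var), then value, body.
            return (var_equal(a.var, b.var, is_def=True)
                    and eq(a.value, b.value) and eq(a.body, b.body))
        raise TypeError("unsupported node: %r" % type(a))

    return eq(lhs, rhs)
```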
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 1e11446..88af05c 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -17,6 +17,7 @@
 # pylint: disable=unused-import
 """Common data structures across all IR variants."""
 from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
+from .base import structural_equal, assert_structural_equal
 from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
 from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
 from .tensor_type import TensorType
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 810d78f..df69a2c 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -149,3 +149,76 @@ def save_json(node):
         Saved json string.
     """
     return tvm.runtime._ffi_node_api.SaveJSON(node)
+
+
+def structural_equal(lhs, rhs, map_free_vars=False):
+    """Check structural equality of lhs and rhs.
+
+    The structural equality is recursively defined in the DAG of IRNodes.
+    There are two kinds of nodes:
+
+    - Graph node: a graph node in lhs can only be mapped as equal to
+      one and only one graph node in rhs.
+    - Normal node: equality is recursively defined without the restriction
+      of graph nodes.
+
+    Vars (tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
+    For example, this means that `%1 = %x + %y; %1 + %1` is not structurally equal
+    to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
+
+    A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var
+    of the same type if one of the following conditions holds:
+
+    - They appear at the same definition point (e.g. a function argument).
+    - They point to the same VarNode via the same_as relation.
+    - They appear at the same usage point, and map_free_vars is set to True.
+
+    These var rules are used to remap variables that occur in function
+    arguments and let-bindings.
+
+    Parameters
+    ----------
+    lhs : Object
+        The left operand.
+
+    rhs : Object
+        The right operand.
+
+    map_free_vars : bool
+        Whether free vars that are not bound to any definition
+        can be mapped as equal to each other.
+
+    Returns
+    -------
+    result : bool
+        The comparison result.
+    """
+    return tvm.runtime._ffi_node_api.StructuralEqual(
+        lhs, rhs, False, map_free_vars)
+
+
+def assert_structural_equal(lhs, rhs, map_free_vars=False):
+    """Assert lhs and rhs are structurally equal to each other.
+
+    Parameters
+    ----------
+    lhs : Object
+        The left operand.
+
+    rhs : Object
+        The right operand.
+
+    map_free_vars : bool
+        Whether free vars that are not bound to any definition
+        can be mapped as equal to each other.
+
+    Raises
+    ------
+    ValueError : if assertion does not hold.
+
+    See Also
+    --------
+    structural_equal
+    """
+    tvm.runtime._ffi_node_api.StructuralEqual(
+        lhs, rhs, True, map_free_vars)
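The graph-node rule in the structural_equal docstring can be modeled in a few lines. This is a hypothetical Python sketch, not the TVM implementation (with TVM installed, the entry point is tvm.ir.structural_equal as defined above). Every Add node is treated as a graph node whose lhs/rhs pairing must stay one-to-one, and free vars are paired freely, roughly as with map_free_vars=True. A pairing is recorded only after the children match, similar to how RunTasks records it once a task's children have been checked.

```python
class Var:
    def __init__(self, name):
        self.name = name


class Add:
    def __init__(self, a, b):
        self.a, self.b = a, b


def graph_equal(lhs, rhs):
    lhs_to_rhs, rhs_to_lhs = {}, {}  # keyed by id() of graph nodes

    def eq(a, b):
        if type(a) is not type(b):
            return False
        # A graph node that is already paired must keep the same partner.
        if id(a) in lhs_to_rhs:
            return lhs_to_rhs[id(a)] is b
        if id(b) in rhs_to_lhs:
            return False
        if isinstance(a, Var):
            ok = True  # free vars are paired freely in this sketch
        else:
            ok = eq(a.a, b.a) and eq(a.b, b.b)
        if ok:
            # Record the pairing only after the children have matched.
            lhs_to_rhs[id(a)] = b
            rhs_to_lhs[id(b)] = a
        return ok

    return eq(lhs, rhs)
```

In `%1 + %1` the second operand must map to the same rhs node as the first, which fails against `%1 + %2`, where the two operands are distinct nodes.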
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index babd08a..9acc465 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -45,8 +45,8 @@ class AttrFunctor;
 
 #define ATTR_FUNCTOR_DISPATCH(OP)                                       \
   vtable.template set_dispatch<OP>(                                     \
-      [](const ObjectRef& n, TSelf* self, Args... args) {                 \
-        return self->VisitAttr_(static_cast<const OP*>(n.get()),  \
+      [](const ObjectRef& n, TSelf* self, Args... args) {               \
+        return self->VisitAttr_(static_cast<const OP*>(n.get()),        \
                                 std::forward<Args>(args)...);           \
       });                                                               \
 
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 9731a51..b07f04a 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm")
 
 TVM_REGISTER_NODE_TYPE(FloatImmNode);
 
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const FloatImmNode*>(node.get());
@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range")
   *ret = Range(args[0], args[1]);
   });
 
+TVM_REGISTER_NODE_TYPE(RangeNode);
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const RangeNode*>(node.get());
     p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
   });
 
-TVM_REGISTER_NODE_TYPE(ArrayNode);
-TVM_REGISTER_NODE_TYPE(MapNode);
-TVM_REGISTER_NODE_TYPE(StrMapNode);
-TVM_REGISTER_NODE_TYPE(RangeNode);
-
 
 GlobalVar::GlobalVar(std::string name_hint) {
   ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4ac769b..ca85cb8 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
   data_ = std::move(n);
 }
 
+
+bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
+  if (functions.size() != other->functions.size()) return false;
+  for (const auto& kv : this->functions) {
+    if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
+    if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
+  }
+  if (type_definitions.size() != other->type_definitions.size()) return false;
+  for (const auto& kv : this->type_definitions) {
+    if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
+    if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
+  }
+  return true;
+}
+
 bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
   return global_var_map_.find(name) != global_var_map_.end();
 }
@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr(
   const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
   auto mod = IRModule(global_funcs, type_definitions);
   BaseFunc func;
-  if (auto* func_node = expr.as<relay::FunctionNode>()) {
-    func = GetRef<relay::Function>(func_node);
+  if (auto* func_node = expr.as<BaseFuncNode>()) {
+    func = GetRef<BaseFunc>(func_node);
   } else {
     func = relay::Function(
         relay::FreeVars(expr), expr, Type(),
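IRModuleNode::SEqualReduce above pairs functions (and type definitions) by the name_hint of their global var, so the GlobalVar objects themselves need not be shared between the two modules. A minimal Python sketch, assuming a module is modeled as a plain dict from name to function and that func_equal is supplied by the caller (both hypothetical):

```python
def module_equal(lhs_funcs, rhs_funcs, func_equal):
    """Compare two modules given as dicts {global name_hint: function}.

    Functions are paired by name; the order of entries does not matter.
    """
    if len(lhs_funcs) != len(rhs_funcs):
        return False
    for name, func in lhs_funcs.items():
        if name not in rhs_funcs:
            return False
        if not func_equal(func, rhs_funcs[name]):
            return False
    return True
```

The size check first rules out an rhs with extra entries, so the loop only needs to look up lhs names in rhs.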
diff --git a/src/node/container.cc b/src/node/container.cc
index 25bfe9d..fc5c62a 100644
--- a/src/node/container.cc
+++ b/src/node/container.cc
@@ -21,11 +21,98 @@
  * \file src/node/container.cc
  */
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <tvm/node/container.h>
 #include <tvm/tir/expr.h>
+#include <cstring>
 
 namespace tvm {
 
+// SEqualReduce traits for runtime containers.
+struct StringObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const runtime::StringObj* lhs,
+                           const runtime::StringObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->size != rhs->size) return false;
+    if (lhs->data == rhs->data) return true;
+    return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+
+struct ADTObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const runtime::ADTObj* lhs,
+                           const runtime::ADTObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->tag != rhs->tag) return false;
+    if (lhs->size != rhs->size) return false;
+
+    for (uint32_t i = 0; i < lhs->size; ++i) {
+      if (!equal((*lhs)[i], (*rhs)[i])) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
+
+
+struct NDArrayContainerTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const runtime::NDArray::Container* lhs,
+                           const runtime::NDArray::Container* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+
+    auto ldt = lhs->dl_tensor.dtype;
+    auto rdt = rhs->dl_tensor.dtype;
+    CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    CHECK(runtime::IsContiguous(lhs->dl_tensor))
+        << "Can only compare contiguous tensor";
+    CHECK(runtime::IsContiguous(rhs->dl_tensor))
+        << "Can only compare contiguous tensor";
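+    // Shapes must match before comparing raw bytes; otherwise memcmp could
+    // read past the end of the smaller buffer.
+    if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
+    for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
+      if (lhs->dl_tensor.shape[i] != rhs->dl_tensor.shape[i]) return false;
+    }
+    if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {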
+      size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
+      return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
+    } else {
+      return false;
+    }
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
+
+
+struct ArrayNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const ArrayNode* lhs,
+                           const ArrayNode* rhs,
+                           SEqualReducer equal) {
+    if (lhs->data.size() != rhs->data.size()) return false;
+    for (size_t i = 0; i < lhs->data.size(); ++i) {
+      if (!equal(lhs->data[i], rhs->data[i])) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(ArrayNode);
+TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<ArrayNode>();
+  });
+
+
 TVM_REGISTER_GLOBAL("node.Array")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     std::vector<ObjectRef> data;
@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
         static_cast<const ArrayNode*>(ptr)->data.size());
   });
 
+
+struct MapNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const MapNode* lhs,
+                           const MapNode* rhs,
+                           SEqualReducer equal) {
+    if (rhs->data.size() != lhs->data.size()) return false;
+    for (const auto& kv : lhs->data) {
+      // Only allow equality checking if the keys are already mapped.
+      // This covers the common use case of storing Map<Var, Value>,
+      // where Var is defined in the function parameters.
+      ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
+      if (!rhs_key.defined()) return false;
+      auto it = rhs->data.find(rhs_key);
+      if (it == rhs->data.end()) return false;
+      if (!equal(kv.second, it->second)) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(MapNode);
+TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<MapNode>();
+  });
+
+
+struct StrMapNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static bool SEqualReduce(const StrMapNode* lhs,
+                           const StrMapNode* rhs,
+                           SEqualReducer equal) {
+    if (rhs->data.size() != lhs->data.size()) return false;
+    for (const auto& kv : lhs->data) {
+      auto it = rhs->data.find(kv.first);
+      if (it == rhs->data.end()) return false;
+      if (!equal(kv.second, it->second)) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(StrMapNode);
+TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<StrMapNode>();
+  });
+
+
 TVM_REGISTER_GLOBAL("node.Map")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     CHECK_EQ(args.size() % 2, 0);
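MapNodeTrait above only accepts a key whose rhs counterpart was already recorded, for example by an earlier DefEqual on function params (which is why PrimFunc visits params before buffer_map). StrMapNodeTrait, by contrast, looks up string keys directly. A hypothetical Python sketch of the Map rule, assuming the existing mapping is passed in as a dict from id(lhs key) to rhs key:

```python
class Var:
    def __init__(self, name):
        self.name = name  # default hash/eq: compared by object identity


def map_equal(lhs_map, rhs_map, lhs_to_rhs, value_equal):
    """Compare two dicts keyed by Var objects.

    lhs_to_rhs maps id(lhs key) -> rhs key, as produced by earlier
    definition points. Keys without a recorded mapping fail the check.
    """
    if len(lhs_map) != len(rhs_map):
        return False
    for key, value in lhs_map.items():
        rhs_key = lhs_to_rhs.get(id(key))
        if rhs_key is None or rhs_key not in rhs_map:
            return False
        if not value_equal(value, rhs_map[rhs_key]):
            return False
    return True
```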
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index 183079f..824874f 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -180,7 +180,7 @@ ObjectPtr<Object>
 ReflectionVTable::CreateInitObject(const std::string& type_key,
                                    const std::string& global_key) const {
   uint32_t tindex = Object::TypeKey2Index(type_key);
-  if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
+  if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
     LOG(FATAL) << "TypeError: " << type_key
                << " is not registered via TVM_REGISTER_NODE_TYPE";
   }
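The CreateInitObject fix checks the creator table instead of the attribute-visitor table. Presumably this matters because containers such as ArrayNode now register VisitAttrs = nullptr together with a creator (see container.cc above), so a check on the visitor table would reject them. A toy registry sketching that distinction; the ReflectionTable class is hypothetical and not TVM's API:

```python
class ReflectionTable:
    """Toy per-type registry (hypothetical)."""

    def __init__(self):
        self.fvisit_attrs = {}  # type_key -> attribute visitor, may be None
        self.fcreate = {}       # type_key -> creator, may be None

    def register(self, type_key, visit_attrs=None, creator=None):
        self.fvisit_attrs[type_key] = visit_attrs
        self.fcreate[type_key] = creator

    def create_init_object(self, type_key):
        # Check the creator, not the visitor: a type may have no
        # attribute visitor but still be constructible.
        creator = self.fcreate.get(type_key)
        if creator is None:
            raise TypeError(type_key + " is not registered via TVM_REGISTER_NODE_TYPE")
        return creator()


table = ReflectionTable()
table.register("Array", visit_attrs=None, creator=list)
```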
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
new file mode 100644
index 0000000..23dfe15
--- /dev/null
+++ b/src/node/structural_equal.cc
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/node/structural_equal.cc
+ */
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/functor.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/registry.h>
+
+#include <unordered_map>
+
+namespace tvm {
+
+// Define the dispatch function here, since its primary user is in this file.
+bool ReflectionVTable::
+SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const {
+  uint32_t tindex = self->type_index();
+  if (tindex >= fsequal_.size() || fsequal_[tindex] == nullptr) {
+    LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
+        << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  return fsequal_[tindex](self, other, equal);
+}
+
+/*!
+ * \brief A non-recursive, stack-based SEqual handler that can remap vars.
+ *
+ *  This handler pushes the Object equality cases onto a stack, and
+ *  traverses the stack to expand the children that need to be checked.
+ *
+ *  SEqual calls happen in the same order as they would if SEqualReduce
+ *  eagerly made recursive calls.
+ */
+class RemapVarSEqualHandler :
+      public SEqualReducer::Handler {
+ public:
+  explicit RemapVarSEqualHandler(bool assert_mode)
+      : assert_mode_(assert_mode) {}
+
+  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+    // We cannot use lhs.same_as(rhs) to check equality
+    // if we choose to enable var remapping.
+    //
+    // Counterexample: (%x, %y) are vars shared
+    // between the two functions (possibly before/after rewriting).
+    //
+    // - function0: fn (%x, %y) { %x + %y }
+    // - function1: fn (%y, %x) { %x + %y }
+    //
+    // Because var remapping is enabled,
+    // %x is mapped to %y, and %y is mapped to %x,
+    // so the bodies of the two functions no longer mean the same thing.
+    //
+    // Takeaway: we can either compare Vars only by address,
+    // in which case same_as can be used as a quick check,
+    // or we have to run a deep comparison and avoid same_as checks.
+    auto run = [=]() {
+      if (!lhs.defined() && !rhs.defined()) return true;
+      if (!lhs.defined() && rhs.defined()) return false;
+      if (!rhs.defined() && lhs.defined()) return false;
+      if (lhs->type_index() != rhs->type_index()) return false;
+      auto it = equal_map_lhs_.find(lhs);
+      if (it != equal_map_lhs_.end()) {
+        return it->second.same_as(rhs);
+      }
+      if (equal_map_rhs_.count(rhs)) return false;
+      // need to push to pending tasks in this case
+      pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
+      return true;
+    };
+    return CheckResult(run(), lhs, rhs);
+  }
+
+  void MarkGraphNode() final {
+    // Mark the current task as a graph node, so lhs and rhs are recorded
+    // in the equality maps once all of its children have been checked.
+    CHECK(!allow_push_to_stack_ && !task_stack_.empty());
+    task_stack_.back().graph_equal = true;
+  }
+
+  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
+    auto it = equal_map_lhs_.find(lhs);
+    if (it != equal_map_lhs_.end()) return it->second;
+    return ObjectRef(nullptr);
+  }
+
+  // Function that implements actual equality check.
+  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+    task_stack_.clear();
+    pending_tasks_.clear();
+    equal_map_lhs_.clear();
+    equal_map_rhs_.clear();
+    if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
+    CHECK_EQ(pending_tasks_.size(), 1U);
+    CHECK(allow_push_to_stack_);
+    task_stack_.emplace_back(std::move(pending_tasks_.back()));
+    pending_tasks_.clear();
+    return RunTasks();
+  }
+
+ protected:
+  // Check the result.
+  bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
+    if (assert_mode_ && !result) {
+      LOG(FATAL)
+          << "ValueError: StructuralEqual check failed, caused by\n"
+          << "lhs = " << lhs << "\nrhs = " << rhs;
+    }
+    return result;
+  }
+  /*!
+   * \brief Run tasks until the task stack is empty.
+   * \return Whether all the checks encountered throughout the process passed.
+   */
+  bool RunTasks() {
+    while (task_stack_.size() != 0) {
+      // Caution: entry becomes invalid when the stack changes
+      auto& entry = task_stack_.back();
+
+      if (entry.children_expanded) {
+        // All the children have been expanded and visited,
+        // which means all the condition checks for
+        // the current entry have passed.
+        // We can safely mark lhs and rhs as equal to each other.
+        auto it = equal_map_lhs_.find(entry.lhs);
+        if (it != equal_map_lhs_.end()) {
+          CHECK(it->second.same_as(entry.rhs));
+        }
+        // Create the map if the equality is graph equality.
+        if (entry.graph_equal) {
+          equal_map_lhs_[entry.lhs] = entry.rhs;
+          equal_map_rhs_[entry.rhs] = entry.lhs;
+        }
+        task_stack_.pop_back();
+      } else {
+        // Mark before expanding. This must happen first, because
+        // entry becomes invalid once the stack changes.
+        entry.children_expanded = true;
+        // Expand the objects
+        // The SEqual of the object can call into this->SEqualReduce
+        // which populates the pending tasks.
+        CHECK_EQ(pending_tasks_.size(), 0U);
+        allow_push_to_stack_ = false;
+        if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
+        allow_push_to_stack_ = true;
+        // Push pending tasks in reverse order, so earlier tasks get to
+        // expand first in the stack
+        while (pending_tasks_.size() != 0) {
+          task_stack_.emplace_back(std::move(pending_tasks_.back()));
+          pending_tasks_.pop_back();
+        }
+      }
+    }
+    return true;
+  }
+
+  // Dispatch to the SEqualReduce function registered in the reflection vtable.
+  bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+    auto compute = [=]() {
+      CHECK(lhs.defined() &&
+            rhs.defined() &&
+            lhs->type_index() == rhs->type_index());
+      // skip entries that already have equality maps.
+      auto it = equal_map_lhs_.find(lhs);
+      if (it != equal_map_lhs_.end()) {
+        return it->second.same_as(rhs);
+      }
+      if (equal_map_rhs_.count(rhs)) return false;
+      // Run reduce check for free nodes.
+      return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
+    };
+    return CheckResult(compute(), lhs, rhs);
+  }
+
+ private:
+  /*! \brief Pending reduce tasks. */
+  struct Task {
+    /*! \brief The lhs operand to be compared. */
+    ObjectRef lhs;
+    /*! \brief The rhs operand to be compared. */
+    ObjectRef rhs;
+    /*! \brief The map free var argument. */
+    bool map_free_vars;
+    /*! \brief Whether the children have been expanded via SEqualReduce. */
+    bool children_expanded{false};
+    /*! \brief Whether the task is about graph equality (needs remapping). */
+    bool graph_equal{false};
+
+    Task() = default;
+    Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
+        : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
+  };
+  // list of pending tasks to be pushed to the stack.
+  std::vector<Task> pending_tasks_;
+  // Internal task stack used to execute the tasks.
+  std::vector<Task> task_stack_;
+  // Whether pushing to the stack is allowed.
+  bool allow_push_to_stack_{true};
+  // In assert mode, a failed check raises an error instead of returning false.
+  bool assert_mode_{false};
+  // reflection vtable
+  ReflectionVTable* vtable_ = ReflectionVTable::Global();
+  // map from lhs to rhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_lhs_;
+  // map from rhs to lhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
+};
+
+
+TVM_REGISTER_GLOBAL("node.StructuralEqual")
+.set_body_typed([](const ObjectRef& lhs,
+                   const ObjectRef& rhs,
+                   bool assert_mode,
+                   bool map_free_vars) {
+  return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
+});
+
+bool StructuralEqual::operator()(const ObjectRef& lhs,
+                                 const ObjectRef& rhs) const {
+  return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
+}
+
+}  // namespace tvm
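RemapVarSEqualHandler replaces recursion with an explicit task stack. Each task is expanded once, its pending child tasks are pushed in reverse order so the first child runs first, and the task is popped only after all of its children have passed. The hypothetical Python sketch below models that scheme on simple labeled trees (the Node class and both functions are illustrative, not TVM's API) and records the visit order to show that it matches an eager recursive comparison. One simplification: the per-node checks run when a task is expanded, whereas the C++ handler performs some of them (definedness, type index, existing mappings) before a child task is pushed.

```python
class Node:
    def __init__(self, label, *children):
        self.label, self.children = label, list(children)


def stack_equal(lhs, rhs, trace=None):
    stack = [[lhs, rhs, False]]  # entries: [lhs, rhs, children_expanded]
    while stack:
        entry = stack[-1]
        lnode, rnode, expanded = entry
        if expanded:
            stack.pop()  # all children passed, so lhs and rhs are equal
            continue
        entry[2] = True  # mark before expanding, as in RunTasks
        if trace is not None:
            trace.append(lnode.label)
        # Local checks of this node (what its SEqualReduce compares directly).
        if lnode.label != rnode.label or len(lnode.children) != len(rnode.children):
            return False
        # Push child tasks in reverse so the first child is expanded first.
        for cl, cr in reversed(list(zip(lnode.children, rnode.children))):
            stack.append([cl, cr, False])
    return True


def recursive_equal(lhs, rhs, trace=None):
    if trace is not None:
        trace.append(lhs.label)
    if lhs.label != rhs.label or len(lhs.children) != len(rhs.children):
        return False
    return all(recursive_equal(a, b, trace) for a, b in zip(lhs.children, rhs.children))
```

Besides avoiding deep native recursion on large IR, the explicit stack gives a natural place to record lhs/rhs pairings when a task is popped.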
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 07759b3..bee0256 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var")
 TVM_REGISTER_GLOBAL("tir.SizeVar")
 .set_body_typed([](std::string s, DataType t) {
     return SizeVar(s, t);
-  });
+});
+
 
 IterVar IterVarNode::make(Range dom,
                           Var var,
@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) {
 TVM_REGISTER_GLOBAL("tir.StringImm")
 .set_body_typed(StringImmNode::make);
 
+
 PrimExpr CastNode::make(DataType t, PrimExpr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) {
   return PrimExpr(node);
 }
 
+
 PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(b.defined()) << "ValueError: b is undefined";
@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
   return PrimExpr(node);
 }
 
+
 PrimExpr NotNode::make(PrimExpr a) {
   CHECK(a.defined()) << "ValueError: a is undefined";
   CHECK(a.dtype().is_bool());
@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) {
   return PrimExpr(node);
 }
 
+
+
 PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
   CHECK(condition.defined()) << "ValueError: condition is undefined";
   CHECK(true_value.defined()) << "ValueError: true_value is undefined";
@@ -270,11 +276,11 @@ bool CallNode::is_vectorizable() const {
 }
 
 PrimExpr CallNode::make(DataType dtype,
-                std::string name,
-                Array<PrimExpr> args,
-                CallType call_type,
-                FunctionRef func,
-                int value_index) {
+                        std::string name,
+                        Array<PrimExpr> args,
+                        CallType call_type,
+                        FunctionRef func,
+                        int value_index) {
   for (size_t i = 0; i < args.size(); ++i) {
     CHECK(args[i].defined());
   }
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 9d875c1..b1efe4a 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1114,7 +1114,7 @@ def test_read_variable_op():
                                        num_output=len(out_name))
             for i in range(len(tf_output)):
                 tvm.testing.assert_allclose(
-                    tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+                    tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
 
         sess.close()
 
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index fbe5213..9e62491 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -17,8 +17,6 @@
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.analysis import graph_equal, assert_graph_equal
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal
 import pytest
 from numpy import isclose
 from typing import Union
@@ -69,6 +67,13 @@ type List[A] {
 }
 """
 
+def assert_graph_equal(lhs, rhs):
+    tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
+
+def graph_equal(lhs, rhs):
+    return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
+
+
 def roundtrip(expr):
     x = relay.fromtext(expr.astext())
     assert_graph_equal(x, expr)
@@ -86,6 +91,12 @@ def parses_as(code, expr):
     result = graph_equal(parsed, expr)
     return result
 
+
+def assert_parses_as(code, expr):
+    parsed = parse_text(code)
+    assert_graph_equal(parsed, expr)
+
+
 def get_scalar(x):
     # type: (relay.Constant) -> (Union[float, int, bool])
     return x.data.asnumpy().item()
@@ -102,7 +113,7 @@ UNIT = relay.Tuple([])
 
 
 def test_comments():
-    assert parses_as(
+    assert_parses_as(
         """
         // This is a line comment!
         ()
@@ -110,7 +121,7 @@ def test_comments():
         UNIT
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         /* This is a block comment!
             This is still a block comment!
@@ -120,7 +131,7 @@ def test_comments():
         UNIT
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         /* This is a block comment!
            /*Block comment is recursive!*/
@@ -172,7 +183,7 @@ def test_negative():
 
 def test_bin_op():
     for bin_op in BINARY_OPS.keys():
-        assert parses_as(
+        assert_parses_as(
             "1 {} 1".format(bin_op),
             BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
         )
@@ -213,7 +224,7 @@ def test_vars():
 
 
 def test_let():
-    assert parses_as(
+    assert_parses_as(
         "let %x = 1; ()",
         relay.Let(
             X,
@@ -222,7 +233,7 @@ def test_let():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         let %x = 1;
         let %y = 2;
@@ -241,7 +252,7 @@ def test_let():
 
 
 def test_seq():
-    assert parses_as(
+    assert_parses_as(
         "();; ()",
         relay.Let(
             _,
@@ -249,7 +260,7 @@ def test_seq():
             UNIT)
     )
 
-    assert parses_as(
+    assert_parses_as(
         "let %_ = 1; ()",
         relay.Let(
             X,
@@ -261,14 +272,10 @@ def test_seq():
 
 def test_graph():
     code = "%0 = (); %1 = 1; (%0, %0, %1)"
-    assert parses_as(
+    assert_parses_as(
         code,
         relay.Tuple([UNIT, UNIT, relay.const(1)])
     )
-    assert not parses_as(
-        code,
-        relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
-    )
 
 
 @raises_parse_error
@@ -287,18 +294,18 @@ def test_let_op():
 
 
 def test_tuple():
-    assert parses_as("()", relay.Tuple([]))
+    assert_parses_as("()", relay.Tuple([]))
 
-    assert parses_as("(0,)", relay.Tuple([relay.const(0)]))
+    assert_parses_as("(0,)", relay.Tuple([relay.const(0)]))
 
-    assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
+    assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
 
-    assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
+    assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
 
 
 def test_func():
     # 0 args
-    assert parses_as(
+    assert_parses_as(
         "fn () { 0 }",
         relay.Function(
             [],
@@ -309,7 +316,7 @@ def test_func():
     )
 
     # 1 arg
-    assert parses_as(
+    assert_parses_as(
         "fn (%x) { %x }",
         relay.Function(
             [X],
@@ -320,7 +327,7 @@ def test_func():
     )
 
     # 2 args
-    assert parses_as(
+    assert_parses_as(
         "fn (%x, %y) { %x + %y }",
         relay.Function(
             [X, Y],
@@ -331,7 +338,7 @@ def test_func():
     )
 
     # annotations
-    assert parses_as(
+    assert_parses_as(
         "fn (%x: int32) -> int32 { %x }",
         relay.Function(
             [X_ANNO],
@@ -342,7 +349,7 @@ def test_func():
     )
 
     # attributes
-    assert parses_as(
+    assert_parses_as(
         "fn (n=5) { () }",
         relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs", n=relay.const(5)))
     )
@@ -370,7 +377,7 @@ def test_recursive_call():
 
 
 def test_ifelse():
-    assert parses_as(
+    assert_parses_as(
         """
         if (True) {
             0
@@ -403,7 +410,7 @@ def test_ifelse_scope():
 def test_call():
     # select right function to call: simple ident case
     id_func = relay.Var("id")
-    assert parses_as(
+    assert_parses_as(
         """
         let %id = fn (%x) { %x };
         10 * %id(10)
@@ -417,7 +424,7 @@ def test_call():
 
     # 0 args
     constant = relay.Var("constant")
-    assert parses_as(
+    assert_parses_as(
         """
         let %constant = fn () { 0 };
         %constant()
@@ -431,7 +438,7 @@ def test_call():
 
     # 1 arg
     id_var = relay.Var("id")
-    assert parses_as(
+    assert_parses_as(
         """
         let %id = fn (%x) { %x };
         %id(1)
@@ -445,7 +452,7 @@ def test_call():
 
     # 2 args
     multiply = relay.Var("multiply")
-    assert parses_as(
+    assert_parses_as(
         """
         let %multiply = fn (%x, %y) { %x * %y };
         %multiply(0, 0)
@@ -463,7 +470,7 @@ def test_call():
     )
 
     # anonymous function
-    assert parses_as(
+    assert_parses_as(
         """
         (fn (%x) { %x })(0)
         """,
@@ -483,7 +490,7 @@ def test_call():
     # TODO(@jmp): re-enable after sequence parsing improvements
     # curried function
     # curried_mult = relay.Var("curried_mult")
-    # assert parses_as(
+    # assert_parses_as(
     #     """
     #     let %curried_mult =
     #         fn (%x) {
@@ -516,7 +523,7 @@ def test_call():
     # )
 
     # op
-    assert parses_as(
+    assert_parses_as(
         "abs(1)",
         relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
     )
@@ -525,7 +532,7 @@ def test_call():
 
 
 def test_incomplete_type():
-    assert parses_as(
+    assert_parses_as(
         "let %_ : _ = (); ()",
         relay.Let(
             _,
@@ -541,7 +548,7 @@ def test_builtin_types():
 
 
 def test_tensor_type():
-    assert parses_as(
+    assert_parses_as(
         "let %_ : Tensor[(), float32] = (); ()",
         relay.Let(
             relay.Var("_", relay.TensorType((), "float32")),
@@ -550,7 +557,7 @@ def test_tensor_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         "let %_ : Tensor[(1), float32] = (); ()",
         relay.Let(
             relay.Var("_", relay.TensorType((1,), "float32")),
@@ -559,7 +566,7 @@ def test_tensor_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         "let %_ : Tensor[(1, 1), float32] = (); ()",
         relay.Let(
             relay.Var("_", relay.TensorType((1, 1), "float32")),
@@ -570,7 +577,7 @@ def test_tensor_type():
 
 
 def test_function_type():
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
         """,
@@ -581,7 +588,7 @@ def test_function_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
         """,
@@ -592,7 +599,7 @@ def test_function_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
         """,
@@ -605,7 +612,7 @@ def test_function_type():
 
 
 def test_tuple_type():
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: () = (); ()
         """,
@@ -616,7 +623,7 @@ def test_tuple_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: (int32,) = (0,); ()
         """,
@@ -627,7 +634,7 @@ def test_tuple_type():
         )
     )
 
-    assert parses_as(
+    assert_parses_as(
         """
         let %_: (int32, int32) = (0, 1); ()
         """,
@@ -648,7 +655,7 @@ def test_adt_defn():
             [],
             [relay.Constructor("Nil", [], glob_typ_var)])
     mod[glob_typ_var] = prog
-    assert parses_as(
+    assert_parses_as(
         """
         type Ayy { Nil }
         """,
@@ -662,7 +669,7 @@ def test_empty_adt_defn():
     glob_typ_var = relay.GlobalTypeVar("Ayy")
     prog = relay.TypeData(glob_typ_var, [], [])
     mod[glob_typ_var] = prog
-    assert parses_as(
+    assert_parses_as(
         """
         type Ayy { }
         """,
@@ -683,7 +690,7 @@ def test_multiple_cons_defn():
                 relay.Constructor("Nil", [], list_var),
             ])
     mod[list_var] = prog
-    assert parses_as(LIST_DEFN, mod)
+    assert_parses_as(LIST_DEFN, mod)
 
 
 def test_multiple_type_param_defn():
@@ -699,7 +706,7 @@ def test_multiple_type_param_defn():
             ])
     mod = tvm.IRModule()
     mod[glob_typ_var] = prog
-    assert parses_as(
+    assert_parses_as(
         """
         type Either[A, B] {
           Left(A),
@@ -755,7 +762,7 @@ def test_match():
         )
         mod[length_var] = length_func
 
-        assert parses_as(
+        assert_parses_as(
             """
             %s
 
@@ -796,7 +803,7 @@ def test_adt_cons_expr():
     )
     mod[make_singleton_var] = make_singleton_func
 
-    assert parses_as(
+    assert_parses_as(
         """
         %s
 
@@ -861,7 +868,7 @@ def test_extern_adt_defn():
     extern_def = relay.TypeData(extern_var, [typ_var], [])
     mod[extern_var] = extern_def
 
-    assert parses_as(
+    assert_parses_as(
         """
         extern type T[A]
         """,
@@ -872,6 +879,7 @@ def test_import_grad():
     mod.import_from_std("gradient.rly")
 
 if __name__ == "__main__":
+    test_graph()
     test_comments()
     test_int_literal()
     test_float_literal()
@@ -882,7 +890,6 @@ if __name__ == "__main__":
     test_op_assoc()
     test_let()
     test_seq()
-    test_graph()
     test_tuple()
     test_func()
     test_defn()
@@ -905,4 +912,4 @@ if __name__ == "__main__":
     test_duplicate_adt_cons_defn()
     test_duplicate_global_var()
     test_extern_adt_defn()
-    test_import_grad()
\ No newline at end of file
+    test_import_grad()
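The parser tests above now compare expressions with `tvm.ir.structural_equal(..., map_free_vars=True)`. The sketch below is a toy, pure-Python model of the rule the commit message describes. It is not TVM's implementation.

- Function parameters are definition points. Their correspondence is recorded in a one-to-one map, which plays the role of DefEqual.
- Bound variables must follow that map.
- Free variables are compared by identity unless `map_free_vars` is set.

It also reproduces the shared-variable case from the commit message, where a pointer-identity (`same_as`) shortcut gives the wrong answer. The classes `Var`, `Add` and `Fn`, the helper `toy_sequal`, and the `use_same_as_shortcut` flag are all hypothetical. The checker is recursive for brevity, unlike the non-recursive checker the commit introduces.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


class Var:
    """Toy variable: two Vars are the same only if they are the same object."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return "%" + self.name


@dataclass(eq=False)
class Add:
    lhs: object
    rhs: object


@dataclass(eq=False)
class Fn:
    params: Tuple[Var, ...]
    body: object


class ToyStructuralEqual:
    def __init__(self, map_free_vars: bool = False, use_same_as_shortcut: bool = False):
        self.map_free_vars = map_free_vars
        self.use_same_as_shortcut = use_same_as_shortcut
        self.lhs_to_rhs: Dict[Var, Var] = {}
        self.rhs_to_lhs: Dict[Var, Var] = {}

    def def_equal(self, a: Var, b: Var) -> bool:
        # Definition point: record a one-to-one correspondence, or check an existing one.
        if a in self.lhs_to_rhs or b in self.rhs_to_lhs:
            return self.lhs_to_rhs.get(a) is b
        self.lhs_to_rhs[a] = b
        self.rhs_to_lhs[b] = a
        return True

    def equal(self, a: object, b: object) -> bool:
        if self.use_same_as_shortcut and a is b:
            return True  # unsafe: ignores any remapping already recorded
        if type(a) is not type(b):
            return False
        if isinstance(a, Var):
            if a in self.lhs_to_rhs or b in self.rhs_to_lhs:
                return self.lhs_to_rhs.get(a) is b  # bound: must follow the mapping
            # Free variable: identity, unless free vars may be remapped.
            return self.def_equal(a, b) if self.map_free_vars else a is b
        if isinstance(a, Fn):
            return (len(a.params) == len(b.params)
                    and all(self.def_equal(p, q) for p, q in zip(a.params, b.params))
                    and self.equal(a.body, b.body))
        if isinstance(a, Add):
            return self.equal(a.lhs, b.lhs) and self.equal(a.rhs, b.rhs)
        return a == b  # leaves such as int constants


def toy_sequal(a, b, **kwargs) -> bool:
    return ToyStructuralEqual(**kwargs).equal(a, b)


x, y = Var("x"), Var("y")
function0 = Fn((x, y), Add(x, y))  # fn (%x, %y) { %x + %y }
function1 = Fn((y, x), Add(x, y))  # fn (%y, %x) { %x + %y }

print(toy_sequal(function0, function1))                             # False
print(toy_sequal(function0, function1, use_same_as_shortcut=True))  # True (wrong)
```

With the shortcut enabled, the body comparison sees the same `%x` object on both sides and stops there, ignoring that `%x` was already paired with `%y` at the parameter list.

Passing `map_free_vars=True` additionally lets distinct free variables match. For example, `toy_sequal(Add(x, 1), Add(Var("z"), 1), map_free_vars=True)` is True, while the default returns False. This may be why the parser test helpers opt into that flag, since parsing text creates fresh variable objects.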
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_ir_structural_equal.py
similarity index 78%
rename from tests/python/relay/test_pass_alpha_equal.py
rename to tests/python/relay/test_ir_structural_equal.py
index 411906d..5881ab9 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_ir_structural_equal.py
@@ -21,23 +21,24 @@ from tvm import relay
 from tvm.relay import analysis
 from tvm.relay.testing import run_opt_pass
 
-def alpha_equal(x, y):
+def sequal(x, y):
     """
     Wrapper around alpha equality which ensures that
     the hash function respects equality.
     """
-    return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
+    return (tvm.ir.structural_equal(x, y) and
+            analysis.structural_hash(x) == analysis.structural_hash(y))
 
-def alpha_equal_commutative(x, y):
+def sequal_commutative(x, y):
     """
     Check for commutative property of equality
     """
-    xy = analysis.alpha_equal(x, y)
-    yx = analysis.alpha_equal(y, x)
+    xy = tvm.ir.structural_equal(x, y)
+    yx = tvm.ir.structural_equal(y, x)
     assert xy == yx
     return xy
 
-def test_tensor_type_alpha_equal():
+def test_tensor_type_sequal():
     t1 = relay.TensorType((3, 4), "float32")
     t2 = relay.TensorType((3, 4), "float32")
     t3 = relay.TensorType((3, 4, 5), "float32")
@@ -49,7 +50,7 @@ def test_tensor_type_alpha_equal():
     assert t1 == t2
 
 
-def test_incomplete_type_alpha_equal():
+def test_incomplete_type_sequal():
     t1 = relay.IncompleteType(relay.TypeKind.ShapeVar)
     t2 = relay.IncompleteType(relay.TypeKind.Type)
     t3 = relay.IncompleteType(relay.TypeKind.Type)
@@ -61,7 +62,7 @@ def test_incomplete_type_alpha_equal():
     assert t2 != t3
 
 
-def test_type_param_alpha_equal():
+def test_type_param_sequal():
     t1 = relay.TypeVar("v1", relay.TypeKind.Type)
     t2 = relay.TypeVar("v2", relay.TypeKind.ShapeVar)
     t3 = relay.TypeVar("v3", relay.TypeKind.Type)
@@ -83,7 +84,7 @@ def test_type_param_alpha_equal():
     assert ft1 != ft3 # kinds still do not match
 
 
-def test_func_type_alpha_equal():
+def test_func_type_sequal():
     t1 = relay.TensorType((1, 2), "float32")
     t2 = relay.TensorType((1, 2, 3), "float32")
 
@@ -143,7 +144,7 @@ def test_func_type_alpha_equal():
     assert ft != more_rels
 
 
-def test_tuple_type_alpha_equal():
+def test_tuple_type_sequal():
     t1 = relay.TensorType((1, 2, 3), "float32")
     t2 = relay.TensorType((1, 2, 3, 4), "float32")
     tp1 = relay.TypeVar("v1", relay.TypeKind.Type)
@@ -161,7 +162,7 @@ def test_tuple_type_alpha_equal():
     assert tup1 != tup4
 
 
-def test_type_relation_alpha_equal():
+def test_type_relation_sequal():
     t1 = relay.TensorType((1, 2), "float32")
     t2 = relay.TensorType((1, 2, 3), "float32")
     t3 = relay.TensorType((1, 2, 3, 4), "float32")
@@ -197,7 +198,7 @@ def test_type_relation_alpha_equal():
 
     assert bigger != diff_num_inputs
 
-def test_type_call_alpha_equal():
+def test_type_call_sequal():
     h1 = relay.GlobalTypeVar("h1")
     h2 = relay.GlobalTypeVar("h2")
     t1 = relay.TensorType((1, 2), "float32")
@@ -221,49 +222,49 @@ def test_type_call_alpha_equal():
     assert tc != different_order_args
 
 
-def test_constant_alpha_equal():
+def test_constant_sequal():
     x = relay.const(1)
     y = relay.const(2)
-    assert alpha_equal(x, x)
-    assert not alpha_equal(x, y)
-    assert alpha_equal(x, relay.const(1))
+    assert sequal(x, x)
+    assert not sequal(x, y)
+    assert sequal(x, relay.const(1))
 
-def test_type_node_alpha_equal():
+def test_type_node_sequal():
     v1 = relay.TypeVar('v1', 6)
     v2 = relay.TypeVar('v2', 6)
-    assert not alpha_equal(v1, v2)
+    assert not sequal(v1, v2)
 
     v1 = relay.TypeVar('v1', 0)
     v2 = relay.TypeVar('v2', 6)
-    assert not alpha_equal(v1, v2)
+    assert not sequal(v1, v2)
 
-    assert alpha_equal_commutative(v1, v1)
+    assert sequal_commutative(v1, v1)
 
-def test_type_node_incompatible_alpha_equal():
+def test_type_node_incompatible_sequal():
     v1 = relay.TypeVar('v1', 6)
     v2 = relay.Var("v2")
-    assert not alpha_equal_commutative(v1, v2)
+    assert not sequal_commutative(v1, v2)
 
-def test_expr_node_incompatible_alpha_equal():
+def test_expr_node_incompatible_sequal():
     v1 = relay.Var("v1")
     v2 = relay.PatternVar(relay.Var("v2"))
-    assert not alpha_equal_commutative(v1, v2)
+    assert not sequal_commutative(v1, v2)
 
-def test_var_alpha_equal():
+def test_var_sequal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
     # normally only pointer equality
-    assert alpha_equal(v1, v1)
-    assert not alpha_equal(v1, v2)
+    assert sequal(v1, v1)
+    assert not sequal(v1, v2)
 
     # let node allows for setting the eq_map
     l1 = relay.Let(v1, relay.const(1), v1)
     l2 = relay.Let(v2, relay.const(1), v2)
     l3 = relay.Let(v1, relay.const(1), v2)
 
-    assert alpha_equal(l1, l2)
-    assert not alpha_equal(l1, l3)
+    assert sequal(l1, l2)
+    assert not sequal(l1, l3)
 
     # type annotations
     tt1 = relay.TensorType([], "int32")
@@ -278,34 +279,34 @@ def test_var_alpha_equal():
     l6 = relay.Let(v5, relay.const(1), v5)
 
     # same annotations
-    assert alpha_equal(l4, l5)
+    assert sequal(l4, l5)
     # different annotations
-    assert not alpha_equal(l4, l6)
+    assert not sequal(l4, l6)
     # one null annotation
-    assert not alpha_equal(l1, l4)
+    assert not sequal(l1, l4)
 
 
-def test_global_var_alpha_equal():
+def test_global_var_sequal():
     v1 = relay.GlobalVar("v1")
     v2 = relay.GlobalVar("v2")
 
     # only pointer equality suffices (smoke test)
-    assert alpha_equal(v1, v1)
-    assert not alpha_equal(v1, v2)
+    assert sequal(v1, v1)
+    assert not sequal(v1, v2)
 
 
-def test_tuple_alpha_equal():
+def test_tuple_sequal():
     v0 = relay.Var("v0")
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
     # unit value is a valid tuple
-    assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
+    assert sequal(relay.Tuple([]), relay.Tuple([]))
 
     tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
     same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
 
-    assert alpha_equal(tup, same)
+    assert sequal(tup, same)
 
     # use the eq_map
 
@@ -315,33 +316,33 @@ def test_tuple_alpha_equal():
                                             relay.Tuple([relay.const(4)])]),
                            v2)
 
-    assert alpha_equal(let_tup, let_mapped)
+    assert sequal(let_tup, let_mapped)
 
     more_fields = relay.Tuple([v1, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)]), v2])
-    assert not alpha_equal(tup, more_fields)
+    assert not sequal(tup, more_fields)
 
     fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)])
-    assert not alpha_equal(tup, fewer_fields)
+    assert not sequal(tup, fewer_fields)
 
     different_end = relay.Tuple([v1, relay.const(2), relay.const(3),
                            relay.Tuple([relay.const(5)])])
-    assert not alpha_equal(tup, different_end)
+    assert not sequal(tup, different_end)
 
     different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
                                  relay.Tuple([relay.const(4)])])
-    assert not alpha_equal(tup, different_start)
+    assert not sequal(tup, different_start)
 
     longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3),
                                  relay.Tuple([relay.const(4), relay.const(5)])])
-    assert not alpha_equal(tup, longer_at_end)
+    assert not sequal(tup, longer_at_end)
 
 
-def test_tuple_get_item_alpha_equal():
+def test_tuple_get_item_sequal():
     x = relay.Var('x')
     y = relay.Var('y')
-    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
-    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
-    assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
+    assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
+    assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
+    assert sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 
 
 def test_function_attr():
@@ -364,10 +365,10 @@ def test_function_attr():
     q10 = relay.multiply(p10, w12)
     func1 = relay.Function([x1, w10, w11, w12], q10)
     func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
-    assert not alpha_equal(func0, func1)
+    assert not sequal(func0, func1)
 
 
-def test_function_alpha_equal():
+def test_function_sequal():
     tt1 = relay.TensorType((1, 2, 3), "float32")
     tt2 = relay.TensorType((4, 5, 6), "int8")
     tt3 = relay.TupleType([tt1, tt2])
@@ -389,58 +390,58 @@ def test_function_alpha_equal():
     func = relay.Function([v1, v2], v1,
                           tt2, basic_tps)
     mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
-    assert alpha_equal(func, mapped)
+    assert sequal(func, mapped)
 
     fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
-    assert not alpha_equal(func, fewer_params)
+    assert not sequal(func, fewer_params)
 
     more_params = relay.Function([relay.Var("v3", tt1),
                                   relay.Var("v4", tt2),
                                   relay.Var("v2", tt2)], v4, tt2, basic_tps)
-    assert not alpha_equal(func, more_params)
+    assert not sequal(func, more_params)
 
     params_unordered = relay.Function([v2, v1], v1,
                                       tt2, basic_tps)
-    assert not alpha_equal(func, params_unordered)
+    assert not sequal(func, params_unordered)
 
     params_mismatch = relay.Function([v1, v3], v1,
                                      tt2, basic_tps)
-    assert not alpha_equal(func, params_mismatch)
+    assert not sequal(func, params_mismatch)
 
     # also would not typecheck
     ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
-    assert not alpha_equal(func, ret_type_mismatch)
+    assert not sequal(func, ret_type_mismatch)
 
     # also mis-typed
     different_body = relay.Function(basic_args, v3, tt2, basic_tps)
-    assert not alpha_equal(func, different_body)
+    assert not sequal(func, different_body)
 
     fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
-    assert not alpha_equal(func, fewer_type_params)
+    assert not sequal(func, fewer_type_params)
 
     more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
-    assert not alpha_equal(func, more_type_params)
+    assert not sequal(func, more_type_params)
 
     type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
-    assert not alpha_equal(func, type_params_unordered)
+    assert not sequal(func, type_params_unordered)
 
     different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
-    assert not alpha_equal(func, different_type_params)
+    assert not sequal(func, different_type_params)
 
     # a well-typed example that also differs in body, ret type, and type params
     tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
-    assert not alpha_equal(func, tupled_example)
+    assert not sequal(func, tupled_example)
 
     # nullable
     no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2])
     # both null
-    assert alpha_equal(no_ret_type, no_ret_type)
+    assert sequal(no_ret_type, no_ret_type)
     # one null
-    assert not alpha_equal(func, no_ret_type)
-    assert not alpha_equal(no_ret_type, func)
+    assert not sequal(func, no_ret_type)
+    assert not sequal(no_ret_type, func)
 
 
-def test_call_alpha_equal():
+def test_call_sequal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
@@ -458,43 +459,43 @@ def test_call_alpha_equal():
     call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
                       attr1, [tt1])
     same = relay.Call(v1, basic_args, attr1, [tt1])
-    assert alpha_equal(call, same)
+    assert sequal(call, same)
 
     different_fn = relay.Call(v2, basic_args, attr1, [tt1])
-    assert not alpha_equal(call, different_fn)
+    assert not sequal(call, different_fn)
 
     fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
-    assert not alpha_equal(call, fewer_args)
+    assert not sequal(call, fewer_args)
 
     reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
                                      relay.Tuple([]), v2], attr1, [tt1])
-    assert not alpha_equal(call, reordered_args)
+    assert not sequal(call, reordered_args)
 
     different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)],
                                 attr1, [tt1])
-    assert not alpha_equal(call, different_args)
+    assert not sequal(call, different_args)
 
     more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
                                 relay.const(3), relay.const(4)], attr1, [tt1])
-    assert not alpha_equal(call, more_args)
+    assert not sequal(call, more_args)
 
     different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
-    assert not alpha_equal(call, different_attrs)
+    assert not sequal(call, different_attrs)
 
     same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
-    assert alpha_equal(call, same_attrs)
+    assert sequal(call, same_attrs)
 
     no_type_args = relay.Call(v1, basic_args, attr1)
-    assert not alpha_equal(call, no_type_args)
+    assert not sequal(call, no_type_args)
 
     more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
-    assert not alpha_equal(call, more_type_args)
+    assert not sequal(call, more_type_args)
 
     different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
-    assert not alpha_equal(call, different_type_arg)
+    assert not sequal(call, different_type_arg)
 
 
-def test_let_alpha_equal():
+def test_let_sequal():
     tt1 = relay.TensorType((), "float32")
     tt2 = relay.TensorType((), "int8")
     v1 = relay.Var("v1")
@@ -504,57 +505,57 @@ def test_let_alpha_equal():
 
     let = relay.Let(v1, relay.const(2), v1)
     mapped = relay.Let(v2, relay.const(2), v2)
-    assert alpha_equal(let, mapped)
+    assert sequal(let, mapped)
 
     mismatched_var = relay.Let(v2, relay.const(2), v3)
-    assert not alpha_equal(let, mismatched_var)
+    assert not sequal(let, mismatched_var)
 
     different_value = relay.Let(v2, relay.const(3), v2)
-    assert not alpha_equal(let, different_value)
+    assert not sequal(let, different_value)
 
     different_body = relay.Let(v2, relay.const(3), relay.const(12))
-    assert not alpha_equal(let, different_body)
+    assert not sequal(let, different_body)
 
     # specified types must match
 
     let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
     same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
-    assert alpha_equal(let_with_type, same_type)
-    assert not alpha_equal(let, let_with_type)
+    assert sequal(let_with_type, same_type)
+    assert not sequal(let, let_with_type)
     v2 = relay.Var("v1", tt2)
     different_type = relay.Let(v2, relay.const(2), v2)
-    assert not alpha_equal(let_with_type, different_type)
+    assert not sequal(let_with_type, different_type)
 
 
-def test_if_alpha_equal():
+def test_if_sequal():
     v1 = relay.Var("v1")
     v2 = relay.Var("v2")
 
     if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
     same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
-    assert alpha_equal(if_sample, same)
+    assert sequal(if_sample, same)
 
     different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
-    assert not alpha_equal(if_sample, different_cond)
+    assert not sequal(if_sample, different_cond)
 
     different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)]))
-    assert not alpha_equal(if_sample, different_true)
+    assert not sequal(if_sample, different_true)
 
     different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
-    assert not alpha_equal(if_sample, different_false)
+    assert not sequal(if_sample, different_false)
 
 
-def test_constructor_alpha_equal():
+def test_constructor_sequal():
     # smoke test: it should be pointer equality
     mod = tvm.IRModule()
     p = relay.prelude.Prelude(mod)
 
-    assert alpha_equal(p.nil, p.nil)
-    assert alpha_equal(p.cons, p.cons)
-    assert not alpha_equal(p.nil, p.cons)
+    assert sequal(p.nil, p.nil)
+    assert sequal(p.cons, p.cons)
+    assert not sequal(p.nil, p.cons)
 
 
-def test_match_alpha_equal():
+def test_match_sequal():
     mod = tvm.IRModule()
     p = relay.prelude.Prelude(mod)
 
@@ -604,27 +605,28 @@ def test_match_alpha_equal():
                      p.cons(x, p.nil()))
     ])
 
-    assert alpha_equal(match, match)
-    assert alpha_equal(match, equivalent)
-    assert not alpha_equal(match, no_cons)
-    assert not alpha_equal(match, no_nil)
-    assert not alpha_equal(match, empty)
-    assert not alpha_equal(match, different_data)
-    assert not alpha_equal(match, different_order)
-    assert not alpha_equal(match, different_nil)
-    assert not alpha_equal(match, different_cons)
-    assert not alpha_equal(match, another_case)
-    assert not alpha_equal(match, wrong_constructors)
-
-
-def test_op_alpha_equal():
+    tvm.ir.assert_structural_equal(match, match)
+    assert sequal(match, match)
+    assert sequal(match, equivalent)
+    assert not sequal(match, no_cons)
+    assert not sequal(match, no_nil)
+    assert not sequal(match, empty)
+    assert not sequal(match, different_data)
+    assert not sequal(match, different_order)
+    assert not sequal(match, different_nil)
+    assert not sequal(match, different_cons)
+    assert not sequal(match, another_case)
+    assert not sequal(match, wrong_constructors)
+
+
+def test_op_sequal():
     # only checks names
     op1 = relay.op.get("add")
     op2 = relay.op.get("add")
-    assert alpha_equal(op1, op2)
+    assert sequal(op1, op2)
 
     op3 = relay.op.get("take")
-    assert not alpha_equal(op1, op3)
+    assert not sequal(op1, op3)
 
 
 def test_graph_equal():
@@ -638,14 +640,14 @@ def test_graph_equal():
 
     z3 = relay.add(relay.add(x, x), relay.add(x, x))
 
-    assert alpha_equal(z0, z1)
-    assert alpha_equal(z0, z1)
+    assert sequal(z0, z1)
+    assert sequal(z0, z1)
 
     # z3's dataflow format is different from z0
     # z0 is computed from a common y0 node
     # Relay view them as different programs
     # Check the difference in the text format.
-    assert not alpha_equal(z0, z3)
+    assert not sequal(z0, z3)
 
 def test_hash_unequal():
     x1 = relay.var("x1", shape=(10, 10), dtype="float32")
@@ -677,7 +679,7 @@ def test_tuple_match():
     b = relay.Var("b")
     clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
     y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
-    assert analysis.alpha_equal(x, y)
+    assert sequal(x, y)
     assert analysis.structural_hash(x) == analysis.structural_hash(y)
 
 
@@ -697,34 +699,34 @@ def test_fn_attribute():
     add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
     add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
 
-    assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
-    assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
+    assert not sequal(add_1_fn, add_fn)
+    assert not sequal(add_fn, add_1_fn)
 
 
 if __name__ == "__main__":
-    test_tensor_type_alpha_equal()
-    test_incomplete_type_alpha_equal()
-    test_constant_alpha_equal()
-    test_type_node_alpha_equal()
-    test_type_node_incompatible_alpha_equal()
-    test_expr_node_incompatible_alpha_equal()
-    test_func_type_alpha_equal()
-    test_tuple_type_alpha_equal()
-    test_type_relation_alpha_equal()
-    test_type_call_alpha_equal()
-    test_constant_alpha_equal()
-    test_global_var_alpha_equal()
-    test_tuple_alpha_equal()
-    test_tuple_get_item_alpha_equal()
-    test_function_alpha_equal()
+    test_tensor_type_sequal()
+    test_incomplete_type_sequal()
+    test_constant_sequal()
+    test_type_node_sequal()
+    test_type_node_incompatible_sequal()
+    test_expr_node_incompatible_sequal()
+    test_func_type_sequal()
+    test_tuple_type_sequal()
+    test_type_relation_sequal()
+    test_type_call_sequal()
+    test_constant_sequal()
+    test_global_var_sequal()
+    test_tuple_sequal()
+    test_tuple_get_item_sequal()
+    test_function_sequal()
     test_function_attr()
-    test_call_alpha_equal()
-    test_let_alpha_equal()
-    test_if_alpha_equal()
-    test_constructor_alpha_equal()
-    test_match_alpha_equal()
-    test_op_alpha_equal()
-    test_var_alpha_equal()
+    test_call_sequal()
+    test_let_sequal()
+    test_if_sequal()
+    test_constructor_sequal()
+    test_match_sequal()
+    test_op_sequal()
+    test_var_sequal()
     test_graph_equal()
     test_hash_unequal()
     test_fn_attribute()
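The `sequal` helper in the renamed test file also asserts that `analysis.structural_hash` agrees with structural equality: terms that compare equal must hash equally. One toy way to get that property is to hash a canonical form in which each bound variable is replaced by the position of its definition. Alpha-renamed terms then collide, while the shared-variable pair from the commit message does not.

This sketch is not necessarily how TVM's structural hash works. The tuple-based terms, `Var`, `canonical_form` and `toy_structural_hash` are hypothetical.

```python
import itertools


class Var:
    """Toy variable compared by identity (hypothetical, not relay.Var)."""

    def __init__(self, name: str):
        self.name = name


def canonical_form(term, map_free_vars: bool = False) -> tuple:
    """Rewrite a toy term so that alpha-equivalent terms get identical forms.

    Terms are Var, ("const", value), ("add", lhs, rhs) or ("fn", params, body).
    """
    bound = {}                # bound Var -> index of its definition
    free = {}                 # free Var -> order of first use
    fresh = itertools.count()

    def visit(t):
        if isinstance(t, Var):
            if t in bound:
                return ("bound", bound[t])
            if map_free_vars:
                return ("free", free.setdefault(t, len(free)))
            return ("free-id", id(t))  # unmapped free vars equal only themselves
        kind = t[0]
        if kind == "const":
            return ("const", t[1])
        if kind == "add":
            return ("add", visit(t[1]), visit(t[2]))
        if kind == "fn":
            _, params, body = t
            for p in params:
                bound[p] = next(fresh)  # a definition point, like DefEqual
            return ("fn", len(params), visit(body))
        raise TypeError("unknown term kind: %r" % (kind,))

    return visit(term)


def toy_structural_hash(term, map_free_vars: bool = False) -> int:
    return hash(canonical_form(term, map_free_vars))


x, y, a, b = Var("x"), Var("y"), Var("a"), Var("b")
function0 = ("fn", (x, y), ("add", x, y))  # fn (%x, %y) { %x + %y }
function1 = ("fn", (y, x), ("add", x, y))  # fn (%y, %x) { %x + %y }
renamed = ("fn", (a, b), ("add", a, b))    # fn (%a, %b) { %a + %b }
print(canonical_form(function0) == canonical_form(renamed))    # True
print(canonical_form(function0) == canonical_form(function1))  # False
```

Equal hashes here only mean the canonical forms matched. In general, a hash like this is a fast filter, and the full equality check is still needed to confirm a match.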
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index 604ec89..3a0bf1f 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass):
 def test_let():
     orig = relay.Let(e.x, e.y, e.z)
     orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
+    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
 
 
 def test_used_let():
     orig = relay.Let(e.c, e.one, e.c + e.c)
     orig = run_opt_pass(orig, transform.DeadCodeElimination())
     expected = relay.Let(e.c, e.one, e.c + e.c)
-    assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
+    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
 
 def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
@@ -75,7 +75,7 @@ def test_inline():
 def test_chain_unused_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
     orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
+    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
 
 
 def use_f(func):
@@ -111,13 +111,13 @@ def test_recursion_dead():
     x = relay.Let(e.a, e.one, e.three)
     dced_f = lambda f: x
     dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert alpha_equal(dced, e.three)
+    assert tvm.ir.structural_equal(dced, e.three)
 
 
 def test_op_let():
     dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
                         transform.DeadCodeElimination())
-    assert alpha_equal(dced, add(e.three, e.two))
+    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
 
 
 def test_tuple_get_item():
@@ -126,10 +126,10 @@ def test_tuple_get_item():
     a = relay.Var('a')
     g = relay.TupleGetItem(t, 0)
     dced = run_opt_pass(g, transform.DeadCodeElimination())
-    assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
     orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
     dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
 
 
 @pytest.mark.timeout(timeout=10, method="thread")
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
index f54dd6b..1299084 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -72,7 +72,7 @@ def test_tuple():
     f = Function([x], body, None, [t])
     expected = relay.Function([x], x, None, [t])
     expected = run_opt_pass(expected, transform.InferType())
-    assert alpha_equal(dcpe(f), expected)
+    assert tvm.ir.structural_equal(dcpe(f), expected)
 
 
 def test_const_inline():
@@ -80,7 +80,7 @@ def test_const_inline():
     d = Var("d", t)
     double = Function([d], d + d)
     orig = double(const(4.0))
-    assert alpha_equal(dcpe(orig), const(8.0))
+    assert tvm.ir.structural_equal(dcpe(orig), const(8.0))
 
 
 def test_ref():
@@ -93,7 +93,7 @@ def test_ref():
     body = Let(r, RefCreate(d), body)
     square = Function([d], body)
     expected = run_opt_pass(Function([d], d * d), transform.InferType())
-    assert alpha_equal(dcpe(square), expected)
+    assert tvm.ir.structural_equal(dcpe(square), expected)
 
 
 def test_empty_ad():
@@ -105,7 +105,7 @@ def test_empty_ad():
     g = dcpe(f, grad=True)
     expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
     expected = run_opt_pass(expected, transform.InferType())
-    assert alpha_equal(g, expected)
+    assert tvm.ir.structural_equal(g, expected)
 
 
 def test_ad():
@@ -180,7 +180,7 @@ def test_head_cons():
     body = hd(p.cons(x, p.nil()))
     f = Function([x], body, None, [t])
     res = dcpe(f, mod)
-    assert alpha_equal(res, Function([x], x, t, [t]))
+    assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
 
 
 def test_map():
@@ -197,7 +197,7 @@ def test_map():
     expected = mod["main"]
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, expected.body)
+    assert tvm.ir.structural_equal(res.body, expected.body)
 
 
 def test_loop():
@@ -211,7 +211,7 @@ def test_loop():
     expected = mod["main"].body
     call = Function([], loop(const(1)))
     res = dcpe(call, mod=mod)
-    assert alpha_equal(res.body, expected)
+    assert tvm.ir.structural_equal(res.body, expected)
 
 
 def test_swap_loop():
@@ -226,7 +226,7 @@ def test_swap_loop():
     prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
     res = Function([], prog)
     res = dcpe(res, mod=mod)
-    assert alpha_equal(prog, res.body)
+    assert tvm.ir.structural_equal(prog, res.body)
 
 
 def test_abs_diff():
@@ -248,7 +248,7 @@ def test_abs_diff():
     orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, make_nat_expr(p, 4))
+    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
 
 
 def test_match_nat_id():
@@ -265,7 +265,7 @@ def test_match_nat_id():
     orig = nat_id(make_nat_expr(p, 3))
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, make_nat_expr(p, 3))
+    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_nat_id():
@@ -280,7 +280,7 @@ def test_nat_id():
     orig = nat_id(make_nat_expr(p, 3))
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, make_nat_expr(p, 3))
+    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_global_match_nat_id():
@@ -294,7 +294,7 @@ def test_global_match_nat_id():
     orig = Match(make_nat_expr(p, 3), [z_case, s_case])
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, make_nat_expr(p, 3))
+    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
 
 
 def test_double():
@@ -304,7 +304,7 @@ def test_double():
     orig = p.double(make_nat_expr(p, 3))
     orig = Function([], orig)
     res = dcpe(orig, mod=mod)
-    assert alpha_equal(res.body, make_nat_expr(p, 6))
+    assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6))
 
 
 def test_concat():
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index ed05096..b164821 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d():
         # Since same dtype, there should not be any transformation
         with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert alpha_equal(mod, legalized_mod)
+            assert tvm.ir.structural_equal(mod, legalized_mod)
 
         ################################################################
         # Check transformations for platforms without fast Int8 support.
@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d():
     # Check no transformation for Intel VNNI.
     with tvm.target.create('llvm -mcpu=skylake-avx512'):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert alpha_equal(mod, legalized_mod)
+        assert tvm.ir.structural_equal(mod, legalized_mod)
 
     # ARM - so check that transformation has happened.
     with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense():
         # Since same dtype, there should not be any transformation
         with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
             legalized_mod = relay.qnn.transform.Legalize()(mod)
-            assert alpha_equal(mod, legalized_mod)
+            assert tvm.ir.structural_equal(mod, legalized_mod)
 
         ################################################################
         # Check transformations for platforms without fast Int8 support.
@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense():
     # Check no transformation for Intel VNNI.
     with tvm.target.create('llvm -mcpu=skylake-avx512'):
         legalized_mod = relay.qnn.transform.Legalize()(mod)
-        assert alpha_equal(mod, legalized_mod)
+        assert tvm.ir.structural_equal(mod, legalized_mod)
 
     # ARM - so check that transformation has happened.
     with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+dotprod'):
diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py
index 2a6103e..29818f8 100644
--- a/tests/python/relay/test_pass_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -76,7 +76,7 @@ def test_order():
     expected_output = relay.Let(b, y, expected_output)
     expected_output = relay.Let(a, x, expected_output)
     expected_output = run_opt_pass(expected_output, transform.InferType())
-    assert alpha_equal(anf, expected_output)
+    assert tvm.ir.structural_equal(anf, expected_output)
 
 
 def test_if():
@@ -93,7 +93,7 @@ def test_if():
     expected_output = relay.Let(d, expected_output, d)
     expected_output = relay.Let(c, cond, expected_output)
     expected_output = run_opt_pass(expected_output, transform.InferType())
-    assert alpha_equal(anf, expected_output)
+    assert tvm.ir.structural_equal(anf, expected_output)
 
 
 # make sure we dont infinite loop.
diff --git a/tests/python/relay/test_pass_to_cps.py b/tests/python/relay/test_pass_to_cps.py
index e2ac924..4aaa9a0 100644
--- a/tests/python/relay/test_pass_to_cps.py
+++ b/tests/python/relay/test_pass_to_cps.py
@@ -17,7 +17,7 @@
 import numpy as np
 import tvm
 from tvm import relay
-from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.analysis import detect_feature
 from tvm.relay.transform import to_cps, un_cps
 from tvm.relay.analysis import Feature
 from tvm.relay.prelude import Prelude
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 74507ba..4591618 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -21,7 +21,6 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import op, transform, analysis
-from tvm.relay.analysis import assert_alpha_equal
 
 
 def run_infer_type(expr, mod=None):
@@ -360,7 +359,7 @@ def test_let_polymorphism():
     body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
     body = run_infer_type(body)
     int32 = relay.TensorType((), "int32")
-    assert_alpha_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+    tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py
index a25ba0a..f2848ff 100644
--- a/tests/python/unittest/test_node_reflection.py
+++ b/tests/python/unittest/test_node_reflection.py
@@ -25,7 +25,7 @@ def test_const_saveload_json():
     z = z + z
     json_str = tvm.ir.save_json(z)
     zz = tvm.ir.load_json(json_str)
-    assert tvm.ir.save_json(zz) == tvm.ir.save_json(z)
+    tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
 
 
 def test_make_smap():
@@ -38,6 +38,7 @@ def test_make_smap():
     arr = tvm.ir.load_json(json_str)
     assert len(arr) == 1
     assert arr[0]["z"].a == arr[0]["x"]
+    tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True)
 
 
 def test_make_node():
@@ -90,7 +91,6 @@ def test_env_func():
 
 if __name__ == "__main__":
     test_env_func()
-    test_make_attrs()
     test_make_node()
     test_make_smap()
     test_const_saveload_json()
diff --git a/tests/python/unittest/test_tir_structural_equal.py b/tests/python/unittest/test_tir_structural_equal.py
new file mode 100644
index 0000000..26f3085
--- /dev/null
+++ b/tests/python/unittest/test_tir_structural_equal.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+from tvm import te
+
+
+def test_exprs():
+    # constants and free variables
+    x = tvm.tir.const(1, "int32")
+    y = tvm.tir.const(10, "int32")
+    vx = te.var("x")
+    vy = te.var("y")
+    vz = te.var("z")
+
+    # test assert trigger.
+    with pytest.raises(ValueError):
+        tvm.ir.assert_structural_equal(x, y)
+
+    assert not tvm.ir.structural_equal(vx, vy)
+    assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
+    # corner case: lhs vx maps to rhs vy, so it cannot also map to rhs vx
+    assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
+    # corner case lhs:vx == rhs:vy, lhs:vy == rhs:vx
+    assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
+    # corner case2: rolling remap.
+    assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True)
+    assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
+    # Definition remap: Let-bound vars are remapped even without map_free_vars
+    assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
+                                   tvm.tir.Let(vy, 1, vy - 1))
+    # By default, a free var with the same address on both sides matches itself
+    assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
+                                   tvm.tir.Let(vy, 1, vy // vz))
+
+    zx = vx + vx
+    zy = vy + vy
+    assert tvm.ir.structural_equal(zx * zx, zx * zx)
+    assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
+    assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
+    assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
+                                   map_free_vars=False)
+
+
+def test_prim_func():
+    x = te.var('x')
+    y = te.var('y')
+    # counterexample: same parameter objects, but x + y is not structurally y + x
+    func0 = tvm.tir.PrimFunc(
+        [x, y], tvm.tir.Evaluate(x + y))
+    func1 = tvm.tir.PrimFunc(
+        [x, y], tvm.tir.Evaluate(y + x))
+    assert not tvm.ir.structural_equal(func0, func1)
+
+    # new cases
+    b = tvm.tir.decl_buffer((x,), "float32")
+    stmt = tvm.tir.LetStmt(
+        x, 10, tvm.tir.Evaluate(x + 1))
+    func0 = tvm.tir.PrimFunc(
+        [x, y, b], stmt)
+    # easiest way to deep copy is via save/load
+    func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
+    tvm.ir.assert_structural_equal(func0, func1)
+
+    data0 = tvm.nd.array([1, 2, 3])
+    data1 = tvm.nd.array([1, 2, 3])
+    # attributes and ndarrays
+    func0 = func0.with_attr("data", data0)
+    func1 = func1.with_attr("data", data1)
+    # IRModules
+    mod0 = tvm.IRModule.from_expr(func0)
+    mod1 = tvm.IRModule.from_expr(func1)
+    tvm.ir.assert_structural_equal(mod0, mod1)
+
+
+def test_attrs():
+    x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
+    y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
+    z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
+    tvm.ir.assert_structural_equal(y, x)
+    assert not tvm.ir.structural_equal(y, z)
+
+
+
+if __name__ == "__main__":
+    test_exprs()
+    test_prim_func()
+    test_attrs()
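
The corner cases in test_exprs above all follow from one rule: with map_free_vars=True, free variables may be renamed, but only through a one-to-one pairing, and Let-bound variables are always paired at their definition site. The sketch below is a toy, pure-Python illustration of that rule under stated assumptions; the classes Var, Const, BinOp, Let and the functions toy_structural_equal and add are hypothetical names invented here, not TVM's API, and this recursive checker is not how TVM itself implements tvm.ir.structural_equal.

```python
from dataclasses import dataclass
from typing import Dict, Union


class Var:
    """Toy variable compared by object identity (hypothetical stand-in for a TIR Var)."""

    def __init__(self, name: str):
        self.name = name


@dataclass(eq=False)
class Const:
    value: int


@dataclass(eq=False)
class BinOp:
    op: str
    a: "Expr"
    b: "Expr"


@dataclass(eq=False)
class Let:
    var: Var  # binding site: always paired, like a definition
    value: "Expr"
    body: "Expr"


Expr = Union[Var, Const, BinOp, Let]


class _Checker:
    def __init__(self, map_free_vars: bool):
        self.map_free_vars = map_free_vars
        # Two maps keep the var pairing one-to-one in both directions.
        self.lhs_to_rhs: Dict[Var, Var] = {}
        self.rhs_to_lhs: Dict[Var, Var] = {}

    def define(self, lhs: Var, rhs: Var) -> bool:
        # Reject a pairing that conflicts with one already recorded.
        if lhs in self.lhs_to_rhs or rhs in self.rhs_to_lhs:
            return self.lhs_to_rhs.get(lhs) is rhs
        self.lhs_to_rhs[lhs] = rhs
        self.rhs_to_lhs[rhs] = lhs
        return True

    def var_equal(self, lhs: Var, rhs: Var) -> bool:
        if lhs in self.lhs_to_rhs:
            return self.lhs_to_rhs[lhs] is rhs
        if rhs in self.rhs_to_lhs:
            return False  # rhs is already paired with a different lhs var
        if lhs is rhs or self.map_free_vars:
            # Record the pairing (identity included) so later uses stay consistent.
            return self.define(lhs, rhs)
        return False

    def equal(self, lhs: Expr, rhs: Expr) -> bool:
        if type(lhs) is not type(rhs):
            return False
        if isinstance(lhs, Var):
            return self.var_equal(lhs, rhs)
        if isinstance(lhs, Const):
            return lhs.value == rhs.value
        if isinstance(lhs, BinOp):
            return (lhs.op == rhs.op and self.equal(lhs.a, rhs.a)
                    and self.equal(lhs.b, rhs.b))
        # Let: compare the bound value, pair the binders, then compare the body.
        return (self.equal(lhs.value, rhs.value)
                and self.define(lhs.var, rhs.var)
                and self.equal(lhs.body, rhs.body))


def toy_structural_equal(lhs: Expr, rhs: Expr, map_free_vars: bool = False) -> bool:
    return _Checker(map_free_vars).equal(lhs, rhs)


def add(a: Expr, b: Expr) -> BinOp:
    return BinOp("+", a, b)


# Mirrors of the test_exprs corner cases, expressed in the toy IR.
vx, vy, vz = Var("x"), Var("y"), Var("z")
assert not toy_structural_equal(vx, vy)
assert toy_structural_equal(vx, vy, map_free_vars=True)
# lhs vx is paired with rhs vy, so it cannot also match rhs vx.
assert not toy_structural_equal(add(vx, vx), add(vy, vx), map_free_vars=True)
# A swapped pair and a rolling remap are both one-to-one, hence allowed.
assert toy_structural_equal(add(vx, vy), add(vy, vx), map_free_vars=True)
assert toy_structural_equal(add(add(vx, vy), vz), add(add(vy, vz), vx), map_free_vars=True)
# Let binders are paired without map_free_vars; the shared free vz matches itself.
assert toy_structural_equal(Let(vx, Const(1), BinOp("//", vx, vz)),
                            Let(vy, Const(1), BinOp("//", vy, vz)))
```

Keeping a reverse map alongside the forward map is what makes the pairing one-to-one: it is why vx + vx vs vy + vx fails (rhs vx would need a second partner) while vx + vy vs vy + vx succeeds.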

