tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: remove AttrsEqual and AttrsHash related code (#5169)
Date Mon, 30 Mar 2020 01:57:37 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 6536b35  remove AttrsEqual and AttrsHash related code (#5169)
6536b35 is described below

commit 6536b356fcd74ee3a2544900c36dfe2340290720
Author: Zhi <5145158+zhiics@users.noreply.github.com>
AuthorDate: Sun Mar 29 18:57:27 2020 -0700

    remove AttrsEqual and AttrsHash related code (#5169)
---
 include/tvm/ir/attrs.h                             | 171 +------------
 src/ir/attr_functor.h                              |  89 -------
 src/ir/attrs.cc                                    | 278 ---------------------
 src/node/structural_equal.cc                       |   1 +
 src/relay/transforms/combine_parallel_conv2d.cc    |   4 +-
 src/relay/transforms/combine_parallel_dense.cc     |   2 +-
 src/relay/transforms/combine_parallel_op.cc        |   3 +-
 src/relay/transforms/combine_parallel_op_batch.cc  |   4 +-
 src/relay/transforms/eliminate_common_subexpr.cc   |   2 +-
 src/relay/transforms/fold_scale_axis.cc            |   4 +-
 src/relay/transforms/fuse_ops.cc                   |   2 +-
 src/relay/transforms/pattern_util.h                |   2 +-
 src/tir/pass/ffi_api.cc                            |  12 -
 tests/python/relay/test_ir_nodes.py                |   1 -
 tests/python/unittest/test_ir_attrs.py             |   9 +-
 .../unittest/test_tir_pass_attrs_hash_equal.py     |  18 +-
 16 files changed, 29 insertions(+), 573 deletions(-)

diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index fbca3bb..0fc832e 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -46,6 +46,8 @@
 
 #include <dmlc/common.h>
 #include <tvm/ir/expr.h>
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
 #include <tvm/runtime/packed_func.h>
 
 #include <unordered_map>
@@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
 };
 
-class AttrsHashHandler;
-class AttrsEqualHandler;
-/*!
- * \brief Content-aware Equality comparator for attrs.
- *
- * This comparator will recursively deep compare the following Attributes.
- *
- * - IntImm, UIntImm, FloatImm, StringImm
- * - Any subclass of BaseAttrsNode
- * - Array of Attributes.
- * - Map from string to Attributes.
- */
-class AttrsEqual {
- public:
-  bool operator()(const double& lhs, const double& rhs) const {
-    // fuzzy float pt comparison
-    constexpr double atol = 1e-9;
-    if (lhs == rhs) return true;
-    double diff = lhs - rhs;
-    return diff > -atol && diff < atol;
-  }
-
-  bool operator()(const int64_t& lhs, const int64_t& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const int& lhs, const int& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const bool& lhs, const bool& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const std::string& lhs, const std::string& rhs) const {
-    return lhs == rhs;
-  }
-  bool operator()(const DataType& lhs, const DataType& rhs) const {
-    return lhs == rhs;
-  }
-  // node comparator
-  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
-
- protected:
-  friend class AttrsEqualHandler;
-  /*! \brief internal handle. */
-  AttrsEqualHandler* handler_{nullptr};
-};
-
-/*!
- * \brief Content-aware hash function.
- *
- * This hash functor will recursively hash the content of the Attributes.
- * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
- */
-class AttrsHash {
- public:
-  size_t operator()(const double& value) const {
-    return std::hash<double>()(value);
-  }
-  size_t operator()(const int64_t& value) const {
-    return std::hash<int64_t>()(value);
-  }
-  size_t operator()(const uint64_t& value) const {
-    return std::hash<uint64_t>()(value);
-  }
-  size_t operator()(const int& value) const {
-    return std::hash<int>()(value);
-  }
-  size_t operator()(const bool& value) const {
-    return std::hash<bool>()(value);
-  }
-  size_t operator()(const std::string& value) const {
-    return std::hash<std::string>()(value);
-  }
-  size_t operator()(const DataType& value) const {
-    return std::hash<int>()(
-        static_cast<int>(value.code()) |
-        (static_cast<int>(value.bits()) << 8) |
-        (static_cast<int>(value.lanes()) << 16));
-  }
-  TVM_DLL size_t operator()(const ObjectRef& value) const;
-
- private:
-  friend class AttrsHashHandler;
-  /*! \brief internal handle. */
-  AttrsHashHandler* handler_{nullptr};
-};
-
 /*!
  * \brief Base class of all attribute class
  * \note Do not subclass AttrBaseNode directly,
@@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
    * \note This function throws when the required field is not present.
    */
   TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
-  /*!
-   * \brief Whether this attribute's content equals to another node.
-   * \param other The pointer to another node.
-   * \param equal The equal comparator
-   * \return The comparison result.
-   */
-  TVM_DLL virtual bool ContentEqual(
-      const Object* other, AttrsEqual equal) const = 0;
-  /*!
-   * \brief Content aware hash.
-   * \param hasher The hasher to run the hash.
-   * \return the hash result.
-   */
-  TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
 
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
@@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
   Array<AttrFieldInfo> ListFieldInfo() const final;
-  bool ContentEqual(const Object* other, AttrsEqual equal) const final;
-  size_t ContentHash(AttrsHash hasher) const final;
   // type info
   static constexpr const char* _type_key = "DictAttrs";
   TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
@@ -386,34 +283,6 @@ class AttrNormalVisitor {
   AttrVisitor* visitor_;
 };
 
-// Wrapper for normal visitor.
-class AttrsEqualVisitor {
- public:
-  bool result_{true};
-  // constructor
-  AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
-      : lhs_(lhs), rhs_(rhs), equal_(equal) {
-  }
-  template<typename T>
-  AttrNopEntry operator()(const char* key, T* lhs_value) {
-    if (!result_) return AttrNopEntry();
-    const T* rhs_value =
-        reinterpret_cast<const T*>(
-            reinterpret_cast<const char*>(rhs_) +
-            (reinterpret_cast<const char*>(lhs_value) -
-             reinterpret_cast<const char*>(lhs_)));
-    if (!equal_(*lhs_value, *rhs_value)) {
-      result_ = false;
-    }
-    return AttrNopEntry();
-  }
-
- private:
-  const Object* lhs_;
-  const Object* rhs_;
-  const AttrsEqual& equal_;
-};
-
 class AttrsSEqualVisitor {
  public:
   bool result_{true};
@@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
   const SEqualReducer& equal_;
 };
 
-class AttrsHashVisitor {
- public:
-  explicit AttrsHashVisitor(const AttrsHash& hasher)
-      : hasher_(hasher) {}
-
-  size_t result_{0};
-
-  template<typename T>
-  AttrNopEntry operator()(const char* key, T* value) {
-    result_ = dmlc::HashCombine(result_, hasher_(*value));
-    return AttrNopEntry();
-  }
-
- private:
-  const AttrsHash& hasher_;
-};
-
 class AttrsSHashVisitor {
  public:
   explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
@@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
     return *this;
   }
   TSelf& set_default(const T& value) {
-    if (AttrsEqual()(value, *data_)) {
+    if (tvm::StructuralEqual()(value, *data_)) {
       trigger_ = false;
     }
     return *this;
@@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
     return visitor.fields_;
   }
 
-  bool ContentEqual(const Object* other, AttrsEqual equal) const final {
-    DerivedType* pself = self();
-    if (pself == other) return true;
-    if (other == nullptr) return false;
-    if (pself->type_index() != other->type_index()) return false;
-    ::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
-    self()->__VisitAttrs__(visitor);
-    return visitor.result_;
-  }
-
-  size_t ContentHash(AttrsHash hasher) const final {
-    ::tvm::detail::AttrsHashVisitor visitor(hasher);
-    visitor.result_ = this->GetTypeKeyHash();
-    self()->__VisitAttrs__(visitor);
-    return visitor.result_;
-  }
-
  private:
   DerivedType* self() const {
     return const_cast<DerivedType*>(
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index 9acc465..dbd5a4f 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
   }
 };
 
-class AttrsEqualHandler :
-      protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
- public:
-  /*!
-   * \brief Check if lhs equals rhs
-   * \param lhs The left operand.
-   * \param rhs The right operand.
-   */
-  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);
-
- protected:
-  bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
-  bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
-};
-
-class AttrsHashHandler :
-      protected AttrFunctor<size_t(const ObjectRef&)> {
- public:
-  /*!
-   * \brief Get hash value of node
-   * \param node The node to be hashed.
-   */
-  size_t Hash(const ObjectRef& node) {
-    if (!node.defined()) return 0;
-    return this->VisitAttr(node);
-  }
-
- protected:
-  size_t VisitAttrDefault_(const Object* lhs) final;
-  size_t VisitAttr_(const tir::IntImmNode* lhs) final;
-  size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
-  size_t VisitAttr_(const tir::StringImmNode* lhs) final;
-  size_t VisitAttr_(const ArrayNode* lhs) final;
-  size_t VisitAttr_(const StrMapNode* lhs) final;
-  size_t VisitAttr_(const tir::AddNode* op) final;
-  size_t VisitAttr_(const tir::SubNode* op) final;
-  size_t VisitAttr_(const tir::MulNode* op) final;
-  size_t VisitAttr_(const tir::DivNode* op) final;
-  size_t VisitAttr_(const tir::ModNode* op) final;
-  size_t VisitAttr_(const tir::FloorDivNode* op) final;
-  size_t VisitAttr_(const tir::FloorModNode* op) final;
-  size_t VisitAttr_(const tir::MinNode* op) final;
-  size_t VisitAttr_(const tir::MaxNode* op) final;
-  size_t VisitAttr_(const tir::GENode* op) final;
-  size_t VisitAttr_(const tir::GTNode* op) final;
-  size_t VisitAttr_(const tir::LENode* op) final;
-  size_t VisitAttr_(const tir::LTNode* op) final;
-  size_t VisitAttr_(const tir::EQNode* op) final;
-  size_t VisitAttr_(const tir::NENode* op) final;
-  size_t VisitAttr_(const tir::AndNode* op) final;
-  size_t VisitAttr_(const tir::OrNode* op) final;
-  size_t VisitAttr_(const tir::NotNode* op) final;
-  size_t VisitAttr_(const tir::CastNode* op) final;
-  size_t VisitAttr_(const tir::CallNode* op) final;
-  size_t VisitAttr_(const tir::SelectNode* op) final;
-  /*!
-   * \brief alias of dmlc::HashCombine
-   * \param lhs The first hash value.
-   * \param rhs The second hash value.
-   */
-  static size_t Combine(size_t lhs, size_t rhs) {
-    return dmlc::HashCombine(lhs, rhs);
-  }
-};
 }  // namespace tvm
 #endif  // TVM_IR_ATTR_FUNCTOR_H_
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 868fec6..066b8f9 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
   return attrs->dict;
 });
 
-
-using namespace tir;
-// Equal handler.
-bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
-  if (lhs.same_as(rhs)) return true;
-  if (!lhs.defined() && rhs.defined()) return false;
-  if (!rhs.defined() && lhs.defined()) return false;
-  return this->VisitAttr(lhs, rhs);
-}
-
-bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
-  if (lhs->IsInstance<BaseAttrsNode>()) {
-    AttrsEqual equal;
-    equal.handler_ = this;
-    return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
-        other.get(), equal);
-  }
-  return lhs == other.get();
-}
-
-bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<IntImmNode>()) {
-    return lhs->value == rhs->value;
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<FloatImmNode>()) {
-    return lhs->value == rhs->value;
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<StringImmNode>()) {
-    return lhs->value == rhs->value;
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<ArrayNode>()) {
-    if (rhs->data.size() != lhs->data.size()) return false;
-    for (size_t i = 0; i < lhs->data.size(); ++i) {
-      if (!Equal(lhs->data[i], rhs->data[i])) return false;
-    }
-    return true;
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<StrMapNode>()) {
-    if (rhs->data.size() != lhs->data.size()) return false;
-    for (const auto& kv : lhs->data) {
-      auto it = rhs->data.find(kv.first);
-      if (it == rhs->data.end()) return false;
-      if (!Equal(kv.second, it->second)) return false;
-    }
-    return true;
-  } else {
-    return false;
-  }
-}
-
-#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
-  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
-    if (const auto* rhs = other.as<NodeName>()) {                       \
-      if (!Equal(lhs->a, rhs->a)) return false;                         \
-      if (!Equal(lhs->b, rhs->b)) return false;                         \
-      return true;                                                      \
-    } else {                                                            \
-      return false;                                                     \
-    }                                                                   \
-  }                                                                     \
-
-TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
-
-bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<NotNode>()) {
-    return Equal(lhs->a, rhs->a);
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<CastNode>()) {
-    if (lhs->dtype != rhs->dtype) return false;
-    return Equal(lhs->value, rhs->value);
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<CallNode>()) {
-    return
-        lhs->name == rhs->name &&
-        lhs->dtype == rhs->dtype &&
-        lhs->call_type == rhs->call_type &&
-        Equal(lhs->args, rhs->args);
-  } else {
-    return false;
-  }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
-  if (const auto* rhs = other.as<SelectNode>()) {
-    return
-        Equal(lhs->condition, rhs->condition) &&
-        Equal(lhs->true_value, rhs->true_value) &&
-        Equal(lhs->false_value, rhs->false_value);
-  } else {
-    return false;
-  }
-}
-
-// Hash Handler.
-size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
-  if (value->IsInstance<BaseAttrsNode>()) {
-    AttrsHash hasher;
-    hasher.handler_ = this;
-    return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
-  } else {
-    return ObjectHash()(GetRef<ObjectRef>(value));
-  }
-}
-
-size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
-  return std::hash<int64_t>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
-  return std::hash<double>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
-  return std::hash<std::string>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
-  size_t result = op->data.size();
-  for (size_t  i = 0; i < op->data.size(); ++i) {
-    result = Combine(result, this->Hash(op->data[i]));
-  }
-  return result;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
-    using Entry = std::pair<std::string, ObjectRef>;
-    std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
-    std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
-        return a.first < b.first;
-      });
-    size_t result = 0;
-    for (const Entry& kv : data) {
-      result = Combine(result, std::hash<std::string>()(kv.first));
-      result = Combine(result, this->Hash(kv.second));
-    }
-    return result;
-}
-
-
-#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
-  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
-    static size_t key = std::hash<std::string>()(NodeName::_type_key);  \
-    return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
-  }                                                                     \
-
-TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
-
-size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
-  static size_t key = std::hash<std::string>()(NotNode::_type_key);
-  return Combine(key, Hash(op->a));
-}
-
-size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
-  static size_t key = std::hash<std::string>()(CastNode::_type_key);
-  AttrsHash hasher;
-  size_t res = key;
-  res = Combine(res, hasher(op->dtype));
-  res = Combine(res, Hash(op->value));
-  return res;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
-  static size_t key = std::hash<std::string>()(CallNode::_type_key);
-  AttrsHash hasher;
-  size_t res = key;
-  res = Combine(res, hasher(op->name));
-  res = Combine(res, hasher(op->dtype));
-  res = Combine(res, Hash(op->args));
-  return res;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
-  static size_t key = std::hash<std::string>()(SelectNode::_type_key);
-  size_t res = key;
-  res = Combine(res, Hash(op->condition));
-  res = Combine(res, Hash(op->true_value));
-  res = Combine(res, Hash(op->false_value));
-  return res;
-}
-
-
-// Default case
-bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
-  if (lhs.same_as(rhs)) return true;
-  if (handler_ == nullptr) {
-    return AttrsEqualHandler().Equal(lhs, rhs);
-  } else {
-    return handler_->Equal(lhs, rhs);
-  }
-}
-
-size_t AttrsHash::operator()(const ObjectRef& node) const {
-  if (!node.defined()) return 0;
-  if (handler_ == nullptr) {
-    return AttrsHashHandler().Hash(node);
-  } else {
-    return handler_->Hash(node);
-  }
-}
-
-size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
-  return hasher(this->dict);
-}
-
-bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
-  if (this == other) return true;
-  if (other == nullptr) return false;
-  if (this->type_index() != other->type_index()) return false;
-  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
-}
-
 TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
 .set_body_typed([](Attrs attrs) {
   return attrs->ListFieldInfo();
 });
 
-TVM_REGISTER_GLOBAL("ir.AttrsEqual")
-.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
-  return AttrsEqual()(lhs, rhs);
-});
-
 }  // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index df7b8ff..b2191c1 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -103,6 +103,7 @@ class RemapVarSEqualHandler :
 
   // Function that implements actual equality check.
   bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+    if (!lhs.defined() && !rhs.defined()) return true;
     task_stack_.clear();
     pending_tasks_.clear();
     equal_map_lhs_.clear();
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 0dbce9b..3884dac 100644
--- a/src/relay/transforms/combine_parallel_conv2d.cc
+++ b/src/relay/transforms/combine_parallel_conv2d.cc
@@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
   }
 
   bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
-    AttrsEqual eq;
+    StructuralEqual eq;
     const Layout kOIHW("OIHW");
     const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
     const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
@@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
   }
 
   bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
-    AttrsEqual eq;
+    StructuralEqual eq;
     auto ta = a->args[index]->type_as<TensorTypeNode>();
     auto tb = b->args[index]->type_as<TensorTypeNode>();
     auto toutput_a = a->type_as<TensorTypeNode>();
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index cd234bb..612dae5 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
 
  protected:
   virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
-    AttrsEqual eq;
+    StructuralEqual eq;
     const auto* attrs_a = a->attrs.as<DenseAttrs>();
     const auto* attrs_b = b->attrs.as<DenseAttrs>();
     CHECK(attrs_a);
diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc
index 6b9926c..a7f7af2 100644
--- a/src/relay/transforms/combine_parallel_op.cc
+++ b/src/relay/transforms/combine_parallel_op.cc
@@ -23,6 +23,7 @@
  * \brief Abstract class to combine parallel ops and their successive element-wise ops.
  */
 
+#include <tvm/node/structural_hash.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
@@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) {
 
 bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) {
     const CallNode* call = branches[0][depth];
-    AttrsEqual attrs_equal;
+    tvm::StructuralEqual attrs_equal;
     // check if all branches in current depth can be combined
     for (auto it = branches.begin() + 1; it != branches.end(); it++) {
       const Branch& branch = *it;
diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc
index fa63573..55ca3f6 100644
--- a/src/relay/transforms/combine_parallel_op_batch.cc
+++ b/src/relay/transforms/combine_parallel_op_batch.cc
@@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode
     return false;
   }
 
-  AttrsEqual eq;
+  StructuralEqual eq;
   for (size_t i = 0; i < a->args.size(); i++) {
     auto ta = a->args[i]->type_as<TensorTypeNode>();
     auto tb = b->args[i]->type_as<TensorTypeNode>();
@@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
 }
 
 bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
-  AttrsEqual eq;
+  StructuralEqual eq;
   auto ta = a->args[index]->type_as<TensorTypeNode>();
   auto tb = b->args[index]->type_as<TensorTypeNode>();
 
diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc
index bb31d32..f905ba5 100644
--- a/src/relay/transforms/eliminate_common_subexpr.cc
+++ b/src/relay/transforms/eliminate_common_subexpr.cc
@@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator {
     const CallNode* new_call = new_expr.as<CallNode>();
     CHECK(new_call);
     const OpNode* op = new_call->op.as<OpNode>();
-    AttrsEqual attrs_equal;
+    StructuralEqual attrs_equal;
 
     if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
       return new_expr;
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index c3114c7..49f6e3f 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
   const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
   const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
-  AttrsEqual equal;
+  StructuralEqual equal;
   if (in_messages[0].defined() &&
       MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
     return in_messages[0];
@@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call,
   }
   Message lhs_message = transformer->GetMessage(call->args[0]);
   Message rhs_message = transformer->GetMessage(call->args[1]);
-  AttrsEqual equal;
+  StructuralEqual equal;
 
   if (lhs_message.defined() && rhs_message.defined()) {
     CHECK(equal(lhs_message->axes, rhs_message->axes));
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 6e95441..9168898 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
   // The output.
   IndexedForwardGraph graph_;
   // attribute equal comparator
-  AttrsEqual attr_equal_;
+  StructuralEqual attr_equal_;
   // Update the message stored at the node.
   void Update(const Expr& node,
               IndexedForwardGraph::Node* parent,
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index 8ce42a2..350d9e1 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
                                      const Array<Integer>& lhs_axes,
                                      Expr* rhs_value = nullptr) {
   if (tlhs->shape.size() < trhs->shape.size()) return false;
-  AttrsEqual equal;
+  StructuralEqual equal;
   size_t base = tlhs->shape.size() - trhs->shape.size();
   size_t j = 0;
 
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 233bfa5..46d0f67 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
       return RewriteForTensorCore(stmt, schedule, extern_buffer);
   });
 
-TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
-.set_body_typed(
-  [](const ObjectRef& lhs, const ObjectRef& rhs) {
-    return AttrsEqual()(lhs, rhs);
-  });
-
-TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
-.set_body_typed([](const ObjectRef &node) -> int64_t {
-    return AttrsHash()(node);
-});
-
-
 TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index 6d4a685..dbd5934 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -106,7 +106,6 @@ def test_function():
     check_json_roundtrip(fn)
 
 
-@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
 def test_function_attrs():
     param_names = ['a', 'b', 'c', 'd']
     params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names])
diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py
index f4148ca..8f2e9bb 100644
--- a/tests/python/unittest/test_ir_attrs.py
+++ b/tests/python/unittest/test_ir_attrs.py
@@ -51,14 +51,13 @@ def test_dict_attrs():
 
 
 def test_attrs_equal():
-    attr_equal = tvm.ir._ffi_api.AttrsEqual
     dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
     dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
     dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
-    assert attr_equal(dattr0, dattr1)
-    assert not attr_equal(dattr0, dattr2)
-    assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
-    assert not attr_equal([1, 2], tvm.runtime.convert(1))
+    assert tvm.ir.structural_equal(dattr0, dattr1)
+    assert not tvm.ir.structural_equal(dattr0, dattr2)
+    assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
+    assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
 
 
 
diff --git a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
index b3587cd..9a115be 100644
--- a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
+++ b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
@@ -21,28 +21,28 @@ def test_attrs_equal():
     x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
-    assert tvm.tir.ir_pass.AttrsEqual(x, y)
-    assert not tvm.tir.ir_pass.AttrsEqual(x, z)
+    assert tvm.ir.structural_equal(x, y)
+    assert not tvm.ir.structural_equal(x, z)
 
     dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
-    assert not tvm.tir.ir_pass.AttrsEqual(dattr, x)
+    assert not tvm.ir.structural_equal(dattr, x)
     dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
-    assert tvm.tir.ir_pass.AttrsEqual(dattr, dattr2)
+    assert tvm.ir.structural_equal(dattr, dattr2)
 
-    assert tvm.tir.ir_pass.AttrsEqual({"x": x}, {"x": y})
+    assert tvm.ir.structural_equal({"x": x}, {"x": y})
     # array related checks
-    assert tvm.tir.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
-    assert not tvm.tir.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
+    assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]})
+    assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]})
 
     n = te.var("n")
-    assert tvm.tir.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
+    assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1})
 
 
 
 
 
 def test_attrs_hash():
-    fhash = tvm.tir.ir_pass.AttrsHash
+    fhash = tvm.ir.structural_hash
     x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
     assert fhash({"x": x}) == fhash({"x": y})


Mime
View raw message