tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5740: [Object][Runtime] Introduce runtime::Map
Date Wed, 10 Jun 2020 03:52:33 GMT

junrushao1994 commented on a change in pull request #5740:
URL: https://github.com/apache/incubator-tvm/pull/5740#discussion_r437844727



##########
File path: include/tvm/runtime/container.h
##########
@@ -1554,6 +1593,954 @@ struct PackedFuncValueConverter<Optional<T>> {
   }
 };
 
+/*! \brief map node content */
+class MapNode : public Object {
+  /*! \brief The number of elements in a memory block */
+  static constexpr int kBlockCap = 16;
+  /*! \brief Maximum load factor of the hash map */
+  static constexpr double kMaxLoadFactor = 0.99;
+  /*! \brief Binary representation of the metadata of an empty slot */
+  static constexpr uint8_t kEmptySlot = uint8_t(0b11111111);
+  /*! \brief Binary representation of the metadata of a protected slot */
+  static constexpr uint8_t kProtectedSlot = uint8_t(0b11111110);
+  /*! \brief Number of probing choices available */
+  static constexpr int kNumJumpDists = 126;
+  /* clang-format off */
+  /*! \brief Candidates of probing distance */
+  TVM_DLL static constexpr uint64_t kJumpDists[kNumJumpDists] {
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+    // Quadratic probing with triangle numbers. See also:
+    // 1) https://en.wikipedia.org/wiki/Quadratic_probing
+    // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/
+    // 3) https://github.com/skarupke/flat_hash_map
+    21, 28, 36, 45, 55, 66, 78, 91, 105, 120,
+    136, 153, 171, 190, 210, 231, 253, 276, 300, 325,
+    351, 378, 406, 435, 465, 496, 528, 561, 595, 630,
+    666, 703, 741, 780, 820, 861, 903, 946, 990, 1035,
+    1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540,
+    1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145,
+    2211, 2278, 2346, 2415, 2485, 2556, 2628,
+    // larger triangle numbers
+    8515, 19110, 42778, 96141, 216153,
+    486591, 1092981, 2458653, 5532801, 12442566,
+    27993903, 62983476, 141717030, 318844378, 717352503,
+    1614057336, 3631522476, 8170957530, 18384510628, 41364789378,
+    93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695,
+    5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000,
+    309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701,
+    17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626,
+    1029107982097042876, 2315492959180353330, 5209859154120846435,
+  };
+  /* clang-format on */
+
+ public:
+  /*! \brief Type of the keys in the hash map */
+  using key_type = ObjectRef;
+  /*! \brief Type of the values in the hash map */
+  using mapped_type = ObjectRef;
+
+  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeMap;
+  static constexpr const char* _type_key = "Map";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
+
+ private:
+  struct KVType;
+  struct Block;
+  struct ListNode;
+
+ public:
+  class iterator;
+
+  /*!
+   * \brief Destroy the MapNode
+   */
+  ~MapNode() { this->Reset(); }
+
+  /*!
+   * \brief Number of elements in the MapNode
+   * \return The result
+   */
+  size_t size() const { return size_; }
+
+  /*!
+   * \brief Index value associated with a key, create new entry if the key does not exist
+   * \param key The indexing key
+   * \return The mutable reference to the value
+   */
+  mapped_type& operator[](const key_type& key) { return Emplace(key, mapped_type()).Val();
}
+
+  /*!
+   * \brief Count the number of times a key exists in the MapNode
+   * \param key The indexing key
+   * \return The result, 0 or 1
+   */
+  size_t count(const key_type& key) const { return !Search(key).IsNone(); }
+
+  /*!
+   * \brief Index value associated with a key, throw exception if the key does not exist
+   * \param key The indexing key
+   * \return The const reference to the value
+   */
+  const mapped_type& at(const key_type& key) const { return At(key); }
+
+  /*!
+   * \brief Index value associated with a key, throw exception if the key does not exist
+   * \param key The indexing key
+   * \return The mutable reference to the value
+   */
+  mapped_type& at(const key_type& key) { return At(key); }
+
+  /*! \return begin iterator */
+  iterator begin() const { return size_ == 0 ? iterator() : iterator(0, this); }
+
+  /*! \return end iterator */
+  iterator end() const { return size_ == 0 ? iterator() : iterator(slots_ + 1, this); }
+
+  /*!
+   * \brief Index value associated with a key
+   * \param key The indexing key
+   * \return The iterator of the entry associated with the key, end iterator if not exists
+   */
+  iterator find(const key_type& key) const {
+    ListNode n = Search(key);
+    return n.IsNone() ? end() : iterator(n.i, this);
+  }
+
+  /*!
+   * \brief Insert and construct in-place with the given args, do nothing if key already
exists
+   * \tparam Args Type of the args forwarded to the constructor
+   */
+  template <typename... Args>
+  void emplace(Args&&... args) {
+    Emplace(std::forward<Args>(args)...);
+  }
+
+  /*!
+   * \brief Erase the entry associated with the key, do nothing if not exists
+   * \param key The indexing key
+   */
+  void erase(const key_type& key) { Erase(key); }
+
+  /*!
+   * \brief Erase the entry associated with the iterator
+   * \param position The iterator
+   */
+  void erase(const iterator& position) {
+    uint64_t i = position.i;
+    if (position.self != nullptr && i <= this->slots_) {
+      Erase(ListNode(i, this));
+    }
+  }
+
+ private:
+  /*!
+   * \brief reset the array to content from iterator.
+   * \param first begin of iterator
+   * \param last end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template <typename IterType>
+  void Assign(IterType first, IterType last) {
+    int64_t cap = std::distance(first, last);
+    this->ReleaseItems();
+    this->Reserve(cap);
+    for (; first != last; ++first) {
+      this->Emplace(*first);
+    }
+  }
+
+  /*!
+   * \brief Search for the given key
+   * \param key The key
+   * \return ListNode that associated with the key
+   */
+  ListNode Search(const key_type& key) const {
+    if (this->size_ == 0) {
+      return ListNode();
+    }
+    for (ListNode n = ListNode::GetHead(ObjectHash()(key), this); !n.IsNone(); n.MoveToNext(this))
{
+      if (ObjectEqual()(key, n.Key())) {
+        return n;
+      }
+    }
+    return ListNode();
+  }
+
+  /*!
+   * \brief Search for the given key, throw exception if not exists
+   * \param key The key
+   * \return ListNode that associated with the key
+   */
+  mapped_type& At(const key_type& key) const {
+    ListNode n = Search(key);
+    CHECK(!n.IsNone()) << "IndexError: key is not in Map";
+    return n.Val();
+  }
+
+  /*!
+   * \brief In-place construct an entry, or do nothing if already exists
+   * \tparam Item Type of arguments forwarded to the constructor
+   * \param arg Arguments fed to the constructor
+   * \return ListNode that associated with the key, no matter whether it already exists
+   */
+  template <typename Item>
+  ListNode Emplace(Item&& arg) {
+    KVType item(std::forward<Item>(arg));
+    return Emplace(std::move(item.k), std::move(item.v));
+  }
+
+  /*!
+   * \brief In-place construct an entry, or do nothing if already exists
+   * \tparam Key Type of the key
+   * \tparam Args Type of the rest of the arguments fed to the constructor
+   * \param key The indexing key
+   * \param args Other arguments
+   * \return ListNode that associated with the key, no matter whether it already exists
+   */
+  template <typename Key, typename... Args>
+  ListNode Emplace(Key&& key, Args&&... args) {
+    ReHashIfNone();
+    // required that `m` to be the head of a linked list through which we can iterator
+    ListNode m = ListNode::FromHash(ObjectHash()(key), this);
+    // `m` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list
+    // Case 1: empty
+    if (m.IsEmpty()) {
+      KVType v(std::forward<Key>(key), std::forward<Args>(args)...);
+      m.NewHead(std::move(v));
+      this->size_ += 1;
+      return m;
+    }
+    // Case 2: body of an irrelevant list
+    if (!m.IsHead()) {
+      // we move the elements around and construct the single-elements linked list
+      return SpareListHead(std::move(m), std::forward<Key>(key), std::forward<Args>(args)...);
+    }
+    // Case 3: head of the relevant list
+    // we iterate through the linked list until the end
+    ListNode n = m;
+    do {
+      // find equal item, do not insert
+      if (ObjectEqual()(key, n.Key())) {
+        return n;
+      }
+      // make sure `m` is the previous element of `n`
+      m = n;
+    } while (n.MoveToNext(this));
+    // `m` is the tail of the linked list
+    // always check capacity before insertion
+    if (ReHashIfFull()) {
+      return Emplace(std::forward<Key>(key), std::forward<Args>(args)...);
+    }
+    uint8_t jump;
+    ListNode empty;
+    // rehash if there is no empty space after `m`
+    if (ReHashIfNoNextEmpty(m, &empty, &jump)) {
+      return Emplace(std::forward<Key>(key), std::forward<Args>(args)...);
+    }
+    KVType v(std::forward<Key>(key), std::forward<Args>(args)...);
+    empty.NewTail(std::move(v));
+    // link `n` to `empty`, and move forward
+    m.SetJump(jump);
+    this->size_ += 1;
+    return empty;
+  }
+
+  /*!
+   * \brief Spare an entry to be the head of a linked list
+   * \tparam Args Type of the arguments fed to the constructor
+   * \param n The given entry to be spared
+   * \param args The arguments
+   * \return ListNode that associated with the head
+   */
+  template <typename... Args>
+  ListNode SpareListHead(ListNode n, Args&&... args) {
+    // `n` is not the head of the linked list
+    // move the original item of `n` (if any)
+    // and construct new item on the position `n`
+    if (ReHashIfFull()) {
+      // always check capacity before insertion
+      return Emplace(std::forward<Args>(args)...);
+    }
+    // To make `n` empty, we
+    // 1) find `w` the previous element of `n` in the linked list
+    // 2) copy the linked list starting from `r = n`
+    // 3) paste them after `w`
+    // read from the linked list after `r`
+    ListNode r = n;
+    // write to the tail of `w`
+    ListNode w = n.GetPrev(this);
+    // after `n` is moved, we disallow writing to the slot
+    bool is_first = true;
+    uint8_t r_meta, jump;
+    ListNode empty;
+    do {
+      // `jump` describes how `w` is jumped to `empty`
+      // rehash if there is no empty space after `w`
+      if (ReHashIfNoNextEmpty(w, &empty, &jump)) {
+        return Emplace(std::forward<Args>(args)...);
+      }
+      // move `r` to `empty`
+      empty.NewTail(std::move(r.Data()));
+      // clear the metadata of `r`
+      r_meta = r.Meta();
+      if (is_first) {
+        is_first = false;
+        r.SetProtected();
+      } else {
+        r.SetEmpty();
+      }
+      // link `w` to `empty`, and move forward
+      w.SetJump(jump);
+      w = empty;
+      // move `r` forward as well
+    } while (r.MoveToNext(this, r_meta));
+    // finally we have done moving the linked list
+    // fill data_ into `n`
+    KVType v(std::forward<Args>(args)...);
+    n.NewHead(std::move(v));
+    this->size_ += 1;
+    return n;
+  }
+
+  /*!
+   * \brief Remove a ListNode
+   * \param n The node to be removed
+   */
+  void Erase(const ListNode& n) {
+    this->size_ -= 1;
+    if (!n.HasNext()) {
+      // `n` is the last
+      if (!n.IsHead()) {
+        // cut the link if there is any
+        n.GetPrev(this).SetJump(0);
+      }
+      n.Data().KVType::~KVType();
+      n.SetEmpty();
+    } else {
+      ListNode last = n, prev = n;
+      for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) {
+      }
+      n.Data() = std::move(last.Data());
+      last.SetEmpty();
+      prev.SetJump(0);
+    }
+  }
+
+  /*!
+   * \brief Remove an entry associated with the given key
+   * \param key The node to be removed
+   */
+  void Erase(const key_type& key) {
+    ListNode n = Search(key);
+    if (!n.IsNone()) {
+      Erase(n);
+    }
+  }
+
+  /*!
+   * \brief Reserve some space
+   * \param count The space to be reserved
+   */
+  void Reserve(uint64_t count) {
+    if (slots_ < count * 2) {
+      ReHash(count * 2);
+    }
+  }
+
+  /*! \brief Clear the container to empty, release all memory acquired */
+  void Reset() {
+    this->ReleaseItems();
+    delete[] data_;
+    data_ = nullptr;
+    slots_ = 0;
+    size_ = 0;
+    fib_ = 63;
+  }
+
+  /*! \brief Clear the container to empty, release all entries */
+  void ReleaseItems() {
+    uint64_t n_blocks = CalcNumBlocks(this->slots_);
+    MapNode* m = this;
+    for (uint64_t bi = 0; bi < n_blocks; ++bi) {
+      uint8_t* m_m = m->data_[bi].b;
+      KVType* m_d = reinterpret_cast<KVType*>(m->data_[bi].b + kBlockCap);
+      for (int j = 0; j < kBlockCap; ++j, ++m_m, ++m_d) {
+        uint8_t& meta = *m_m;
+        if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
+          meta = uint8_t(kEmptySlot);
+          m_d->KVType::~KVType();
+        }
+      }
+    }
+    this->size_ = 0;
+  }
+
+  /*!
+   * \brief Re-hashing with the given required capacity
+   * \param required The lower bound of capacity required
+   */
+  void ReHash(uint64_t required) {
+    constexpr uint64_t one = 1;
+    uint64_t new_n_slots = static_cast<uint64_t>(required / kMaxLoadFactor) + 1;
+    new_n_slots = std::max(new_n_slots, required);
+    if (new_n_slots <= 0) {
+      return;
+    }
+    uint8_t new_fib = __builtin_clzll(new_n_slots);
+    new_n_slots = one << (64 - new_fib);
+    if (new_n_slots <= slots_ + 1) {
+      return;
+    }
+    ObjectPtr<MapNode> p = MoveFrom(new_fib, new_n_slots, this);
+    std::swap(p->data_, this->data_);
+    std::swap(p->slots_, this->slots_);
+    std::swap(p->size_, this->size_);
+    std::swap(p->fib_, this->fib_);
+  }
+
+  /*!
+   * \brief Re-hashing if the container is empty
+   * \return If re-hashing happens
+   */
+  bool ReHashIfNone() {
+    constexpr uint64_t min_size = 7;
+    if (slots_ == 0) {
+      ReHash(min_size);
+      return true;
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Re-hashing if the container achieves max load factor
+   * \return If re-hashing happens
+   */
+  bool ReHashIfFull() {
+    constexpr uint64_t min_size = 7;
+    if (slots_ == 0 || size_ + 1 > (slots_ + 1) * kMaxLoadFactor) {
+      ReHash(std::max(min_size, size_ + 1));
+      return true;
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Re-hashing if cannot find space for the linked-list
+   * \param n Find the next empty element of this node
+   * \param empty The resulting empty element
+   * \param jump Jump required to the empty element
+   * \return If re-hashing happens
+   */
+  bool ReHashIfNoNextEmpty(const ListNode& n, ListNode* empty, uint8_t* jump) {
+    constexpr uint64_t min_size = 7;
+    *empty = n.GetNextEmpty(this, jump);
+    if (empty->IsNone()) {
+      ReHash(std::max(min_size, slots_ * 2 + 1));
+      return true;
+    }
+    return false;
+  }
+
+ public:
+  /*!
+   * \brief Create an empty container
+   * \return The object created
+   */
+  static ObjectPtr<MapNode> Empty() {
+    ObjectPtr<MapNode> p = make_object<MapNode>();
+    p->data_ = nullptr;
+    p->slots_ = 0;
+    p->size_ = 0;
+    p->fib_ = 63;
+    return p;
+  }
+
+ private:
+  /*!
+   * \brief Create an empty container
+   * \param fib The fib shift provided
+   * \param n_slots Number of slots required
+   * \return The object created
+   */
+  static ObjectPtr<MapNode> Empty(uint8_t fib, uint64_t n_slots) {
+    if (n_slots == 0) {
+      return Empty();
+    }
+    ObjectPtr<MapNode> p = make_object<MapNode>();
+    uint64_t n_blocks = CalcNumBlocks(n_slots - 1);
+    Block* block = p->data_ = new Block[n_blocks];
+    p->slots_ = n_slots - 1;
+    p->size_ = 0;
+    p->fib_ = fib;
+    for (uint64_t i = 0; i < n_blocks; ++i, ++block) {
+      std::fill(block->b, block->b + kBlockCap, uint8_t(kEmptySlot));
+    }
+    return p;
+  }
+
+  /*!
+   * \brief Create an empty container with elements moving from another MapNode
+   * \param fib The fib shift provided
+   * \param n_slots Number of slots required
+   * \param m The source container
+   * \return The object created
+   */
+  static ObjectPtr<MapNode> MoveFrom(uint8_t fib, uint64_t n_slots, MapNode* m) {
+    ObjectPtr<MapNode> p = MapNode::Empty(fib, n_slots);
+    uint64_t n_blocks = CalcNumBlocks(m->slots_);
+    for (uint64_t bi = 0; bi < n_blocks; ++bi) {
+      uint8_t* m_m = m->data_[bi].b;
+      KVType* m_d = reinterpret_cast<KVType*>(m->data_[bi].b + kBlockCap);
+      for (int j = 0; j < kBlockCap; ++j, ++m_m, ++m_d) {
+        uint8_t& meta = *m_m;
+        if (meta != uint8_t(kProtectedSlot) && meta != uint8_t(kEmptySlot)) {
+          meta = uint8_t(kEmptySlot);
+          p->Emplace(std::move(m_d->k), std::move(m_d->v));
+        }
+      }
+    }
+    delete[] m->data_;
+    m->data_ = nullptr;
+    m->slots_ = 0;
+    m->size_ = 0;
+    m->fib_ = 0;
+    return p;
+  }
+
+  /*!
+   * \brief Create an empty container with elements copying from another MapNode
+   * \param m The source container
+   * \return The object created
+   */
+  static ObjectPtr<MapNode> CopyFrom(MapNode* m) {
+    ObjectPtr<MapNode> p = make_object<MapNode>();
+    uint64_t n_blocks = CalcNumBlocks(m->slots_);
+    p->data_ = new Block[n_blocks];
+    p->slots_ = m->slots_;
+    p->size_ = m->size_;
+    p->fib_ = m->fib_;
+    for (uint64_t bi = 0; bi < n_blocks; ++bi) {
+      uint8_t* m_m = m->data_[bi].b;
+      uint8_t* p_m = p->data_[bi].b;
+      KVType* m_d = reinterpret_cast<KVType*>(m->data_[bi].b + kBlockCap);
+      KVType* p_d = reinterpret_cast<KVType*>(p->data_[bi].b + kBlockCap);
+      for (int j = 0; j < kBlockCap; ++j, ++m_m, ++m_d, ++p_m, ++p_d) {
+        uint8_t& meta = *p_m = *m_m;
+        CHECK(meta != kProtectedSlot);
+        if (meta != uint8_t(kEmptySlot)) {
+          new (p_d) KVType(*m_d);
+        }
+      }
+    }
+    return p;
+  }
+
+  static uint64_t CalcNumBlocks(uint64_t n_slots_m1) {
+    uint64_t n_slots = n_slots_m1 > 0 ? n_slots_m1 + 1 : 0;
+    return (n_slots + kBlockCap - 1) / kBlockCap;
+  }
+
+  /*! \brief Alternative to std::pair with standard layout */
+  struct KVType {
+    template <class K, class V>
+    KVType(const K& k, const V& v) : k(k), v(v) {}
+    template <class K, class V>
+    KVType(K&& k, V&& v) : k(std::forward<K>(k)), v(std::forward<V>(v))
{}
+    template <class K, class V>
+    KVType(const K& k, V&& v) : k(k), v(std::forward<V>(v)) {}
+    template <class K, class V>
+    KVType(K&& k, const V& v) : k(std::forward<K>(k)), v(v) {}
+    /*! \brief The STL type */
+    using TStl = std::pair<key_type, mapped_type>;
+    /*! \brief Converting from STL type */
+    KVType(const TStl& kv) : k(kv.first), v(kv.second) {}  // NOLINT(*)
+    /*! \brief Converting to STL type */
+    operator TStl() const { return std::make_pair(k, v); }
+    /*! \brief The key, or std::pair::first */
+    key_type k;
+    /*! \brief The value, or std::pair::second */
+    mapped_type v;
+  };
+
+  /*! \brief POD type of a chunk of memory used to */
+  struct Block {
+    uint8_t b[kBlockCap + kBlockCap * sizeof(KVType)];
+  };
+
+  /*! \brief The implicit in-place linked list used to index a chain */
+  struct ListNode {

Review comment:
       I will write some comprehensive comments giving a high-level description of this algorithm,
so that people can understand how it works.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message