tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #5160: [NODE][IR] Introduce StructuralHash for the Unified IR.
Date Sat, 28 Mar 2020 21:48:29 GMT
MarisaKirisame commented on a change in pull request #5160: [NODE][IR] Introduce StructuralHash
for the Unified IR.
URL: https://github.com/apache/incubator-tvm/pull/5160#discussion_r399713065
 
 

 ##########
 File path: src/node/structural_hash.cc
 ##########
 @@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/node/structural_hash.cc
+ */
+#include <tvm/node/structural_hash.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/functor.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/registry.h>
+
+#include <unordered_map>
+#include <algorithm>
+
+
+namespace tvm {
+
+// Define the dispatch function here since the primary user is in this file.
+//
+// Dispatches structural hashing of `self` to the per-type function stored in
+// fshash_reduce_, indexed by the object's runtime type index, and forwards
+// `reducer` to it unchanged.
+void ReflectionVTable::
+SHashReduce(const Object* self, SHashReducer reducer) const {
+  uint32_t tindex = self->type_index();
+  // A type index beyond the table, or a null entry, means no hash function is
+  // available for this type; report it as a fatal TypeError naming the type key.
+  if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) {
+    LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey()
+        << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  fshash_reduce_[tindex](self, reducer);
+}
+
+// Hash handler that handles free vars
+// by assigning a unique counter in the order of their occurrence.
+//
+// This algorithm depends on the determinism of the traversal of SHash function.
+// In particular, when we traverse unordered_map, we should first sort
+// the entries by keys(or hash of keys) before traversing.
+
+class VarCountingSHashHandler :
+      public SHashReducer::Handler {
+ public:
+  /*! \brief Pending reduce tasks. */
+  struct Task {
+    /*!
+     * \brief The object operand to be hashed.
+     *  If the object is nullptr, then the reduced hash is already set
+     *  to the correct value.
+     */
+    ObjectRef object;
+    /*! \brief The partially reduced hash value. */
+    size_t reduced_hash;
+    /*! \brief The expected location in the result stack. */
+    size_t result_stack_index = std::numeric_limits<size_t>::max();
+    /*! \brief Whether the children have been expanded via SHashReduce */
+    bool children_expanded{false};
+    /*! \brief Whether the node is a graph node. */
+    bool graph_node_hash{false};
+    /*! \brief whether to map the free variables. */
+    bool map_free_vars;
+
+    Task() = default;
+    explicit Task(ObjectRef object, size_t reduced_hash, bool map_free_vars)
+        : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {}
+  };
+
+
+  VarCountingSHashHandler() {}
+
+  void MarkGraphNode() final {
+    // need to push to pending tasks in this case
+    CHECK(!allow_push_to_stack_ && !task_stack_.empty());
+    task_stack_.back().graph_node_hash = true;
+  }
+
+  bool LookupHashedValue(const ObjectRef& key, size_t* hash_value) final {
+    auto it = hash_memo_.find(key);
+    if (it != hash_memo_.end()) {
+      hash_value[0] = it->second;
+      return true;
+    }
+    return false;
+  }
+
+  void SHashReduceHashedValue(size_t hashed_value) final {
+    pending_tasks_.emplace_back(
+        Task(ObjectRef(nullptr), hashed_value, false));
+  }
+
+  void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final {
+    CHECK(!hash_memo_.count(GetRef<ObjectRef>(var)));
+    if (map_free_vars) {
+      // use counter value.
+      size_t value = std::hash<size_t>()(free_var_counter_++);
+      pending_tasks_.emplace_back(
+          Task(ObjectRef(nullptr), value, false));
+    } else {
+      // use pointer hash
+      size_t value = std::hash<const runtime::Object*>()(var);
+      pending_tasks_.emplace_back(
+          Task(ObjectRef(nullptr), value, false));
+    }
+  }
+
+  void SHashReduce(const ObjectRef& object, bool map_free_vars) final {
+    // Directly push the result
+    // Note: it is still important to push the result to pending tasks
+    // so that the reduction order of hash values stays the same.
+    if (!object.defined()) {
+      pending_tasks_.emplace_back(Task(ObjectRef(nullptr), 0, false));
+      return;
+    }
+    auto it = hash_memo_.find(object);
+    if (it != hash_memo_.end()) {
+      pending_tasks_.emplace_back(
+          Task(ObjectRef(nullptr), it->second, false));
+    } else {
+      // Push a pending task with initial value.
+      pending_tasks_.emplace_back(
+          Task(object, object->GetTypeKeyHash(), map_free_vars));
+    }
+  }
+
+  size_t Hash(const ObjectRef& object, bool map_free_vars) {
+    CHECK_EQ(task_stack_.size(), 0U);
+    CHECK_EQ(pending_tasks_.size(), 0U);
+    CHECK_EQ(result_stack_.size(), 0U);
+
+    this->SHashReduce(object, map_free_vars);
+    CHECK_EQ(pending_tasks_.size(), 1U);
+    CHECK(allow_push_to_stack_);
+    task_stack_.emplace_back(std::move(pending_tasks_.back()));
+    pending_tasks_.clear();
+
+    this->RunTasks();
+
+    CHECK_EQ(result_stack_.size(), 1U);
+    size_t ret = result_stack_.back();
+    result_stack_.pop_back();
+    return ret;
+  }
+
+ protected:
+  /*!
+   * \brief Pop the top entry of the task stack and push the hash into the result stack.
+   */
+  void PopTaskStack() {
+    const auto& entry = task_stack_.back();
+    result_stack_.push_back(entry.reduced_hash);
+    task_stack_.pop_back();
+  }
+  /*!
+   * \brief Compute the reduced hash value for the task.
+   * \param task The indicated task.
+   */
+  size_t ReduceHash(const Task& task) {
+    size_t stack_begin = task.result_stack_index;
+    CHECK_LE(stack_begin, result_stack_.size());
+
+    // combine in the reverse order of the stack.
+    size_t reduced_hash = task.reduced_hash;
+    for (size_t i = result_stack_.size(); i != stack_begin; --i) {
+      reduced_hash = HashCombine(reduced_hash, result_stack_[i - 1]);
+    }
+    result_stack_.resize(stack_begin);
+    return reduced_hash;
+  }
+  // run the tasks.
+  void RunTasks() {
+    while (task_stack_.size() != 0) {
+      // Caution: entry becomes invalid when the stack changes
+      auto& entry = task_stack_.back();
+      if (entry.children_expanded) {
+        // reduce hash
+        entry.reduced_hash = ReduceHash(entry);
+        // When all the children have been expanded and visited,
+        // entry.reduced_hash contains the reduced hash result.
+        auto it = hash_memo_.find(entry.object);
+        if (it != hash_memo_.end()) {
+          // use the pre-computed hash for the object.
+          entry.reduced_hash = it->second;
+        } else {
+          // Append the graph node counter to the hash
+          // so that we can distinguish DAG from trees.
+          if (entry.graph_node_hash) {
+            entry.reduced_hash = HashCombine(
+                entry.reduced_hash,
+                std::hash<size_t>()(graph_node_counter_++));
+          }
+          hash_memo_[entry.object] = entry.reduced_hash;
+        }
+        // send value to parent.
+        this->PopTaskStack();
+      } else if (!entry.object.defined()) {
+        // Direct send value to parent
 
 Review comment:
   ```suggestion
           // Directly send value to parent
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message