tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] zhiics commented on a change in pull request #5030: [RELAY] Added a AnnotatedRegion utility class
Date Mon, 23 Mar 2020 18:16:39 GMT
zhiics commented on a change in pull request #5030: [RELAY] Added a AnnotatedRegion utility
class
URL: https://github.com/apache/incubator-tvm/pull/5030#discussion_r396644680
 
 

 ##########
 File path: src/relay/analysis/annotated_region_set.h
 ##########
 @@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/annotated_region_set.h
+ * \brief Define data structures to extract and manipulate regions from
+ * a relay function. Regions are denoted by region_begin and region_end
+ * annotations that exist on all the input and output edges of the region.
+ */
+
+#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/ir/error.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include <list>
+
+namespace tvm {
+namespace relay {
+
+class AnnotatedRegion;
+class AnnotatedRegionSet;
+/*!
+ * \brief Node storing one annotated region: its integer ID, the set of
+ * expressions in the region, and the lists of the region's input and
+ * output expressions. The data members are protected; outside of the
+ * friend classes AnnotatedRegionSet / AnnotatedRegionSetNode they are
+ * only reachable through the const getters below.
+ */
+class AnnotatedRegionNode : public Object {
+ public:
+  void VisitAttrs(AttrVisitor* v) {  // Expose fields to TVM reflection.
+    v->Visit("id", &id);  // "id" is the only field visited in place.
+    Array<Expr> nodes_array(nodes.begin(), nodes.end());  // Temporary copy (not written back).
+    v->Visit("nodes", &nodes_array);
+    Array<Expr> args_array(ins.begin(), ins.end());  // Inputs, exposed as "args".
+    v->Visit("args", &args_array);
+    Array<Expr> rets_array(outs.begin(), outs.end());  // Outputs, exposed as "rets".
+    v->Visit("rets", &rets_array);
+  }
+
+  /*!
+   * \brief Get the region ID.
+   * \return The value of the id member, which defaults to -1.
+   */
+  int GetID() const {
+    return id;
+  }
+
+  /*!
+   * \brief Get the region's inputs.
+   * \return A copy of the list of input expressions (returned by value).
+   */
+  std::list<Expr> GetInputs() const {
+    return ins;
+  }
+
+  /*!
+   * \brief Get the region's outputs.
+   * \return A copy of the list of output expressions (returned by value).
+   */
+  std::list<Expr> GetOutputs() const {
+    return outs;
+  }
+
+  /*!
+   * \brief Get the region's nodes.
+   * \return A copy of the set of expressions contained in this region
+   * (returned by value).
+   */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const {
+    return nodes;
+  }
+
+  static constexpr const char* _type_key = "relay.AnnotatedRegion";  // Used by the macro below.
+  TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
+
+ protected:
+  /*! \brief The region ID; defaults to -1. */
+  int id{-1};
+  /*! \brief The inputs to this region (reflected as "args"). */
+  std::list<Expr> ins;
+  /*! \brief The outputs of this region (reflected as "rets"). */
+  std::list<Expr> outs;
+  /*! \brief Expressions contained in this region (reflected as "nodes"). */
+  std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes;
+
+  friend class AnnotatedRegionSet;  // Friends may access the protected members.
+  friend class AnnotatedRegionSetNode;
+};
+
+/*!
+ * \brief An ObjectRef handle to an AnnotatedRegionNode, holding the
+ * properties of a region as used by the AnnotatedRegionSet class. This
+ * should be considered read-only: although operator-> returns a non-const
+ * node pointer, the node's data members are protected, so code outside the
+ * friend classes can only call its const getters.
+ */
+class AnnotatedRegion : public ObjectRef {
+ public:
+  AnnotatedRegion() {  // Allocates a new AnnotatedRegionNode with default members.
+    auto n = make_object<AnnotatedRegionNode>();
+    data_ = std::move(n);
+  }
+
+  /*!
+   * \brief Construct from an object pointer.
+   * \param n The object pointer. Presumably expected to point to an
+   * AnnotatedRegionNode: operator-> static_casts it without a type check.
+   */
+  explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {}
+
+  /*!
+   * \brief Access the underlying node.
+   * \return A mutable pointer to the AnnotatedRegionNode.
+   * Fails a CHECK if this reference is null.
+   */
+  AnnotatedRegionNode* operator->() const {
+    auto* ptr = get_mutable();
+    CHECK(ptr != nullptr);
+    return static_cast<AnnotatedRegionNode*>(ptr);
+  }
+};
+
+class AnnotatedRegionSetNode : public Object {
+  using UnorderedRegionSet =
+  std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>;
+  // Create iterator alias for a RegionSet object.
+  using iterator = UnorderedRegionSet::iterator;
+  using const_iterator = UnorderedRegionSet::const_iterator;
+
+ public:
+  /*! \brief Default constructor. */
+  AnnotatedRegionSetNode() = default;
+
+  /*! \return The begin iterator */
+  iterator begin() {
+    return regions_.begin();
+  }
+  /*! \return The end iterator */
+  iterator end() {
+    return regions_.end();
+  }
+  /*! \return The const begin iterator */
+  const_iterator begin() const {
+    return regions_.begin();
+  }
+  /*! \return The const end iterator */
+  const_iterator end() const {
+    return regions_.end();
+  }
+
+  /*!
+   * \brief Get the region that an expression belongs to.
+   *
+   * \param expr Which expr to get the region for.
+   *
+   * \return A pointer to the region, nullptr if the expression
+   * doesn't belong to a region.
+   */
+  AnnotatedRegion GetRegion(const Expr& expr) const;
 
 Review comment:
   As a utility, we might want to overload `operator[](const Expr& expr)`

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message