tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use
Date Fri, 01 May 2020 00:51:34 GMT

junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418367759



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Meta node collector.
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same.
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Construct a collector that registers nodes into \p meta.
+   * \param meta The meta data context; stored as a raw pointer, so it must
+   *        outlive this collector.
+   */
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  /*!
+   * \brief Register \p n and every Stmt/PrimExpr traversed beneath it in the meta context.
+   * \param n The node to collect. Undefined refs, string nodes, variables, buffers and
+   *        iteration variables are skipped because they can be printed directly
+   *        (as a string literal or via their identifier). Nodes that are neither a
+   *        Stmt nor a PrimExpr are ignored.
+   */
+  void Collect(const ObjectRef& n) {
+    // These nodes can be printed directly (StringLiteral or use identifier to identify).
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  /*! \brief Register the statement in the meta context, then recurse into its children. */
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  /*! \brief Register the expression in the meta context, then recurse into its children. */
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  /*! \brief The meta data context that collected nodes are registered in. */
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_)
{}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }

Review comment:
       ```suggestion
         while (name_alloc_map_.count(
           unique_prefix =
             prefix + "_" + std::to_string(++it->second)
         ) > 0);
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message