From: tqchen@apache.org
To: commits@tvm.apache.org
Date: Fri, 10 Apr 2020 14:46:32 +0000
Subject: [incubator-tvm] branch master updated: [REFACTOR][IR] Move to runtime::String (#5276)

This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 5da361d  [REFACTOR][IR] Move to runtime::String (#5276)
5da361d is described below

commit 5da361d3adf87033b90ab5ff6f3117e8af1bee43
Author: Zhi <5145158+zhiics@users.noreply.github.com>
AuthorDate: Fri Apr 10 07:46:23 2020 -0700

    [REFACTOR][IR] Move to runtime::String (#5276)

    * Use runtime::String
    * move string to tvm namespace
    * add const char* constructor
    * implicit cast from std::string
---
 include/tvm/ir/expr.h                             |  7 +--
 include/tvm/ir/transform.h                        | 11 ++--
 include/tvm/node/container.h                      |  2 +
 include/tvm/node/node.h                           |  2 +
 include/tvm/relay/transform.h                     |  5 +-
 include/tvm/runtime/container.h                   | 10 +++-
 include/tvm/target/target.h                       |  6 +--
 include/tvm/tir/stmt_functor.h                    |  4 +-
 include/tvm/tir/transform.h                       |  4 +-
 python/tvm/autotvm/task/task.py                   |  3 ++
 python/tvm/relay/backend/graph_runtime_codegen.py |  3 +-
 python/tvm/runtime/container.py                   | 63 ++++++++++++++++++----
 python/tvm/runtime/object_generic.py              |  4 +-
 python/tvm/target/target.py                       | 10 ++--
 src/autotvm/touch_extractor.cc                    |  8 +--
 src/ir/attrs.cc                                   |  2 +-
 src/ir/expr.cc                                    |  7 ++-
 src/ir/op.cc                                      |  8 +--
 src/ir/transform.cc                               | 26 ++++-----
 src/node/container.cc                             |  1 -
 src/relay/backend/build_module.cc                 | 17 +++---
 src/relay/backend/compile_engine.cc               | 17 +++---
 src/relay/backend/contrib/codegen_c/codegen_c.h   |  2 +-
 src/relay/backend/graph_runtime_codegen.cc        |  9 ++--
 src/relay/backend/vm/compiler.cc                  |  6 +--
 src/relay/backend/vm/inline_primitives.cc         |  2 +-
 src/relay/backend/vm/lambda_lift.cc               |  2 +-
 src/relay/backend/vm/removed_unused_funcs.cc      |  7 ++-
 src/relay/ir/transform.cc                         |  4 +-
 src/relay/op/tensor/transform.cc                  |  1 -
 src/relay/transforms/alter_op_layout.cc           |  3 +-
 src/relay/transforms/annotate_target.cc           | 11 ++--
 src/relay/transforms/canonicalize_cast.cc         |  3 +-
 src/relay/transforms/canonicalize_ops.cc          |  3 +-
 src/relay/transforms/combine_parallel_conv2d.cc   |  3 +-
 src/relay/transforms/combine_parallel_dense.cc    |  3 +-
 src/relay/transforms/combine_parallel_op_batch.cc |  3 +-
 src/relay/transforms/convert_layout.cc            |  4 +-
 src/relay/transforms/device_annotation.cc         |  3 +-
 src/relay/transforms/eliminate_common_subexpr.cc  |  3 +-
 src/relay/transforms/fast_math.cc                 |  3 +-
 src/relay/transforms/fold_scale_axis.cc           |  6 +--
 src/relay/transforms/fuse_ops.cc                  |  3 +-
 src/relay/transforms/inline.cc                    |  2 +-
 src/relay/transforms/legalize.cc                  |  2 +-
 src/relay/transforms/merge_composite.cc           | 21 ++++----
 src/relay/transforms/partition_graph.cc           |  2 +-
 src/relay/transforms/simplify_inference.cc        |  3 +-
 src/relay/transforms/to_a_normal_form.cc          |  2 +-
 src/runtime/container.cc                          | 32 ++++++++---
 src/target/build_common.h                         |  2 +-
 src/target/generic_func.cc                        |  5 +-
 src/target/llvm/codegen_cpu.cc                    |  2 +-
 src/target/llvm/codegen_llvm.cc                   |  2 +-
 src/target/llvm/llvm_module.cc                    |  2 +-
 src/target/source/codegen_c.cc                    |  2 +-
 src/target/source/codegen_metal.cc                |  2 +-
 src/target/source/codegen_opengl.cc               |  2 +-
 src/target/source/codegen_vhls.cc                 |  7 ++-
 src/target/spirv/build_vulkan.cc                  |  2 +-
 src/target/spirv/codegen_spirv.cc                 |  2 +-
 src/target/stackvm/codegen_stackvm.cc             |  2 +-
 src/target/target.cc                              | 40 +++++++-------
 src/tir/ir/expr.cc                                | 25 +++++----
 src/tir/ir/stmt_functor.cc                        |  6 +--
 src/tir/ir/transform.cc                           |  2 +-
 src/tir/pass/arg_binder.cc                        | 18 ++++---
 src/tir/pass/hoist_if_then_else.cc                | 11 ++--
 src/tir/pass/tensor_core.cc                       |  2 +-
 src/tir/transforms/bind_device_type.cc            |  3 +-
 src/tir/transforms/make_packed_api.cc             | 13 +++--
 src/tir/transforms/remap_thread_axis.cc           |  8 ++-
 src/tir/transforms/split_host_device.cc           |  2 +-
 tests/cpp/container_test.cc                       |  2 +-
 tests/python/relay/test_annotate_target.py        |  4 +-
 tests/python/relay/test_call_graph.py             |  2 +-
 tests/python/relay/test_external_codegen.py       |  5 +-
 tests/python/relay/test_ir_nodes.py               |  4 +-
 .../python/relay/test_ir_structural_equal_hash.py |  6 +--
 tests/python/relay/test_pass_inline.py            | 28 +++++-----
 tests/python/relay/test_pass_merge_composite.py   | 32 +++++------
 tests/python/relay/test_pass_partition_graph.py   | 49 +++++++----------
 tests/python/unittest/test_ir_attrs.py            |  2 +-
 topi/include/topi/contrib/cublas.h                |  4 +-
 topi/include/topi/contrib/rocblas.h               |  2 +-
 85 files changed, 364 insertions(+), 306 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index a683fd6..6822159 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr { * \param value The value to be constructed. */ TVM_DLL PrimExpr(float value); // NOLINT(*) + /*! - * \brief construct from string. - * \param str The value to be constructed. + * \brief construct from runtime String. + * \param value The value to be constructed. */ - TVM_DLL PrimExpr(std::string str); // NOLINT(*) + TVM_DLL PrimExpr(runtime::String value); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const {
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index ecd234a..3a9913f 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -57,6 +57,7 @@ #define TVM_IR_TRANSFORM_H_ #include +#include #include #include #include @@ -95,9 +96,9 @@ class PassContextNode : public Object { int fallback_device{static_cast(kDLCPU)}; /*! \brief The list of required passes. */ - Array required_pass; + Array required_pass; /*! \brief The list of disabled passes. */ - Array disabled_pass; + Array disabled_pass; TraceFunc trace_func; @@ -197,7 +198,7 @@ class PassInfoNode : public Object { std::string name; /*! \brief The passes that are required to perform the current pass.
*/ - Array required; + Array required; PassInfoNode() = default; @@ -226,7 +227,7 @@ class PassInfo : public ObjectRef { */ TVM_DLL PassInfo(int opt_level, std::string name, - Array required); + Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -346,7 +347,7 @@ Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const Array& required); + const Array& required); } // namespace transform } // namespace tvm diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 461fa11..cf2ac26 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -36,6 +36,8 @@ namespace tvm { +using runtime::String; +using runtime::StringObj; using runtime::Object; using runtime::ObjectPtr; using runtime::ObjectRef; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 04f477b..b39e3b4 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -35,6 +35,7 @@ #define TVM_NODE_NODE_H_ #include +#include #include #include #include @@ -62,6 +63,7 @@ using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using runtime::String; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index deb084c..2dcf7f3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ +#include #include #include #include @@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< Function(Function, IRModule, PassContext)>& pass_func, int opt_level, const std::string& name, - const tvm::Array& required); + const tvm::Array& required); /*! \brief Remove expressions which does not effect the program result. * @@ -355,7 +356,7 @@ TVM_DLL Pass Inline(); * * \return The pass. */ -TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); +TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); } // namespace transform diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 50b406b..083f87f 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -360,7 +360,15 @@ class String : public ObjectRef { * \note If user passes const reference, it will trigger copy. If it's rvalue, * it will be moved into other. */ - explicit String(std::string other); + String(std::string other); // NOLINT(*) + + /*! + * \brief Construct a new String object + * + * \param other a char array. + */ + String(const char* other) // NOLINT(*) + : String(std::string(other)) {} /*! * \brief Change the value the reference object points to. diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index f6fd3c4..59aa955 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -52,11 +52,11 @@ class TargetNode : public Object { /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */ int thread_warp_size = 1; /*! \brief Keys for this target */ - Array keys_array; + Array keys_array; /*! \brief Options for this target */ - Array options_array; + Array options_array; /*! \brief Collection of imported libs */ - Array libs_array; + Array libs_array; /*! 
\return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 6824022..ad5c5cd 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -326,7 +326,7 @@ class StmtExprMutator : * won't do further recursion. * \param postorder The function called after recursive mutation. * The recursive mutation result is passed to postorder for further mutation. - * \param only_enable List of StringImm. + * \param only_enable List of runtime::String. * If it is empty, all IRNode will call preorder/postorder * If it is not empty, preorder/postorder will only be called * when the IRNode's type key is in the list. @@ -334,7 +334,7 @@ class StmtExprMutator : TVM_DLL Stmt IRTransform(Stmt node, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, - const Array& only_enable = {}); + const Array& only_enable = {}); /*! * \brief recursively visit the ir in post DFS order node, apply fvisit diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 860014d..5ad40a3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, int opt_level, const std::string& name, - const tvm::Array& required); + const tvm::Array& required); /*! * \brief Transform the high-level PrimFunc to a low-level version @@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); * * \return The pass. */ -TVM_DLL Pass RemapThreadAxis(Map axis_map); +TVM_DLL Pass RemapThreadAxis(Map axis_map); /*! diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index ddee149..00b6676 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -24,6 +24,7 @@ registers the standard task. import numpy as np from tvm import target as _target +from tvm import runtime from tvm.ir import container from tvm.tir import expr from tvm.te import tensor, placeholder @@ -55,6 +56,8 @@ def serialize_args(args): return x if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): return x.value + if isinstance(x, runtime.container.String): + return str(x) if x is None: return None raise RuntimeError('Do not support type "%s" in argument. Consider to use' diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 3e5f015..8210f27 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -84,8 +84,7 @@ class GraphRuntimeCodegen(object): lowered_func = self._get_irmodule() param_names = self._list_params_name() params = {} - for name in param_names: - key = name.value + for key in param_names: arr = self._get_param_by_name(key) param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx) arr.copyto(param) diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index dd59011..a719dcd 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,8 +16,9 @@ # under the License. """Runtime container structures.""" import tvm._ffi - +from tvm._ffi.base import string_types from tvm.runtime import Object, ObjectTypes +from tvm.runtime import _ffi_api def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. 
@@ -75,18 +76,19 @@ class ADT(Object): for f in fields: assert isinstance(f, ObjectTypes), "Expect object or " \ "tvm NDArray type, but received : {0}".format(type(f)) - self.__init_handle_by_constructor__(_ADT, tag, *fields) + self.__init_handle_by_constructor__(_ffi_api.ADT, tag, + *fields) @property def tag(self): - return _GetADTTag(self) + return _ffi_api.GetADTTag(self) def __getitem__(self, idx): return getitem_helper( - self, _GetADTFields, len(self), idx) + self, _ffi_api.GetADTFields, len(self), idx) def __len__(self): - return _GetADTSize(self) + return _ffi_api.GetADTSize(self) def tuple_object(fields=None): @@ -106,7 +108,7 @@ def tuple_object(fields=None): for f in fields: assert isinstance(f, ObjectTypes), "Expect object or tvm " \ "NDArray type, but received : {0}".format(type(f)) - return _Tuple(*fields) + return _ffi_api.Tuple(*fields) @tvm._ffi.register_object("runtime.String") @@ -115,7 +117,7 @@ class String(Object): Parameters ---------- - string : Str + string : str The string used to construct a runtime String object Returns @@ -124,7 +126,50 @@ class String(Object): The created object. """ def __init__(self, string): - self.__init_handle_by_constructor__(_String, string) + self.__init_handle_by_constructor__(_ffi_api.String, string) + + def __str__(self): + return _ffi_api.GetStdString(self) + + def __len__(self): + return _ffi_api.GetStringSize(self) + + def __hash__(self): + return _ffi_api.StringHash(self) + + def __eq__(self, other): + if isinstance(other, string_types): + return self.__str__() == other + + if not isinstance(other, String): + return False + + return _ffi_api.CompareString(self, other) == 0 + + def __ne__(self, other): + return not self.__eq__(other) + + def __gt__(self, other): + return _ffi_api.CompareString(self, other) > 0 + + def __lt__(self, other): + return _ffi_api.CompareString(self, other) < 0 + + def __getitem__(self, key): + return self.__str__()[key] + + def startswith(self, string): + """Check if the runtime string starts with a given string + Parameters + ---------- + string : str + The provided string -tvm._ffi._init_api("tvm.runtime.container") + Returns + ------- + ret : boolean + Return true if the runtime string starts with the given string, + otherwise, false. + """ + return self.__str__().startswith(string) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 22354db..a7716df 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -19,7 +19,7 @@ from numbers import Number, Integral from tvm._ffi.base import string_types -from . import _ffi_node_api +from . 
import _ffi_node_api, _ffi_api from .object import ObjectBase, _set_class_object_generic from .ndarray import NDArrayBase from .packed_func import PackedFuncBase, convert_to_tvm_func @@ -56,7 +56,7 @@ def convert_to_object(value): if isinstance(value, Number): return const(value) if isinstance(value, string_types): - return _ffi_node_api.String(value) + return _ffi_api.String(value) if isinstance(value, (list, tuple)): value = [convert_to_object(x) for x in value] return _ffi_node_api.Array(*value) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index a83ea0c..fd15ff9 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -48,26 +48,26 @@ class Target(Object): @property def keys(self): if not self._keys: - self._keys = [k.value for k in self.keys_array] + self._keys = [str(k) for k in self.keys_array] return self._keys @property def options(self): if not self._options: - self._options = [o.value for o in self.options_array] + self._options = [str(o) for o in self.options_array] return self._options @property def libs(self): if not self._libs: - self._libs = [l.value for l in self.libs_array] + self._libs = [str(l) for l in self.libs_array] return self._libs @property def model(self): for opt in self.options_array: - if opt.value.startswith('-model='): - return opt.value[7:] + if opt.startswith('-model='): + return opt[7:] return 'unknown' @property diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index b5bf2ed..fbd0829 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto var : vars) { Array > feature_row; ItervarFeature &fea = touch_analyzer.itervar_map[var]; - feature_row.push_back(Array{std::string("_itervar_"), var}); + feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var}); - Array attr{std::string("_attr_"), + Array attr{tvm::tir::StringImmNode::make("_attr_"), FloatImm(DataType::Float(32), trans(fea.length)), IntImm(DataType::Int(32), fea.nest_level), FloatImm(DataType::Float(32), trans(fea.topdown_product)), @@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > feature_row.push_back(attr); // arithmetic - feature_row.push_back(Array{std::string("_arith_"), + feature_row.push_back(Array{tvm::tir::StringImmNode::make("_arith_"), FloatImm(DataType::Float(32), trans(fea.add_ct)), FloatImm(DataType::Float(32), trans(fea.mul_ct)), FloatImm(DataType::Float(32), trans(fea.div_ct)), @@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto k : bufs) { TouchPattern &v = fea.touch_feature[k]; feature_row.push_back( - Array{k, + Array{tvm::tir::StringImmNode::make(k), FloatImm(DataType::Float(32), trans(v.stride)), FloatImm(DataType::Float(32), trans(v.mod)), FloatImm(DataType::Float(32), trans(v.count)), diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index 066b8f9..bee103d 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs( if (val.IsObjectRef()) { dict.Set(key, val.operator ObjectRef()); } else if (val.type_code() == kTVMStr) { - dict.Set(key, PrimExpr(val.operator std::string())); + dict.Set(key, val.operator String()); } else { dict.Set(key, val.operator PrimExpr()); } diff --git a/src/ir/expr.cc b/src/ir/expr.cc index b07f04a..1f0337e 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value) PrimExpr::PrimExpr(float value) : 
PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr::PrimExpr(std::string str) - : PrimExpr(tir::StringImmNode::make(str)) {} +PrimExpr::PrimExpr(runtime::String value) + : PrimExpr(tir::StringImmNode::make(value)) {} PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { using runtime::ObjectTypeChecker; @@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { if (ptr->IsInstance()) { return te::Tensor(ptr)(); } + if (ptr->IsInstance()) { + return tir::StringImmNode::make(runtime::String(ptr)); + } CHECK(ObjectTypeChecker::Check(ptr.get())) << "Expect type " << ObjectTypeChecker::TypeName() << " but get " << ptr->GetTypeKey(); diff --git a/src/ir/op.cc b/src/ir/op.cc index 6a50240..b024165 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -140,10 +141,9 @@ void OpRegistry::UpdateAttr(const std::string& key, // Frontend APIs TVM_REGISTER_GLOBAL("relay.op._ListOpNames") .set_body_typed([]() { - Array ret; - for (const std::string& name : - dmlc::Registry::ListAllNames()) { - ret.push_back(tvm::PrimExpr(name)); + Array ret; + for (const std::string& name : dmlc::Registry::ListAllNames()) { + ret.push_back(name); } return ret; }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 61c1fc2..6e38aac 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -212,7 +213,7 @@ class SequentialNode : public PassNode { PassInfo::PassInfo(int opt_level, std::string name, - tvm::Array required) { + tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -274,12 +275,10 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { } // linearly scan the pass array to match pass_name -inline bool PassArrayContains(const Array& pass_array, +inline bool PassArrayContains(const Array& pass_array, const std::string& pass_name) { for (auto x : pass_array) { - auto* str_name = x.as(); - CHECK(str_name) << "pass name must be str"; - if (str_name->value == pass_name) return true; + if (x == pass_name) return true; } return false; } @@ -324,9 +323,7 @@ IRModule SequentialNode::operator()(const IRModule& module, if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { - const auto* name = it.as(); - CHECK(name); - mod = GetPass(name->value)(mod, pass_ctx); + mod = GetPass(it)(mod, pass_ctx); } mod = pass(mod, pass_ctx); } @@ -337,7 +334,7 @@ Pass CreateModulePass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required) { + const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } @@ -345,7 +342,7 @@ Pass CreateModulePass( TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") -.set_body_typed([](int opt_level, std::string name, tvm::Array required) { +.set_body_typed([](int opt_level, std::string name, tvm::Array required) { return PassInfo(opt_level, name, required); }); @@ -363,8 +360,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "opt_level: " << node->opt_level; p->stream << "required passes: [" << "\n"; for (const auto& it : node->required) { - const auto* str = it.as(); - p->stream << str->value << ", "; + p->stream << it << ", "; } p->stream << "]\n"; }); @@ -401,7 +397,7 @@ TVM_REGISTER_GLOBAL("transform.Sequential") tvm::Array passes = args[0]; int 
opt_level = args[1]; std::string name = args[2]; - tvm::Array required = args[3]; + tvm::Array required = args[3]; PassInfo pass_info = PassInfo(opt_level, name, required); *ret = Sequential(passes, pass_info); }); @@ -427,8 +423,8 @@ TVM_REGISTER_GLOBAL("transform.PassContext") auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; - tvm::Array required = args[2]; - tvm::Array disabled = args[3]; + tvm::Array required = args[2]; + tvm::Array disabled = args[3]; TraceFunc trace_func = args[4]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; diff --git a/src/node/container.cc b/src/node/container.cc index e7e4979..bce2eee 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -63,7 +63,6 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) static_cast(n)).operator std::string(); }); - struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index eaf78bc..e2d5e93 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -86,9 +86,10 @@ struct GraphCodegen { std::unordered_map GetParams() { std::unordered_map ret; - auto names = CallFunc>("list_params_name", nullptr); - for (auto expr : names) { - auto key = expr.as()->value; + auto names = CallFunc>("list_params_name", nullptr); + for (const auto& expr : names) { + // Implicit cast from runtime::String to std::string + std::string key = expr; ret[key] = CallFunc("get_param_by_name", key); } return ret; @@ -191,12 +192,12 @@ class RelayBuildModule : public runtime::ModuleNode { /*! * \brief List all paramter names * - * \return Array names of params + * \return Array names of params */ - Array ListParamNames() { - Array ret; + Array ListParamNames() { + Array ret; for (const auto& kv : params_) { - ret.push_back(tir::StringImmNode::make(kv.first)); + ret.push_back(kv.first); } return ret; } @@ -272,7 +273,7 @@ class RelayBuildModule : public runtime::ModuleNode { } Array pass_seqs; - Array entry_functions{tvm::PrimExpr{"main"}}; + Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. 
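// ---------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of this commit.] The
// build_module.cc hunks above rely on the conversions this commit adds to
// include/tvm/runtime/container.h. A minimal self-contained example (the
// function name is hypothetical):
#include <tvm/node/container.h>  // tvm::Array; also re-exports runtime::String
#include <string>

void StringConversionSketch() {
  tvm::runtime::String s = "main";  // new const char* constructor
  std::string key = s;              // implicit cast to std::string, as used in
                                    // GraphCodegen::GetParams() above
  // brace-init through the same conversion, as in entry_functions{"main"}:
  tvm::Array<tvm::runtime::String> entry_functions{"main"};
  (void)key;
  (void)entry_functions;
}
// ---------------------------------------------------------------------------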
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index f75da07..9cb6b2e 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -617,17 +617,18 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; CHECK(src_func.defined()); - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; - if (ext_mods.find(code_gen->value) == ext_mods.end()) { - ext_mods[code_gen->value] = IRModule({}, {}); + std::string code_gen_name = code_gen; + if (ext_mods.find(code_gen_name) == ext_mods.end()) { + ext_mods[code_gen_name] = IRModule({}, {}); } - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); auto gv = GlobalVar(std::string(symbol_name)); - ext_mods[code_gen->value]->Add(gv, src_func); + ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } } @@ -691,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode { } // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { + if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); const auto name_node = - key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node); diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 79d4d3f..1db3f20 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -70,7 +70,7 @@ class CSourceModuleCodegenBase { */ std::string GetExtSymbol(const Function& func) const { const auto name_node = - func->GetAttr(tvm::attr::kGlobalSymbol); + func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index c7f1be8..4279db0 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -419,7 +419,7 @@ class GraphRuntimeCodegen auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); CCacheKey key = (*pf0)(func, target); CachedFunc ext_func = (*pf1)(compile_engine_, key); @@ -482,7 +482,7 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - CHECK(op->GetAttr(attr::kCompiler).defined()) + CHECK(op->GetAttr(attr::kCompiler).defined()) << "Only functions supported by custom codegen"; return {}; } @@ -633,10 +633,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, 
this](TVMArgs args, TVMRetValue* rv) { - Array ret; + Array ret; for (const auto &kv : this->output_.params) { - tvm::PrimExpr name = tir::StringImmNode::make(kv.first); - ret.push_back(name); + ret.push_back(kv.first); } *rv = ret; }); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d68bff6..e2b0fff 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); } else { // Next generate the invoke instruction. @@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor { auto cfunc = engine_->Lower(key); auto op_index = -1; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); } else { @@ -873,7 +873,7 @@ void VMCompiler::Lower(IRModule mod, IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { Array pass_seqs; - Array entry_functions{tvm::PrimExpr{"main"}}; + Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 74b2a47..12113b0 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 80745e1..59c549c 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); func = Function(func->params, VisitExpr(func->body), diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index dd11fce..c2fe37f 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -87,11 +87,10 @@ struct CallTracer : ExprVisitor { * \return The module with dead functions removed. 
*/ IRModule RemoveUnusedFunctions(const IRModule& module, - Array entry_funcs) { + Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { - auto* str_name = entry.as(); - auto funcs = CallTracer(module).Trace(str_name->value); + auto funcs = CallTracer(module).Trace(entry); called_funcs.insert(funcs.cbegin(), funcs.cend()); } auto existing_functions = module->functions; @@ -108,7 +107,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module, namespace transform { -Pass RemoveUnusedFunctions(Array entry_functions) { +Pass RemoveUnusedFunctions(Array entry_functions) { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a4bab36..fa709eb 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,14 +145,14 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + (func->GetAttr(attr::kCompiler).defined()); } Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, - const tvm::Array& required) { + const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return FunctionPass(pass_func, pass_info); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 87b4602..7aa8bf1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1177,7 +1177,6 @@ Array ArangeCompute(const Attrs& attrs, te::Tensor start = inputs[0]; te::Tensor stop = inputs[1]; te::Tensor step = inputs[2]; - Array empty = {0}; return { DynamicArange(start, stop, step, param->dtype) }; } diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 63c1cb9..aab0b3a 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -125,8 +125,7 @@ Pass AlterOpLayout() { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; - return CreateFunctionPass(pass_func, 3, "AlterOpLayout", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index c3d34cb..44ef35a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -59,11 +59,12 @@ class AnnotateTargetWrapper : public ExprMutator { // handle composite functions Function func = Downcast(call->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); + auto comp_name = func->GetAttr(attr::kComposite); if (comp_name.defined()) { - size_t i = comp_name->value.find('.'); + std::string comp_name_str = comp_name; + size_t i = comp_name_str.find('.'); if (i != std::string::npos) { - std::string target = comp_name->value.substr(0, i); + std::string target = comp_name_str.substr(0, i); if (target == target_) return true; } } @@ -147,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator { Function func; Expr new_body; // don't step into composite functions - if (fn->GetAttr(attr::kComposite).defined()) { + if (fn->GetAttr(attr::kComposite).defined()) { func = 
GetRef(fn); new_body = func->body; } else { @@ -225,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) { return Downcast(relay::annotate_target::AnnotateTarget(f, target)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {tir::StringImmNode::make("InferType")}); + {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index 759a4ae..ebcbd57 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -133,8 +133,7 @@ Pass CanonicalizeCast() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeCast(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 97a128d..1d3111b 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -74,8 +74,7 @@ Pass CanonicalizeOps() { [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; - return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 3884dac..af6b135 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -220,8 +220,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 612dae5..1278020 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -80,8 +80,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) { [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelDense(f, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelDense", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 55ca3f6..361565e 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -193,8 +193,7 @@ Pass CombineParallelOpBatch(const std::string& op_name, batch_op_name, min_num_branches)); }; - return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } 
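// [Editor's note: illustrative sketch, not part of this commit.] The
// {"InferType"} pattern repeated in the passes above compiles because the
// `required` argument of CreateFunctionPass is now an Array of
// runtime::String, so each literal converts through the new const char*
// constructor:
//
//   tvm::Array<tvm::runtime::String> required{"InferType"};
//
// replacing the former {tir::StringImmNode::make("InferType")} spelling.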
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index 871969d..dbb2c38 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -133,9 +133,7 @@ Pass ConvertLayout(const std::string& desired_layout) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; return CreateFunctionPass( - pass_func, 3, "ConvertLayout", - {tir::StringImmNode::make("InferType"), - tir::StringImmNode::make("CanonicalizeOps")}); + pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index b4d61f1..908ba87 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -573,8 +573,7 @@ Pass RewriteAnnotatedOps(int fallback_device) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); }; - return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index f905ba5..68c59f5 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -91,8 +91,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) { [=](Function f, IRModule m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; - return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index cf00a89..8234dea 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -70,8 +70,7 @@ Pass FastMath() { [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; - return CreateFunctionPass(pass_func, 4, "FastMath", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FastMath") diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index 49f6e3f..cfe74bf 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -960,8 +960,7 @@ Pass ForwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") @@ -973,8 +972,7 @@ Pass BackwardFoldScaleAxis() { return Downcast( relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; - return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } 
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 9168898..f646042 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -980,8 +980,7 @@ Pass FuseOps(int fuse_opt_level) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; - return CreateFunctionPass(pass_func, 1, "FuseOps", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.FuseOps") diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index ef3c51f..ba0f568 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -131,7 +131,7 @@ class Inliner : ExprMutator { fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined()) { + if (!func->GetAttr(attr::kCompiler).defined()) { CHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 01411a6..0b5c671 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -101,7 +101,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; - return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize); diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 35b93dc..75d95f0 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -159,9 +159,9 @@ class MergeCompositeWrapper : public ExprMutator { if (call->op->IsInstance()) { Function func = Downcast(call->op); CHECK(func.defined()); - const auto name_node = func->GetAttr(attr::kComposite); + auto name_node = func->GetAttr(attr::kComposite); // don't step into existing composite functions - if (name_node.defined() && name_node->value != "") { + if (name_node.defined() && name_node != "") { tvm::Array new_args; for (const auto& arg : call->args) { auto new_e = this->Mutate(arg); @@ -185,7 +185,7 @@ class MergeCompositeWrapper : public ExprMutator { auto free_vars = FreeVars(extract); // make the composite function auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); - f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_)); + f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_)); // find the expressions associated with the free vars using the args_map // this tells us which expressions should be given as inputs to the composite function Array args; @@ -207,16 +207,14 @@ class MergeCompositeWrapper : public ExprMutator { PackedFunc check_; }; -Expr MergeComposite(const Expr& expr, const Array& pattern_names, +Expr MergeComposite(const Expr& expr, const Array& pattern_names, const Array& patterns, const std::vector& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); Expr merged_expr = expr; // merge the patterns one-by-one 
in order for (size_t i = 0; i < patterns.size(); i++) { - std::string pattern_name = pattern_names[i]->value; - Expr pattern = patterns[i]; - PackedFunc check = checks[i]; - merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr); + merged_expr = + MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr); } return merged_expr; } @@ -225,7 +223,7 @@ Expr MergeComposite(const Expr& expr, const Array& pattern_names namespace transform { -Pass MergeComposite(const tvm::Array& pattern_names, +Pass MergeComposite(const tvm::Array& pattern_names, const tvm::Array& patterns, const std::vector& checks) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { @@ -236,8 +234,9 @@ Pass MergeComposite(const tvm::Array& pattern_names, return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { - tvm::Array pattern_names = args[0]; +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") +.set_body([](TVMArgs args, TVMRetValue* rv) { + tvm::Array pattern_names = args[0]; tvm::Array patterns = args[1]; std::vector checks; for (int i = 2; i < args.size(); i++) { diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index a4e3863..8eeac17 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -245,7 +245,7 @@ class Partitioner : public ExprMutator { global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::tir::StringImmNode::make(target)); + tvm::runtime::String(target)); global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index bc7c15e..d349fdd 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -204,8 +204,7 @@ Pass SimplifyInference() { [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyInference(f)); }; - return CreateFunctionPass(pass_func, 0, "SimplifyInference", - {tir::StringImmNode::make("InferType")}); + return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 6e35dfb..21c5162 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) { for (const auto& it : funcs) { CHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 400f646..81dfd3d 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -32,14 +32,14 @@ namespace runtime { using namespace vm; -TVM_REGISTER_GLOBAL("runtime.container._GetADTTag") +TVM_REGISTER_GLOBAL("runtime.GetADTTag") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("runtime.container._GetADTSize") +TVM_REGISTER_GLOBAL("runtime.GetADTSize") 
.set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); @@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTSize") }); -TVM_REGISTER_GLOBAL("runtime.container._GetADTFields") +TVM_REGISTER_GLOBAL("runtime.GetADTFields") .set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("runtime.container._Tuple") +TVM_REGISTER_GLOBAL("runtime.Tuple") .set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { @@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("runtime.container._Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("runtime.container._ADT") +TVM_REGISTER_GLOBAL("runtime.ADT") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); @@ -76,11 +76,31 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT") *rv = ADT(tag, fields); }); -TVM_REGISTER_GLOBAL("runtime.container._String") +TVM_REGISTER_GLOBAL("runtime.String") .set_body_typed([](std::string str) { return String(std::move(str)); }); +TVM_REGISTER_GLOBAL("runtime.GetStringSize") +.set_body_typed([](String str) { + return static_cast(str.size()); +}); + +TVM_REGISTER_GLOBAL("runtime.GetStdString") +.set_body_typed([](String str) { + return std::string(str); +}); + +TVM_REGISTER_GLOBAL("runtime.CompareString") +.set_body_typed([](String lhs, String rhs) { + return lhs.compare(rhs); +}); + +TVM_REGISTER_GLOBAL("runtime.StringHash") +.set_body_typed([](String str) { + return static_cast(std::hash()(str)); +}); + TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); diff --git a/src/target/build_common.h b/src/target/build_common.h index fc45cef..5ba51da 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol)] = info; } return fmap; diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 8eef4b7..44d017f 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -150,12 +151,12 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); - Array tags = args[2]; + Array tags = args[2]; bool allow_override = args[3]; std::vector tags_vector; for (auto& tag : tags) { - tags_vector.push_back(tag.as()->value); + tags_vector.push_back(tag); } generic_func diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f0b0a4b..a863056 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name, void CodeGenCPU::AddFunction(const PrimFunc& f) { CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol 
attribute"; export_system_symbols_.emplace_back( diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 28f4efd..bb0b7e4 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { llvm::FunctionType* ftype = llvm::FunctionType::get( ret_void ? t_void_ : t_int_, param_types, false); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 9ea77ac..52dccba 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); entry_func = global_symbol; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 0cb4742..a0e18a6 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { // reserve keywords ReserveKeywordsAsUnique(); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 2f31a3e..715c0ae 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { GetUniqueName("_"); // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 4748599..13d87d2 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { arg_kinds.push_back(kind); } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 6c1c3b9..7486164 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -147,7 +147,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { std::string whole_code = cg.Finish(); // Generate source code for compilation. 
-  Array<Array<PrimExpr> > kernel_info;
+  Array<Array<runtime::String> > kernel_info;

   for (auto kv : mod->functions) {
     CHECK(kv.second->IsInstance<PrimFuncNode>())
@@ -161,11 +161,10 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
       code = (*f)(code).operator std::string();
     }

-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-    std::string func_name = global_symbol;
-    kernel_info.push_back(Array<PrimExpr>({func_name, code}));
+    kernel_info.push_back({global_symbol, code});
   }

   std::string xclbin;

diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc
index b6f9b86..5872141 100644
--- a/src/target/spirv/build_vulkan.cc
+++ b/src/target/spirv/build_vulkan.cc
@@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) {
     CHECK(calling_conv.defined() &&
           calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
         << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";

diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 0241e22..db2a2f3 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -78,7 +78,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
   builder_->MakeInst(spv::OpReturn);
   builder_->MakeInst(spv::OpFunctionEnd);

-  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";

diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc
index af8b341..da75a70 100644
--- a/src/target/stackvm/codegen_stackvm.cc
+++ b/src/target/stackvm/codegen_stackvm.cc
@@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) {
     CHECK(kv.second->IsInstance<PrimFuncNode>())
         << "CodeGenStackVM: Can only take PrimFunc";
     auto f = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
     std::string f_name = global_symbol;

diff --git a/src/target/target.cc b/src/target/target.cc
index 8fb9cb6..306fba4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name,
   std::string device_flag = "-device=";
   std::string keys_flag = "-keys=";
   for (auto& item : options) {
-    t->options_array.push_back(tir::StringImmNode::make(item));
+    t->options_array.push_back(item);

     if (item.find(libs_flag) == 0) {
       std::stringstream ss(item.substr(libs_flag.length()));
       std::string lib_item;
       while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(tir::StringImmNode::make(lib_item));
+        t->libs_array.push_back(lib_item);
       }
     } else if (item.find(device_flag) == 0) {
       t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+      t->keys_array.push_back(t->device_name);
     } else if (item.find(keys_flag) == 0) {
       std::stringstream ss(item.substr(keys_flag.length()));
       std::string key_item;
       while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(tir::StringImmNode::make(key_item));
+        t->keys_array.push_back(key_item);
       }
     }
   }
   if (t->device_name.length() > 0) {
-    t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+    t->keys_array.push_back(t->device_name);
   }
   t->device_type = kDLCPU;
   t->thread_warp_size = 1;
   if (target_name == "c" && t->device_name == "micro_dev") {
     t->device_type = kDLMicroDev;
   } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back(tir::StringImmNode::make("cpu"));
+    t->keys_array.push_back("cpu");
   } else if (target_name == "cuda" || target_name == "nvptx") {
     t->device_type = kDLGPU;
-    t->keys_array.push_back(tir::StringImmNode::make("cuda"));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back("cuda");
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 1024;
     t->thread_warp_size = 32;
   } else if (target_name == "rocm" || target_name == "opencl") {
@@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLROCM;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(target_name);
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 256;
     if (t->device_name == "intel_graphics") {
       t->thread_warp_size = 16;
@@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLVulkan;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(target_name);
+    t->keys_array.push_back("gpu");
     t->max_num_threads = 256;
   } else if (target_name == "sdaccel") {
     t->device_type = kDLOpenCL;
-    t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back("sdaccel");
+    t->keys_array.push_back("hls");
   } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
     t->device_type = kDLAOCL;
-    t->keys_array.push_back(tir::StringImmNode::make("aocl"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back("aocl");
+    t->keys_array.push_back("hls");
   } else if (target_name == "opengl") {
     t->device_type = kOpenGL;
-    t->keys_array.push_back(tir::StringImmNode::make("opengl"));
+    t->keys_array.push_back("opengl");
   } else if (target_name == "stackvm") {
     t->device_type = kDLCPU;
   } else if (target_name == "ext_dev") {
@@ -168,7 +168,7 @@ TVM_REGISTER_GLOBAL("target.TargetFromString")
 std::vector<std::string> TargetNode::keys() const {
   std::vector<std::string> result;
   for (auto& expr : keys_array) {
-    result.push_back(expr.as<tir::StringImmNode>()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -176,7 +176,7 @@ std::vector<std::string> TargetNode::keys() const {
 std::vector<std::string> TargetNode::options() const {
   std::vector<std::string> result;
   for (auto& expr : options_array) {
-    result.push_back(expr.as<tir::StringImmNode>()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -184,7 +184,7 @@ std::vector<std::string> TargetNode::options() const {
 std::unordered_set<std::string> TargetNode::libs() const {
   std::unordered_set<std::string> result;
   for (auto& expr : libs_array) {
-    result.insert(expr.as<tir::StringImmNode>()->value);
+    result.insert(expr);
   }
   return result;
 }
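[A short sketch, not from the patch: with keys_array/options_array now holding runtime::String instead of tir::StringImm, target keys round-trip to Python as plain strings. Assumes this commit's Python Target wrapper.]

    import tvm

    tgt = tvm.target.create("cuda")
    assert "cuda" in tgt.keys   # no StringImm .value unwrapping needed
    assert "gpu" in tgt.keys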
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 891d137..0efa33a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -47,7 +47,6 @@ Var::Var(std::string name_hint, Type type_annotation) {
   data_ = std::move(n);
 }

-
 Var Var::copy_with_suffix(const std::string& suffix) const {
   const VarNode* node = get();
   ObjectPtr<VarNode> new_ptr;
@@ -826,20 +825,28 @@ TVM_REGISTER_GLOBAL("tir.Load")
   }
 });

-
-
 TVM_REGISTER_GLOBAL("tir.Call")
 .set_body_typed([](
   DataType type, std::string name,
-  Array<PrimExpr> args, int call_type,
+  Array<ObjectRef> args, int call_type,
   FunctionRef func, int value_index
 ) {
+  Array<PrimExpr> prim_expr_args;
+  for (const auto& it : args) {
+    CHECK(it->IsInstance<runtime::StringObj>() ||
+          it->IsInstance<PrimExprNode>());
+    if (const auto* str = it.as<runtime::StringObj>()) {
+      prim_expr_args.push_back(StringImmNode::make(str->data));
+    } else {
+      prim_expr_args.push_back(Downcast<PrimExpr>(it));
+    }
+  }
   return CallNode::make(type,
-                name,
-                args,
-                static_cast<CallNode::CallType>(call_type),
-                func,
-                value_index);
+                        name,
+                        prim_expr_args,
+                        static_cast<CallNode::CallType>(call_type),
+                        func,
+                        value_index);
 });
 }  // namespace tir

diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ea19982..96fc435 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -120,10 +120,10 @@ class IRTransformer final :
 Stmt IRTransform(Stmt ir_node,
                  const runtime::PackedFunc& f_preorder,
                  const runtime::PackedFunc& f_postorder,
-                 const Array<PrimExpr>& only_enable) {
+                 const Array<runtime::String>& only_enable) {
   std::unordered_set<uint32_t> only_type_index;
-  for (PrimExpr s : only_enable) {
-    only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
+  for (auto s : only_enable) {
+    only_type_index.insert(Object::TypeKey2Index(s.c_str()));
   }
   IRTransformer transform(f_preorder, f_postorder, only_type_index);
   return transform(std::move(ir_node));

diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index 773c67d..001c7cf 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -124,7 +124,7 @@ Pass CreatePrimFuncPass(
     const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array<PrimExpr>& required) {
+    const tvm::Array<runtime::String>& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return PrimFuncPass(pass_func, pass_info);
 }

diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index 30542ea..c684b9e 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -42,7 +42,8 @@ void BinderAddAssert(PrimExpr cond,
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint";
-    asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
+    asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
+                                               EvaluateNode::make(0)));
   }
 }

@@ -173,7 +174,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   ndim_err_msg << arg_name
                << ".ndim is expected to equal "
                << buffer->shape.size();
-  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+  auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
+  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
   // type checks
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
@@ -187,7 +189,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   if (!(dtype == DataType::Int(4) ||
         dtype == DataType::UInt(4) ||
         dtype == DataType::Int(1))) {
-    asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+    auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
+    asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
+    asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
   }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -245,9 +249,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
".strides:" << " expected to be compact array"; if (conds.size() != 0) { + auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); Stmt check = AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), - stride_err_msg.str(), EvaluateNode::make(0)); + stride_msg, EvaluateNode::make(0)); check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); } @@ -269,9 +274,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; - asserts_.emplace_back( - AssertStmtNode::make( - NotNode::make(is_null), stride_null_err_msg.str(), nop)); + asserts_.emplace_back(AssertStmtNode::make( + NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop)); for (size_t k = 0; k < buffer->strides.size(); ++k) { std::ostringstream field_name; diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 1fd43ff..8bc4620 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -159,8 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - return IRTransform(parent_for_stmt, nullptr, replace_target_for, - {PrimExpr("For")}); + return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"}); } // Remove IfThenElse node from a For node. @@ -186,11 +185,9 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { } }); - then_for = IRTransform(for_stmt, nullptr, replace_then_case, - {PrimExpr("IfThenElse")}); + then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"}); if (if_stmt.as()->else_case.defined()) { - else_for = IRTransform(for_stmt, nullptr, replace_else_case, - {PrimExpr("IfThenElse")}); + else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"}); } return std::make_pair(then_for, else_for); @@ -411,7 +408,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { *ret = new_for; } }); - return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")}); + return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")}); } Stmt HoistIfThenElse(Stmt stmt) { diff --git a/src/tir/pass/tensor_core.cc b/src/tir/pass/tensor_core.cc index 88f7496..dc2df98 100644 --- a/src/tir/pass/tensor_core.cc +++ b/src/tir/pass/tensor_core.cc @@ -860,7 +860,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto it = matrix_abc_.find(simplify_name(node->name)); CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; - auto matrix_abc = "wmma." + it->second; + auto matrix_abc = tvm::tir::StringImmNode::make("wmma." 
     Stmt body = this->VisitStmt(op->body);
     return AttrStmtNode::make(op->node,
                               op->attr_key,

diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc
index 486f21c..952d663 100644
--- a/src/tir/transforms/bind_device_type.cc
+++ b/src/tir/transforms/bind_device_type.cc
@@ -47,7 +47,8 @@ class DeviceTypeBinder: public StmtExprMutator {
       var_ = nullptr;
       std::ostringstream os;
       os << "device_type need to be " << device_type_;
-      return AssertStmtNode::make(op->value == value, os.str(), body);
+      return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()),
+                                  body);
     }
   }
   return StmtExprMutator::VisitStmt_(op);

diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index c49b044..b1dd235 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -41,12 +41,13 @@ namespace tvm {
 namespace tir {

 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
-  return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
+  return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg),
+                              EvaluateNode::make(0));
 }

 PrimFunc MakePackedAPI(PrimFunc&& func,
                        int num_unpacked_args) {
-  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
   std::string name_hint = global_symbol;
@@ -140,17 +141,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
         AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
                              tcode == kTVMNDArrayHandle ||
                              tcode == kTVMDLTensorHandle ||
-                             tcode == kTVMNullptr, msg.str(), nop));
+                             tcode == kTVMNullptr,
+                             tvm::tir::StringImmNode::make(msg.str()), nop));
     } else if (t.is_int() || t.is_uint()) {
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be int";
-      seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
+      seq_check.emplace_back(
+          AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop));
     } else {
       CHECK(t.is_float());
       std::ostringstream msg;
       msg << name_hint << ": Expect arg[" << i << "] to be float";
       seq_check.emplace_back(
-          AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
+          AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop));
     }
   } else {
     args.push_back(v_arg);

diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc
index f695b3c..f366353 100644
--- a/src/tir/transforms/remap_thread_axis.cc
+++ b/src/tir/transforms/remap_thread_axis.cc
@@ -76,12 +76,10 @@ class ThreadAxisRewriter : private StmtExprMutator {
 };

-PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
   std::unordered_map<std::string, IterVar> tmap;
   for (const auto& kv : thread_map) {
-    const StringImmNode* str = kv.first.as<StringImmNode>();
-    CHECK(str != nullptr);
-    tmap[str->value] = kv.second;
+    tmap[kv.first] = kv.second;
   }

   auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
@@ -101,7 +99,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {

 namespace transform {

-Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
   auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
     return RemapThreadAxis(std::move(f), thread_map);
   };
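[Sketch, not from the patch: the AssertStmt changes above make the message an explicit tir.StringImm instead of a bare std::string. Assuming this commit's tvm.tir Python API, the equivalent construction would be roughly:]

    from tvm import tir

    nop = tir.Evaluate(tir.IntImm("int32", 0))
    cond = tir.IntImm("int32", 1) > tir.IntImm("int32", 0)
    stmt = tir.AssertStmt(cond, tir.StringImm("device_type need to be 1"), nop)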
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index ae32bdc..5149d28 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined())
       << "SplitHostDevice: Require the target attribute";
-  auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol.defined())
       << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";

diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index f1198e7..063247d 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -261,7 +261,7 @@ TEST(String, empty) {
   using namespace std;
   String s{"hello"};
   CHECK_EQ(s.empty(), false);
-  s = "";
+  s = std::string("");
   CHECK_EQ(s.empty(), true);
 }

diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py
index 7301ef7..dd00d7e 100644
--- a/tests/python/relay/test_annotate_target.py
+++ b/tests/python/relay/test_annotate_target.py
@@ -231,7 +231,7 @@ def test_composite_function():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+        add_relu = add_relu.with_attr("Composite", "test.add_relu")

         # merged function
         r = relay.Call(add_relu, [a, b])
@@ -249,7 +249,7 @@ def test_composite_function():
         add_node = relay.add(in_1, in_2)
         relu_node = relay.nn.relu(add_node)
         add_relu = relay.Function([in_1, in_2], relu_node)
-        add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+        add_relu = add_relu.with_attr("Composite", "test.add_relu")

         # merged function
         cb_1 = relay.annotation.compiler_begin(a, "test")

diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py
index 0af55d2..bae077c 100644
--- a/tests/python/relay/test_call_graph.py
+++ b/tests/python/relay/test_call_graph.py
@@ -134,7 +134,7 @@ def test_recursive_func():
     func = relay.Function([i],
                           sb.get(),
                           ret_type=relay.TensorType([], 'int32'))
-    func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
+    func = func.with_attr("Compiler", "a")
     mod[sum_up] = func
     iarg = relay.var('i', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg], sum_up(iarg))

diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 724e81d..b4496bb 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -79,9 +79,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",

 def set_external_func_attr(func, compiler, ext_symbol):
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
-    func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
-    func = func.with_attr("global_symbol",
-                          runtime.container.String(ext_symbol))
+    func = func.with_attr("Compiler", compiler)
+    func = func.with_attr("global_symbol", ext_symbol)
     return func

diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index dbd5934..5a71023 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -96,12 +96,14 @@ def test_function():
     body = relay.Tuple(tvm.runtime.convert([]))
     type_params = tvm.runtime.convert([])
     fn = relay.Function(params, body, ret_type, type_params)
-    fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
tvm.tir.StringImm("value")) + fn = fn.with_attr("test_attribute", "value") + fn = fn.with_attr("test_attribute1", "value1") assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None assert fn.attrs["test_attribute"] == "value" + assert fn.attrs["test_attribute1"] == "value1" str(fn) check_json_roundtrip(fn) diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index 271960e..e1a0a01 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -356,7 +356,7 @@ def test_function_attr(): p00 = relay.subtract(z00, w01) q00 = relay.multiply(p00, w02) func0 = relay.Function([x0, w00, w01, w02], q00) - func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a")) + func0 = func0.with_attr("FuncName", "a") x1 = relay.var('x1', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10)) @@ -366,7 +366,7 @@ def test_function_attr(): p10 = relay.subtract(z10, w11) q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) - func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b")) + func1 = func1.with_attr("FuncName", "b") assert not consistent_equal(func0, func1) @@ -698,7 +698,7 @@ def test_fn_attribute(): d = relay.var('d', shape=(10, 10)) add_1 = relay.add(c, d) add_1_fn = relay.Function([c, d], add_1) - add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test")) + add_1_fn = add_1_fn.with_attr("TestAttribute", "test") add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) assert not consistent_equal(add_1_fn, add_fn) diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index 0f6d539..3b41f07 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -209,7 +209,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler(): g11 = relay.GlobalVar("g11") fn11 = relay.Function([x11], x11) fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Compiler", "a") mod[g11] = fn11 x1 = relay.var("x1", shape=(3, 5)) @@ -244,7 +244,7 @@ def test_call_chain_inline_multiple_levels_extern_compiler(): x11 = relay.var("x11", shape=(3, 5)) fn11 = relay.Function([x11], x11) fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Compiler", "a") x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) @@ -367,7 +367,7 @@ def test_recursive_not_called_extern_compiler(): x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") g1 = relay.GlobalVar("g1") mod[g1] = fn1 mod["main"] = relay.Function([x, y], x + y + g1(x)) @@ -380,7 +380,7 @@ def test_recursive_not_called_extern_compiler(): x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") mod["main"] = relay.Function([x, y], x + y + fn1(x)) return mod @@ -446,7 +446,7 @@ def test_globalvar_as_call_arg_extern_compiler(): sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) fn1 = fn1.with_attr("Inline", 
tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -456,7 +456,7 @@ def test_globalvar_as_call_arg_extern_compiler(): sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -478,7 +478,7 @@ def test_globalvar_as_call_arg_extern_compiler(): sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) @@ -486,7 +486,7 @@ def test_globalvar_as_call_arg_extern_compiler(): sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") p0 = relay.var("p0", shape=(3, 5)) p1 = relay.var("p1", shape=(3, 5)) @@ -539,10 +539,10 @@ def test_inline_globalvar_without_args_extern_compiler(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") fn2 = relay.Function([], relay.const(2)) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") g1 = relay.GlobalVar('g1') g2 = relay.GlobalVar('g2') mod[g1] = fn1 @@ -555,10 +555,10 @@ def test_inline_globalvar_without_args_extern_compiler(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") fn2 = relay.Function([], relay.const(2)) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") p = relay.var('p', 'bool') mod['main'] = relay.Function([p], relay.Call( relay.If(p, fn1, fn2), [])) @@ -787,7 +787,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler(): y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Compiler", "aa") g0 = relay.GlobalVar("g0") mod[g0] = fn0 @@ -811,7 +811,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler(): y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Compiler", "aa") x1 = relay.var("x1", shape=(3, 5)) y1 = relay.var("y1", shape=(3, 5)) diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 110d855..e3c8991 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -184,7 +184,7 @@ def test_simple_merge(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) - add_relu = add_relu.with_attr("Composite", 
tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", "add_relu") # merged function r = relay.Call(add_relu, [a, b]) @@ -249,8 +249,7 @@ def test_branch_merge(): sub_node = relay.subtract(in_1, in_2) mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) - add_sub_mul = add_sub_mul.with_attr("Composite", - tir.StringImm("add_sub_mul")) + add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul") # add_sub_mul1 function in_3 = relay.var('in_3', shape=(10, 10)) @@ -259,8 +258,7 @@ def test_branch_merge(): sub_node_1 = relay.subtract(in_3, in_4) mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) - add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", - tir.StringImm("add_sub_mul")) + add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul") # merged function m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) @@ -319,8 +317,7 @@ def test_reuse_call_merge(): add_node_1 = relay.add(in_1, add_node) add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) - add_add_add = add_add_add.with_attr("Composite", - tir.StringImm("add_add_add")) + add_add_add = add_add_add.with_attr("Composite", "add_add_add") # merged function sub_node = relay.subtract(a, b) @@ -404,7 +401,7 @@ def test_multiple_patterns(): r = relay.nn.relu(bias_node) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", - tir.StringImm("conv2d_bias_relu")) + "conv2d_bias_relu") # add_relu function in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) @@ -412,7 +409,7 @@ def test_multiple_patterns(): add_node = relay.add(in_4, in_5) r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) - add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", "add_relu") # merged function conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -481,8 +478,7 @@ def test_merge_order(): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.with_attr('Composite', - tir.StringImm(composite_name)) + merged_func = merged_func.with_attr('Composite', composite_name) ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) @@ -547,13 +543,13 @@ def test_parallel_merge(): y = relay.var('y') branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) - func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul")) + func_1 = func_1.with_attr('Composite', "add_sub_mul") call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) - func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul")) + func_2 = func_2.with_attr('Composite', "add_sub_mul") call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -632,14 +628,14 @@ def test_multiple_input_subgraphs(): add_relu_1 = relay.add(x, y) add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) - add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu')) + add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu') add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], 
         x1 = relay.var('x1')
         y1 = relay.var('y1')
         add_relu_2 = relay.add(x1, y1)
         add_relu_2 = relay.nn.relu(add_relu_2)
         add_relu_2 = relay.Function([x1, y1], add_relu_2)
-        add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
+        add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
         add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
         x2 = relay.var('x2')
         y2 = relay.var('y2')
@@ -647,7 +643,7 @@ def test_multiple_input_subgraphs():
         sub = relay.subtract(x2, y2)
         add_sub_mul = relay.multiply(add, sub)
         add_sub_mul = relay.Function([x2, y2], add_sub_mul)
-        add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
+        add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
         add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
         return relay.Function(inputs, add_sub_mul_call)

@@ -660,7 +656,7 @@ def test_multiple_input_subgraphs():
             add_relu = relay.add(x, y)
             add_relu = relay.nn.relu(add_relu)
             add_relu = relay.Function([x, y], add_relu)
-            add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
+            add_relu = add_relu.with_attr('Composite', 'add_relu')
             add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
             add_relu_calls.append(add_relu_call)

@@ -720,7 +716,7 @@ def test_tuple_get_item_merge():
         tuple_get_item_node = bn_node[0]
         relu_node = relay.nn.relu(tuple_get_item_node)
         bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
-        bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
+        bn_relu = bn_relu.with_attr("Composite", "bn_relu")

         # merged function
         r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])

diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 3959613..1968f34 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -24,7 +24,6 @@ import tvm
 import tvm.relay.testing
 from tvm import relay
 from tvm import runtime
-from tvm.runtime import container
 from tvm.relay import transform
 from tvm.contrib import util
 from tvm.relay.op.annotation import compiler_begin, compiler_end
@@ -307,8 +306,8 @@ def test_extern_ccompiler_default_ops():
     func = relay.Function([x0, y0], add)
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-    func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
-    func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+    func = func.with_attr("Compiler", "ccompiler")
+    func = func.with_attr("global_symbol", "ccompiler_0")
     glb_0 = relay.GlobalVar("ccompiler_0")
     mod[glb_0] = func
     add_call = relay.Call(glb_0, [x, y])
@@ -392,8 +391,8 @@ def test_extern_dnnl():
     func = relay.Function([data0, input0, input1], out)
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-    func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
-    func = func.with_attr("global_symbol", container.String("dnnl_0"))
+    func = func.with_attr("Compiler", "dnnl")
+    func = func.with_attr("global_symbol", "dnnl_0")
     glb_var = relay.GlobalVar("dnnl_0")
     mod = tvm.IRModule()
     mod[glb_var] = func
@@ -518,10 +517,8 @@ def test_function_lifting():
                                bn.astuple())
     func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-    func0 = func0.with_attr("Compiler",
-                            tvm.tir.StringImm("test_compiler"))
func0.with_attr("global_symbol", - container.String("test_compiler_0")) + func0 = func0.with_attr("Compiler", "test_compiler") + func0 = func0.with_attr("global_symbol", "test_compiler_0") gv0 = relay.GlobalVar("test_compiler_0") mod[gv0] = func0 @@ -537,10 +534,8 @@ def test_function_lifting(): func1 = relay.Function([data1, weight1], conv) func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", - tvm.tir.StringImm("test_compiler")) - func1 = func1.with_attr("global_symbol", - container.String("test_compiler_1")) + func1 = func1.with_attr("Compiler", "test_compiler") + func1 = func1.with_attr("global_symbol", "test_compiler_1") gv1 = relay.GlobalVar("test_compiler_1") mod[gv1] = func1 @@ -611,10 +606,8 @@ def test_function_lifting_inline(): bn.astuple()) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_compiler")) - func0 = func0.with_attr("global_symbol", - container.String("test_compiler_0")) + func0 = func0.with_attr("Compiler", "test_compiler") + func0 = func0.with_attr("global_symbol", "test_compiler_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -648,8 +641,8 @@ def test_constant_propagation(): func = relay.Function([y0], add) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) - func = func.with_attr("global_symbol", container.String("ccompiler_0")) + func = func.with_attr("Compiler", "ccompiler") + func = func.with_attr("global_symbol", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [y]) @@ -748,10 +741,8 @@ def test_multiple_outputs(): bn_mean, bn_var], tuple_o) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_target")) - func0 = func0.with_attr("global_symbol", - container.String("test_target_2")) + func0 = func0.with_attr("Compiler", "test_target") + func0 = func0.with_attr("global_symbol", "test_target_2") gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 @@ -816,10 +807,8 @@ def test_mixed_single_multiple_outputs(): func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", - tvm.tir.StringImm("test_target")) - func1 = func1.with_attr("global_symbol", - container.String("test_target_1")) + func1 = func1.with_attr("Compiler", "test_target") + func1 = func1.with_attr("global_symbol", "test_target_1") gv1 = relay.GlobalVar("test_target_1") mod[gv1] = func1 @@ -831,10 +820,8 @@ def test_mixed_single_multiple_outputs(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_target")) - func0 = func0.with_attr("global_symbol", - container.String("test_target_0")) + func0 = func0.with_attr("Compiler", "test_target") + func0 = func0.with_attr("global_symbol", "test_target_0") gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 diff --git a/tests/python/unittest/test_ir_attrs.py 
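[Sketch, not from the patch: as the test updates above show, with_attr now accepts a plain Python str, which the FFI wraps into tvm::runtime::String, and the stored attribute compares directly against str.]

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 10))
    fn = relay.Function([x], x)
    fn = fn.with_attr("Compiler", "ccompiler")   # no tvm.tir.StringImm(...) wrapper
    assert fn.attrs["Compiler"] == "ccompiler"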
diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py
index 8f2e9bb..48495f4 100644
--- a/tests/python/unittest/test_ir_attrs.py
+++ b/tests/python/unittest/test_ir_attrs.py
@@ -41,7 +41,7 @@ def test_dict_attrs():
     dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
     assert dattr.x.value == 1
     datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
-    assert dattr.name.value == "xyz"
+    assert dattr.name == "xyz"
     assert isinstance(dattr, tvm.ir.DictAttrs)
     assert "name" in dattr
     assert dattr["x"].value == 1

diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h
index 66b8a10..ee18dea 100644
--- a/topi/include/topi/contrib/cublas.h
+++ b/topi/include/topi/contrib/cublas.h
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.cublas.matmul"),
+        runtime::String("tvm.contrib.cublas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
     { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.cublas.batch_matmul"),
+        runtime::String("tvm.contrib.cublas.batch_matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),

diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h
index 2fcafc7..9fe1825 100644
--- a/topi/include/topi/contrib/rocblas.h
+++ b/topi/include/topi/contrib/rocblas.h
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        PrimExpr("tvm.contrib.rocblas.matmul"),
+        runtime::String("tvm.contrib.rocblas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),