tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [LLVM] Fix generation of LLVM intrinsics (#5282)
Date Sat, 11 Apr 2020 04:19:12 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 403929f  [LLVM] Fix generation of LLVM intrinsics (#5282)
403929f is described below

commit 403929f9019e0c44732a64999b8fbaee803cd7e5
Author: Krzysztof Parzyszek <kparzysz@quicinc.com>
AuthorDate: Fri Apr 10 23:19:03 2020 -0500

    [LLVM] Fix generation of LLVM intrinsics (#5282)
    
    * [LLVM] Fix generation of LLVM intrinsics
    
    The type list in the call to llvm::Intrinsic::getDeclaration is not
    the intrinsic's signature, it's the list of overloaded types. Without
    this fix, the updated unit test would cause the following error:
    
    TVMError: LLVM module verification failed with the following errors:
    Intrinsic name not mangled correctly for type arguments! Should be:
    llvm.ctlz.i32
    i32 (i32, i1)* @llvm.ctlz.i32.i1
    
    Special handling for llvm.prefetch, sig matching for overloaded intrinsics only
    
    The prefetch intrinsic returns void in LLVM, while it returns i32 in TVM.
    This case needs to be handled specially, because rule-based intrinsic
    translation would cause an invalid LLVM type to be created.
    
    Do the signature matching only for overloaded intrinsics. It's not needed
    for non-overloaded ones, so this can save a bit of compile-time.
    
    * Include intrinsic name in the error message
    
    * Fix number of arguments for llvm.fmuladd and llvm.pow
---
 src/target/llvm/codegen_llvm.cc                   | 92 +++++++++++++++++++++--
 src/target/llvm/codegen_llvm.h                    | 15 ++++
 src/target/llvm/intrin_rule_llvm.cc               |  6 +-
 tests/python/unittest/test_target_codegen_llvm.py | 29 ++++++-
 topi/python/topi/arm_cpu/bitserial_conv2d.py      |  9 +--
 5 files changed, 133 insertions(+), 18 deletions(-)

diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bb0b7e4..7112691 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -684,6 +684,74 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
   return call;
 }
 
+llvm::Function* CodeGenLLVM::GetIntrinsicDecl(
+    llvm::Intrinsic::ID id, llvm::Type* ret_type,
+    llvm::ArrayRef<llvm::Type*> arg_types) {
+  llvm::Module* module = module_.get();
+
+  if (!llvm::Intrinsic::isOverloaded(id)) {
+    return llvm::Intrinsic::getDeclaration(module, id, {});
+  }
+
+  llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
+  llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
+  llvm::SmallVector<llvm::Type*, 4> overload_types;
+
+#if TVM_LLVM_VERSION >= 90
+  auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
+    overload_types.clear();
+    llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
+    auto match =
+        llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
+    if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
+      bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
+      if (error) {
+        return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
+      }
+    }
+    return match;
+  };
+
+  // First, try matching the signature assuming non-vararg case.
+  auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
+  switch (try_match(fn_ty, false)) {
+    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
+      // The return type doesn't match, there is nothing else to do.
+      return nullptr;
+    case llvm::Intrinsic::MatchIntrinsicTypes_Match:
+      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+    case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
+      break;
+  }
+
+  // Keep adding one type at a time (starting from empty list), and
+  // try matching the vararg signature.
+  llvm::SmallVector<llvm::Type*, 4> var_types;
+  for (int i = 0, e = arg_types.size(); i <= e; ++i) {
+    if (i > 0) var_types.push_back(arg_types[i - 1]);
+    auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
+    if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
+      return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+    }
+  }
+  // Failed to identify the type.
+  return nullptr;
+
+#else  // TVM_LLVM_VERSION
+  llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
+  // matchIntrinsicType returns true on error.
+  if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
+    return nullptr;
+  }
+  for (llvm::Type* t : arg_types) {
+    if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
+      return nullptr;
+    }
+  }
+  return llvm::Intrinsic::getDeclaration(module, id, overload_types);
+#endif  // TVM_LLVM_VERSION
+}
+
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   if (op->is_intrinsic("llvm_intrin")) {
     CHECK_GE(op->args.size(), 2U);
@@ -691,19 +759,27 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
         Downcast<IntImm>(op->args[0])->value);
     int64_t num_signature  = Downcast<IntImm>(op->args[1])->value;
     std::vector<llvm::Value*> arg_value;
-    std::vector<llvm::Type*> sig_type;
+    std::vector<llvm::Type*> arg_type;
     for (size_t i = 2; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
       if (i - 2 < static_cast<size_t>(num_signature)) {
-        sig_type.push_back(arg_value.back()->getType());
+        arg_type.push_back(arg_value.back()->getType());
       }
     }
-    llvm::Type *return_type = GetLLVMType(GetRef<PrimExpr>(op));
-    if (sig_type.size() > 0 && return_type != sig_type[0]) {
-      sig_type.insert(sig_type.begin(), return_type);
-    }
-    llvm::Function* f = llvm::Intrinsic::getDeclaration(
-        module_.get(), id, sig_type);
+    // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
+    // returns int32. This causes problems because prefetch is one of
+    // those intrinsics that is generated automatically via the
+    // tvm.intrin.rule mechanism. Any other intrinsic with a type
+    // mismatch will have to be treated specially here.
+    // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
+    // type as LLVM.
+    llvm::Type *return_type = (id != llvm::Intrinsic::prefetch)
+        ? GetLLVMType(GetRef<PrimExpr>(op))
+        : llvm::Type::getVoidTy(*ctx_);
+
+    llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
+    CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
+             << llvm::Intrinsic::getName(id, {});
     return builder_->CreateCall(f, arg_value);
   } else if (op->is_intrinsic(CallNode::bitwise_and)) {
     return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 6249aa4..e785f3e 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -232,6 +232,21 @@ class CodeGenLLVM :
    * \param type The corresponding TVM Type.
    */
   llvm::Type* GetLLVMType(const PrimExpr& expr) const;
+  /*!
+   * \brief Get the declaration of the LLVM intrinsic based on the intrinsic
+   *        id, and the type of the actual call.
+   *
+   * \param id The intrinsic id.
+   * \param ret_type The call return type.
+   * \param arg_types The types of the call arguments.
+   *
+   * \return Return the llvm::Function pointer, or nullptr if the declaration
+   *         could not be generated (e.g. if the argument/return types do not
+   *         match).
+   */
+  llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id,
+                                   llvm::Type* ret_type,
+                                   llvm::ArrayRef<llvm::Type*> arg_types);
   // initialize the function state.
   void InitFuncState();
   // Get alignment given index.
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index 880a0fe..58bfb37 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -30,7 +30,7 @@ namespace codegen {
 namespace llvm {
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
-.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);
+.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
@@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
 });
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
@@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
 });
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
-.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
+.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 34135c6..3de1d16 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -40,6 +40,30 @@ def test_llvm_intrin():
     fcode = tvm.build(func, None, "llvm")
 
 
+def test_llvm_overloaded_intrin():
+    # Name lookup for overloaded intrinsics in LLVM 4- requires a name
+    # that includes the overloaded types.
+    if tvm.target.codegen.llvm_version_major() < 5:
+        return
+
+    def use_llvm_intrinsic(A, C):
+        ib = tvm.tir.ir_builder.create()
+        L = A.vload((0,0))
+        I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz',
+            tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))
+        S = C.vstore((0,0), I)
+        ib.emit(S)
+        return ib.get()
+
+    A = tvm.te.placeholder((1,1), dtype = 'int32', name = 'A')
+    C = tvm.te.extern((1,1), [A],
+        lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]),
+        name = 'C' , dtype = 'int32')
+
+    s = tvm.te.create_schedule(C.op)
+    f = tvm.build(s, [A, C], target = 'llvm')
+
+
 def test_llvm_import():
     # extern "C" is necessary to get the correct signature
     cc_code = """
@@ -82,9 +106,9 @@ def test_llvm_import():
 
 def test_llvm_lookup_intrin():
     ib = tvm.tir.ir_builder.create()
-    m = te.size_var("m")
     A = ib.pointer("uint8x8", name="A")
-    x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
+    z = tvm.tir.const(0, 'int32')
+    x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
     ib.emit(x)
     body = ib.get()
     func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
@@ -680,6 +704,7 @@ if __name__ == "__main__":
     test_llvm_vadd_pipeline()
     test_llvm_add_pipeline()
     test_llvm_intrin()
+    test_llvm_overloaded_intrin()
     test_llvm_flip_pipeline()
     test_llvm_madd_pipeline()
     test_llvm_temp_space()
diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py
index bdda496..b7da66f 100644
--- a/topi/python/topi/arm_cpu/bitserial_conv2d.py
+++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py
@@ -197,7 +197,6 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
         ww, xx = ins
         zz = outs[0]
 
-        args_1 = tvm.tir.const(1, 'uint32')
         args_2 = tvm.tir.const(2, 'uint32')
 
         if unipolar:
@@ -237,10 +236,10 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
                             cnts8[i] = upper_half + lower_half
                         for i in range(m//2):
                             cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_1, cnts8[i*2], cnts8[i*2+1])
+                                                                args_2, cnts8[i*2], cnts8[i*2+1])
                         for i in range(m//4):
                             cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_1, cnts4[i*2], cnts4[i*2+1])
+                                                                args_2, cnts4[i*2], cnts4[i*2+1])
                         cnts = tvm.tir.call_pure_intrin(
                             full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
                         shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
@@ -257,10 +256,10 @@ def _intrin_popcount(m, k_i, w_b, x_b, unipolar):
                                 cnts8[i] = tvm.tir.popcount(w_ & x_)
                         for i in range(m//2):
                             cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_1, cnts8[i*2], cnts8[i*2+1])
+                                                                args_2, cnts8[i*2], cnts8[i*2+1])
                         for i in range(m//4):
                             cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
-                                                                args_1, cnts4[i*2], cnts4[i*2+1])
+                                                                args_2, cnts4[i*2], cnts4[i*2+1])
                         cnts = tvm.tir.call_pure_intrin(
                             full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
                         shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)


Mime
View raw message