tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] liangfu commented on a change in pull request #5601: [TVM Core] Add bfloat16
Date Tue, 19 May 2020 09:31:58 GMT

liangfu commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r427150166



##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
       default:
         LOG(FATAL) << "do not support " << dtype;
     }
+  } else if (dtype.is_bfloat()) {
+    CHECK_EQ(dtype.bits(), 16);

Review comment:
       Since bfloat is assumed to be 16-bit, can we keep the terminology more consistent? The
data type is variously termed `bf`, `bf16`, `bfloat16`, and `bfloat` in the proposed change. Or
are we going to support more data types like bfloat18 and bfloat20 in the future?

##########
File path: tests/python/unittest/test_tir_transform_bf16_legalize.py
##########
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import topi
+from tvm import te
+from tvm.tir import const
+
+
+def lower_stmt(sche, params, passfunc):
+    func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"]
+    func = passfunc()(
+        tvm.IRModule.from_expr(func))["main"]
+    stmt = func.body
+    return stmt
+
+def to32(v):
+    return topi.cast(v, 'float')
+def to16(v):
+    return topi.cast(v, 'bf16')
+
+def test_promote():
+    def runpass(op, passfunc):
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), lambda i: op(a[i], b[i]))
+        s = te.create_schedule(c.op)
+        return lower_stmt(s, [a, b, c], passfunc)
+    
+    def get_promoted(op):
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), lambda i:
+                topi.cast(op(topi.cast(a[i],'float'),
+                    topi.cast(b[i],'float')), 'bf16')
+                )
+        s = te.create_schedule(c.op)
+        func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"]
+        return func.body
+
+    def test_promoted(op):
+        stmt = runpass(op, tvm.tir.transform.BF16Promote)
+        tvm.ir.assert_structural_equal(stmt, get_promoted(op))
+    test_promoted(topi.add)
+    test_promoted(topi.subtract)
+    test_promoted(topi.multiply)
+    test_promoted(topi.divide)
+
+def test_eliminate():
+    def get_eliminated():
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), lambda i: to16(
+            topi.add(
+                to32(
+                    to16(
+                        topi.add(
+                            to32(a[i]),
+                            to32(b[i]),
+                        )
+                    )
+                ),
+                to32(
+                    to16(
+                        topi.add(
+                            to32(a[i]),
+                            to32(b[i]),
+                        )
+                    )
+                )
+            )
+        ))
+        s = te.create_schedule(c.op)
+        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
+        return stmt
+
+    def get_target():
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), lambda i: to16(
+            topi.add(topi.add(
+                        to32(a[i]),
+                        to32(b[i]),
+                    ),
+                    topi.add(
+                        to32(a[i]),
+                        to32(b[i]),
+                    )
+                )
+        ))
+        s = te.create_schedule(c.op)
+        func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"]
+        return func.body
+
+    tvm.ir.assert_structural_equal(get_eliminated(), get_target())
+
+def test_legalize():
+    def check(fcompute_before, fcompute_after):
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), fcompute_before(a,b))
+        s = te.create_schedule(c.op)
+        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize)
+
+        a = te.placeholder((100,), dtype='bf16')
+        b = te.placeholder((100,), dtype='bf16')
+        c = te.compute((100,), fcompute_after(a,b))
+        s = te.create_schedule(c.op)
+        func = tvm.driver.build_module.form_irmodule(s, [a,b,c], "main", None)["main"]
+        tvm.ir.assert_structural_equal(stmt, func.body)
+
+    def orig1(a,b):
+        return lambda i: a[i]+b[i]+a[99-i]+b[99-i]
+    def after1(a,b):
+        return lambda i: to16(to32(a[i])+to32(b[i])+to32(a[99-i])+to32(b[99-i]))
+    def orig2(a,b):
+        return lambda i: a[i]*b[i]+a[99-i]*b[99-i]+a[i]
+    def after2(a,b):
+        return lambda i: to16(to32(a[i])*to32(b[i])+to32(a[99-i])*to32(b[99-i])+to32(a[i]))
+
+    check(orig1, after1)
+    check(orig2, after2)
+
+if __name__ == "__main__":
+    test_promote()
+    test_eliminate()
+    test_legalize()

Review comment:
       Please leave a newline at EOF, even if this is a test script :)

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -906,7 +954,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
   llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
     if (t.is_int()) {                                                                \
       return builder_->CreateICmpS##Op(a, b);                                        \
-    } else if (t.is_uint()) {                                                        \
+    } else if (t.is_uint() || t.is_bfloat()) {                                       \

Review comment:
       Isn't comparing bfloat16 this way risky?
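
For illustration, a minimal numpy sketch (my own, not code from this PR) of why reusing the
unsigned integer compare for bf16 bit patterns is unsafe: the sign bit makes every negative
value compare greater than every positive one, and NaN ordering is lost entirely.

```python
import numpy as np

def bf16_bits(x):
    # Hypothetical helper: float32 -> bf16 bit pattern by truncation
    # (keep the upper 16 bits of the float32 representation).
    return (np.array(x, dtype='float32').view('uint32') >> 16).astype('uint16')

a = bf16_bits(-1.0)  # 0xBF80
b = bf16_bits(1.0)   # 0x3F80
print(a > b)         # True under unsigned compare, yet -1.0 < 1.0 as floats
```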

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -555,12 +558,48 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
   builder_->SetInsertPoint(for_end);
 }
 
+static llvm::Value* GetInt32VectorOrScalar(
+    llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>* builder, uint32_t v,
+    int lanes) {
+  if (lanes == 1) {
+    return builder->getInt32(v);
+  } else {
+    std::vector<llvm::Constant*> consts;
+    for (int i = 0; i < lanes; i++) {
+      consts.emplace_back(builder->getInt32(v));
+    }
+    return llvm::ConstantVector::get(consts);
+  }
+}
+
 // cast operator
 llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
   llvm::Type* target = DTypeToLLVMType(to);
   if (value->getType() == target) return value;
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
+  } else if (to.is_float() && from.is_bfloat()) {
+    CHECK_EQ(from.bits(), 16);
+    CHECK_EQ(to.bits(), 32);
+    llvm::Type* extended_type = (from.lanes() == 1)
+                                    ? static_cast<llvm::Type*>(builder_->getInt32Ty())
+                                    : llvm::VectorType::get(builder_->getInt32Ty(), from.lanes());
+    auto v = builder_->CreateZExt(value, extended_type);
+    v = builder_->CreateShl(v, 16);

Review comment:
       Potential endianness problem here?
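
For what it's worth, `zext` and `shl` act on integer values in LLVM IR rather than on byte
layout, so the widening itself should be endianness-neutral; byte order only enters once the
pattern is stored to or loaded from memory. A minimal numpy sketch of the same value-level
computation (my own illustration, not code from the PR):

```python
import numpy as np

def bf16_to_f32(bits16):
    # Same value-level trick as the generated IR: zero-extend the bf16
    # pattern to 32 bits, shift it into the top half, reinterpret as float32.
    return (np.asarray(bits16, dtype='uint16').astype('uint32') << 16).view('float32')

print(bf16_to_f32([0x3F80, 0xBF80]))  # [ 1. -1.]
```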

##########
File path: tests/python/unittest/test_target_codegen_llvm.py
##########
@@ -710,6 +710,52 @@ def _transform(f, *_):
         module(a_, b_, c_)
         tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
 
+def np_float2np_bf16(arr):
+    ''' Convert a numpy array of float to a numpy array 
+    of bf16 in uint16'''
+    orig = arr.view('<u4')
+    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
+    return np.right_shift(orig + bias, 16).astype('uint16')
+
+def np_float2tvm_bf16(arr):
+    ''' Convert a numpy array of float to a TVM array 
+    of bf16'''
+    nparr = np_float2np_bf16(arr)
+    return tvm.nd.empty(nparr.shape, 'bf16').copyfrom(nparr)
+
+def np_bf162np_float(arr):
+    ''' Convert a numpy array of bf16 (uint16) to a numpy array 
+    of float'''
+    u32 = np.left_shift(arr.astype('uint32'), 16)

Review comment:
       Are we going to produce a potential endianness problem here?
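
As with the cast in codegen, the shift here is value-level, so the byte-order question really
attaches to the explicit little-endian `'<u4'` view in `np_float2np_bf16` above. A round-trip
sketch, under the assumption (the hunk is truncated at this comment) that `np_bf162np_float`
finishes by viewing the widened word as float32:

```python
import numpy as np

def np_float2np_bf16(arr):
    # Round-to-nearest-even, then keep the upper 16 bits (as in the diff;
    # note the explicit little-endian '<u4' view of the float32 input).
    orig = arr.view('<u4')
    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
    return np.right_shift(orig + bias, 16).astype('uint16')

def np_bf162np_float(arr):
    # Assumed completion of the truncated helper above.
    return np.left_shift(arr.astype('uint32'), 16).view('<f4')

x = np.array([1.0, -2.5, 3.14159], dtype='float32')
print(np_bf162np_float(np_float2np_bf16(x)))  # ~[1.0, -2.5, 3.140625]
```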






