tvm-dev mailing list archives

From Gus Smith <notificati...@github.com>
Subject [dmlc/tvm] [RFC][Datatypes] Bring Your Own Datatypes: supporting custom (software) datatypes (#3060)
Date Sun, 21 Apr 2019 01:51:34 GMT
See the custom datatypes pull request at #2900. 

This RFC proposes support for custom datatypes. For now, we propose supporting only software
datatypes. By "software datatype", we mean a datatype (e.g. a [posit](https://en.wikipedia.org/wiki/Unum_(number_format)#Type_III_Unum_-_Posit))
that is implemented by a library (e.g. [SoftPosit](https://gitlab.com/cerlane/SoftPosit)).
That is, this RFC does not address support for custom datatype hardware.

Research into custom datatypes for machine learning has picked up recently; see, for
example, Facebook's [Rethinking Floating Point for Deep Learning](https://arxiv.org/abs/1811.01721)
paper and the [Deep Positron](https://arxiv.org/abs/1812.01762) paper. Much of the initial
exploration of a new datatype is done by emulating the datatype's hardware in software,
which is slower but lets the researcher evaluate the datatype's numerical behavior.
TVM is well suited to this kind of exploration, and supporting software custom datatypes
requires relatively few changes to TVM. With this support, TVM can become a valuable tool
for datatype researchers who want to test their datatypes on real deep learning models.

# Proposed Design

This design assumes the programmer has a datatype library which they would like to use. In
this example (which is included as a unit test in the linked PR) we will implement `bfloat16`.
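
For reference, the arithmetic that such a library performs for `bfloat16` can be modeled in a few lines of Python. This is a minimal sketch, not the library used in the PR. The helper names are hypothetical, and the sketch assumes round-to-nearest-even when narrowing from `float32`, with every NaN mapped to a single quiet NaN. A `bfloat16` value is the upper 16 bits of a `float32`, which is why a 16-bit unsigned integer is enough to store it.

```python
import math
import struct


def float_to_bits(x: float) -> int:
    """Round x to float32 and reinterpret it as a 32-bit unsigned integer.

    struct raises OverflowError for finite values beyond the float32 range.
    """
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_float(b: int) -> float:
    """Reinterpret a 32-bit unsigned integer as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]


def float_to_bfloat16(x: float) -> int:
    """Narrow a float32 value to bfloat16 bits (round to nearest even)."""
    if math.isnan(x):
        return 0x7FC0  # canonical quiet NaN (an assumption of this sketch)
    bits = float_to_bits(x)
    # Add just under half a bfloat16 ulp, plus 1 if the kept low bit is odd,
    # so that exact ties round to the even encoding; then drop the low 16 bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_to_float(b: int) -> float:
    """Widen bfloat16 bits to float32 exactly by appending 16 zero bits."""
    return bits_to_float((b & 0xFFFF) << 16)


def bfloat16_add(a: int, b: int) -> int:
    """Add by widening both operands, adding as Python floats, and narrowing back."""
    return float_to_bfloat16(bfloat16_to_float(a) + bfloat16_to_float(b))
```

For example, `float_to_bfloat16(1.01171875)` returns `0x3F82`: the value lies exactly halfway between the encodings `0x3F81` and `0x3F82`, and the tie goes to the even one.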

## Frontend

We will first allow the programmer to register custom datatypes in Python:
```python
tvm.custom_datatypes.register("bfloat", 24)
```
This reserves type code 24 for a custom datatype named `bfloat`. The programmer is then
free to use `bfloat` wherever they would previously have used a built-in datatype:
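```python
import tvm
import topi

X = tvm.placeholder((3, ), name="X")
Y = tvm.placeholder((3, ), name="Y")
# Cast both inputs to bfloat, add them as bfloats, and cast the sum back to float.
Z = topi.cast(
    topi.cast(X, dtype="custom[bfloat]16") +
    topi.cast(Y, dtype="custom[bfloat]16"),
    dtype="float")
```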
Note the `dtype` string format used to denote custom datatypes. When this code is compiled,
TVM needs to know how to lower operations involving custom datatypes. In this example,
there are two `float`-to-`bfloat` casts, a `bfloat` add, and a `bfloat`-to-`float` cast. We
allow the programmer to register a lowering function for each type of operation on their custom
datatype:
```python
# Lower casts from float to bfloat into calls to FloatToBFloat16_wrapper.
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("FloatToBFloat16_wrapper"),
    "Cast", "llvm", "bfloat", "float")
# Lower casts from bfloat to float into calls to BFloat16ToFloat_wrapper.
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16ToFloat_wrapper"),
    "Cast", "llvm", "float", "bfloat")
# Lower bfloat adds into calls to BFloat16Add_wrapper.
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16Add_wrapper"), "Add",
    "llvm", "bfloat")
```
`register_op` takes a lowering function, the name of the operation, the compilation target,
and the datatype. For casts, the datatype is the destination type, and an additional argument
gives the source type. Here, we use a convenience function, `create_lower_func`, which creates
a lowering function that lowers matching operations to a call to an external function whose
name is passed to `create_lower_func`. For example, the first call to `register_op` creates a
lowering function that lowers casts from `float` to `bfloat` into calls to an external library
function called `FloatToBFloat16_wrapper`. These library functions can be made available by
loading the shared library that defines them:
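```python
from ctypes import CDLL, RTLD_GLOBAL

# RTLD_GLOBAL makes the library's symbols available to code loaded afterwards,
# such as the module TVM builds.
CDLL("libmybfloat16.so", RTLD_GLOBAL)
```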
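
To make the behavior of `create_lower_func` concrete, it can be modeled with a toy IR. In this minimal sketch, `OpNode`, `CallNode`, and `storage_type` are hypothetical stand-ins for TVM's IR classes and helpers. The real function builds TVM call nodes. Here the lowering function swaps an operation node for a call to the named external function, keeps the operands, and gives the call the storage type of the result.

```python
import re
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class OpNode:
    """Toy IR node for an operation such as "Add" or "Cast"."""
    op: str
    dtype: str
    args: Tuple[object, ...]


@dataclass(frozen=True)
class CallNode:
    """Toy IR node for a call to an external function."""
    func_name: str
    dtype: str
    args: Tuple[object, ...]


def storage_type(dtype: str) -> str:
    """Map "custom[<name>]<bits>" to "uint<bits>"; leave built-in dtypes unchanged."""
    m = re.fullmatch(r"custom\[\w+\](\d+)", dtype)
    return f"uint{m.group(1)}" if m else dtype


def create_lower_func(extern_name: str) -> Callable[[OpNode], CallNode]:
    """Return a lowering function that replaces a node with a call to extern_name."""
    def lower(node: OpNode) -> CallNode:
        return CallNode(extern_name, storage_type(node.dtype), node.args)
    return lower
```

For example, `create_lower_func("BFloat16Add_wrapper")` applied to a `bfloat` add node yields a `uint16` call to `BFloat16Add_wrapper` with the same two operands.
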
Finally, we can build our program:
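```python
from tvm import ir_pass

tgt = "llvm"
s = tvm.create_schedule([Z.op])
flist = tvm.lower(s, [X, Y, Z])
flist = [flist]
# Lower custom datatypes to their storage types and library calls before building.
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
built_cast = tvm.build(flist[0], target=tgt)
```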
Note that we manually run the datatype lowering pass. Once this pass runs, all custom datatypes
will be lowered to implementations using built-in datatypes.

## Backend

In this section we describe additions to the backend which support the proposed frontend design.

### Custom Datatype Registry

`src/codegen/custom_datatypes/registry.{h,cc}` implements the *datatype registry*. This registry
lets programmers register custom datatypes, each with a chosen name and type code. The registry
is global at both TVM compile time and runtime, and TVM queries it for information about custom
datatypes at many different points.
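
The registry itself is C++, but its role can be sketched as a two-way map between names and type codes. This is a hypothetical Python model, not the PR's code. In particular, rejecting duplicate names and codes is an assumption about how this sketch handles conflicts.

```python
from typing import Dict


class DatatypeRegistry:
    """Toy model of a global, bidirectional name <-> type-code registry."""

    def __init__(self) -> None:
        self._code_by_name: Dict[str, int] = {}
        self._name_by_code: Dict[int, str] = {}

    def register(self, name: str, code: int) -> None:
        # Refuse to rebind either key so lookups in both directions stay consistent.
        if name in self._code_by_name or code in self._name_by_code:
            raise ValueError(f"{name!r} or type code {code} is already registered")
        self._code_by_name[name] = code
        self._name_by_code[code] = name

    def get_type_code(self, name: str) -> int:
        return self._code_by_name[name]

    def get_type_name(self, code: int) -> str:
        return self._name_by_code[code]

    def is_registered(self, code: int) -> bool:
        return code in self._name_by_code
```

In this model, the frontend call `register("bfloat", 24)` corresponds to `registry.register("bfloat", 24)`. Later stages can then translate between the code stored in a type and the name used in `dtype` strings.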

### Storage Types

When the programmer specifies a custom datatype as a `dtype` parameter, they write it
as `custom[<type name>]<bits>`. The `<bits>` field specifies the width of
the datatype, as it does for built-in datatypes. This field is especially important
for custom datatypes, however, because it determines the underlying *storage type* of the custom
datatype: when the custom datatype is lowered, it becomes an opaque unsigned integer
`<bits>` bits wide. For example, `custom[bfloat]16` is stored as a 16-bit unsigned integer.
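
The mapping from a `dtype` string to its storage type can be sketched as follows. This is a minimal sketch that assumes the type name consists of identifier characters and the width is a decimal number. The function names are hypothetical, and TVM's actual dtype parser may accept more forms.

```python
import re
from typing import NamedTuple

# custom[<type name>]<bits>, e.g. "custom[bfloat]16"
_CUSTOM_DTYPE = re.compile(r"custom\[(?P<name>[A-Za-z_]\w*)\](?P<bits>\d+)")


class CustomDType(NamedTuple):
    type_name: str
    bits: int


def parse_custom_dtype(dtype: str) -> CustomDType:
    """Split a custom dtype string into its type name and bit width."""
    m = _CUSTOM_DTYPE.fullmatch(dtype)
    if m is None:
        raise ValueError(f"not a custom dtype string: {dtype!r}")
    return CustomDType(m.group("name"), int(m.group("bits")))


def storage_type(dtype: str) -> str:
    """Return the opaque unsigned-integer type used to store a custom dtype."""
    return f"uint{parse_custom_dtype(dtype).bits}"
```

Because the width alone determines the storage type, two custom datatypes of the same width (say, a 16-bit posit and `bfloat`) share a storage type. Only their registered lowering functions distinguish them.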

### Lowering Function Registration

When the programmer uses `register_op` to register a lowering function, on the backend we
register the lowering function as a TVM global under the namespace `tvm.custom_datatypes.lower`.
For casts, this looks like `tvm.custom_datatypes.lower.Cast.<target>.<type>.<src_type>`.
For other types of ops, this looks like `tvm.custom_datatypes.lower.<op>.<target>.<type>`.
This makes the lowering functions easy to locate later on.
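
The naming scheme is mechanical, so it can be sketched directly. In this hypothetical Python model, a plain dictionary stands in for TVM's global function registry. `register_op` mirrors the argument order used in the frontend example: lowering function, op name, target, type, and, for casts, the source type.

```python
from typing import Callable, Dict, Optional

NAMESPACE = "tvm.custom_datatypes.lower"

# Toy stand-in for TVM's global function registry.
GLOBAL_FUNCS: Dict[str, Callable] = {}


def lowering_func_name(op: str, target: str, type_name: str,
                       src_type_name: Optional[str] = None) -> str:
    """Build the global name under which a lowering function is registered."""
    if op == "Cast":
        if src_type_name is None:
            raise ValueError("Cast lowering functions need a source type")
        return f"{NAMESPACE}.Cast.{target}.{type_name}.{src_type_name}"
    return f"{NAMESPACE}.{op}.{target}.{type_name}"


def register_op(lower_func: Callable, op: str, target: str, type_name: str,
                src_type_name: Optional[str] = None) -> None:
    """Store lower_func under its conventional global name."""
    name = lowering_func_name(op, target, type_name, src_type_name)
    GLOBAL_FUNCS[name] = lower_func
```

With this scheme, the first registration in the frontend example lands at `tvm.custom_datatypes.lower.Cast.llvm.bfloat.float`, and the add lands at `tvm.custom_datatypes.lower.Add.llvm.bfloat`.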

### Datatype Lowering Pass

Finally, the datatype lowering pass in `src/pass/lower_custom_datatypes.cc` lowers the datatypes.
After the pass runs, every use of a custom datatype has been lowered to the appropriate
storage type. Each time the pass finds an IR node of a custom datatype, it looks up the matching
lowering function using the name format described above and transforms the node with it.
In our example above, a function that lowers `bfloat` adds to calls to the `BFloat16Add_wrapper`
function is registered as `tvm.custom_datatypes.lower.Add.llvm.bfloat`. During datatype lowering,
the pass looks up this function and uses it to transform the `bfloat` add node into a call
node that calls the `BFloat16Add_wrapper` function.
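
The pass can be sketched as a bottom-up rewrite over a toy expression tree. This is a minimal Python model, not the C++ pass. The node classes, `LOWER_FUNCS`, and `extern_call` (a stand-in for `create_lower_func`) are hypothetical. Built-in types are written as the bare name `float` so that lookup keys match the RFC's example, and only variables, adds, casts, and calls are handled.

```python
import re
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union

_CUSTOM = re.compile(r"custom\[(\w+)\](\d+)")
PREFIX = "tvm.custom_datatypes.lower"


@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class Add:
    dtype: str
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Cast:
    dtype: str
    value: "Expr"


@dataclass(frozen=True)
class Call:
    func_name: str
    dtype: str
    args: Tuple["Expr", ...]


Expr = Union[Var, Add, Cast, Call]
LOWER_FUNCS: Dict[str, Callable[[Expr], Expr]] = {}  # toy global registry


def is_custom(dtype: str) -> bool:
    return _CUSTOM.fullmatch(dtype) is not None


def type_name(dtype: str) -> str:
    m = _CUSTOM.fullmatch(dtype)
    return m.group(1) if m else dtype


def storage(dtype: str) -> str:
    m = _CUSTOM.fullmatch(dtype)
    return f"uint{m.group(2)}" if m else dtype


def lower_custom_datatypes(e: Expr, target: str) -> Expr:
    """Rewrite e bottom-up so that no custom dtype remains."""
    if isinstance(e, Var):
        return Var(e.name, storage(e.dtype))
    if isinstance(e, Call):
        return Call(e.func_name, storage(e.dtype),
                    tuple(lower_custom_datatypes(x, target) for x in e.args))
    if isinstance(e, Add):
        node = Add(storage(e.dtype), lower_custom_datatypes(e.a, target),
                   lower_custom_datatypes(e.b, target))
        if is_custom(e.dtype):
            return LOWER_FUNCS[f"{PREFIX}.Add.{target}.{type_name(e.dtype)}"](node)
        return node
    if isinstance(e, Cast):
        node = Cast(storage(e.dtype), lower_custom_datatypes(e.value, target))
        # Key on the original (pre-lowering) source and destination types.
        if is_custom(e.dtype) or is_custom(e.value.dtype):
            key = f"{PREFIX}.Cast.{target}.{type_name(e.dtype)}.{type_name(e.value.dtype)}"
            return LOWER_FUNCS[key](node)
        return node
    raise TypeError(f"unsupported node: {e!r}")


def extern_call(name: str) -> Callable[[Expr], Expr]:
    """Replace a Cast or Add node with a call to `name` on its operands."""
    def lower(node: Expr) -> Expr:
        args = (node.value,) if isinstance(node, Cast) else (node.a, node.b)
        return Call(name, node.dtype, args)
    return lower


for key, func_name in [("Cast.llvm.bfloat.float", "FloatToBFloat16_wrapper"),
                       ("Cast.llvm.float.bfloat", "BFloat16ToFloat_wrapper"),
                       ("Add.llvm.bfloat", "BFloat16Add_wrapper")]:
    LOWER_FUNCS[f"{PREFIX}.{key}"] = extern_call(func_name)

# The frontend example: cast(cast(x, bfloat) + cast(y, bfloat), float).
x, y = Var("x", "float"), Var("y", "float")
bf = "custom[bfloat]16"
lowered = lower_custom_datatypes(Cast("float", Add(bf, Cast(bf, x), Cast(bf, y))), "llvm")
```

After lowering, `lowered` is a `float` call to `BFloat16ToFloat_wrapper` whose argument is a `uint16` call to `BFloat16Add_wrapper` on two `FloatToBFloat16_wrapper` calls, so only built-in types remain.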

# Roadmap
- [x] Add enough custom datatype infrastructure to support simple examples (e.g. a simple
`bfloat` program).
- [ ] Identify real datatype libraries to begin testing with. (1 week)
- [ ] Work out bugs involved in getting infrastructure working with a real library. (2 weeks)
- [ ] Test case: inference with a commonly-used deep learning model. (3 weeks)
- [ ] Test case: training with a commonly-used deep learning model. (3 weeks)


View this issue on GitHub: https://github.com/dmlc/tvm/issues/3060