See the custom datatypes pull request at #2900. 

This RFC proposes support for custom datatypes. Specifically, we only propose 
supporting software datatypes at the moment. By "software datatype", we mean a 
datatype (e.g. a 
[posit](https://en.wikipedia.org/wiki/Unum_(number_format)#Type_III_Unum_-_Posit))
 which is implemented by a library (e.g. 
[SoftPosit](https://gitlab.com/cerlane/SoftPosit)). That is, this RFC does not 
address support for custom datatype hardware.

Research into custom datatypes for machine learning has been picking up as of 
late; see, for example, Facebook's [Rethinking Floating Point for Deep 
Learning](https://arxiv.org/abs/1811.01721) paper, and the [Deep 
Positron](https://arxiv.org/abs/1812.01762) paper. Often, much of the initial 
exploration into new datatypes is done by emulating datatype hardware in 
software, which is slower, but allows the researcher to evaluate the numerical 
behavior of their datatype. TVM is well positioned to support this kind of 
exploration, and adding support for software custom datatypes requires very 
little modification to the codebase. With this support, TVM can become an 
indispensable tool for datatype researchers seeking to test their datatypes on 
real deep learning models.

# Proposed Design

This design assumes the programmer has a datatype library which they would like 
to use. In this example (which is included as a unit test in the linked PR) we 
will implement `bfloat16`.

## Frontend

We will first allow the programmer to register custom datatypes in Python:
```python
tvm.custom_datatypes.register("bfloat", 24)
```
This will reserve the type code 24 for a custom datatype called `bfloat`. The 
programmer will now be free to use `bfloat` wherever they might have used 
built-in datatypes in the past:
```python
X = tvm.placeholder((3, ), name="X")
Y = tvm.placeholder((3, ), name="Y")
Z = topi.cast(
    topi.cast(X, dtype="custom[bfloat]16") +
    topi.cast(Y, dtype="custom[bfloat]16"),
    dtype="float")
```
Note the `dtype` string formatting used to signify custom datatypes. When this 
code is compiled, TVM will need to know how to lower operations involving the 
custom datatypes. In this example, we see two `float`-to-`bfloat` casts, a 
`bfloat` add, and a `bfloat`-to-`float` cast. We will allow the programmer to 
register a lowering function for each type of operation on their custom 
datatype:
```python
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("FloatToBFloat16_wrapper"),
    "Cast", "llvm", "bfloat", "float")
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16ToFloat_wrapper"),
    "Cast", "llvm", "float", "bfloat")
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16Add_wrapper"), "Add",
    "llvm", "bfloat")
```
`register_op` takes a lowering function, the name of the operation, the 
compilation target, and the datatype (for `Cast`, both the destination and 
source datatypes). Here, we use a convenience function, 
`create_lower_func`. This function creates a lowering function which will lower 
matching operations to a call to an external function, whose name is passed to 
`create_lower_func`. So, for example, in the first call to `register_op` we 
create a lowering function which will lower casts from `float`s to `bfloat`s 
with calls to an external library function called `FloatToBFloat16_wrapper`. 
These library functions can be made available by loading the shared library 
that contains them:
```python
CDLL("libmybfloat16.so", RTLD_GLOBAL)
```
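For programmers who need more control than `create_lower_func` provides, a 
lowering function can also be written by hand. The following is a minimal 
sketch rather than code from the PR: it assumes the operands of the `Add` node 
have already been lowered to `bfloat`'s `uint16` storage type, and it uses 
`tvm.call_pure_extern` to emit the external call.
```python
import tvm

def lower_bfloat_add(op):
    # Sketch of a hand-written lowering function, roughly what
    # create_lower_func("BFloat16Add_wrapper") would generate. We assume the
    # operands of `op` are already expressed in bfloat's uint16 storage type.
    return tvm.call_pure_extern("uint16", "BFloat16Add_wrapper", op.a, op.b)

# Hypothetical alternative to the create_lower_func-based registration above.
tvm.custom_datatypes.register_op(lower_bfloat_add, "Add", "llvm", "bfloat")
```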
Finally, we can build our program:
```python
from tvm import ir_pass

tgt = "llvm"  # the compilation target used in the register_op calls above
s = tvm.create_schedule([Z.op])
# tvm.lower returns a single lowered function here; wrap it in a list so the
# custom datatype lowering pass can be applied to each function uniformly.
flist = [tvm.lower(s, [X, Y, Z])]
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
built_cast = tvm.build(flist[0], target=tgt)
```
Note that we manually run the datatype lowering pass. Once this pass runs, all 
custom datatypes will be lowered to implementations using built-in datatypes.
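
As a quick sanity check, the built module can then be invoked like any other 
TVM function. This is a sketch rather than code from the PR; note that all 
external arguments here are `float32`, since the `bfloat` values exist only 
inside the lowered kernel:
```python
import numpy as np

ctx = tvm.cpu(0)
x = tvm.nd.array(np.random.rand(3).astype("float32"), ctx)
y = tvm.nd.array(np.random.rand(3).astype("float32"), ctx)
z = tvm.nd.array(np.zeros(3, dtype="float32"), ctx)
built_cast(x, y, z)
# z now holds x + y, computed by round-tripping through bfloat.
print(z.asnumpy())
```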

## Backend

In this section we describe additions to the backend which support the proposed 
frontend design.

### Custom Datatype Registry

`src/codegen/custom_datatypes/registry.{h,cc}` implements the *datatype 
registry*. This registry allows programmers to register custom datatypes, 
choosing a name and a code. The registry is global across TVM compile time and 
runtime, and is consulted wherever TVM needs information about a custom 
datatype, for example when parsing `dtype` strings or when lowering custom 
datatype operations.
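
To make the registry's role concrete, here is a toy Python model of its 
behavior; the real registry is implemented in C++, and the method names below 
are illustrative rather than part of the PR:
```python
class DatatypeRegistry:
    """Toy model of the global custom datatype registry."""

    def __init__(self):
        self._name_to_code = {}
        self._code_to_name = {}

    def register(self, name, code):
        # Reserve `code` for the datatype `name`, e.g. register("bfloat", 24).
        self._name_to_code[name] = code
        self._code_to_name[code] = name

    def get_type_code(self, name):
        return self._name_to_code[name]

    def get_type_name(self, code):
        return self._code_to_name[code]
```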

### Storage Types

When the programmer specifies their custom datatype as a `dtype` parameter, 
they format it as `custom[<type name>]<bits>`. The `<bits>` field specifies the 
width of the datatype, as it does for built-in datatypes. This information is 
especially important for custom datatypes, however, as it specifies the 
underlying *storage type* of the custom datatype. When the custom datatype is 
lowered, it becomes an opaque unsigned integer of the width specified in 
`<bits>`; this unsigned integer type is the storage type of the custom 
datatype.
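
For example, continuing the `bfloat` example from the frontend section, the 
placeholder below (a sketch) declares a tensor of 16-bit `bfloat` values; 
after the lowering pass runs, it is handled as an opaque `uint16` buffer:
```python
import tvm

# custom[bfloat]16: a custom "bfloat" value stored in 16 bits (uint16).
A = tvm.placeholder((3, ), dtype="custom[bfloat]16", name="A")
```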

### Lowering Function Registration

When the programmer uses `register_op` to register a lowering function, on the 
backend we register the lowering function as a TVM global under the namespace 
`tvm.custom_datatypes.lower`. For casts, this looks like 
`tvm.custom_datatypes.lower.Cast.<target>.<type>.<src_type>`. For other types 
of ops, this looks like `tvm.custom_datatypes.lower.<op>.<target>.<type>`. This 
makes it possible to easily locate the lowering functions later on.
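
Under this scheme, the registrations from the frontend example end up at the 
following global names. The lowering pass performs this lookup internally; the 
`tvm.get_global_func` calls below are only a sketch of where the functions 
live:
```python
import tvm

# Casts: tvm.custom_datatypes.lower.Cast.<target>.<type>.<src_type>
float_to_bfloat = tvm.get_global_func(
    "tvm.custom_datatypes.lower.Cast.llvm.bfloat.float")
bfloat_to_float = tvm.get_global_func(
    "tvm.custom_datatypes.lower.Cast.llvm.float.bfloat")
# Other ops: tvm.custom_datatypes.lower.<op>.<target>.<type>
bfloat_add = tvm.get_global_func(
    "tvm.custom_datatypes.lower.Add.llvm.bfloat")
```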

### Datatype Lowering Pass

Finally, to lower the datatypes, we implement the datatype lowering pass in 
`src/pass/lower_custom_datatypes.cc`. After the pass runs, all uses of custom 
datatypes will have been lowered to their appropriate storage types. Each time 
the pass finds an IR node of a custom datatype, it looks up the appropriate 
lowering function using the name format described above. The pass transforms 
the node using the registered lowering function.
In our example above, a function which lowers `bfloat` adds to the 
`BFloat16Add_wrapper` function gets registered as 
`tvm.custom_datatypes.lower.Add.llvm.bfloat`. During datatype lowering, the 
pass looks up this function and uses it to transform the `bfloat` add node into 
a call node that calls the `BFloat16Add_wrapper` function.
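
The per-node logic can be modeled in a few lines of Python; the real pass is a 
C++ IR mutator, and the helper below is illustrative only:
```python
import tvm

def lower_custom_node(op_name, target, dtype_name, node):
    # Build the lookup key described above, fetch the registered lowering
    # function, and apply it to the node of custom datatype.
    key = "tvm.custom_datatypes.lower.{}.{}.{}".format(op_name, target, dtype_name)
    return tvm.get_global_func(key)(node)

# e.g. lower_custom_node("Add", "llvm", "bfloat", add_node) returns the
# extern call to BFloat16Add_wrapper.
```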

# Roadmap
- [x] Add enough custom datatype infrastructure to support simple examples 
(e.g. a simple `bfloat` program).
- [ ] Identify real datatype libraries to begin testing with. (1 week)
- [ ] Work out bugs involved in getting infrastructure working with a real 
library. (2 weeks)
- [ ] Test case: inference with a commonly-used deep learning model. (3 weeks)
- [ ] Test case: training with a commonly-used deep learning model. (3 weeks)

