See the custom datatypes pull request at #2900. This RFC proposes support for custom datatypes. Specifically, we propose supporting only software datatypes at the moment. By "software datatype", we mean a datatype (e.g. a [posit](https://en.wikipedia.org/wiki/Unum_(number_format)#Type_III_Unum_-_Posit)) which is implemented by a library (e.g. [SoftPosit](https://gitlab.com/cerlane/SoftPosit)). That is, this RFC does not cover support for custom datatype hardware.
Research into custom datatypes for machine learning has been picking up as of late; see, for example, Facebook's [Rethinking Floating Point for Deep Learning](https://arxiv.org/abs/1811.01721) paper and the [Deep Positron](https://arxiv.org/abs/1812.01762) paper. Much of the initial exploration of a new datatype is often done by emulating the datatype hardware in software, which is slower, but allows the researcher to evaluate the numerical behavior of the datatype. TVM is perfectly primed to support this kind of exploration, and supporting software custom datatypes requires very little modification. With support for software custom datatypes, TVM can become an indispensable tool for datatype researchers seeking to test their datatypes on real deep learning models.

# Proposed Design

This design assumes the programmer has a datatype library which they would like to use. In this example (which is included as a unit test in the linked PR), we will implement `bfloat16`.

## Frontend

We will first allow the programmer to register custom datatypes in Python:

```python
tvm.custom_datatypes.register("bfloat", 24)
```

This will reserve the type code 24 for a custom datatype called `bfloat`. The programmer will now be free to use `bfloat` wherever they might have used built-in datatypes in the past:

```python
X = tvm.placeholder((3, ), name="X")
Y = tvm.placeholder((3, ), name="Y")
Z = topi.cast(
    topi.cast(X, dtype="custom[bfloat]16") +
    topi.cast(Y, dtype="custom[bfloat]16"),
    dtype="float")
```

Note the `dtype` string formatting used to signify custom datatypes. When this code is compiled, TVM will need to know how to lower operations involving the custom datatypes. In this example, we see two `float`-to-`bfloat` casts, a `bfloat` add, and a `bfloat`-to-`float` cast.
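To make the numerical behavior of these casts concrete, here is a sketch of what a `bfloat16` library does under the hood. This is purely illustrative (the function names `float_to_bfloat16` and `bfloat16_to_float` are not part of TVM or the linked PR), and it uses simple truncation rather than the round-to-nearest-even a production library would use:

```python
import struct

def float_to_bfloat16(f):
    # Reinterpret the float32 as a uint32 and keep the top 16 bits:
    # the sign, the 8 exponent bits, and the 7 most significant mantissa bits.
    bits = struct.unpack("<I", struct.pack("<f", f))[0]
    return bits >> 16

def bfloat16_to_float(b):
    # Widen back to float32 by zero-filling the low 16 mantissa bits.
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# Values exactly representable in bfloat16 round-trip unchanged.
assert bfloat16_to_float(float_to_bfloat16(1.5)) == 1.5
```

Because `bfloat16` shares `float32`'s exponent range, the conversion is just this bit-level truncation and widening; this is why a software `bfloat16` library is cheap to emulate.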
We will allow the programmer to register a lowering function for each type of operation on their custom datatype:

```python
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("FloatToBFloat16_wrapper"),
    "Cast", "llvm", "bfloat", "float")
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16ToFloat_wrapper"),
    "Cast", "llvm", "float", "bfloat")
tvm.custom_datatypes.register_op(
    tvm.custom_datatypes.create_lower_func("BFloat16Add_wrapper"),
    "Add", "llvm", "bfloat")
```

`register_op` takes a lowering function, the name of the operation, the compilation target, and the datatype. Here, we use a convenience function, `create_lower_func`, which creates a lowering function that lowers matching operations to calls to an external function whose name is passed to `create_lower_func`. So, for example, in the first call to `register_op`, we create a lowering function which will lower casts from `float`s to `bfloat`s to calls to an external library function called `FloatToBFloat16_wrapper`. These library functions can be made available by loading them:

```python
from ctypes import CDLL, RTLD_GLOBAL
CDLL("libmybfloat16.so", RTLD_GLOBAL)
```

Finally, we can build our program:

```python
s = tvm.create_schedule([Z.op])
flist = tvm.lower(s, [X, Y, Z])
flist = [flist]
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
built_cast = tvm.build(flist[0], target=tgt)
```

Note that we manually run the datatype lowering pass. Once this pass runs, all custom datatypes will have been lowered to implementations using built-in datatypes.

## Backend

In this section, we describe additions to the backend which support the proposed frontend design.

### Custom Datatype Registry

`src/codegen/custom_datatypes/registry.{h,cc}` implements the *datatype registry*. This registry allows programmers to register custom datatypes, choosing a name and a type code.
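As a rough mental model of what the registry tracks, here is a toy Python sketch of a bidirectional name/code mapping. The class name, methods, and the minimum-code threshold are all assumptions for illustration; the real registry is implemented in C++:

```python
class DatatypeRegistry:
    """Toy name <-> type-code registry, mirroring the bookkeeping
    the real C++ registry performs."""

    # Codes below this value are assumed reserved for TVM's built-in
    # types; the exact threshold here is an illustrative assumption.
    _MIN_CUSTOM_CODE = 8

    def __init__(self):
        self._name_to_code = {}
        self._code_to_name = {}

    def register(self, name, code):
        if code < self._MIN_CUSTOM_CODE:
            raise ValueError(
                "type code %d is reserved for built-in types" % code)
        self._name_to_code[name] = code
        self._code_to_name[code] = name

    def get_code(self, name):
        return self._name_to_code[name]

    def get_name(self, code):
        return self._code_to_name[code]

registry = DatatypeRegistry()
registry.register("bfloat", 24)
```

Both directions of the mapping are needed: the frontend resolves names to codes when parsing `dtype` strings, and the lowering pass resolves codes back to names when looking up lowering functions.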
The registry is global at TVM compile time and runtime, and is used to get information about custom datatypes at many different points.

### Storage Types

When the programmer specifies their custom datatype as a `dtype` parameter, they format it as `custom[<type name>]<bits>`. The `<bits>` field specifies the width of the datatype, as it does for built-in datatypes. This information is especially important for custom datatypes, however, as it specifies the underlying *storage type* of the custom datatype: when the custom datatype is lowered, it will be lowered to an opaque unsigned integer of the width specified in `<bits>`.

### Lowering Function Registration

When the programmer uses `register_op` to register a lowering function, on the backend we register the lowering function as a TVM global under the namespace `tvm.custom_datatypes.lower`. For casts, the name looks like `tvm.custom_datatypes.lower.Cast.<target>.<type>.<src_type>`; for other ops, it looks like `tvm.custom_datatypes.lower.<op>.<target>.<type>`. This makes it easy to locate the lowering functions later on.

### Datatype Lowering Pass

Finally, to lower the datatypes, we implement the datatype lowering pass in `src/pass/lower_custom_datatypes.cc`. After the pass runs, all uses of custom datatypes will have been lowered to their appropriate storage types. Each time the pass finds an IR node of a custom datatype, it looks up the appropriate lowering function using the name format described above, and transforms the node using the registered lowering function. In our example above, a function which lowers `bfloat` adds to calls to `BFloat16Add_wrapper` gets registered as `tvm.custom_datatypes.lower.Add.llvm.bfloat`. During datatype lowering, the pass looks up this function and uses it to transform the `bfloat` add node into a call node calling `BFloat16Add_wrapper`.
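The dtype string format and the lowering-function name format described above can be sketched in a few lines of Python. The helper names `parse_custom_dtype` and `lower_func_key` are hypothetical, introduced here only to make the two formats concrete:

```python
import re

def parse_custom_dtype(dtype_str):
    """Split a dtype string like 'custom[bfloat]16' into (type name, bits)."""
    m = re.fullmatch(r"custom\[(\w+)\](\d+)", dtype_str)
    if m is None:
        raise ValueError("not a custom dtype string: " + dtype_str)
    return m.group(1), int(m.group(2))

def lower_func_key(op, target, dtype, src_dtype=None):
    """Build the global-function name under which a lowering function
    is registered and later looked up by the lowering pass."""
    key = "tvm.custom_datatypes.lower.{}.{}.{}".format(op, target, dtype)
    if src_dtype is not None:  # casts additionally encode the source type
        key += "." + src_dtype
    return key

assert parse_custom_dtype("custom[bfloat]16") == ("bfloat", 16)
assert lower_func_key("Add", "llvm", "bfloat") == \
    "tvm.custom_datatypes.lower.Add.llvm.bfloat"
```

The `bits` value recovered by the parser is exactly what determines the opaque unsigned integer storage type that the lowering pass substitutes for the custom type.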
# Roadmap

- [x] Add enough custom datatype infrastructure to support simple examples (e.g. a simple `bfloat` program).
- [ ] Identify real datatype libraries to begin testing with. (1 week)
- [ ] Work out bugs involved in getting the infrastructure working with a real library. (2 weeks)
- [ ] Test case: inference with a commonly-used deep learning model. (3 weeks)
- [ ] Test case: training with a commonly-used deep learning model. (3 weeks)

---

View this issue on GitHub: https://github.com/dmlc/tvm/issues/3060
