Updated design details

# Details on legalization

Since most hardware has no native support for bf16 computation, we added a 
`BF16Legalization` pass that uses fp32 to compute on bf16 data. It has 3 
sub-passes: `Promotion`, `Elimination` and `Lowering`.

## BF16Promotion

It adds `cast_to_fp32()` before each Op involving bf16 operands and uses the 
fp32 version of the Op to compute. Finally, it adds a `cast_to_bf16()` after 
each Op that is altered. For example, writing `cast32` and `cast16` as 
shorthand for the two casts:

`add(a,b)` => `cast16(add(cast32(a), cast32(b)))`

We call this phase "BF16Promotion". It is a sub-pass of the `BF16Legalization` 
pass.
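
To make the rewrite rule concrete, here is a minimal sketch of promotion on a toy expression tree. The tuple-based IR and the names `promote` and `show` are hypothetical and exist only for illustration; this is not TVM's TIR or the actual pass implementation.

```python
# Toy IR (an assumption for illustration): a bf16 variable is a string,
# and an Op is a tuple of the form (op_name, operand, ...).

def promote(expr):
    """Wrap every Op as cast16(op(cast32(operand), ...)), bottom-up."""
    if isinstance(expr, str):
        # A plain bf16 variable is left untouched.
        return expr
    op, *args = expr
    # Cast each (already promoted) operand to fp32 before the Op uses it.
    promoted_args = tuple(("cast32", promote(a)) for a in args)
    # Cast the fp32 result back to bf16.
    return ("cast16", (op, *promoted_args))

def show(expr):
    """Render the toy tree in the call notation used in this post."""
    if isinstance(expr, str):
        return expr
    op, *args = expr
    return f"{op}({', '.join(show(a) for a in args)})"

print(show(promote(("add", "a", "b"))))
# prints: cast16(add(cast32(a), cast32(b)))
```

Because the rule is applied to every Op independently, nested expressions end up with back-to-back casts between the inner and the outer Op.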

## BF16CastElimination

Note that promotion introduces redundant casts, e.g.

`add(a, neg(b))` => `cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))`

The pattern `cast32(cast16(some_fp32_value))` can be simplified to 
`some_fp32_value`.

Thus, we add an optimization pass after "BF16Promotion" in the 
`BF16Legalization` pass, which eliminates these redundant casts.
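
The simplification can be sketched on a toy expression tree as follows. The tuple-based IR and the names `eliminate` and `show` are hypothetical, and the sketch assumes that the operand of every `cast16` is an fp32 value, which holds for trees produced by the promotion rule described earlier. It is not the actual TVM implementation.

```python
# Toy IR (an assumption for illustration): a variable is a string,
# and an Op is a tuple of the form (op_name, operand, ...).

def eliminate(expr):
    """Rewrite cast32(cast16(x)) to x, working bottom-up."""
    if isinstance(expr, str):
        return expr
    op, *args = expr
    args = [eliminate(a) for a in args]
    inner = args[0]
    if op == "cast32" and not isinstance(inner, str) and inner[0] == "cast16":
        # inner is cast16(x) with x already in fp32, so both casts can go.
        return inner[1]
    return (op, *args)

def show(expr):
    """Render the toy tree in the call notation used in this post."""
    if isinstance(expr, str):
        return expr
    op, *args = expr
    return f"{op}({', '.join(show(a) for a in args)})"

# Promoted form of add(a, neg(b)), written out by hand.
promoted = ("cast16", ("add",
                       ("cast32", "a"),
                       ("cast32", ("cast16", ("neg", ("cast32", "b"))))))

print(show(eliminate(promoted)))
# prints: cast16(add(cast32(a), neg(cast32(b))))
```

Only the cast pair between `neg` and `add` is removed; the casts on the bf16 inputs `a` and `b` and the final cast back to bf16 remain.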

## BF16Lowering

This pass replaces all bf16 dtypes with uint16. It also lowers the casts 
between bf16 and fp32 to shifts and other TIR nodes.
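
As a rough illustration of how such casts can be expressed with shifts on integer bit patterns, here is a sketch in plain Python. The function names are hypothetical, and the round-to-nearest-even step in the fp32-to-bf16 direction is an assumption (a common choice), not necessarily what the pass emits. NaN inputs are not handled specially.

```python
import struct

def fp32_bits(x: float) -> int:
    """Reinterpret an fp32 value as its 32-bit unsigned bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_fp32(bits: int) -> float:
    """Reinterpret a 32-bit unsigned bit pattern as an fp32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def fp32_to_bf16_bits(x: float) -> int:
    """fp32 -> bf16 stored as uint16: keep the upper 16 bits, with rounding."""
    bits = fp32_bits(x)
    # Round to nearest even (assumed rounding mode): add 0x7FFF plus the
    # lowest bit that survives the shift, then drop the lower 16 bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_fp32(h: int) -> float:
    """bf16 stored as uint16 -> fp32: shift into the upper 16 bits."""
    return bits_to_fp32((h & 0xFFFF) << 16)

print(hex(fp32_to_bf16_bits(1.0)))   # prints: 0x3f80
print(bf16_bits_to_fp32(0x3F80))     # prints: 1.0
```

The bf16-to-fp32 direction is exact because bf16 shares fp32's sign and exponent layout; only the fp32-to-bf16 direction loses mantissa bits.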

After the `BF16Legalization` pass, no bf16-related nodes remain in the IR.




