Updated design details
# Details on legalization

Since most hardware has no native support for computing on bf16, we added a pass, `BF16Legalization`, that uses fp32 to compute on bf16 data. It has three sub-passes: `BF16Promotion`, `BF16CastElimination` and `BF16Lowering`.

## BF16Promotion

This sub-pass adds a `cast_to_fp32()` before each bf16 operand of every Op that involves bf16, and computes the Op in fp32 instead. Finally, it adds a `cast_to_bf16()` after each Op it has altered, so the result is bf16 again. For example, writing `cast32` for `cast_to_fp32()` and `cast16` for `cast_to_bf16()`:

`add(a, b)` => `cast16(add(cast32(a), cast32(b)))`

## BF16CastElimination

Promotion on its own introduces redundant casts. For example:

`add(a, neg(b))` => `cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))`

The pattern `cast32(cast16(some_fp32_value))` can be simplified to `some_fp32_value`. Therefore, `BF16Legalization` runs an optimization sub-pass after `BF16Promotion` that eliminates these redundant casts.

## BF16Lowering

This sub-pass replaces every bf16 dtype with uint16. It also lowers the casts between bf16 and fp32 into shifts and other TIR nodes. After the `BF16Legalization` pass, no bf16-related nodes remain in the IR.
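The Promotion and CastElimination rewrites can be sketched on a toy expression tree. This is a minimal Python sketch under stated assumptions: `Expr`, `promote`, `eliminate` and `show` are hypothetical names, and the tree is not TVM's TIR. It only reproduces the `add(a, neg(b))` example from the BF16CastElimination section.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical toy IR for illustration only; this is NOT TVM's TIR API.
@dataclass(frozen=True)
class Expr:
    op: str                        # "var", "add", "neg", "cast32" or "cast16"
    dtype: str                     # "bf16" or "fp32"
    args: Tuple["Expr", ...] = ()
    name: str = ""                 # only used by "var" nodes


def var(name, dtype="bf16"):
    return Expr("var", dtype, (), name)


def cast32(e):
    return Expr("cast32", "fp32", (e,))


def cast16(e):
    return Expr("cast16", "bf16", (e,))


def promote(e):
    """Promotion: compute each bf16 op in fp32, casting operands up and the result back to bf16."""
    if e.op == "var":
        return e
    args = tuple(promote(a) for a in e.args)
    if e.dtype != "bf16":
        return Expr(e.op, e.dtype, args, e.name)
    widened = Expr(e.op, "fp32", tuple(cast32(a) for a in args))
    return cast16(widened)


def eliminate(e):
    """CastElimination: rewrite cast32(cast16(x)) to x whenever x is already fp32."""
    args = tuple(eliminate(a) for a in e.args)
    e = Expr(e.op, e.dtype, args, e.name)
    if e.op == "cast32" and args[0].op == "cast16" and args[0].args[0].dtype == "fp32":
        return args[0].args[0]
    return e


def show(e):
    if e.op == "var":
        return e.name
    return f"{e.op}({', '.join(show(a) for a in e.args)})"


a, b = var("a"), var("b")
expr = Expr("add", "bf16", (a, Expr("neg", "bf16", (b,))))
promoted = promote(expr)
print(show(promoted))             # cast16(add(cast32(a), cast32(cast16(neg(cast32(b))))))
print(show(eliminate(promoted)))  # cast16(add(cast32(a), neg(cast32(b))))
```

Removing a `cast32(cast16(...))` pair keeps the intermediate value in fp32 and skips one bf16 rounding, so the result can differ in the last bits from rounding to bf16 after every Op.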
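The bf16/fp32 casts handled by BF16Lowering can be illustrated at the bit level, with bf16 values held as 16-bit unsigned integers. This is a minimal sketch under assumptions: it uses round-to-nearest-even for fp32 to bf16 and maps every NaN to a quiet NaN, but the post does not say which rounding mode the pass actually uses. The helper names are hypothetical, and Python's `struct` module stands in for the bit reinterpretation a TIR lowering would express with shifts and reinterprets.

```python
import struct


def fp32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x (x must be within fp32 range)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_fp32(u: int) -> float:
    """Reinterpret a 32-bit pattern as an fp32 value (returned as a Python float)."""
    return struct.unpack("<f", struct.pack("<I", u & 0xFFFFFFFF))[0]


def cast_to_fp32(bf16_bits: int) -> float:
    # bf16 is the upper half of an fp32, so widening is a 16-bit left shift.
    return bits_fp32((bf16_bits & 0xFFFF) << 16)


def cast_to_bf16(x: float) -> int:
    u = fp32_bits(x)
    # NaN: exponent all ones and a non-zero mantissa. Rounding could turn it into inf,
    # so keep the sign and return a quiet NaN instead.
    if (u & 0x7F800000) == 0x7F800000 and (u & 0x007FFFFF):
        return ((u >> 16) & 0x8000) | 0x7FC0
    # Round to nearest, ties to even: add 0x7FFF plus the lowest bit that is kept.
    rounding_bias = 0x7FFF + ((u >> 16) & 1)
    return ((u + rounding_bias) >> 16) & 0xFFFF


print(hex(cast_to_bf16(1.0)))                # 0x3f80
print(cast_to_fp32(cast_to_bf16(3.14159)))   # 3.140625
```

Plain truncation (`u >> 16`) would be cheaper, but it always rounds toward zero; the extra add makes the rounding unbiased at the cost of one more integer op per cast.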