The GitHub Actions job "Teams" on tvm.git/main has succeeded. Run started by GitHub user yanyanyanggg (triggered by yanyanyanggg).
Head commit for run: bddc091bffc31a3cc9dde16c169222774784e0dc / Park Woorak
[TIR][Schedule] Fix bug on bfloat16 conversion (#18556)

## Description

This PR fixes a conversion bug that occurs when performing operations on `bfloat16` tensors.

In short, when the `BF16ComputeLegalize` compile pass visits a `BufferStoreNode` whose stored value has a different dtype from the buffer's, it should use `DTypeConversion()` instead of a plain `cast` so that the appropriate conversion logic is applied.

## Test

I added a test for this situation based on the existing tests. With the fix, `B[i] = A[i]` is correctly rewritten to `B[i] = bf16tof32(A[i])`, so the test passes.

I'm not sure whether the structure or name of the added test is appropriate, so I'm happy to change it if there are any comments.

## Process

### Problem observed

This bug was identified when applying `nn.Linear()` to a `bfloat16` tensor produced excessively large numbers. It appears to affect other operations as well, but it is particularly noticeable when the inner dimension of `MatMul` is a multiple of `8` (`16` for CUDA and ROCm).

#### Example of problematic code

```python
from ml_dtypes import bfloat16
import numpy as np

import tvm
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.target import Target

n = 10
INNER_DIM = 8 * n  # the bug shows up when INNER_DIM is a multiple of 8
dtype = "bfloat16"


class TestModule(nn.Module):
    def __init__(self):
        self.weight = nn.Parameter((32, INNER_DIM), dtype=dtype)

    def run(self, x: Tensor):
        t = op.matmul(self.weight, x, out_dtype=dtype)
        return t

    def get_default_spec(self):
        mod_spec = {
            "run": {
                "x": nn.spec.Tensor([INNER_DIM, 100], dtype),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)


def compile_module(model, target):
    ...  # body elided in the original


def create_vm(ex, device):
    ...  # body elided in the original


def main():
    target = "metal"  # or "cuda", "vulkan", ...
    model = TestModule()
    ex, _ = compile_module(model, target)
    device = tvm.device(target, 0)
    vm = create_vm(ex, device=device)
    frun = vm["run"]

    # Upload an all-ones bfloat16 weight as the packed parameter.
    params = []
    param = tvm.runtime.empty(
        (32, INNER_DIM),
        dtype="bfloat16",
        device=device,
    )
    param.copyfrom(np.ones((32, INNER_DIM), dtype=bfloat16))
    params.append(param)

    inputs = np.ones((INNER_DIM, 100), dtype=bfloat16)
    arr = frun(inputs, params)
    print(f"{arr=}")  # arr has weird values!


if __name__ == "__main__":
    main()
```

When the inner dimension is not a multiple of `8` (or `16`), `PadEinsum` applies `T.if_then_else()`, which avoids the issue. `PadEinsum` itself was not the culprit; rather, it helped identify the issue.

### Problem identified

I found that the problem was avoided whenever the expression was wrapped with `T.if_then_else()` or `T.Cast()` before the `BF16ComputeLegalize` compile pass was applied.

#### Statement with problem

```python
weight_reindex_shared[v0, v1, v2] = weight[v1, v2]
```

#### Statements without problem

```python
# 1) wrapped with T.if_then_else()
weight_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 511, weight[v1, v2], T.bfloat16(0.0))

# 2) wrapped with T.Cast()
weight_reindex_pad_shared[v0, v1, v2] = T.Cast("float32", weight[v1, v2])

# ...
```

In the `BF16ComputeLegalize` compile pass, when a specific `Expr` (here, `weight[...]`) is processed through `PromoteToTarget()` (and eventually `DTypeConversion()`), it is rewritten into the TO-BE form below, which applies the conversion logic. The problematic statement, in contrast, is only wrapped in a plain `T.Cast()` (AS-IS).

#### AS-IS

```python
T.Cast("float32", weight[...])
```

#### TO-BE

```python
T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", weight[...])), T.uint32(16)))
```

### Fixing the problem

This situation is caused by L332 of `src/tir/transforms/unsupported_dtype_legalize.cc` in the code below. Changing this line to apply `DTypeConversion()` instead of `cast()` resolves the issue.
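To make the AS-IS/TO-BE difference concrete, here is a minimal NumPy sketch; it is not TVM code. It assumes, as one plausible reading of the "excessively large numbers", that the `bfloat16` data is held as raw `uint16` bit patterns by the time the conversion runs, so a plain numeric cast converts the bits as an integer, whereas the TO-BE expression widens the bits, shifts them into the high half of a 32-bit word, and reinterprets that word as `float32`. All helper names (`f32_to_bf16_bits`, `to_be_conversion`, `as_is_numeric_cast`) are hypothetical.

```python
import numpy as np


def f32_to_bf16_bits(values):
    """Hypothetical helper: keep the upper 16 bits of each float32 (truncation).

    Exact for values such as 1.0, 2.0 and -0.5 that bfloat16 represents exactly.
    """
    bits32 = np.asarray(values, dtype=np.float32).view(np.uint32)
    return (bits32 >> np.uint32(16)).astype(np.uint16)


def to_be_conversion(bits16):
    """Mirror of the TO-BE expression:
    reinterpret("float32", shift_left(Cast("uint32", reinterpret("uint16", x)), uint32(16)))
    """
    widened = bits16.astype(np.uint32)   # T.Cast("uint32", ...)
    shifted = widened << np.uint32(16)   # T.shift_left(..., T.uint32(16))
    return shifted.view(np.float32)      # T.reinterpret("float32", ...)


def as_is_numeric_cast(bits16):
    """Assumed AS-IS behaviour on uint16 storage: a value cast of the raw bits."""
    return bits16.astype(np.float32)


bits = f32_to_bf16_bits([1.0, 2.0, -0.5])
print(bits.tolist())                      # [16256, 16384, 48896]
print(to_be_conversion(bits).tolist())    # [1.0, 2.0, -0.5]
print(as_is_numeric_cast(bits).tolist())  # [16256.0, 16384.0, 48896.0]
```

Under this assumption, an all-ones input like the one in the reproduction above would be read as 16256.0 instead of 1.0, which is consistent with the inflated matmul outputs; this is an illustration of the bit-level difference, not a trace of the generated kernel.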
When the `Expr` is wrapped with `T.if_then_else()` or another expression, it is handled correctly by the other visit functions reached through L312 or L313, which is why the problem was avoided in those cases.

#### L332

```diff
- value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ value = DTypeConversion(value, new_buf->dtype.with_lanes(value.dtype().lanes()));
```

https://github.com/apache/tvm/blob/26b107fa12672c3b958da222fc87755a69d64c42/src/tir/transforms/unsupported_dtype_legalize.cc#L311-L338

Report URL: https://github.com/apache/tvm/actions/runs/20051632486

With regards,
GitHub Actions via GitBox
