The GitHub Actions job "Teams" on tvm.git/main has succeeded.
Run started by GitHub user yanyanyanggg (triggered by yanyanyanggg).

Head commit for run:
bddc091bffc31a3cc9dde16c169222774784e0dc / Park Woorak <[email protected]>
[TIR][Schedule] Fix bug on bfloat16 conversion (#18556)

## Description
This PR fixes a conversion bug that occurs when performing operations on
`bfloat16` tensors.
In short, when the `BF16ComputeLegalize` compile pass visits a
`BufferStoreNode` and the stored value's dtype differs from the buffer's,
`DTypeConversion()` should be used instead of a plain `cast` so that the
appropriate conversion logic is applied.


## Test
I added a test for this situation based on the existing tests.
With the fix, `B[i] = A[i]` is properly turned into `B[i] = bf16tof32(A[i])`,
so the test passes.

I'm not sure whether the structure or the name of the added test is
appropriate, so I'd be glad to adjust it if there are any comments.
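 
For reference, here is a minimal sketch of how such a check can be exercised from Python. This is not the test added by the PR; the buffer shape, block name, and the `"main"` entry are placeholders, and it simply runs the pass and prints the legalized TIR for inspection.

```python
# Hypothetical sketch (not the PR's test): run BF16ComputeLegalize on a trivial
# copy between bfloat16 buffers and inspect how the BufferStore is rewritten.
import tvm
from tvm.script import tir as T

@T.prim_func
def copy(A: T.Buffer((16,), "bfloat16"), B: T.Buffer((16,), "bfloat16")):
    for i in range(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi]

mod = tvm.IRModule({"main": copy})
legalized = tvm.tir.transform.BF16ComputeLegalize()(mod)
# With the fix, the stored value goes through the reinterpret/shift conversion
# logic (DTypeConversion) rather than a plain cast.
print(legalized.script())
```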

## Process
### Problem observed
This bug was identified when applying `nn.Linear()` to a `bfloat16`
tensor produced excessively large numbers.
While it appears to affect other operations as well, it is particularly
noticeable when the inner dimension of `MatMul` is a multiple of
`8` (`16` for CUDA and ROCm).
#### Example of problematic code
```python
from ml_dtypes import bfloat16
import numpy as np

import tvm
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.target import Target

dtype = "bfloat16"  # dtype used for the weights, inputs, and matmul output below
n = 10
INNER_DIM = 8 * n   # if INNER_DIM is a multiple of 8

class TestModule(nn.Module):
  def __init__(self):
    self.weight = nn.Parameter((32, INNER_DIM), dtype=dtype)

  def run(self, x: Tensor):
    t = op.matmul(self.weight, x, out_dtype=dtype)
    return t

  def get_default_spec(self):
    mod_spec = {
      "run": {
        "x": nn.spec.Tensor([INNER_DIM, 100], dtype),
        "$": {
          "param_mode": "packed",
          "effect_mode": "none",
        },
      },
    }
    return nn.spec.ModuleSpec.from_raw(mod_spec, self)

def compile_module(...):
  ...  # compilation helper elided in this excerpt; create_vm below is likewise an elided helper

def main():
  target = "metal" # or "cuda", "vulkan", ...
  model = TestModule()
  ex, _ = compile_module(model, target)
  device = tvm.device(target, 0)
  vm = create_vm(ex, device=device)
  frun = vm["run"]
  
  params = []
  param = tvm.runtime.empty(
    (32, INNER_DIM),
    dtype="bfloat16",
    device=device,
  )
  param.copyfrom(np.ones((32, INNER_DIM), dtype=bfloat16))
  params.append(param)
  
  inputs = np.ones((INNER_DIM, 100), dtype=bfloat16)
  arr = frun(inputs, params)
  print(f"{arr=}")  # arr has weird values!
```
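 
For comparison, a plain NumPy reference for the same all-ones inputs shows what the output should be: a `(32, 100)` tensor where every element equals `INNER_DIM` (here `80.0`, which is exactly representable in `bfloat16`). The snippet below is just that sanity check, not TVM code.

```python
import numpy as np
from ml_dtypes import bfloat16

INNER_DIM = 8 * 10
weight = np.ones((32, INNER_DIM), dtype=bfloat16)
x = np.ones((INNER_DIM, 100), dtype=bfloat16)

# Accumulate in float32, as the promoted bfloat16 matmul is expected to do.
expected = weight.astype(np.float32) @ x.astype(np.float32)
print(expected)  # every element is 80.0 (= INNER_DIM)
```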

In cases where the inner dimension is not a multiple of `8` (or `16`),
the issue was avoided because `PadEinsum` wrapped the expression with
`T.if_then_else()`. `PadEinsum` itself was not the culprit; rather, it
helped narrow down the issue.

### Problem Identified
I found that the problem was avoided whenever the expression was wrapped
with `T.if_then_else()` or `T.cast()` before the `BF16ComputeLegalize`
compile pass was applied.

#### Statement with problem
```python
weight_reindex_shared[v0, v1, v2] = weight[v1, v2]
```
#### Statements without problem
```python
# 1) wrapped with T.if_then_else()
weight_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 511, weight[v1, v2], T.bfloat16(0.0))
# 2) wrapped with T.Cast()
weight_reindex_pad_shared[v0, v1, v2] = T.Cast("float32", weight[v1, v2])
# ...
```

In the `BF16ComputeLegalize` compile pass, if a specific `Expr` (here,
`weight[...]`) is processed through `PromoteToTarget()` (and eventually
`DTypeConversion()`), it is rewritten into the form below (TO-BE), which
applies the proper conversion logic. The problematic statement, in
contrast, only applies `T.Cast()` (AS-IS).

#### AS-IS
```python
T.Cast("float32", weight[...])
```
#### TO-BE
```python
T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", weight[...])), T.uint32(16)))
```
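 
To see concretely why the TO-BE form is the right conversion, here is a small NumPy sketch (not TVM code) of the same bit manipulation: `bfloat16` is the upper 16 bits of a `float32`, so reinterpreting the bits as `uint16`, widening to `uint32`, shifting left by 16, and reinterpreting the result as `float32` recovers the exact `float32` value.

```python
import numpy as np
from ml_dtypes import bfloat16

x = np.array([1.0, -2.5, 3.140625], dtype=bfloat16)

# Mirrors the TO-BE expression: reinterpret("uint16", x) -> Cast("uint32", ...)
# -> shift_left(..., 16) -> reinterpret("float32", ...).
bits = x.view(np.uint16).astype(np.uint32)
widened = np.left_shift(bits, np.uint32(16)).view(np.float32)

print(widened)                # exact float32 values: [ 1.  -2.5  3.140625]
print(x.astype(np.float32))   # identical
```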

### Fixing the problem
This situation is caused by L332 in the code linked below. Changing this
part to apply `DTypeConversion()` instead of `cast()` resolves the issue.
(When the `Expr` is wrapped with `T.if_then_else()` or something similar,
it is handled correctly by other visit functions via L312 or L313, which
is why the problem was avoided in those cases.)

#### L332
```diff
- value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value);
+ value = DTypeConversion(value, new_buf->dtype.with_lanes(value.dtype().lanes()));
```


https://github.com/apache/tvm/blob/26b107fa12672c3b958da222fc87755a69d64c42/src/tir/transforms/unsupported_dtype_legalize.cc#L311-L338

Report URL: https://github.com/apache/tvm/actions/runs/20051632486

With regards,
GitHub Actions via GitBox


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
