The GitHub Actions job "Nightly Docker Update" on tvm.git/main has failed.
Run started by GitHub user areusch (triggered by areusch).

Head commit for run:
e3f5ac1c6bccebc4bf1c35c9a1d81cf4c0a1740d / Yeongjae Jang 
<[email protected]>
[Relax] Correct YaRN RoPE frequency scaling formula to align with the original 
paper (#18576)

## Summary
Fixed the frequency calculation for YaRN RoPE scaling and corrected the
correction-range computation.

## Description
Greetings:

This PR corrects the mathematical formulation of the
[YaRN](https://arxiv.org/abs/2309.00071) RoPE scaling.
I have verified that this change eliminates the discrepancy observed
when comparing against a PyTorch baseline (an implementation of
`gpt-oss`).

### in `yarn_find_correction_range()`
#### `low`, `high`
Removed the `tir.floor` and `tir.ceil` operations in
`yarn_find_correction_dim()`.
The YaRN paper applies no floor or ceil function when computing these
values.
The `gpt-oss` implementation likewise keeps these thresholds as
floating-point values so that the ramp function interpolates smoothly.
Rounding them introduced quantization errors in the ramp mask.
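
To illustrate the effect of rounding, here is a minimal plain-Python sketch, not the TVM code. The helper names (`correction_dim`, `correction_range`, `ramp_mask`) and all parameter values are assumptions chosen for illustration; the correction-dimension formula is the one commonly used in YaRN implementations, and clamping of the range is omitted.

```python
import math

def correction_dim(num_rotations, dim, base, max_pos):
    # Real-valued dimension index at which a RoPE component completes
    # `num_rotations` full rotations over `max_pos` positions.
    return dim * math.log(max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

def correction_range(beta_fast, beta_slow, dim, base, max_pos, rounded):
    low = correction_dim(beta_fast, dim, base, max_pos)
    high = correction_dim(beta_slow, dim, base, max_pos)
    if rounded:
        # Behavior before the fix: snap the thresholds to integers.
        low, high = math.floor(low), math.ceil(high)
    return low, high

def ramp_mask(low, high, n):
    # Linear ramp from 0 at `low` to 1 at `high`, clamped to [0, 1].
    return [min(max((i - low) / (high - low), 0.0), 1.0) for i in range(n)]

# Illustrative values only (not taken from the PR).
args = dict(beta_fast=32, beta_slow=1, dim=64, base=10000.0, max_pos=4096)
exact = correction_range(rounded=False, **args)
snapped = correction_range(rounded=True, **args)
mask_exact = ramp_mask(*exact, n=32)
mask_snapped = ramp_mask(*snapped, n=32)
max_diff = max(abs(a - b) for a, b in zip(mask_exact, mask_snapped))
print(exact, snapped, max_diff)
```

With these values the exact thresholds are non-integers, so snapping them moves both ends of the ramp and changes every mask entry that lies strictly inside it.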

### in `rope_freq_yarn()`
#### `freq_inter`
Currently, the implementation calculates the inverse frequency as:
```
freq_inter = tir.const(1, "float32") / tir.power(
    scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32")
)
```

Because `scaling_factor` sits inside the base of the power, it is raised
to the exponent as well, leading to non-uniform scaling across dimensions.

According to the YaRN method (and an implementation of `gpt-oss`), the
scaling factor should be applied linearly, outside the power:
```
exponent = d * 2 % d_range / tir.const(d_range, "float32")
freq_power = tir.power(theta, exponent)
freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
```
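
The difference between the two formulas can be checked numerically. The following is a plain-Python sketch under assumed values (`theta`, `scaling_factor`, and `d_range` are illustrative, and `freq_old`/`freq_new` are hypothetical names that mirror the two snippets above, not TVM functions).

```python
import math

def freq_old(d, d_range, theta, scaling_factor):
    # Previous form: scaling_factor is inside the base of the power.
    exponent = (d * 2 % d_range) / d_range
    return 1.0 / (scaling_factor * theta) ** exponent

def freq_new(d, d_range, theta, scaling_factor):
    # Corrected form: scaling_factor divides the frequency directly.
    exponent = (d * 2 % d_range) / d_range
    return 1.0 / (scaling_factor * theta ** exponent)

def freq_unscaled(d, d_range, theta):
    exponent = (d * 2 % d_range) / d_range
    return 1.0 / theta ** exponent

theta, scaling_factor, d_range = 10000.0, 4.0, 64  # illustrative values
for d in (0, 8, 16, 24):
    base = freq_unscaled(d, d_range, theta)
    ratio_old = freq_old(d, d_range, theta, scaling_factor) / base
    ratio_new = freq_new(d, d_range, theta, scaling_factor) / base
    print(d, ratio_old, ratio_new)
```

In the old form the ratio to the unscaled frequency is `scaling_factor ** -exponent`, which varies with `d` (and equals 1 at `d = 0`, meaning no scaling at all). In the corrected form the ratio is `1 / scaling_factor` for every dimension.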

#### `d_range`
`yarn_find_correction_range()` was incorrectly being given the current
dimension index `d` to calculate the thresholds.
This caused the ramp boundaries to shift from one dimension to the next.
The call now passes the total dimension size (`d_range`), so the
frequency thresholds are the same for every dimension.

Before: 
```
yarn_find_correction_range(..., d, ...)
```

After: 
```
yarn_find_correction_range(..., d_range, ...)
```
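
A small plain-Python sketch of why the argument matters follows. It reuses the correction-dimension formula commonly found in YaRN implementations; the helper names and parameter values are assumptions for illustration, and clamping of the range is left out.

```python
import math

def correction_dim(num_rotations, dim, base, max_pos):
    # Real-valued threshold index; note that it scales linearly with `dim`.
    return dim * math.log(max_pos / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

def correction_range(beta_fast, beta_slow, dim, base, max_pos):
    return (correction_dim(beta_fast, dim, base, max_pos),
            correction_dim(beta_slow, dim, base, max_pos))

d_range, base, max_pos = 64, 10000.0, 4096  # illustrative values

# Before: thresholds recomputed from the current index d, so they differ per d.
per_index = {correction_range(32, 1, d, base, max_pos) for d in range(1, 5)}

# After: thresholds computed from the total size d_range, identical for every d.
fixed = {correction_range(32, 1, d_range, base, max_pos) for _ in range(1, 5)}

print(len(per_index), len(fixed))
```

Because the threshold is proportional to its `dim` argument, passing `d` yields a different `(low, high)` pair for each index, while passing `d_range` yields a single pair shared by all of them.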

Thank you very much for reading.

---------

Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com>

Report URL: https://github.com/apache/tvm/actions/runs/20647757109

With regards,
GitHub Actions via GitBox

