This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a7bfc857b9 [FIX] Inline ceil_log2 in gpu_2d_continuous_cumsum to fix
MakePackedAPI error (#18957)
a7bfc857b9 is described below
commit a7bfc857b901177a5e82a86d6f9cc2ffed763e09
Author: Gabe Guralnick <[email protected]>
AuthorDate: Wed Apr 1 21:00:49 2026 -0700
[FIX] Inline ceil_log2 in gpu_2d_continuous_cumsum to fix MakePackedAPI
error (#18957)
- The intermediate variable `ceil_log2` in `gpu_2d_continuous_cumsum`
created a `LetStmt`-bound `Var` in the TIR function
- When `MakePackedAPI` processed the function, it reported `ceil_log2`
as an undefined variable not passed as an API argument
- Inline the expression directly into `total_rounds` to avoid the
intermediate `Var`; the computation is identical
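The claim that the computation is identical can be sanity-checked outside TVM with a rough plain-Python sketch. It does not use TVMScript, and the value of `LOG_BLOCK_N` below is a hypothetical placeholder; the real constant is defined in `cumsum.py`. Python floats are 64-bit, so this only approximates the `float32` cast in the kernel.

```python
import math

# Hypothetical placeholder; the real LOG_BLOCK_N is defined in cumsum.py.
LOG_BLOCK_N = 8


def total_rounds_with_intermediate(n: int) -> int:
    """Old form: bind ceil(log2(n)) to a name first, then divide."""
    ceil_log2 = int(math.ceil(math.log2(float(n))))
    return ceil_log2 // LOG_BLOCK_N


def total_rounds_inlined(n: int) -> int:
    """New form: same expression with the intermediate name removed."""
    return int(math.ceil(math.log2(float(n)))) // LOG_BLOCK_N
```

Both functions evaluate the same expression, so they agree for every positive `n`. In plain Python the two forms are equivalent in every respect; the difference described above arises only in TIR, where the named intermediate became a `LetStmt`-bound `Var`.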
## Test plan
- Compile a model that uses GPU sampling (e.g. any LLM with top-p
sampling on Metal) and verify compilation succeeds
- The error this fixes: `Check failed: undefined.size() == 0: In
PrimFunc gpu_2d_continuous_cumsum variables [ceil_log2] are used, but
are not passed in as API arguments`
Co-authored-by: Akaash Parthasarathy <[email protected]>
---
python/tvm/relax/backend/gpu_generic/cumsum.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/python/tvm/relax/backend/gpu_generic/cumsum.py b/python/tvm/relax/backend/gpu_generic/cumsum.py
index bd2cec3bcd..a2054fdf41 100644
--- a/python/tvm/relax/backend/gpu_generic/cumsum.py
+++ b/python/tvm/relax/backend/gpu_generic/cumsum.py
@@ -159,8 +159,7 @@ def gpu_2d_continuous_cumsum(
A = T.match_buffer(var_a, [m, n], dtype=in_dtype)
Out = T.match_buffer(var_out, [m, n], dtype=out_dtype)
Tmp = T.alloc_buffer([m, n], dtype=out_dtype)
- ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
- total_rounds = ceil_log2 // LOG_BLOCK_N
+ total_rounds = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n)))) // LOG_BLOCK_N
block_inclusive_inside_block(
m, n, A, Out, Tmp, src_offset=T.int64(0), tmp_offset=T.int64(0)