This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a6e2ea8ac8 [Relax][Frontend][KVCache] Add masked sequence prefill
helper for encoder valid lengths (#19392)
a6e2ea8ac8 is described below
commit a6e2ea8ac8443e3ecebe0ef59b1ea7e410210458
Author: Xijing Wang <[email protected]>
AuthorDate: Mon Apr 13 21:13:56 2026 -0400
[Relax][Frontend][KVCache] Add masked sequence prefill helper for encoder
valid lengths (#19392)
Adds `_attention_sequence_prefill_with_mask` in
`python/tvm/relax/frontend/nn/llm/kv_cache.py` — a masked variant of the
existing sequence prefill kernel that supports right-padded encoder
batches with per-sample `valid_lens`.
The existing `_attention_sequence_prefill` assumes all positions in `[0,
seq_len)` are valid, which breaks for padded encoder inputs where each
batch element has a different valid prefix length. This helper adds the
masking semantics needed for correctness:
- accepts a per-batch `valid_lens` input
- ignores padded query rows and padded key/value positions
- excludes padded `(row, col)` pairs from the online softmax update
It reuses the existing prefill kernel config and schedule — no new
tuning knobs, no target-specific changes, no performance claims.
Correctness only.
## Motivation: encoder batch prefill for downstream consumers
This is the TVM-side primitive needed to support **encoder batch
prefill** in downstream projects like `mlc-llm`, where padded encoder
batches with `valid_lens` need to be lowered without materializing an
explicit broadcast attention mask on the host.
The helper is generic and useful for any encoder-style sequence prefill
consumer with per-sample valid lengths.
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 267 +++++++++++++++++++++
...test_frontend_nn_llm_sequence_prefill_masked.py | 197 +++++++++++++++
2 files changed, 464 insertions(+)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 749707cb29..12cdcbcdce 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -2305,6 +2305,273 @@ def _attention_sequence_prefill(h_kv, h_q, d, dtype,
target: Target, causal=0, s
return sch.mod["main"].with_attr("tir.is_scheduled", True)
+def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target,
sm_scale=1.0): # pylint: disable=line-too-long
+ """Tiled sequence prefill kernel with a per-batch right-padding mask.
+
+ This is the counterpart of :func:`_attention_sequence_prefill` for batched
+ encoder-style inputs where each sample in the batch is padded to a common
+ ``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
+ The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
+ and applies the padding mask inside the QKV load path and the online
+ softmax update, so no explicit mask tensor broadcast or additive bias is
+ needed on the host side.
+
+ Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
+ positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
+ keys/values are zeroed at load time; padded ``(row, col)`` pairs are
+ excluded from the max/sum of the online softmax via a ``-inf`` slot.
+ """
+ (
+ _,
+ LOAD_VEC,
+ group_size,
+ bdx,
+ num_warps,
+ tile_x,
+ tile_y,
+ tile_z,
+ ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
+
+ def _valid_length_mask(valid_len, row, col, qo_len):
+ """Return True when both the query row and the key col are unpadded."""
+        return tir.And(
+            tir.And(row < qo_len, row < valid_len),
+ col < valid_len,
+ )
+
+ # fmt: off
+ @T.prim_func
+ def batch_sequence_prefill_kv_masked( # pylint: disable=too-many-branches
+ var_q: T.handle, # [batch_size, qo_len, h_q, d]
+ var_k: T.handle, # [batch_size, kv_len, h_kv, d]
+ var_v: T.handle, # [batch_size, kv_len, h_kv, d]
+ var_valid_lens: T.handle, # [batch_size], int32
+ var_output: T.handle, # [batch_size, qo_len, h_q, d]
+ var_lse: T.handle # [batch_size, qo_len, h_q]
+ ):
+ batch_size = T.int32(is_size_var=True)
+ qo_len = T.int32(is_size_var=True)
+ kv_len = T.int32(is_size_var=True)
+ q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)
+ k = T.match_buffer(var_k, (batch_size, kv_len, h_kv, d), dtype)
+ v = T.match_buffer(var_v, (batch_size, kv_len, h_kv, d), dtype)
+ valid_lens = T.match_buffer(var_valid_lens, (batch_size,), "int32")
+ output = T.match_buffer(var_output, (batch_size, qo_len, h_q, d),
dtype)
+ lse = T.match_buffer(var_lse, (batch_size, qo_len, h_q), dtype)
+
+ batch_tiles: T.int32 = T.ceildiv(qo_len * group_size, tile_x)
+
+ for lbx in T.thread_binding(T.cast(batch_size, "int32") * batch_tiles,
thread="blockIdx.x"):
+ for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
+ for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
+ for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
+                        with T.block("attn"):
+ vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby,
lty, ltx])
+ T.reads()
+ T.writes()
+
+                            Q_smem = T.alloc_buffer((tile_x, d), dtype,
scope="shared")
+                            K_smem = T.alloc_buffer((tile_z, d), dtype,
scope="shared")
+                            V_smem = T.alloc_buffer((tile_z, d), dtype,
scope="shared")
+                            S_smem = T.alloc_buffer((tile_x, tile_z),
"float32", scope="shared")
+
+                            S_local = T.alloc_buffer((tile_x, tile_z),
"float32", scope="local")
+                            O_local = T.alloc_buffer((tile_x, d),
"float32", scope="local")
+
+                            m_smem = T.alloc_buffer((tile_x,),
"float32", scope="shared")
+                            m_prev_smem = T.alloc_buffer((tile_x,),
"float32", scope="shared")
+                            d_smem = T.alloc_buffer((tile_x,),
"float32", scope="shared")
+
+                            m_new = T.alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),),
"float32", scope="local"
+                            )
+                            m_prev = T.alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),),
"float32", scope="local"
+                            )
+                            d_new = T.alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),),
"float32", scope="local"
+                            )
+
+ b_idx: T.int32 = vbx // batch_tiles
+ valid_len: T.int32 = valid_lens[b_idx]
+ tile_id: T.int32 = vbx % batch_tiles
+ LH_start: T.int32 = tile_id * tile_x
+ T.tvm_storage_sync("shared")
+
+ # init states
+ for i in T.serial(T.ceildiv(tile_x, bdx *
num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx
+ tx
+ if row < tile_x:
+ m_smem[row] = -5e4
+ d_smem[row] = 1.0
+
+ for li, lj in T.grid(tile_x, tile_y):
+                                with T.block("O_init"):
+ i, j = T.axis.remap("SS", [li, lj])
+ O_local[i, j] = 0.0
+ T.tvm_storage_sync("shared")
+
+ # Load Q; padded rows are zeroed so they
contribute nothing downstream.
+ for li, lj in T.grid(tile_x, tile_y):
+                                with T.block("Q_load"):
+ i, j = T.axis.remap("SS", [li, lj])
+ T.reads()
+ T.writes()
+ cur_L = (LH_start + i) // group_size
+ cur_H_qo = by * group_size + (LH_start +
i) % group_size
+                                    if tir.And(cur_L < qo_len, cur_L <
valid_len):
+ Q_smem[i, j] = q[b_idx, cur_L,
cur_H_qo, j]
+ else:
+ Q_smem[i, j] = 0.0
+ T.tvm_storage_sync("shared")
+
+ for iterator in T.serial(T.ceildiv(kv_len,
tile_z)):
+ L_kv_start: T.int32 = iterator * tile_z
+ L_kv_base: T.int32 = 0
+ for lz, ly in T.grid(tile_z, tile_y):
+                                with T.block("K_load"):
+ i, j = T.axis.remap("SS", [lz, ly])
+ T.reads()
+ T.writes()
+ cur_L = L_kv_start + i
+                                    if tir.And(cur_L < kv_len, cur_L <
valid_len):
+ K_smem[i, j] = k[b_idx, L_kv_base
+ cur_L, by, j]
+ else:
+ K_smem[i, j] = 0.0
+ T.tvm_storage_sync("shared")
+ for lz, ly in T.grid(tile_z, tile_y):
+                                with T.block("V_load"):
+ i, j = T.axis.remap("SS", [lz, ly])
+ T.reads()
+ T.writes()
+ cur_L = L_kv_start + i
+                                    if tir.And(cur_L < kv_len, cur_L <
valid_len):
+ V_smem[i, j] = v[b_idx, L_kv_base
+ cur_L, by, j]
+ else:
+ V_smem[i, j] = 0.0
+ T.tvm_storage_sync("shared")
+
+ # Compute S
+                            with T.block():
+ for li, lj, lk in T.grid(tile_x, tile_z,
tile_y):
+                                    with T.block("S_gemm"):
+ i, j, k = T.axis.remap("SSR", [li,
lj, lk])
+ with T.init():
+ S_local[i, j] = 0.0
+ S_local[i, j] += (
+ T.cast(Q_smem[i, k], "float32")
+ * T.cast(K_smem[j, k],
"float32")
+ * sm_scale
+ * math.log2(math.exp(1))
+ )
+ T.tvm_storage_sync("shared")
+ for li, lj in T.grid(tile_x, tile_z):
+                                    with T.block("S_store"):
+ i, j = T.axis.remap("SS", [li, lj])
+ S_smem[i, j] = S_local[i, j]
+ T.tvm_storage_sync("shared")
+
+ # Update S, m, d — use padding mask instead of
causal.
+ for i in T.serial(T.ceildiv(tile_x, bdx *
num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty *
bdx + tx
+ if row < tile_x:
+                                    with T.block("update1"):
+ m_prev[i] = m_smem[row]
+ m_new[i] = m_smem[row]
+ row_: T.int32 = (LH_start + row)
// group_size
+ for j in T.serial(tile_z):
+ if _valid_length_mask(
+ valid_len,
+ row=row_,
+ col=L_kv_start + j,
+ qo_len=qo_len,
+ ):
+ m_new[i] = T.max(
+ m_new[i], S_smem[row,
j]
+ )
+ d_new[i] = d_smem[row] * T.exp2(
+ m_prev[i] - m_new[i]
+ )
+
+ for i in T.serial(T.ceildiv(tile_x, bdx *
num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty *
bdx + tx
+                                with T.block("update"):
+ for j in T.serial(tile_z):
+ # sync is outside the branch, so
the predicate is inside
+ if row < tile_x:
+ row_: T.int32 = (
+ LH_start + row
+ ) // group_size
+ if _valid_length_mask(
+ valid_len,
+ row=row_,
+ col=L_kv_start + j,
+ qo_len=qo_len,
+ ):
+ S_smem[row, j] = T.exp2(
+ S_smem[row, j] -
m_new[i]
+ )
+ else:
+ S_smem[row, j] =
T.exp2(-5e4 - m_new[i])
+
+ for i in T.serial(T.ceildiv(tile_x, bdx *
num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty *
bdx + tx
+ if row < tile_x:
+                                    with T.block("update"):
+ for j in T.serial(tile_z):
+ d_new[i] += S_smem[row, j]
+ m_smem[row] = m_new[i]
+ d_smem[row] = d_new[i]
+ m_prev_smem[row] = m_prev[i]
+ T.tvm_storage_sync("shared")
+
+ # Update O
+                            with T.block():
+ for li, lj, lk in T.grid(tile_x, tile_y,
tile_z):
+                                    with T.block("O_gemm"):
+ i, j, k = T.axis.remap("SSR", [li,
lj, lk])
+ with T.init():
+ O_local[i, j] *= T.exp2(
+ m_prev_smem[i] - m_smem[i]
+ )
+ O_local[i, j] += S_smem[i, k] *
T.cast(
+ V_smem[k, j], "float32"
+ )
+
+ # Store O
+ for li, lj in T.grid(tile_x, tile_y):
+                                with T.block("O_store"):
+ i, j = T.axis.remap("SS", [li, lj])
+ cur_L: T.int32 = 0 + (LH_start + i) //
group_size
+ cur_H_qo: T.int32 = (
+ by * group_size + (LH_start + i) %
group_size
+ )
+ if cur_L < qo_len:
+ output[b_idx, cur_L, cur_H_qo, j] = (
+ O_local[i, j] / d_smem[i]
+ )
+
+ # Store LSE
+ for li in T.grid(tile_x):
+                                with T.block("lse_store"):
+ i = T.axis.remap("S", [li])
+ cur_L: T.int32 = 0 + (LH_start + i) //
group_size
+ cur_H_qo: T.int32 = (
+ by * group_size + (LH_start + i) %
group_size
+ )
+ if cur_L < qo_len:
+ lse[b_idx, cur_L, cur_H_qo] =
m_smem[i] + T.log2(
+ d_smem[i]
+ )
+
+ # fmt: on
+    sch = tvm.tir.Schedule(batch_sequence_prefill_kv_masked)
+ sch = _schedule_prefill_kernel(
+ sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False
+ )
+    return sch.mod["main"].with_attr("tir.is_scheduled", True)
+
+
def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, dtype, rope_scaling:
dict[str, Any]):
group_size = h_q // h_kv
diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
new file mode 100644
index 0000000000..b64ef459d8
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Focused correctness tests for ``_attention_sequence_prefill_with_mask``.
+
+The masked variant is the encoder-style counterpart of
+``_attention_sequence_prefill``: each sample in a padded batch carries its
+own ``valid_len`` and the kernel applies the padding mask inside the QKV
+load path and the online softmax update. These tests cover the four shape
+/ mask regimes that can break the kernel independently of any scheduler
+tuning:
+
+* ``valid_len == 0`` — entire batch row is padding
+* ``valid_len == seq_len`` — full-length row, must match the unmasked kernel
+* mixed ``valid_lens`` — typical encoder batch
+* grouped-query attention — ``h_q > h_kv`` with ``group_size > 1``
+
+The reference is a float32 NumPy implementation of masked softmax attention
+restricted to the valid prefix, so the kernel is only compared on the
+unpadded positions (padded positions are intentionally free to contain
+arbitrary garbage).
+"""
+# ruff: noqa: E501
+import math
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.relax.frontend.nn.llm.kv_cache import
_attention_sequence_prefill_with_mask
+
+
+def _reference_masked_attention(q, k, v, valid_lens, sm_scale):
+ """NumPy fp32 reference. Only the first ``valid_lens[b]`` rows are
written."""
+ batch, seq_q, h_q, d = q.shape
+ _, seq_kv, h_kv, _ = k.shape
+ group_size = h_q // h_kv
+ out = np.zeros_like(q, dtype=np.float32)
+ q32 = q.astype(np.float32)
+ k32 = k.astype(np.float32)
+ v32 = v.astype(np.float32)
+ for b in range(batch):
+ L = int(valid_lens[b])
+ if L == 0:
+ continue
+ for h in range(h_q):
+ hk = h // group_size
+ qh = q32[b, :L, h, :] # [L, d]
+ kh = k32[b, :L, hk, :] # [L, d]
+ vh = v32[b, :L, hk, :] # [L, d]
+ s = (qh @ kh.T) * sm_scale # [L, L]
+ m = s.max(axis=-1, keepdims=True)
+ e = np.exp(s - m)
+ p = e / e.sum(axis=-1, keepdims=True)
+ out[b, :L, h, :] = p @ vh
+ return out
+
+
+def _build_masked_prefill(h_kv, h_q, d, dtype, target):
+ sm_scale = 1.0 / math.sqrt(d)
+ tir_func = _attention_sequence_prefill_with_mask(
+ h_kv=h_kv,
+ h_q=h_q,
+ d=d,
+ dtype=dtype,
+ target=target,
+ sm_scale=sm_scale,
+ )
+ mod = tvm.IRModule({"main": tir_func})
+    return tvm.tir.build(mod["main"], target=target), sm_scale
+
+
+def _run_case(
+ *,
+ target,
+ dev,
+ h_kv,
+ h_q,
+ d,
+ batch,
+ seq,
+ valid_lens,
+ dtype="float16",
+ seed=0,
+):
+ target = tvm.target.Target(target)
+ built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target)
+
+ np_dtype = {"float16": np.float16, "float32": np.float32}[dtype]
+ rng = np.random.default_rng(seed)
+ q_np = (rng.standard_normal((batch, seq, h_q, d)) * 0.1).astype(np_dtype)
+ k_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+ v_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+ valid_np = np.asarray(valid_lens, dtype=np.int32)
+ out_np = np.zeros((batch, seq, h_q, d), dtype=np_dtype)
+ lse_np = np.zeros((batch, seq, h_q), dtype=np_dtype)
+
+ q_nd = tvm.runtime.tensor(q_np, device=dev)
+ k_nd = tvm.runtime.tensor(k_np, device=dev)
+ v_nd = tvm.runtime.tensor(v_np, device=dev)
+ valid_nd = tvm.runtime.tensor(valid_np, device=dev)
+ out_nd = tvm.runtime.tensor(out_np, device=dev)
+ lse_nd = tvm.runtime.tensor(lse_np, device=dev)
+
+ built.main(q_nd, k_nd, v_nd, valid_nd, out_nd, lse_nd)
+
+ got = out_nd.numpy().astype(np.float32)
+ ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+
+ # Only compare valid rows. Padding rows are undefined by design.
+ rtol, atol = (2e-2, 2e-2) if dtype == "float16" else (1e-4, 1e-4)
+ for b in range(batch):
+ L = int(valid_np[b])
+ if L == 0:
+ continue
+ np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol,
atol=atol)
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_zero(target, dev):
+ """All samples are fully padded: kernel must not crash and must stay
bounded."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=2,
+ seq=16,
+ valid_lens=[0, 0],
+ )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_full(target, dev):
+ """All samples are fully valid: must match a plain unmasked attention."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=2,
+ seq=32,
+ valid_lens=[32, 32],
+ )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_mixed(target, dev):
+ """Typical encoder batch with different valid lengths per sample."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=4,
+ seq=64,
+ valid_lens=[10, 64, 5, 33],
+ )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_mixed_gqa(target, dev):
+ """Grouped-query attention: ``group_size = h_q / h_kv > 1``."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=2,
+ h_q=4,
+ d=64,
+ batch=3,
+ seq=32,
+ valid_lens=[8, 32, 17],
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()