This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a6e2ea8ac8 [Relax][Frontend][KVCache] Add masked sequence prefill 
helper for encoder valid lengths (#19392)
a6e2ea8ac8 is described below

commit a6e2ea8ac8443e3ecebe0ef59b1ea7e410210458
Author: Xijing Wang <[email protected]>
AuthorDate: Mon Apr 13 21:13:56 2026 -0400

    [Relax][Frontend][KVCache] Add masked sequence prefill helper for encoder 
valid lengths (#19392)
    
    Adds `_attention_sequence_prefill_with_mask` in
    `python/tvm/relax/frontend/nn/llm/kv_cache.py` — a masked variant of the
    existing sequence prefill kernel that supports right-padded encoder
    batches with per-sample `valid_lens`.
    
    The existing `_attention_sequence_prefill` assumes all positions in `[0,
    seq_len)` are valid, which breaks for padded encoder inputs where each
    batch element has a different valid prefix length. This helper adds the
    masking semantics needed for correctness:
    
    - accepts a per-batch `valid_lens` input
    - ignores padded query rows and padded key/value positions
    - excludes padded `(row, col)` pairs from the online softmax update
    
    It reuses the existing prefill kernel config and schedule — no new
    tuning knobs, no target-specific changes, no performance claims.
    Correctness only.
    
    ## Motivation: encoder batch prefill for downstream consumers
    
    This is the TVM-side primitive needed to support **encoder batch
    prefill** in downstream projects like `mlc-llm`, where padded encoder
    batches with `valid_lens` need to be lowered without materializing an
    explicit broadcast attention mask on the host.
    
    The helper is generic and useful for any encoder-style sequence prefill
    consumer with per-sample valid lengths.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py       | 267 +++++++++++++++++++++
 ...test_frontend_nn_llm_sequence_prefill_masked.py | 197 +++++++++++++++
 2 files changed, 464 insertions(+)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 749707cb29..12cdcbcdce 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -2305,6 +2305,273 @@ def _attention_sequence_prefill(h_kv, h_q, d, dtype, 
target: Target, causal=0, s
     return sch.mod["main"].with_attr("tirx.is_scheduled", True)
 
 
+def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target, 
sm_scale=1.0):  # pylint: disable=line-too-long
+    """Tiled sequence prefill kernel with a per-batch right-padding mask.
+
+    This is the counterpart of :func:`_attention_sequence_prefill` for batched
+    encoder-style inputs where each sample in the batch is padded to a common
+    ``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
+    The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
+    and applies the padding mask inside the QKV load path and the online
+    softmax update, so no explicit mask tensor broadcast or additive bias is
+    needed on the host side.
+
+    Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
+    positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
+    keys/values are zeroed at load time; padded ``(row, col)`` pairs skip the
+    running-max update and get a finite ``-5e4`` score (not ``-inf``).
+    """
+    (
+        _,
+        LOAD_VEC,
+        group_size,
+        bdx,
+        num_warps,
+        tile_x,
+        tile_y,
+        tile_z,
+    ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
+
+    def _valid_length_mask(valid_len, row, col, qo_len):
+        """Return True when both the query row and the key col are unpadded."""
+        return tirx.And(
+            tirx.And(row < qo_len, row < valid_len),
+            col < valid_len,
+        )
+
+    # fmt: off
+    @T.prim_func
+    def batch_sequence_prefill_kv_masked(  # pylint: disable=too-many-branches
+        var_q: T.handle, # [batch_size, qo_len, h_q, d]
+        var_k: T.handle, # [batch_size, kv_len, h_kv, d]
+        var_v: T.handle, # [batch_size, kv_len, h_kv, d]
+        var_valid_lens: T.handle, # [batch_size], int32
+        var_output: T.handle, # [batch_size, qo_len, h_q, d]
+        var_lse: T.handle # [batch_size, qo_len, h_q]
+    ):
+        batch_size = T.int32(is_size_var=True)
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+        q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)
+        k = T.match_buffer(var_k, (batch_size, kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (batch_size, kv_len, h_kv, d), dtype)
+        valid_lens = T.match_buffer(var_valid_lens, (batch_size,), "int32")
+        output = T.match_buffer(var_output, (batch_size, qo_len, h_q, d), 
dtype)
+        lse = T.match_buffer(var_lse, (batch_size, qo_len, h_q), dtype)
+
+        batch_tiles: T.int32 = T.ceildiv(qo_len * group_size, tile_x)
+
+        for lbx in T.thread_binding(T.cast(batch_size, "int32") * batch_tiles, 
thread="blockIdx.x"):
+            for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
+                for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
+                    for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
+                        with T.sblock("attn"):
+                            vbx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, 
lty, ltx])
+                            T.reads()
+                            T.writes()
+
+                            Q_smem = T.sblock_alloc_buffer((tile_x, d), dtype, 
scope="shared")
+                            K_smem = T.sblock_alloc_buffer((tile_z, d), dtype, 
scope="shared")
+                            V_smem = T.sblock_alloc_buffer((tile_z, d), dtype, 
scope="shared")
+                            S_smem = T.sblock_alloc_buffer((tile_x, tile_z), 
"float32", scope="shared")
+
+                            S_local = T.sblock_alloc_buffer((tile_x, tile_z), 
"float32", scope="local")
+                            O_local = T.sblock_alloc_buffer((tile_x, d), 
"float32", scope="local")
+
+                            m_smem = T.sblock_alloc_buffer((tile_x,), 
"float32", scope="shared")
+                            m_prev_smem = T.sblock_alloc_buffer((tile_x,), 
"float32", scope="shared")
+                            d_smem = T.sblock_alloc_buffer((tile_x,), 
"float32", scope="shared")
+
+                            m_new = T.sblock_alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),), 
"float32", scope="local"
+                            )
+                            m_prev = T.sblock_alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),), 
"float32", scope="local"
+                            )
+                            d_new = T.sblock_alloc_buffer(
+                                (math.ceil(tile_x / (bdx * num_warps)),), 
"float32", scope="local"
+                            )
+
+                            b_idx: T.int32 = vbx // batch_tiles
+                            valid_len: T.int32 = valid_lens[b_idx]
+                            tile_id: T.int32 = vbx % batch_tiles
+                            LH_start: T.int32 = tile_id * tile_x
+                            T.tvm_storage_sync("shared")
+
+                            # init states
+                            for i in T.serial(T.ceildiv(tile_x, bdx * 
num_warps)):
+                                row: T.int32 = i * bdx * num_warps + ty * bdx 
+ tx
+                                if row < tile_x:
+                                    m_smem[row] = -5e4
+                                    d_smem[row] = 1.0
+
+                            for li, lj in T.grid(tile_x, tile_y):
+                                with T.sblock("O_init"):
+                                    i, j = T.axis.remap("SS", [li, lj])
+                                    O_local[i, j] = 0.0
+                            T.tvm_storage_sync("shared")
+
+                            # Load Q; padded rows are zeroed so they 
contribute nothing downstream.
+                            for li, lj in T.grid(tile_x, tile_y):
+                                with T.sblock("Q_load"):
+                                    i, j = T.axis.remap("SS", [li, lj])
+                                    T.reads()
+                                    T.writes()
+                                    cur_L = (LH_start + i) // group_size
+                                    cur_H_qo = by * group_size + (LH_start + 
i) % group_size
+                                    if tirx.And(cur_L < qo_len, cur_L < 
valid_len):
+                                        Q_smem[i, j] = q[b_idx, cur_L, 
cur_H_qo, j]
+                                    else:
+                                        Q_smem[i, j] = 0.0
+                            T.tvm_storage_sync("shared")
+
+                            for iterator in T.serial(T.ceildiv(kv_len, 
tile_z)):
+                                L_kv_start: T.int32 = iterator * tile_z
+                                L_kv_base: T.int32 = 0
+                                for lz, ly in T.grid(tile_z, tile_y):
+                                    with T.sblock("K_load"):
+                                        i, j = T.axis.remap("SS", [lz, ly])
+                                        T.reads()
+                                        T.writes()
+                                        cur_L = L_kv_start + i
+                                        if tirx.And(cur_L < kv_len, cur_L < 
valid_len):
+                                            K_smem[i, j] = k[b_idx, L_kv_base 
+ cur_L, by, j]
+                                        else:
+                                            K_smem[i, j] = 0.0
+                                T.tvm_storage_sync("shared")
+                                for lz, ly in T.grid(tile_z, tile_y):
+                                    with T.sblock("V_load"):
+                                        i, j = T.axis.remap("SS", [lz, ly])
+                                        T.reads()
+                                        T.writes()
+                                        cur_L = L_kv_start + i
+                                        if tirx.And(cur_L < kv_len, cur_L < 
valid_len):
+                                            V_smem[i, j] = v[b_idx, L_kv_base 
+ cur_L, by, j]
+                                        else:
+                                            V_smem[i, j] = 0.0
+                                T.tvm_storage_sync("shared")
+
+                                # Compute S
+                                with T.sblock():
+                                    for li, lj, lk in T.grid(tile_x, tile_z, 
tile_y):
+                                        with T.sblock("S_gemm"):
+                                            i, j, k = T.axis.remap("SSR", [li, 
lj, lk])
+                                            with T.init():
+                                                S_local[i, j] = 0.0
+                                            S_local[i, j] += (
+                                                T.cast(Q_smem[i, k], "float32")
+                                                * T.cast(K_smem[j, k], 
"float32")
+                                                * sm_scale
+                                                * math.log2(math.exp(1))
+                                            )
+                                T.tvm_storage_sync("shared")
+                                for li, lj in T.grid(tile_x, tile_z):
+                                    with T.sblock("S_store"):
+                                        i, j = T.axis.remap("SS", [li, lj])
+                                        S_smem[i, j] = S_local[i, j]
+                                T.tvm_storage_sync("shared")
+
+                                # Update S, m, d — use padding mask instead of 
causal.
+                                for i in T.serial(T.ceildiv(tile_x, bdx * 
num_warps)):
+                                    row: T.int32 = i * bdx * num_warps + ty * 
bdx + tx
+                                    if row < tile_x:
+                                        with T.sblock("update1"):
+                                            m_prev[i] = m_smem[row]
+                                            m_new[i] = m_smem[row]
+                                            row_: T.int32 = (LH_start + row) 
// group_size
+                                            for j in T.serial(tile_z):
+                                                if _valid_length_mask(
+                                                    valid_len,
+                                                    row=row_,
+                                                    col=L_kv_start + j,
+                                                    qo_len=qo_len,
+                                                ):
+                                                    m_new[i] = T.max(
+                                                        m_new[i], S_smem[row, 
j]
+                                                    )
+                                            d_new[i] = d_smem[row] * T.exp2(
+                                                m_prev[i] - m_new[i]
+                                            )
+
+                                for i in T.serial(T.ceildiv(tile_x, bdx * 
num_warps)):
+                                    row: T.int32 = i * bdx * num_warps + ty * 
bdx + tx
+                                    with T.sblock("update"):
+                                        for j in T.serial(tile_z):
+                                            # sync is outside the branch, so 
the predicate is inside
+                                            if row < tile_x:
+                                                row_: T.int32 = (
+                                                    LH_start + row
+                                                ) // group_size
+                                                if _valid_length_mask(
+                                                    valid_len,
+                                                    row=row_,
+                                                    col=L_kv_start + j,
+                                                    qo_len=qo_len,
+                                                ):
+                                                    S_smem[row, j] = T.exp2(
+                                                        S_smem[row, j] - 
m_new[i]
+                                                    )
+                                                else:
+                                                    S_smem[row, j] = 
T.exp2(-5e4 - m_new[i])
+
+                                for i in T.serial(T.ceildiv(tile_x, bdx * 
num_warps)):
+                                    row: T.int32 = i * bdx * num_warps + ty * 
bdx + tx
+                                    if row < tile_x:
+                                        with T.sblock("update"):
+                                            for j in T.serial(tile_z):
+                                                d_new[i] += S_smem[row, j]
+                                            m_smem[row] = m_new[i]
+                                            d_smem[row] = d_new[i]
+                                            m_prev_smem[row] = m_prev[i]
+                                T.tvm_storage_sync("shared")
+
+                                # Update O
+                                with T.sblock():
+                                    for li, lj, lk in T.grid(tile_x, tile_y, 
tile_z):
+                                        with T.sblock("O_gemm"):
+                                            i, j, k = T.axis.remap("SSR", [li, 
lj, lk])
+                                            with T.init():
+                                                O_local[i, j] *= T.exp2(
+                                                    m_prev_smem[i] - m_smem[i]
+                                                )
+                                            O_local[i, j] += S_smem[i, k] * 
T.cast(
+                                                V_smem[k, j], "float32"
+                                            )
+
+                            # Store O
+                            for li, lj in T.grid(tile_x, tile_y):
+                                with T.sblock("O_store"):
+                                    i, j = T.axis.remap("SS", [li, lj])
+                                    cur_L: T.int32 = 0 + (LH_start + i) // 
group_size
+                                    cur_H_qo: T.int32 = (
+                                        by * group_size + (LH_start + i) % 
group_size
+                                    )
+                                    if cur_L < qo_len:
+                                        output[b_idx, cur_L, cur_H_qo, j] = (
+                                            O_local[i, j] / d_smem[i]
+                                        )
+
+                            # Store LSE
+                            for li in T.grid(tile_x):
+                                with T.sblock("lse_store"):
+                                    i = T.axis.remap("S", [li])
+                                    cur_L: T.int32 = 0 + (LH_start + i) // 
group_size
+                                    cur_H_qo: T.int32 = (
+                                        by * group_size + (LH_start + i) % 
group_size
+                                    )
+                                    if cur_L < qo_len:
+                                        lse[b_idx, cur_L, cur_H_qo] = 
m_smem[i] + T.log2(
+                                            d_smem[i]
+                                        )
+
+    # fmt: on
+    sch = tvm.s_tir.Schedule(batch_sequence_prefill_kv_masked)
+    sch = _schedule_prefill_kernel(
+        sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False
+    )
+    return sch.mod["main"].with_attr("tirx.is_scheduled", True)
+
+
 def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: 
dict[str, Any]):
     group_size = h_q // h_kv
 
diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py 
b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
new file mode 100644
index 0000000000..b64ef459d8
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Focused correctness tests for ``_attention_sequence_prefill_with_mask``.
+
+The masked variant is the encoder-style counterpart of
+``_attention_sequence_prefill``: each sample in a padded batch carries its
+own ``valid_len`` and the kernel applies the padding mask inside the QKV
+load path and the online softmax update. These tests cover the four shape
+/ mask regimes that can break the kernel independently of any scheduler
+tuning:
+
+* ``valid_len == 0``       — entire batch row is padding
+* ``valid_len == seq_len`` — full-length row, must match the unmasked kernel
+* mixed ``valid_lens``     — typical encoder batch
+* grouped-query attention  — ``h_q > h_kv`` with ``group_size > 1``
+
+The reference is a float32 NumPy implementation of masked softmax attention
+restricted to the valid prefix, so the kernel is only compared on the
+unpadded positions (padded positions are intentionally free to contain
+arbitrary garbage).
+"""
+# ruff: noqa: E501
+import math
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.relax.frontend.nn.llm.kv_cache import 
_attention_sequence_prefill_with_mask
+
+
+def _reference_masked_attention(q, k, v, valid_lens, sm_scale):
+    """NumPy fp32 reference. Only the first ``valid_lens[b]`` rows are 
written."""
+    batch, seq_q, h_q, d = q.shape
+    _, seq_kv, h_kv, _ = k.shape
+    group_size = h_q // h_kv
+    out = np.zeros_like(q, dtype=np.float32)
+    q32 = q.astype(np.float32)
+    k32 = k.astype(np.float32)
+    v32 = v.astype(np.float32)
+    for b in range(batch):
+        L = int(valid_lens[b])
+        if L == 0:
+            continue
+        for h in range(h_q):
+            hk = h // group_size
+            qh = q32[b, :L, h, :]  # [L, d]
+            kh = k32[b, :L, hk, :]  # [L, d]
+            vh = v32[b, :L, hk, :]  # [L, d]
+            s = (qh @ kh.T) * sm_scale  # [L, L]
+            m = s.max(axis=-1, keepdims=True)
+            e = np.exp(s - m)
+            p = e / e.sum(axis=-1, keepdims=True)
+            out[b, :L, h, :] = p @ vh
+    return out
+
+
+def _build_masked_prefill(h_kv, h_q, d, dtype, target):
+    sm_scale = 1.0 / math.sqrt(d)
+    tir_func = _attention_sequence_prefill_with_mask(
+        h_kv=h_kv,
+        h_q=h_q,
+        d=d,
+        dtype=dtype,
+        target=target,
+        sm_scale=sm_scale,
+    )
+    mod = tvm.IRModule({"main": tir_func})
+    return tvm.tirx.build(mod["main"], target=target), sm_scale
+
+
+def _run_case(
+    *,
+    target,
+    dev,
+    h_kv,
+    h_q,
+    d,
+    batch,
+    seq,
+    valid_lens,
+    dtype="float16",
+    seed=0,
+):
+    target = tvm.target.Target(target)
+    built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target)
+
+    np_dtype = {"float16": np.float16, "float32": np.float32}[dtype]
+    rng = np.random.default_rng(seed)
+    q_np = (rng.standard_normal((batch, seq, h_q, d)) * 0.1).astype(np_dtype)
+    k_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+    v_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+    valid_np = np.asarray(valid_lens, dtype=np.int32)
+    out_np = np.zeros((batch, seq, h_q, d), dtype=np_dtype)
+    lse_np = np.zeros((batch, seq, h_q), dtype=np_dtype)
+
+    q_nd = tvm.runtime.tensor(q_np, device=dev)
+    k_nd = tvm.runtime.tensor(k_np, device=dev)
+    v_nd = tvm.runtime.tensor(v_np, device=dev)
+    valid_nd = tvm.runtime.tensor(valid_np, device=dev)
+    out_nd = tvm.runtime.tensor(out_np, device=dev)
+    lse_nd = tvm.runtime.tensor(lse_np, device=dev)
+
+    built.main(q_nd, k_nd, v_nd, valid_nd, out_nd, lse_nd)
+
+    got = out_nd.numpy().astype(np.float32)
+    ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+
+    # Only compare valid rows. Padding rows are undefined by design.
+    rtol, atol = (2e-2, 2e-2) if dtype == "float16" else (1e-4, 1e-4)
+    for b in range(batch):
+        L = int(valid_np[b])
+        if L == 0:
+            continue
+        np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol, 
atol=atol)
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_zero(target, dev):
+    """All samples are fully padded: the kernel must run without crashing 
(outputs are not checked)."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=2,
+        seq=16,
+        valid_lens=[0, 0],
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_full(target, dev):
+    """All samples are fully valid: must match a plain unmasked attention."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=2,
+        seq=32,
+        valid_lens=[32, 32],
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_mixed(target, dev):
+    """Typical encoder batch with different valid lengths per sample."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=4,
+        seq=64,
+        valid_lens=[10, 64, 5, 33],
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_valid_len_mixed_gqa(target, dev):
+    """Grouped-query attention: ``group_size = h_q / h_kv > 1``."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=2,
+        h_q=4,
+        d=64,
+        batch=3,
+        seq=32,
+        valid_lens=[8, 32, 17],
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to