This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5dda0b3a3a [DLight] Add CPU Reduction schedule rule for softmax-like
operators (#19374)
5dda0b3a3a is described below
commit 5dda0b3a3a5378514e556d063887be7372ed5d06
Author: Soowon Jeong <[email protected]>
AuthorDate: Sat Apr 11 03:44:17 2026 +0900
[DLight] Add CPU Reduction schedule rule for softmax-like operators (#19374)
This PR resolves part of #18569
## Description
This PR adds a DLight CPU schedule rule targeting reduction patterns
(softmax, layer norm, RMS norm) that previously had no CPU-specific
schedule.
Without this rule, LLVM auto-vectorization produces suboptimal code for
RVV targets — #18569 reports RVV softmax is **1.34x slower than
scalar**.
The root cause is:
- LLVM fully unrolls the 185-element reduction loop
- Generates harmful fixed-width vectors (`<8 x float>` = 256-bit) on a
128-bit VLEN target
- No parallelization of the batch axis
### Schedule strategy
1. **Parallelize** leading spatial axes (batch dimension)
2. **Compute-at** all blocks under the spatial loop for data locality
3. **Vectorize** injective blocks (exp, delta, norm) on inner axis
4. **Split + unroll** reduction inner axis to VLEN-sized chunks
### Results (shape=(14, 185), float32, `fast_softmax`)
| Config | ASM Instructions | Vector Insns | LLVM IR Lines |
|--------|:---:|:---:|:---:|
| RV scalar (baseline) | 1,463 | 0 | 1,792 |
| RVV unscheduled (**the bug**) | 3,282 | 960 | 2,345 |
| RVV + this schedule | **1,111** | **105** | **1,338** |
- **66% fewer instructions** vs unscheduled RVV
- **24% fewer instructions** vs scalar baseline
- `fast_softmax` polynomial exp fully vectorizes into RVV
`vfsub`/`vfmul`/`vfmax` instructions with zero scalar `exp()` calls
### Limitations (follow-up work)
- Reduction blocks (max, sum) use split+unroll rather than vectorized
partial reduction via `rfactor`, because TVM's `rfactor` primitive
requires the reduction block to be the first child of its enclosing
loop — incompatible with `compute_at` when multiple blocks share one
spatial loop. A follow-up PR will register RVV reduction intrinsics
(`vfredmax`/`vfredusum`) and use `tensorize` to vectorize reductions.
## Testing
```bash
pytest tests/python/s_tir/dlight/test_cpu_reduction.py -v
```
- 14 shape x operator applicability tests
- 2 TIR structure verification tests
- 4 RVV codegen quality tests (code size, exp vectorization, instruction
count)
- Existing `test_cpu_gemv.py` unaffected (10/10 pass)
---
python/tvm/s_tir/dlight/cpu/__init__.py | 1 +
python/tvm/s_tir/dlight/cpu/reduction.py | 153 ++++++++++++++
tests/python/s_tir/dlight/test_cpu_reduction.py | 270 ++++++++++++++++++++++++
3 files changed, 424 insertions(+)
diff --git a/python/tvm/s_tir/dlight/cpu/__init__.py
b/python/tvm/s_tir/dlight/cpu/__init__.py
index 8743c616bb..20e1e9a3b8 100644
--- a/python/tvm/s_tir/dlight/cpu/__init__.py
+++ b/python/tvm/s_tir/dlight/cpu/__init__.py
@@ -20,3 +20,4 @@ CPU-generic schedule rules.
"""
from .gemv import GEMV
+from .reduction import Reduction
diff --git a/python/tvm/s_tir/dlight/cpu/reduction.py
b/python/tvm/s_tir/dlight/cpu/reduction.py
new file mode 100644
index 0000000000..2e804f9537
--- /dev/null
+++ b/python/tvm/s_tir/dlight/cpu/reduction.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""CPU reduction rule for operators including softmax, layer norm, RMS norm,
etc."""
+
+from tvm import DataType, s_tir, tirx
+from tvm.target import Target
+from tvm.target.codegen import llvm_get_vector_width
+
+from ..analysis import normalize_prim_func
+from ..base import get_extent
+from .base import CPUScheduleRule
+
+
+def _get_num_leading_s(dom_kind: str) -> int:
+ """Count leading spatial ('S') axes in a dom_kind string."""
+ return len(dom_kind) - len(dom_kind.lstrip("S"))
+
+
class Reduction(CPUScheduleRule):
    """CPU reduction rule for softmax, layer norm, RMS norm, and similar operators.

    Targets patterns with a mix of reduction (SR) and injective (SS) blocks,
    where all blocks share the same leading spatial axes.
    Example: softmax = maxelem(SR) -> exp(SS) -> expsum(SR) -> norm(SS).

    Schedule strategy:
    1. Parallelize leading spatial axes (batch dimension).
    2. Move all blocks under the spatial loop via compute_at.
    3. Vectorize injective blocks (exp, delta, norm) on their inner axis.
    4. Split reduction inner axis to VLEN-sized chunks and annotate for
       LLVM unrolling, preventing harmful full-unroll by the backend.

    Note: vectorized reduction via rfactor is not used here because TVM's
    rfactor primitive requires the reduction block to be the first child of
    its enclosing loop, which is incompatible with compute_at when multiple
    blocks share the same spatial loop. A follow-up using RVV reduction
    intrinsics (vfredmax/vfredusum) via tensorize can address this.
    """

    def apply(  # pylint: disable=too-many-locals,too-many-return-statements,too-many-branches
        self,
        func: tirx.PrimFunc,
        target: Target,
        _: bool,
    ) -> None | s_tir.Schedule | list[s_tir.Schedule]:
        """Try to schedule *func*; return the schedule or None if the rule
        does not apply (None lets the next rule in the chain try)."""
        # Bail out on non-PrimFunc input or a target this rule cannot serve.
        if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target):
            return None

        sch = s_tir.Schedule(func)
        block_infos = normalize_prim_func(sch)
        # Need at least two blocks: one reduction plus one trailing injective.
        if block_infos is None or len(block_infos) < 2:
            return None

        # Must have at least one reduction block and last block must be injective.
        if not any(not bi.is_injective() for bi in block_infos):
            return None
        if not block_infos[-1].is_injective():
            return None

        # Every block must start with at least one spatial axis, and all blocks
        # must agree on the minimum number of leading spatial axes.
        num_leading_s = None
        for bi in block_infos:
            dk = bi.dom_kind()
            if not dk or dk[0] != "S":
                return None
            n = _get_num_leading_s(dk)
            # Keep the minimum across blocks so compute_at below has a loop
            # nest that every block shares.
            num_leading_s = n if num_leading_s is None else min(num_leading_s, n)
        if not num_leading_s:
            return None

        # Infer dtype from the last block's write buffer.
        # Fall back to 32 bits when the block writes nothing.
        last_block_stmt = sch.get(block_infos[-1].block_rv)
        dtype_bits = (
            DataType(last_block_stmt.writes[0].buffer.dtype).bits if last_block_stmt.writes else 32
        )

        # Determine vector lanes from target VLEN.
        # llvm_get_vector_width may report <=0 for targets without a known
        # vector width; 128 bits is used as a conservative default then.
        vlen_bits = llvm_get_vector_width(target)
        if vlen_bits <= 0:
            vlen_bits = 128
        # Floor of 2 keeps vectorization meaningful even for wide dtypes.
        vec_lanes = max(vlen_bits // dtype_bits, 2)

        # --- Phase 1: Parallelize spatial on the last block ---
        last_block = block_infos[-1]
        loops = sch.get_loops(last_block.block_rv)
        if num_leading_s > 1:
            spatial = sch.fuse(*loops[:num_leading_s])
        else:
            spatial = loops[0]
        sch.parallel(spatial)

        # --- Phase 2: Vectorize the last (injective) block ---
        # NOTE(review): done before compute_at of the other blocks — the
        # phase order appears intentional; confirm before reordering.
        self._vectorize_inner(sch, last_block.block_rv, vec_lanes)

        # --- Phase 3: compute_at all preceding blocks under spatial ---
        # Reversed so each block is moved under the consumer that was just
        # anchored to the spatial loop.
        for block_info in reversed(block_infos[:-1]):
            sch.compute_at(block_info.block_rv, spatial, preserve_unit_loops=True)

        # --- Phase 4: Vectorize injective, split+unroll reduction blocks ---
        for block_info in block_infos[:-1]:
            if block_info.is_injective():
                self._vectorize_inner(sch, block_info.block_rv, vec_lanes)
            else:
                self._unroll_reduction_inner(sch, block_info.block_rv, vec_lanes)

        return sch

    @staticmethod
    def _vectorize_inner(sch: s_tir.Schedule, block_rv, vec_lanes: int) -> None:
        """Split the innermost loop to vec_lanes and vectorize."""
        block_loops = sch.get_loops(block_rv)
        # A single loop means there is no inner data axis to vectorize.
        if len(block_loops) <= 1:
            return
        inner = block_loops[-1]
        extent = get_extent(sch, inner)
        if isinstance(extent, int):
            if extent > vec_lanes:
                # Split so the vector body has exactly vec_lanes lanes.
                _, vec_loop = sch.split(inner, factors=[None, vec_lanes])
                sch.vectorize(vec_loop)
            elif extent >= 2:
                # Short loop: vectorize it whole; extent 1 is left scalar.
                sch.vectorize(inner)
        else:
            # Symbolic extent: split anyway so the inner loop has a fixed
            # vec_lanes extent that can be vectorized.
            _, vec_loop = sch.split(inner, factors=[None, vec_lanes])
            sch.vectorize(vec_loop)

    @staticmethod
    def _unroll_reduction_inner(sch: s_tir.Schedule, block_rv, vec_lanes: int) -> None:
        """Split the reduction inner loop and annotate for unrolling."""
        block_loops = sch.get_loops(block_rv)
        if len(block_loops) <= 1:
            return
        inner = block_loops[-1]
        extent = get_extent(sch, inner)
        # Loops no longer than vec_lanes need no splitting.
        if isinstance(extent, int) and extent <= vec_lanes:
            return
        _, inner_loop = sch.split(inner, factors=[None, vec_lanes])
        # Cap LLVM's unrolling at vec_lanes instead of letting the backend
        # fully unroll the reduction (the pathology described in the class
        # docstring).
        sch.annotate(inner_loop, ann_key="pragma_auto_unroll_max_step", ann_val=vec_lanes)
        sch.annotate(inner_loop, ann_key="pragma_unroll_explicit", ann_val=1)
diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py
b/tests/python/s_tir/dlight/test_cpu_reduction.py
new file mode 100644
index 0000000000..db8280a61a
--- /dev/null
+++ b/tests/python/s_tir/dlight/test_cpu_reduction.py
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Tests for CPU DLight Reduction schedule rule."""
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te, tirx, topi
+from tvm.s_tir import dlight as dl
+from tvm.s_tir.dlight.cpu import Reduction
+from tvm.target import Target
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
def _llvm_target():
    """Return a generic host LLVM target."""
    config = {"kind": "llvm"}
    return Target(config)
+
+
def _rvv_target():
    """Return a riscv64 cross-compilation target with the vector ('v') extension."""
    config = {
        "kind": "llvm",
        "mtriple": "riscv64-linux-gnu",
        "mcpu": "generic-rv64",
        "mabi": "lp64d",
        "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c", "+v"],
    }
    return Target(config)
+
+
def _build_softmax(batch, features, fast=False):
    """Build an IRModule computing softmax (or fast_softmax) over axis 1."""
    data = te.placeholder((batch, features), dtype="float32", name="A")
    if fast:
        out = topi.nn.fast_softmax(data, axis=1)
    else:
        out = topi.nn.softmax(data, axis=1)
    prim_func = te.create_prim_func([data, out])
    return tvm.IRModule({"main": prim_func})
+
+
def _apply_and_check(mod, target):
    """Apply the Reduction rule to mod["main"] and assert it produced a schedule."""
    result = Reduction().apply(mod["main"], target, False)
    assert result is not None, "Reduction rule should apply"
    return result
+
+
+# ---------------------------------------------------------------------------
+# Test: schedule applicability
+# ---------------------------------------------------------------------------
+
+
@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"])
@pytest.mark.parametrize(
    "batch,features",
    [
        (1, 10),
        (1, 128),
        (14, 185),
        (32, 256),
        (64, 512),
        (128, 1024),
        (1, 30522),
    ],
)
def test_reduction_applies(batch, features, fast):
    """The rule must schedule softmax/fast_softmax across a range of shapes."""
    mod = _build_softmax(batch, features, fast=fast)
    _apply_and_check(mod, _llvm_target())
+
+
+# ---------------------------------------------------------------------------
+# Test: scheduled TIR structure
+# ---------------------------------------------------------------------------
+
+
def test_softmax_schedule_structure():
    """Verify the scheduled TIR has expected structure:
    - parallel on batch axis
    - vectorized innermost loops for injective blocks
    - split+unroll for reduction blocks
    """
    target = _llvm_target()
    mod = _build_softmax(14, 185, fast=False)

    # Direct rule application must yield a valid schedule; the rule itself
    # does not set tirx.is_scheduled (only ApplyDefaultSchedule does).
    sch = _apply_and_check(mod, target)
    assert sch.mod is not None

    # The ApplyDefaultSchedule pass should both apply the rule and mark the
    # function as scheduled.
    with target:
        scheduled = dl.ApplyDefaultSchedule(Reduction())(mod)
        func = scheduled["main"]

    assert func.attrs and func.attrs.get("tirx.is_scheduled", False)
+
+
def test_fast_softmax_schedule_structure():
    """fast_softmax should keep T_fast_exp as a separate vectorizable block."""
    sch = _apply_and_check(_build_softmax(14, 185, fast=True), _llvm_target())
    script = str(sch.mod)

    # The exp block must survive as its own block (not inlined away).
    assert any(marker in script for marker in ("T_fast_exp", "T_softmax_delta"))
    # Batch axis parallelized.
    assert "T.parallel" in script
    # At least one inner loop vectorized.
    assert "T.vectorized" in script
+
+
+# ---------------------------------------------------------------------------
+# Test: LLVM IR quality (cross-compile to RISC-V RVV)
+# ---------------------------------------------------------------------------
+
+
def _codegen(mod, target, fmt):
    """Lower *mod* for *target* and return the host-module source as *fmt*.

    fmt is "ll" for LLVM IR or "s" for assembly. Only codegen is performed
    (no linking), so cross-targets work without a RISC-V toolchain.
    Shared by _codegen_llvm_ir and _codegen_asm, which previously duplicated
    this entire pipeline.
    """
    from tvm.tirx.build import split_host_device_mods

    bound = tirx.transform.BindTarget(target.with_host(target))(mod)
    pipeline = tirx.get_tir_pipeline("default")
    lowered = pipeline(bound)
    host_mod, _ = split_host_device_mods(lowered)
    host_mod = tirx.pipeline.finalize_host_passes()(host_mod)
    built = tvm.target.codegen.build_module(host_mod, target)
    return built.inspect_source(fmt)


def _codegen_llvm_ir(mod, target):
    """Lower and codegen to LLVM IR (no linking)."""
    return _codegen(mod, target, "ll")


def _codegen_asm(mod, target):
    """Lower and codegen to assembly (no linking)."""
    return _codegen(mod, target, "s")
+
+
@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"])
def test_rvv_code_size_reduction(fast):
    """Scheduled RVV code should be smaller than unscheduled.

    The original issue (apache/tvm#18569) shows RVV softmax is 1.34x slower
    than scalar, partly due to LLVM generating bloated code with excessive
    unrolling. The schedule should reduce code size significantly.
    """
    target = _rvv_target()
    mod = _build_softmax(14, 185, fast=fast)

    # Unscheduled baseline.
    ir_unsched = _codegen_llvm_ir(mod, target)
    n_unsched = len(ir_unsched.splitlines())

    # Scheduled.
    with target:
        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
        ir_sched = _codegen_llvm_ir(mod_sched, target)
        n_sched = len(ir_sched.splitlines())

    # Scheduled should be meaningfully smaller: at least 25% reduction,
    # matching the ratio < 0.75 threshold asserted below.
    # (The comment previously claimed 30%, contradicting the assertion.)
    ratio = n_sched / n_unsched
    assert ratio < 0.75, (
        f"Expected >=25% code reduction, got {(1 - ratio) * 100:.1f}% "
        f"({n_unsched} -> {n_sched} lines)"
    )
+
+
def test_rvv_fast_softmax_vectorizes_exp():
    """fast_softmax + schedule should produce RVV vector instructions
    for the polynomial exp approximation (no scalar exp calls)."""
    target = _rvv_target()
    mod = _build_softmax(14, 185, fast=True)
    with target:
        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
    ir = _codegen_llvm_ir(mod_sched, target)

    # fast_exp uses a polynomial approximation, so no libm/intrinsic exp
    # calls should remain in the IR.
    scalar_exp = len([line for line in ir.splitlines() if "llvm.exp.f32" in line])
    assert scalar_exp == 0, f"Expected 0 scalar exp calls, got {scalar_exp}"

    # Scalable (<vscale x ...>) vectors prove RVV vectorization happened.
    n_svec = ir.count("<vscale x")
    assert n_svec > 0, "Expected scalable vector operations in LLVM IR"
+
+
def _count_asm_insns(asm):
    """Count instruction lines in an assembly listing, skipping blanks,
    assembler directives ('.'), and comments ('#', '/')."""
    count = 0
    for line in asm.splitlines():
        stripped = line.strip()
        if stripped and not stripped.startswith((".", "#", "/")):
            count += 1
    return count


def test_rvv_asm_instruction_reduction():
    """Scheduled RVV assembly should have fewer total instructions
    than both unscheduled RVV and scalar RV."""
    rvv = _rvv_target()
    # Same riscv64 target as _rvv_target but without the 'v' extension.
    rv = Target(
        {
            "kind": "llvm",
            "mtriple": "riscv64-linux-gnu",
            "mcpu": "generic-rv64",
            "mabi": "lp64d",
            "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c"],
        }
    )

    mod = _build_softmax(14, 185, fast=True)

    # Scalar baseline.
    n_rv = _count_asm_insns(_codegen_asm(mod, rv))

    # RVV unscheduled (the bug case from apache/tvm#18569).
    n_rvv = _count_asm_insns(_codegen_asm(mod, rvv))

    # RVV scheduled.
    with rvv:
        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
        asm_sched = _codegen_asm(mod_sched, rvv)
    n_sched = _count_asm_insns(asm_sched)

    # Scheduled should be smaller than both unscheduled RVV and scalar.
    assert n_sched < n_rvv, (
        f"Scheduled ({n_sched}) should have fewer instructions than unscheduled RVV ({n_rvv})"
    )
    assert n_sched <= n_rv * 1.1, (
        f"Scheduled ({n_sched}) should not be much larger than scalar RV ({n_rv})"
    )
+
+
# Allow running this test file directly (python test_cpu_reduction.py)
# without invoking the pytest launcher explicitly.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])