This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dda0b3a3a [DLight] Add CPU Reduction schedule rule for softmax-like 
operators (#19374)
5dda0b3a3a is described below

commit 5dda0b3a3a5378514e556d063887be7372ed5d06
Author: Soowon Jeong <[email protected]>
AuthorDate: Sat Apr 11 03:44:17 2026 +0900

    [DLight] Add CPU Reduction schedule rule for softmax-like operators (#19374)
    
    This PR resolves part of #18569
    
    ## Description
    
    This PR adds a DLight CPU schedule rule targeting reduction patterns
    (softmax, layer norm, RMS norm) that previously had no CPU-specific
    schedule.
    
    Without this rule, LLVM auto-vectorization produces suboptimal code for
    RVV targets — #18569 reports RVV softmax is **1.34x slower than
    scalar**.
    The root causes are:
    - LLVM fully unrolls the 185-element reduction loop
    - Generates harmful fixed-width vectors (`<8 x float>` = 256-bit) on a
      128-bit VLEN target
    - No parallelization of the batch axis
    
    ### Schedule strategy
    
    1. **Parallelize** leading spatial axes (batch dimension)
    2. **Compute-at** all blocks under the spatial loop for data locality
    3. **Vectorize** injective blocks (exp, delta, norm) on inner axis
    4. **Split + unroll** reduction inner axis to VLEN-sized chunks
    
    ### Results (shape=(14, 185), float32, `fast_softmax`)
    
    | Config | ASM Instructions | Vector Insns | LLVM IR Lines |
    |--------|:---:|:---:|:---:|
    | RV scalar (baseline) | 1,463 | 0 | 1,792 |
    | RVV unscheduled (**the bug**) | 3,282 | 960 | 2,345 |
    | RVV + this schedule | **1,111** | **105** | **1,338** |
    
    - **66% fewer instructions** vs unscheduled RVV
    - **24% fewer instructions** vs scalar baseline
    - `fast_softmax` polynomial exp fully vectorizes into RVV
    `vfsub`/`vfmul`/`vfmax` instructions with zero scalar `exp()` calls
    
    ### Limitations (follow-up work)
    
    - Reduction blocks (max, sum) use split+unroll rather than vectorized
      partial reduction via `rfactor`, because TVM's `rfactor` primitive
      requires the reduction block to be the first child of its enclosing
      loop — incompatible with `compute_at` when multiple blocks share one
      spatial loop. A follow-up PR will register RVV reduction intrinsics
      (`vfredmax`/`vfredusum`) and use `tensorize` to vectorize reductions.
    
    ## Testing
    
    ```bash
    pytest tests/python/s_tir/dlight/test_cpu_reduction.py -v
    ```
    
    - 14 applicability tests (7 shapes x 2 operators: softmax, fast_softmax)
    - 2 TIR structure verification tests
    - 4 RVV codegen quality tests (code size, exp vectorization, instruction
    count)
    - Existing `test_cpu_gemv.py` unaffected (10/10 pass)
---
 python/tvm/s_tir/dlight/cpu/__init__.py         |   1 +
 python/tvm/s_tir/dlight/cpu/reduction.py        | 153 ++++++++++++++
 tests/python/s_tir/dlight/test_cpu_reduction.py | 270 ++++++++++++++++++++++++
 3 files changed, 424 insertions(+)

diff --git a/python/tvm/s_tir/dlight/cpu/__init__.py 
b/python/tvm/s_tir/dlight/cpu/__init__.py
index 8743c616bb..20e1e9a3b8 100644
--- a/python/tvm/s_tir/dlight/cpu/__init__.py
+++ b/python/tvm/s_tir/dlight/cpu/__init__.py
@@ -20,3 +20,4 @@ CPU-generic schedule rules.
 """
 
 from .gemv import GEMV
+from .reduction import Reduction
diff --git a/python/tvm/s_tir/dlight/cpu/reduction.py 
b/python/tvm/s_tir/dlight/cpu/reduction.py
new file mode 100644
index 0000000000..2e804f9537
--- /dev/null
+++ b/python/tvm/s_tir/dlight/cpu/reduction.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""CPU reduction rule for operators including softmax, layer norm, RMS norm, etc."""
+
+from tvm import DataType, s_tir, tirx
+from tvm.target import Target
+from tvm.target.codegen import llvm_get_vector_width
+
+from ..analysis import normalize_prim_func
+from ..base import get_extent
+from .base import CPUScheduleRule
+
+
+def _get_num_leading_s(dom_kind: str) -> int:
+    """Count leading spatial ('S') axes in a dom_kind string."""
+    return len(dom_kind) - len(dom_kind.lstrip("S"))
+
+
+class Reduction(CPUScheduleRule):
+    """CPU reduction rule for softmax, layer norm, RMS norm, and similar operators.
+
+    Targets patterns with a mix of reduction (SR) and injective (SS) blocks,
+    where all blocks share the same leading spatial axes.
+    Example: softmax = maxelem(SR) -> exp(SS) -> expsum(SR) -> norm(SS).
+
+    Schedule strategy:
+      1. Parallelize leading spatial axes (batch dimension).
+      2. Move all blocks under the spatial loop via compute_at.
+      3. Vectorize injective blocks (exp, delta, norm) on their inner axis.
+      4. Split reduction inner axis to VLEN-sized chunks and annotate for
+         LLVM unrolling, preventing harmful full-unroll by the backend.
+
+    Note: vectorized reduction via rfactor is not used here because TVM's
+    rfactor primitive requires the reduction block to be the first child of
+    its enclosing loop, which is incompatible with compute_at when multiple
+    blocks share the same spatial loop. A follow-up using RVV reduction
+    intrinsics (vfredmax/vfredusum) via tensorize can address this.
+    """
+
+    def apply(  # pylint: 
disable=too-many-locals,too-many-return-statements,too-many-branches
+        self,
+        func: tirx.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> None | s_tir.Schedule | list[s_tir.Schedule]:
+        if not isinstance(func, tirx.PrimFunc) or not 
self.is_target_available(target):
+            return None
+
+        sch = s_tir.Schedule(func)
+        block_infos = normalize_prim_func(sch)
+        if block_infos is None or len(block_infos) < 2:
+            return None
+
+        # Must have at least one reduction block and last block must be injective.
+        if not any(not bi.is_injective() for bi in block_infos):
+            return None
+        if not block_infos[-1].is_injective():
+            return None
+
+        # Every block must start with at least one spatial axis; use the minimum
+        # number of leading spatial axes over all blocks as the shared prefix.
+        num_leading_s = None
+        for bi in block_infos:
+            dk = bi.dom_kind()
+            if not dk or dk[0] != "S":
+                return None
+            n = _get_num_leading_s(dk)
+            num_leading_s = n if num_leading_s is None else min(num_leading_s, 
n)
+        if not num_leading_s:
+            return None
+
+        # Infer dtype from the last block's write buffer.
+        last_block_stmt = sch.get(block_infos[-1].block_rv)
+        dtype_bits = (
+            DataType(last_block_stmt.writes[0].buffer.dtype).bits if 
last_block_stmt.writes else 32
+        )
+
+        # Determine vector lanes from target VLEN.
+        vlen_bits = llvm_get_vector_width(target)
+        if vlen_bits <= 0:
+            vlen_bits = 128
+        vec_lanes = max(vlen_bits // dtype_bits, 2)
+
+        # --- Phase 1: Parallelize spatial on the last block ---
+        last_block = block_infos[-1]
+        loops = sch.get_loops(last_block.block_rv)
+        if num_leading_s > 1:
+            spatial = sch.fuse(*loops[:num_leading_s])
+        else:
+            spatial = loops[0]
+        sch.parallel(spatial)
+
+        # --- Phase 2: Vectorize the last (injective) block ---
+        self._vectorize_inner(sch, last_block.block_rv, vec_lanes)
+
+        # --- Phase 3: compute_at all preceding blocks under spatial ---
+        for block_info in reversed(block_infos[:-1]):
+            sch.compute_at(block_info.block_rv, spatial, 
preserve_unit_loops=True)
+
+        # --- Phase 4: Vectorize injective, split+unroll reduction blocks ---
+        for block_info in block_infos[:-1]:
+            if block_info.is_injective():
+                self._vectorize_inner(sch, block_info.block_rv, vec_lanes)
+            else:
+                self._unroll_reduction_inner(sch, block_info.block_rv, 
vec_lanes)
+
+        return sch
+
+    @staticmethod
+    def _vectorize_inner(sch, block_rv, vec_lanes):
+        """Split the innermost loop to vec_lanes and vectorize."""
+        block_loops = sch.get_loops(block_rv)
+        if len(block_loops) <= 1:
+            return
+        inner = block_loops[-1]
+        extent = get_extent(sch, inner)
+        if isinstance(extent, int):
+            if extent > vec_lanes:
+                _, vec_loop = sch.split(inner, factors=[None, vec_lanes])
+                sch.vectorize(vec_loop)
+            elif extent >= 2:
+                sch.vectorize(inner)
+        else:
+            _, vec_loop = sch.split(inner, factors=[None, vec_lanes])
+            sch.vectorize(vec_loop)
+
+    @staticmethod
+    def _unroll_reduction_inner(sch, block_rv, vec_lanes):
+        """Split the reduction inner loop and annotate for unrolling."""
+        block_loops = sch.get_loops(block_rv)
+        if len(block_loops) <= 1:
+            return
+        inner = block_loops[-1]
+        extent = get_extent(sch, inner)
+        if isinstance(extent, int) and extent <= vec_lanes:
+            return
+        _, inner_loop = sch.split(inner, factors=[None, vec_lanes])
+        sch.annotate(inner_loop, ann_key="pragma_auto_unroll_max_step", 
ann_val=vec_lanes)
+        sch.annotate(inner_loop, ann_key="pragma_unroll_explicit", ann_val=1)
diff --git a/tests/python/s_tir/dlight/test_cpu_reduction.py 
b/tests/python/s_tir/dlight/test_cpu_reduction.py
new file mode 100644
index 0000000000..db8280a61a
--- /dev/null
+++ b/tests/python/s_tir/dlight/test_cpu_reduction.py
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Tests for CPU DLight Reduction schedule rule."""
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te, tirx, topi
+from tvm.s_tir import dlight as dl
+from tvm.s_tir.dlight.cpu import Reduction
+from tvm.target import Target
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _llvm_target():
+    return Target({"kind": "llvm"})
+
+
+def _rvv_target():
+    return Target(
+        {
+            "kind": "llvm",
+            "mtriple": "riscv64-linux-gnu",
+            "mcpu": "generic-rv64",
+            "mabi": "lp64d",
+            "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c", "+v"],
+        }
+    )
+
+
+def _build_softmax(batch, features, fast=False):
+    A = te.placeholder((batch, features), dtype="float32", name="A")
+    B = topi.nn.fast_softmax(A, axis=1) if fast else topi.nn.softmax(A, axis=1)
+    func = te.create_prim_func([A, B])
+    return tvm.IRModule({"main": func})
+
+
+def _apply_and_check(mod, target):
+    """Apply Reduction rule and verify it was applied."""
+    rule = Reduction()
+    result = rule.apply(mod["main"], target, False)
+    assert result is not None, "Reduction rule should apply"
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Test: schedule applicability
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"])
+@pytest.mark.parametrize(
+    "batch,features",
+    [
+        (1, 10),
+        (1, 128),
+        (14, 185),
+        (32, 256),
+        (64, 512),
+        (128, 1024),
+        (1, 30522),
+    ],
+)
+def test_reduction_applies(batch, features, fast):
+    """Reduction rule should apply to softmax/fast_softmax of various shapes."""
+    mod = _build_softmax(batch, features, fast=fast)
+    target = _llvm_target()
+    _apply_and_check(mod, target)
+
+
+# ---------------------------------------------------------------------------
+# Test: scheduled TIR structure
+# ---------------------------------------------------------------------------
+
+
+def test_softmax_schedule_structure():
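+    """Smoke-test both ways of applying the rule to softmax.
+    - direct ``Reduction.apply`` must return a schedule with a module
+    - ``dl.ApplyDefaultSchedule`` must mark ``main`` as ``tirx.is_scheduled``
+    The parallel/vectorized/unrolled loop structure is not inspected here.
+    """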
+    mod = _build_softmax(14, 185, fast=False)
+    target = _llvm_target()
+    sch = _apply_and_check(mod, target)
+    scheduled_mod = sch.mod
+
+    # tirx.is_scheduled is only set by ApplyDefaultSchedule, so it is not checked
+    # here; a direct apply() only needs to yield a non-None module.
+    assert scheduled_mod is not None
+
+    # Verify via ApplyDefaultSchedule path
+    with target:
+        scheduled = dl.ApplyDefaultSchedule(Reduction())(mod)
+    func = scheduled["main"]
+
+    # Check tirx.is_scheduled is set
+    assert func.attrs and func.attrs.get("tirx.is_scheduled", False)
+
+
+def test_fast_softmax_schedule_structure():
+    """fast_softmax should keep T_fast_exp as a separate vectorizable block."""
+    mod = _build_softmax(14, 185, fast=True)
+    target = _llvm_target()
+    sch = _apply_and_check(mod, target)
+    script = str(sch.mod)
+
+    # fast_exp (or softmax delta) block should exist (not inlined)
+    assert "T_fast_exp" in script or "T_softmax_delta" in script
+    # Should have T.parallel
+    assert "T.parallel" in script
+    # Should have T.vectorized
+    assert "T.vectorized" in script
+
+
+# ---------------------------------------------------------------------------
+# Test: LLVM IR quality (cross-compile to RISC-V RVV)
+# ---------------------------------------------------------------------------
+
+
+def _codegen_llvm_ir(mod, target):
+    """Lower and codegen to LLVM IR (no linking)."""
+    bound = tirx.transform.BindTarget(target.with_host(target))(mod)
+    pipeline = tirx.get_tir_pipeline("default")
+    lowered = pipeline(bound)
+    from tvm.tirx.build import split_host_device_mods
+
+    host_mod, _ = split_host_device_mods(lowered)
+    host_mod = tirx.pipeline.finalize_host_passes()(host_mod)
+    built = tvm.target.codegen.build_module(host_mod, target)
+    return built.inspect_source("ll")
+
+
+def _codegen_asm(mod, target):
+    """Lower and codegen to assembly (no linking)."""
+    bound = tirx.transform.BindTarget(target.with_host(target))(mod)
+    pipeline = tirx.get_tir_pipeline("default")
+    lowered = pipeline(bound)
+    from tvm.tirx.build import split_host_device_mods
+
+    host_mod, _ = split_host_device_mods(lowered)
+    host_mod = tirx.pipeline.finalize_host_passes()(host_mod)
+    built = tvm.target.codegen.build_module(host_mod, target)
+    return built.inspect_source("s")
+
+
+@pytest.mark.parametrize("fast", [False, True], ids=["softmax", "fast_softmax"])
+def test_rvv_code_size_reduction(fast):
+    """Scheduled RVV code should be smaller than unscheduled.
+
+    The original issue (apache/tvm#18569) shows RVV softmax is 1.34x slower
+    than scalar, partly due to LLVM generating bloated code with excessive
+    unrolling. The schedule should reduce code size significantly.
+    """
+    target = _rvv_target()
+    mod = _build_softmax(14, 185, fast=fast)
+
+    # Unscheduled
+    ir_unsched = _codegen_llvm_ir(mod, target)
+    n_unsched = len(ir_unsched.splitlines())
+
+    # Scheduled
+    with target:
+        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
+    ir_sched = _codegen_llvm_ir(mod_sched, target)
+    n_sched = len(ir_sched.splitlines())
+
+    # Scheduled should be meaningfully smaller (at least 25% reduction)
+    ratio = n_sched / n_unsched
+    assert ratio < 0.75, (
+        f"Expected >=25% code reduction, got {(1 - ratio) * 100:.1f}% "
+        f"({n_unsched} -> {n_sched} lines)"
+    )
+
+
+def test_rvv_fast_softmax_vectorizes_exp():
+    """fast_softmax + schedule should produce RVV vector instructions
+    for the polynomial exp approximation (no scalar exp calls)."""
+    target = _rvv_target()
+    mod = _build_softmax(14, 185, fast=True)
+    with target:
+        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
+    ir = _codegen_llvm_ir(mod_sched, target)
+
+    # Should have zero scalar exp calls (fast_exp uses polynomial)
+    scalar_exp = sum(1 for line in ir.splitlines() if "llvm.exp.f32" in line)
+    assert scalar_exp == 0, f"Expected 0 scalar exp calls, got {scalar_exp}"
+
+    # Should have scalable vector operations
+    n_svec = ir.count("<vscale x")
+    assert n_svec > 0, "Expected scalable vector operations in LLVM IR"
+
+
+def test_rvv_asm_instruction_reduction():
+    """Scheduled RVV assembly should have fewer instructions than
+    unscheduled RVV and at most 10% more than scalar RV."""
+    rvv = _rvv_target()
+    rv = Target(
+        {
+            "kind": "llvm",
+            "mtriple": "riscv64-linux-gnu",
+            "mcpu": "generic-rv64",
+            "mabi": "lp64d",
+            "mattr": ["+64bit", "+m", "+a", "+f", "+d", "+c"],
+        }
+    )
+
+    mod = _build_softmax(14, 185, fast=True)
+
+    # Scalar baseline
+    asm_rv = _codegen_asm(mod, rv)
+    n_rv = len(
+        [
+            line
+            for line in asm_rv.splitlines()
+            if line.strip() and not line.strip().startswith((".", "#", "/"))
+        ]
+    )
+
+    # RVV unscheduled
+    asm_rvv = _codegen_asm(mod, rvv)
+    n_rvv = len(
+        [
+            line
+            for line in asm_rvv.splitlines()
+            if line.strip() and not line.strip().startswith((".", "#", "/"))
+        ]
+    )
+
+    # RVV scheduled
+    with rvv:
+        mod_sched = dl.ApplyDefaultSchedule(Reduction())(mod)
+    asm_sched = _codegen_asm(mod_sched, rvv)
+    n_sched = len(
+        [
+            line
+            for line in asm_sched.splitlines()
+            if line.strip() and not line.strip().startswith((".", "#", "/"))
+        ]
+    )
+
+    # Scheduled should be smaller than unscheduled RVV and within +10% of scalar
+    assert n_sched < n_rvv, (
+        f"Scheduled ({n_sched}) should have fewer instructions than 
unscheduled RVV ({n_rvv})"
+    )
+    assert n_sched <= n_rv * 1.1, (
+        f"Scheduled ({n_sched}) should not be much larger than scalar RV 
({n_rv})"
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

Reply via email to