This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d58e9f7bbe [WebGPU] Add gating logic for subgroup shuffle primitives (#18823)
d58e9f7bbe is described below

commit d58e9f7bbef42c2256df693d8fdec8bd8cce0b69
Author: Sai Gopal Reddy Kovvuri <[email protected]>
AuthorDate: Mon Apr 6 14:56:15 2026 -0400

    [WebGPU] Add gating logic for subgroup shuffle primitives (#18823)
    
    ## Summary
    This adds gating logic on top of #17699 to support optional subgroup
    shuffle
    primitives based on a compile-time flag.
    
    ## Problem
    The PR #17699 always generates subgroup shuffle ops when targeting
    WebGPU.
    However, not all WebGPU devices support subgroups. We need a way to:
    - Default to shared memory reductions (universally compatible)
    - Optionally enable subgroup shuffles for devices that support them
    
    ## Solution
    Implement gating via TVM target parameter:
    - Default `thread_warp_size=1` disables warp reductions (uses shared
    memory + barriers)
    - Add target parser `UpdateWebGPUAttrs()` that sets
    `thread_warp_size=32` when `supports_subgroups=true`
    - Add `--enable-subgroups` CLI flag in mlc-llm to surface the option to
    users
    
    The gating happens at the reduction path selection level
    (`IsWarpReduction()` in
    `lower_thread_allreduce.cc`), ensuring subgroup ops are never generated
    unless explicitly enabled.
    
    ## Testing
    
    Tested with Llama-3.2-1B-q4f16_1. Baseline (no flag) uses shared memory
    reductions;
    with flag, generates subgroupShuffle* ops.
    Both the generated WGSLs here:
    https://gist.github.com/ksgr5566/301664a5dda3e46f44092be4d09b2d4f
    Benchmarking:
    https://gist.github.com/ksgr5566/c9bd5bc5aadba999ec2f2c38eb0c49b3
---
 src/s_tir/transform/lower_thread_allreduce.cc      |   4 +-
 src/target/source/codegen_webgpu.cc                |   7 +-
 src/target/source/codegen_webgpu.h                 |   2 +
 src/target/source/intrin_rule_webgpu.cc            |  59 ++++++++++++
 src/target/target_kind.cc                          |  35 +++++++
 ...test_s_tir_transform_lower_thread_all_reduce.py | 101 +++++++++++++++++++++
 tests/python/target/test_target_target.py          |  20 ++++
 web/src/webgpu.ts                                  |   3 +
 8 files changed, 228 insertions(+), 3 deletions(-)

diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc
index f1e5a3cfaf..f7253a7689 100644
--- a/src/s_tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
                        int contiguous_reduce_extent) {
     if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
-        (target_->kind->name != "metal")) {
+        (target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
       return false;
     }
 
-    need_warp_shuffle_mask_ = target_->kind->name != "metal";
+    need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
 
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index ae4c93e80c..ca7b3878ae 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
   if (enable_fp16_) {
     header_stream << "enable f16;\n\n";
   }
+  if (enable_subgroups_) {
+    header_stream << "enable subgroups;\n\n";
+  }
   return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
 }
 
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   }
 }
 
-CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
+CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
+  enable_subgroups_ = target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
+}
 
 runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
   // clear previous generated state.
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index f53d090e58..d0a541677a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {
 
   // whether enable fp16
   bool enable_fp16_{false};
+  // whether enable subgroups
+  bool enable_subgroups_{false};
 
   /*! \brief the header stream for function label and enable directive if any, goes before any other
    * declaration */
diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc
index 968df9a579..86658d8e28 100644
--- a/src/target/source/intrin_rule_webgpu.cc
+++ b/src/target/source/intrin_rule_webgpu.cc
@@ -32,6 +32,30 @@ namespace intrin {
 
 using tirx::FLowerIntrinsic;
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+struct WebGPUWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tirx.webgpu.subgroup_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tirx.webgpu.subgroup_shuffle_up");
+    } else {
+      TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tirx.webgpu.subgroup_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  TVM_FFI_ICHECK(call != nullptr);
+  TVM_FFI_ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  PrimExpr lane_or_delta = Cast(DataType::UInt(32, call->args[2].dtype().lanes()), call->args[2]);
+  ffi::Array<PrimExpr> webgpu_args{{call->args[1], lane_or_delta}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), webgpu_args);
+}
+
 // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions
 
 struct ReturnAbs {
@@ -113,6 +137,41 @@ TVM_REGISTER_OP("tirx.trunc")
 // extra dispatch
 
 TVM_REGISTER_OP("tirx.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);
 
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+                               DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("lane", "Expr", "The source thread id.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be added.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 8c328dd9cf..2fb5e17d5f 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -256,6 +256,36 @@ ffi::Map<ffi::String, ffi::Any> UpdateROCmAttrs(ffi::Map<ffi::String, ffi::Any>
   return target;
 }
 
+/*!
+ * \brief Update WebGPU target attributes for subgroup-enabled lowering.
+ * Runtime routing on the WebLLM side guarantees subgroup size == 32.
+ * Runtime routing on the WebLLM side guarantees
+ * maxComputeInvocationsPerWorkgroup >= 1024.
+ * This is intentionally constrained for the subgroup-enabled WASM variant.
+ * When supports_subgroups is true, canonicalize thread_warp_size to 32 so
+ * TIR lowering can emit subgroup shuffle reductions.
+ * \param target The Target to update
+ * \return The updated attributes
+ */
+ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String, ffi::Any> target) {
+  bool subgroups = false;
+  if (target.count("supports_subgroups")) {
+    subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+  }
+
+  if (target.count("thread_warp_size")) {
+    int64_t thread_warp_size = Downcast<Integer>(target.at("thread_warp_size"))->value;
+    TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1)
+        << "WebGPU target with thread_warp_size=" << thread_warp_size
+        << " requires supports_subgroups=true";
+  }
+
+  if (subgroups) {
+    target.Set("thread_warp_size", int64_t(32));
+  }
+  return target;
+}
+
 /*!
  * \brief Test Target Parser
  * \param target The Target to update
@@ -429,6 +459,11 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
 
 TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
     .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
+    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
+    // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
+    // subgroup ops are emitted.
+    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
+    .set_target_canonicalizer(UpdateWebGPUAttrs)
     .set_default_keys({"webgpu", "gpu"});
 
 TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index aff3376052..558d67cade 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -406,5 +406,106 @@ def test_metal_no_mask():
     assert "tvm_storage_sync" in After_script
 
 
+def test_webgpu_warp_reduce():
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+            T.func_attr(
+                {
+                    "target": T.target(
+                        {
+                            "kind": "webgpu",
+                            "supports_subgroups": True,
+                            "host": "llvm",
+                        }
+                    ),
+                }
+            )
+            A_flat = T.decl_buffer(4096, data=A.data)
+
+            for i in range(128):
+                threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                reduce_data = T.alloc_buffer((1,), "float32", scope="local")
+                reduce = T.decl_buffer(1, data=reduce_data.data, scope="local")
+
+                with T.attr(
+                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret("handle", T.uint64(0)),
+                ):
+                    T.tvm_thread_allreduce(
+                        T.uint32(1),
+                        A_flat[0],
+                        T.bool(True),
+                        reduce[0],
+                        threadIdx_x,
+                    )
+                if threadIdx_x == 0:
+                    B[i] = reduce[0]
+
+    After = transform(Before)
+    assert After is not None
+    After_script = After.script()
+    assert "tvm_warp_shuffle_down" in After_script
+    assert "tvm_warp_shuffle(" in After_script
+    assert "tvm_storage_sync" not in After_script
+    assert "T.uint32(" not in After_script
+
+
+def test_webgpu_multi_warp_reduce():
+    transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+            T.func_attr(
+                {
+                    "target": T.target(
+                        {
+                            "kind": "webgpu",
+                            "max_num_threads": 1024,
+                            "supports_subgroups": True,
+                            "host": "llvm",
+                        }
+                    ),
+                }
+            )
+            blockIdx_x = T.launch_thread("blockIdx.x", 1)
+            cross_thread_B = T.alloc_buffer((1,), "float32", scope="local")
+            threadIdx_z = T.launch_thread("threadIdx.z", 1)
+            threadIdx_y = T.launch_thread("threadIdx.y", 2)
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
+            cross_thread_B_1 = T.decl_buffer((1,), data=cross_thread_B.data, scope="local")
+            with T.attr(
+                T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                A_1 = T.decl_buffer((256,), data=A.data)
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    A_1[threadIdx_y * 128 + threadIdx_x],
+                    T.bool(True),
+                    cross_thread_B_1[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B_1 = T.decl_buffer((2,), data=B.data)
+                B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    After = transform(Before)
+    assert After is not None
+    After_script = After.script()
+    assert "tvm_warp_shuffle_down" in After_script
+    assert "tvm_storage_sync" in After_script
+    assert "\"tirx.volatile\": T.bool(True)" in After_script
+    assert "T.uint32(" not in After_script
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py
index 05c79abea7..43d55a27fc 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -426,5 +426,25 @@ def test_cli_string_rejected():
         Target("llvm -mcpu=cortex-a53")
 
 
+def test_webgpu_target_subgroup_attrs():
+    """Test WebGPU target defaults and supports_subgroups canonicalization."""
+    # Default: thread_warp_size=1, supports_subgroups=False
+    tgt_default = Target({"kind": "webgpu"})
+    assert tgt_default.attrs["thread_warp_size"] == 1
+    assert tgt_default.attrs["supports_subgroups"] == 0
+
+    # With supports_subgroups=True: thread_warp_size is set to 32
+    tgt_subgroups = Target({"kind": "webgpu", "supports_subgroups": True})
+    assert tgt_subgroups.attrs["thread_warp_size"] == 32
+    assert tgt_subgroups.attrs["supports_subgroups"] == 1
+
+    for config in [
+        {"kind": "webgpu", "thread_warp_size": 32},
+        {"kind": "webgpu", "thread_warp_size": 32, "supports_subgroups": False},
+    ]:
+        with pytest.raises(tvm.TVMError, match="requires supports_subgroups=true"):
+            Target(config)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
     if (adapter.features.has("shader-f16")) {
       requiredFeatures.push("shader-f16");
     }
+    if (adapter.features.has("subgroups")) {
+      requiredFeatures.push("subgroups");
+    }
     // requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
     // issue when building. However, it is still needed for older browsers, hence `as any`.
    const adapterInfo = adapter.info || await (adapter as any).requestAdapterInfo();

Reply via email to