This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d58e9f7bbe [WebGPU] Add gating logic for subgroup shuffle primitives
(#18823)
d58e9f7bbe is described below
commit d58e9f7bbef42c2256df693d8fdec8bd8cce0b69
Author: Sai Gopal Reddy Kovvuri <[email protected]>
AuthorDate: Mon Apr 6 14:56:15 2026 -0400
[WebGPU] Add gating logic for subgroup shuffle primitives (#18823)
## Summary
This adds gating logic on top of #17699 to support optional subgroup
shuffle
primitives based on a compile-time flag.
## Problem
The PR #17699 always generates subgroup shuffle ops when targeting
WebGPU.
However, not all WebGPU devices support subgroups. We need a way to:
- Default to shared memory reductions (universally compatible)
- Optionally enable subgroup shuffles for devices that support them
## Solution
Implement gating via TVM target parameter:
- Default `thread_warp_size=1` disables warp reductions (uses shared
memory + barriers)
- Add target parser `UpdateWebGPUAttrs()` that sets
`thread_warp_size=32` when `supports_subgroups=true`
- Add `--enable-subgroups` CLI flag in mlc-llm to surface the option to
users
The gating happens at the reduction path selection level
(`IsWarpReduction()` in
`lower_thread_allreduce.cc`), ensuring subgroup ops are never generated
unless explicitly enabled.
## Testing
Tested with Llama-3.2-1B-q4f16_1. Baseline (no flag) uses shared memory
reductions;
with flag, generates subgroupShuffle* ops.
Both generated WGSL outputs (baseline and subgroup-enabled) are available here:
https://gist.github.com/ksgr5566/301664a5dda3e46f44092be4d09b2d4f
Benchmarking:
https://gist.github.com/ksgr5566/c9bd5bc5aadba999ec2f2c38eb0c49b3
---
src/s_tir/transform/lower_thread_allreduce.cc | 4 +-
src/target/source/codegen_webgpu.cc | 7 +-
src/target/source/codegen_webgpu.h | 2 +
src/target/source/intrin_rule_webgpu.cc | 59 ++++++++++++
src/target/target_kind.cc | 35 +++++++
...test_s_tir_transform_lower_thread_all_reduce.py | 101 +++++++++++++++++++++
tests/python/target/test_target_target.py | 20 ++++
web/src/webgpu.ts | 3 +
8 files changed, 228 insertions(+), 3 deletions(-)
diff --git a/src/s_tir/transform/lower_thread_allreduce.cc
b/src/s_tir/transform/lower_thread_allreduce.cc
index f1e5a3cfaf..f7253a7689 100644
--- a/src/s_tir/transform/lower_thread_allreduce.cc
+++ b/src/s_tir/transform/lower_thread_allreduce.cc
@@ -742,11 +742,11 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent,
int reduce_extent,
int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
- (target_->kind->name != "metal")) {
+ (target_->kind->name != "metal") && (target_->kind->name != "webgpu"))
{
return false;
}
- need_warp_shuffle_mask_ = target_->kind->name != "metal";
+ need_warp_shuffle_mask_ = target_->kind->name != "metal" &&
target_->kind->name != "webgpu";
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index ae4c93e80c..ca7b3878ae 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
+ if (enable_subgroups_) {
+ header_stream << "enable subgroups;\n\n";
+ }
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str()
+ stream.str();
}
@@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
}
}
-CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
+CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
+ enable_subgroups_ =
target_->GetAttr<Bool>("supports_subgroups").value_or(Bool(false));
+}
runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool
skip_readonly_decl) {
// clear previous generated state.
diff --git a/src/target/source/codegen_webgpu.h
b/src/target/source/codegen_webgpu.h
index f53d090e58..d0a541677a 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {
// whether enable fp16
bool enable_fp16_{false};
+ // whether enable subgroups
+ bool enable_subgroups_{false};
/*! \brief the header stream for function label and enable directive if any,
goes before any other
* declaration */
diff --git a/src/target/source/intrin_rule_webgpu.cc
b/src/target/source/intrin_rule_webgpu.cc
index 968df9a579..86658d8e28 100644
--- a/src/target/source/intrin_rule_webgpu.cc
+++ b/src/target/source/intrin_rule_webgpu.cc
@@ -32,6 +32,30 @@ namespace intrin {
using tirx::FLowerIntrinsic;
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+struct WebGPUWarpIntrinsic {
+ const Op operator()(DataType t, const Op& orig_op) const {
+ if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+ return Op::Get("tirx.webgpu.subgroup_shuffle");
+ } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+ return Op::Get("tirx.webgpu.subgroup_shuffle_up");
+ } else {
+ TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+ return Op::Get("tirx.webgpu.subgroup_shuffle_down");
+ }
+ }
+};
+
+template <typename T>
+static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
+ const CallNode* call = e.as<CallNode>();
+ TVM_FFI_ICHECK(call != nullptr);
+ TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width,
warp_size
+ PrimExpr lane_or_delta = Cast(DataType::UInt(32,
call->args[2].dtype().lanes()), call->args[2]);
+ ffi::Array<PrimExpr> webgpu_args{{call->args[1], lane_or_delta}};
+ return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)),
webgpu_args);
+}
+
// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions
struct ReturnAbs {
@@ -113,6 +137,41 @@ TVM_REGISTER_OP("tirx.trunc")
// extra dispatch
TVM_REGISTER_OP("tirx.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
DispatchFastErf);
+// warp-level primitives. Follows implementation in intrin_rule_metal.cc
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle")
+ .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+ DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up")
+ .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+ DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down")
+ .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic",
+ DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle")
+ .set_num_inputs(2)
+ .add_argument("var", "Expr", "The variable to sync.")
+ .add_argument("lane", "Expr", "The source thread id.")
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
+ .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_up")
+ .set_num_inputs(2)
+ .add_argument("var", "Expr", "The variable to sync.")
+ .add_argument("delta", "Expr", "The source lane id offset to be added.")
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
+ .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down")
+ .set_num_inputs(2)
+ .add_argument("var", "Expr", "The variable to sync.")
+ .add_argument("delta", "Expr", "The source lane id offset to be
subtracted.")
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
+ .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
+
} // namespace intrin
} // namespace codegen
} // namespace tvm
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 8c328dd9cf..2fb5e17d5f 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -256,6 +256,36 @@ ffi::Map<ffi::String, ffi::Any>
UpdateROCmAttrs(ffi::Map<ffi::String, ffi::Any>
return target;
}
+/*!
+ * \brief Update WebGPU target attributes for subgroup-enabled lowering.
+ * Runtime routing on the WebLLM side guarantees subgroup size == 32.
+ * Runtime routing on the WebLLM side guarantees
+ * maxComputeInvocationsPerWorkgroup >= 1024.
+ * This is intentionally constrained for the subgroup-enabled WASM variant.
+ * When supports_subgroups is true, canonicalize thread_warp_size to 32 so
+ * TIR lowering can emit subgroup shuffle reductions.
+ * \param target The Target to update
+ * \return The updated attributes
+ */
+ffi::Map<ffi::String, ffi::Any> UpdateWebGPUAttrs(ffi::Map<ffi::String,
ffi::Any> target) {
+ bool subgroups = false;
+ if (target.count("supports_subgroups")) {
+ subgroups = Downcast<Bool>(target.at("supports_subgroups"));
+ }
+
+ if (target.count("thread_warp_size")) {
+ int64_t thread_warp_size =
Downcast<Integer>(target.at("thread_warp_size"))->value;
+ TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1)
+ << "WebGPU target with thread_warp_size=" << thread_warp_size
+ << " requires supports_subgroups=true";
+ }
+
+ if (subgroups) {
+ target.Set("thread_warp_size", int64_t(32));
+ }
+ return target;
+}
+
/*!
* \brief Test Target Parser
* \param target The Target to update
@@ -429,6 +459,11 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
+ .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
+ // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction
returns false, so no
+ // subgroup ops are emitted.
+ .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
+ .set_target_canonicalizer(UpdateWebGPUAttrs)
.set_default_keys({"webgpu", "gpu"});
TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index aff3376052..558d67cade 100644
---
a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
+++
b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -406,5 +406,106 @@ def test_metal_no_mask():
assert "tvm_storage_sync" in After_script
+def test_webgpu_warp_reduce():
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def main(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128,
"float32")):
+ T.func_attr(
+ {
+ "target": T.target(
+ {
+ "kind": "webgpu",
+ "supports_subgroups": True,
+ "host": "llvm",
+ }
+ ),
+ }
+ )
+ A_flat = T.decl_buffer(4096, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce_data = T.alloc_buffer((1,), "float32", scope="local")
+ reduce = T.decl_buffer(1, data=reduce_data.data, scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+ After = transform(Before)
+ assert After is not None
+ After_script = After.script()
+ assert "tvm_warp_shuffle_down" in After_script
+ assert "tvm_warp_shuffle(" in After_script
+ assert "tvm_storage_sync" not in After_script
+ assert "T.uint32(" not in After_script
+
+
+def test_webgpu_multi_warp_reduce():
+ transform = tvm.s_tir.transform.LowerThreadAllreduce()
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1,
2), "float32")):
+ T.func_attr(
+ {
+ "target": T.target(
+ {
+ "kind": "webgpu",
+ "max_num_threads": 1024,
+ "supports_subgroups": True,
+ "host": "llvm",
+ }
+ ),
+ }
+ )
+ blockIdx_x = T.launch_thread("blockIdx.x", 1)
+ cross_thread_B = T.alloc_buffer((1,), "float32", scope="local")
+ threadIdx_z = T.launch_thread("threadIdx.z", 1)
+ threadIdx_y = T.launch_thread("threadIdx.y", 2)
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
+ cross_thread_B_1 = T.decl_buffer((1,), data=cross_thread_B.data,
scope="local")
+ with T.attr(
+ T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ A_1 = T.decl_buffer((256,), data=A.data)
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_1[threadIdx_y * 128 + threadIdx_x],
+ T.bool(True),
+ cross_thread_B_1[0],
+ threadIdx_x,
+ )
+ if threadIdx_x == 0:
+ B_1 = T.decl_buffer((2,), data=B.data)
+ B_1[threadIdx_y] = cross_thread_B_1[0]
+
+ After = transform(Before)
+ assert After is not None
+ After_script = After.script()
+ assert "tvm_warp_shuffle_down" in After_script
+ assert "tvm_storage_sync" in After_script
+ assert "\"tirx.volatile\": T.bool(True)" in After_script
+ assert "T.uint32(" not in After_script
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index 05c79abea7..43d55a27fc 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -426,5 +426,25 @@ def test_cli_string_rejected():
Target("llvm -mcpu=cortex-a53")
+def test_webgpu_target_subgroup_attrs():
+ """Test WebGPU target defaults and supports_subgroups canonicalization."""
+ # Default: thread_warp_size=1, supports_subgroups=False
+ tgt_default = Target({"kind": "webgpu"})
+ assert tgt_default.attrs["thread_warp_size"] == 1
+ assert tgt_default.attrs["supports_subgroups"] == 0
+
+ # With supports_subgroups=True: thread_warp_size is set to 32
+ tgt_subgroups = Target({"kind": "webgpu", "supports_subgroups": True})
+ assert tgt_subgroups.attrs["thread_warp_size"] == 32
+ assert tgt_subgroups.attrs["supports_subgroups"] == 1
+
+ for config in [
+ {"kind": "webgpu", "thread_warp_size": 32},
+ {"kind": "webgpu", "thread_warp_size": 32, "supports_subgroups":
False},
+ ]:
+ with pytest.raises(tvm.TVMError, match="requires
supports_subgroups=true"):
+ Target(config)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index 55d188516d..199fa14235 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference:
"low-power" | "high-perfo
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
+ if (adapter.features.has("subgroups")) {
+ requiredFeatures.push("subgroups");
+ }
// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers,
hence `as any`.
const adapterInfo = adapter.info || await (adapter as
any).requestAdapterInfo();