This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d33630c2a2 [Vulkan] Avoid explicit layout decoration on non-interface
allocations (#18914)
d33630c2a2 is described below
commit d33630c2a28c6b6d6b86f58ec952d1f4e14df29f
Author: Yu Chengye <[email protected]>
AuthorDate: Sat Mar 21 21:43:24 2026 +0800
[Vulkan] Avoid explicit layout decoration on non-interface allocations
(#18914)
SPIR-V codegen currently emits `ArrayStride` and `Offset` decorations
for non-interface allocations in `GetStructArrayType()`. That is correct
for descriptor-backed interface blocks, but not for static workgroup
allocations.
I hit this while bringing up the tilelang Vulkan shared-memory allocation
path: Vulkan validation rejected shaders that used shared memory
lowered through this path:
```
tvm.error.InternalError: Check failed: res == SPV_SUCCESS (-10 vs. 0) :
index=44 error:[VUID-StandaloneSpirv-None-10684] Invalid explicit layout
decorations on type for operand '25[%_ptr_Workgroup__struct_24]'
%A_shared = OpVariable %_ptr_Workgroup__struct_24 Workgroup
```
FIX: This PR keeps the layout decoration for interface blocks and skips
it for non-interface allocations such as static shared/workgroup memory.
A new compile-only test is added for this.
One possible concern is that there is an existing test using
`fetch_to_shared` that already exercises this path.
---
src/target/spirv/ir_builder.cc | 19 ++++++++++---------
tests/python/codegen/test_target_codegen_vulkan.py | 18 ++++++++++++++++++
2 files changed, 28 insertions(+), 9 deletions(-)
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 135888c23d..f912e48276 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -178,23 +178,24 @@ SType IRBuilder::GetStructArrayType(const SType&
value_type, uint32_t num_elems,
} else {
ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type,
value_type).Commit(&global_);
}
- int nbits = value_type.type.bits() * value_type.type.lanes();
- TVM_FFI_ICHECK_EQ(nbits % 8, 0);
- uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
- // decorate the array type.
- this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride,
nbytes);
+ if (interface_block) {
+ int nbits = value_type.type.bits() * value_type.type.lanes();
+ TVM_FFI_ICHECK_EQ(nbits % 8, 0);
+ uint32_t nbytes = static_cast<uint32_t>(nbits) / 8;
+ // Explicit layout is required for descriptor-backed interface blocks.
+ this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride,
nbytes);
+ }
// declare struct of array
SType struct_type;
struct_type.id = id_counter_++;
struct_type.type = DataType::Handle();
struct_type.element_type_id = value_type.id;
ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_);
- // decorate the array type.
- ib_.Begin(spv::OpMemberDecorate)
- .AddSeq(struct_type, 0, spv::DecorationOffset, 0)
- .Commit(&decorate_);
if (interface_block) {
+ ib_.Begin(spv::OpMemberDecorate)
+ .AddSeq(struct_type, 0, spv::DecorationOffset, 0)
+ .Commit(&decorate_);
// Runtime array are always decorated as Block or BufferBlock
// (shader storage buffer)
if (spirv_support_.supports_storage_buffer_storage_class) {
diff --git a/tests/python/codegen/test_target_codegen_vulkan.py
b/tests/python/codegen/test_target_codegen_vulkan.py
index c975073922..e6a24061c8 100644
--- a/tests/python/codegen/test_target_codegen_vulkan.py
+++ b/tests/python/codegen/test_target_codegen_vulkan.py
@@ -515,6 +515,24 @@ def test_codegen_decl_buffer():
vulkan_codegen(Module, target)
[email protected]_vulkan(support_required="compile-only")
+def test_codegen_static_shared_memory():
+ """The codegen should accept static shared/workgroup allocations."""
+
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,),
"float32")):
+ A_shared = T.alloc_buffer((128,), dtype="float32", scope="shared")
+
+ for bx in T.thread_binding(1, thread="blockIdx.x"):
+ for tx in T.thread_binding(128, thread="threadIdx.x"):
+ A_shared[tx] = A[tx]
+ B[tx] = A_shared[tx]
+
+ tvm.compile(Module, target="vulkan")
+
+
@tvm.testing.requires_gpu
@tvm.testing.requires_vulkan
def test_unary():