https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/145577
>From 44cb96a30dc3b2b28449661a52ac5a73c63e2139 Mon Sep 17 00:00:00 2001 From: Steven Perron <stevenper...@google.com> Date: Tue, 24 Jun 2025 15:44:10 -0400 Subject: [PATCH] [HLSL][SPIRV] Handle `uint` type for spec constant The testing only tried `unsigned int` and not `uint`. We want to correctly handle these sugared types as specialization constants. --- clang/lib/Sema/SemaHLSL.cpp | 8 ++-- .../test/AST/HLSL/vk.spec-constant.usage.hlsl | 11 +++++ .../vk-features/vk.spec-constant.hlsl | 45 +++++++++++++------ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d003967a522a1..e36db2ee10c20 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -120,7 +120,7 @@ static ResourceClass getResourceClass(RegisterType RT) { llvm_unreachable("unexpected RegisterType value"); } -static Builtin::ID getSpecConstBuiltinId(QualType Type) { +static Builtin::ID getSpecConstBuiltinId(const Type *Type) { const auto *BT = dyn_cast<BuiltinType>(Type); if (!BT) { if (!Type->isEnumeralType()) @@ -654,7 +654,8 @@ SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, auto *VD = cast<VarDecl>(D); - if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) { + if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) == + Builtin::NotBuiltin) { Diag(VD->getLocation(), diag::err_specialization_const); return nullptr; } @@ -3920,7 +3921,8 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) { return false; } - Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType()); + Builtin::ID BID = + getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType()); // Argument 1: The ID from the attribute int ConstantID = ConstIdAttr->getId(); diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl index c0955c1ea7b43..733c4e2ee5a36 100644 --- a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl 
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl @@ -64,6 +64,17 @@ const unsigned short ushort_const = 10; [[vk::constant_id(6)]] const unsigned int uint_const = 12; +// CHECK: VarDecl {{.*}} uint_const_2 'const hlsl_private uint':'const hlsl_private unsigned int' static cinit +// CHECK-NEXT: CallExpr {{.*}} 'unsigned int' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int (*)(unsigned int, unsigned int) noexcept' <FunctionToPointerDecay> +// CHECK-NEXT: DeclRefExpr {{.*}} 'unsigned int (unsigned int, unsigned int) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_uint' 'unsigned int (unsigned int, unsigned int) noexcept' +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast> +// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 6 +// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast> +// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 12 +[[vk::constant_id(6)]] +const uint uint_const_2 = 12; + // CHECK: VarDecl {{.*}} ulong_const 'const hlsl_private unsigned long long' static cinit // CHECK-NEXT: CallExpr {{.*}} 'unsigned long long' diff --git a/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl index cbc1fa61eae2b..15c54beb03d38 100644 --- a/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl +++ b/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl @@ -21,6 +21,9 @@ const unsigned short ushort_const = 10; [[vk::constant_id(6)]] const unsigned int uint_const = 12; +[[vk::constant_id(6)]] +const uint uint_const_2 = 12; + [[vk::constant_id(7)]] const unsigned long long ulong_const = 25; @@ -50,6 +53,7 @@ void main() { long long l = long_const; unsigned short us = ushort_const; unsigned int ui = uint_const; + uint ui2 = uint_const_2; unsigned long long ul = ulong_const; half h = half_const; float f = float_const; @@ -63,6 +67,7 @@ void main() { // CHECK: @_ZL10long_const = internal addrspace(10) global i64 0, align 8 // CHECK: @_ZL12ushort_const = 
internal addrspace(10) global i16 0, align 2 // CHECK: @_ZL10uint_const = internal addrspace(10) global i32 0, align 4 +// CHECK: @_ZL12uint_const_2 = internal addrspace(10) global i32 0, align 4 // CHECK: @_ZL11ulong_const = internal addrspace(10) global i64 0, align 8 // CHECK: @_ZL10half_const = internal addrspace(10) global float 0.000000e+00, align 4 // CHECK: @_ZL11float_const = internal addrspace(10) global float 0.000000e+00, align 4 @@ -79,6 +84,7 @@ void main() { // CHECK-NEXT: [[L:%.*]] = alloca i64, align 8 // CHECK-NEXT: [[US:%.*]] = alloca i16, align 2 // CHECK-NEXT: [[UI:%.*]] = alloca i32, align 4 +// CHECK-NEXT: [[UI2:%.*]] = alloca i32, align 4 // CHECK-NEXT: [[UL:%.*]] = alloca i64, align 8 // CHECK-NEXT: [[H:%.*]] = alloca float, align 4 // CHECK-NEXT: [[F:%.*]] = alloca float, align 4 @@ -98,16 +104,18 @@ void main() { // CHECK-NEXT: store i16 [[TMP5]], ptr [[US]], align 2 // CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(10) @_ZL10uint_const, align 4 // CHECK-NEXT: store i32 [[TMP6]], ptr [[UI]], align 4 -// CHECK-NEXT: [[TMP7:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8 -// CHECK-NEXT: store i64 [[TMP7]], ptr [[UL]], align 8 -// CHECK-NEXT: [[TMP8:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4 -// CHECK-NEXT: store float [[TMP8]], ptr [[H]], align 4 -// CHECK-NEXT: [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4 -// CHECK-NEXT: store float [[TMP9]], ptr [[F]], align 4 -// CHECK-NEXT: [[TMP10:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8 -// CHECK-NEXT: store double [[TMP10]], ptr [[D]], align 8 -// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4 -// CHECK-NEXT: store i32 [[TMP11]], ptr [[E]], align 4 +// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr addrspace(10) @_ZL12uint_const_2, align 4 +// CHECK-NEXT: store i32 [[TMP7]], ptr [[UI2]], align 4 +// CHECK-NEXT: [[TMP8:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8 
+// CHECK-NEXT: store i64 [[TMP8]], ptr [[UL]], align 8 +// CHECK-NEXT: [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4 +// CHECK-NEXT: store float [[TMP9]], ptr [[H]], align 4 +// CHECK-NEXT: [[TMP10:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4 +// CHECK-NEXT: store float [[TMP10]], ptr [[F]], align 4 +// CHECK-NEXT: [[TMP11:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8 +// CHECK-NEXT: store double [[TMP11]], ptr [[D]], align 8 +// CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4 +// CHECK-NEXT: store i32 [[TMP12]], ptr [[E]], align 4 // CHECK-NEXT: ret void // // CHECK-LABEL: define internal spir_func void @__cxx_global_var_init( @@ -169,12 +177,21 @@ void main() { // CHECK-SAME: ) #[[ATTR3]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @_Z20__spirv_SpecConstantij(i32 6, i32 12) +// CHECK-NEXT: store i32 [[TMP1]], ptr addrspace(10) @_ZL12uint_const_2, align 4 +// CHECK-NEXT: ret void +// +// +// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7( +// CHECK-SAME: ) #[[ATTR3]] { +// CHECK-NEXT: [[ENTRY:.*:]] +// CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() // CHECK-NEXT: [[TMP1:%.*]] = call i64 @_Z20__spirv_SpecConstantiy(i32 7, i64 25) // CHECK-NEXT: store i64 [[TMP1]], ptr addrspace(10) @_ZL11ulong_const, align 8 // CHECK-NEXT: ret void // // -// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7( +// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8( // CHECK-SAME: ) #[[ATTR3]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() @@ -183,7 +200,7 @@ void main() { // CHECK-NEXT: ret void // // -// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8( +// CHECK-LABEL: define internal spir_func void 
@__cxx_global_var_init.9( // CHECK-SAME: ) #[[ATTR3]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() @@ -192,7 +209,7 @@ void main() { // CHECK-NEXT: ret void // // -// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.9( +// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10( // CHECK-SAME: ) #[[ATTR3]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() @@ -201,7 +218,7 @@ void main() { // CHECK-NEXT: ret void // // -// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10( +// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.11( // CHECK-SAME: ) #[[ATTR3]] { // CHECK-NEXT: [[ENTRY:.*:]] // CHECK-NEXT: [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry() _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits