https://github.com/s-perron updated 
https://github.com/llvm/llvm-project/pull/145577

>From 44cb96a30dc3b2b28449661a52ac5a73c63e2139 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenper...@google.com>
Date: Tue, 24 Jun 2025 15:44:10 -0400
Subject: [PATCH] [HLSL][SPIRV] Handle `uint` type for spec constant

The tests only covered `unsigned int`, not `uint`. We want to
correctly handle sugared types such as `uint` as specialization
constants, so desugar the variable's type before looking up the
spec-constant builtin.
---
 clang/lib/Sema/SemaHLSL.cpp                   |  8 ++--
 .../test/AST/HLSL/vk.spec-constant.usage.hlsl | 11 +++++
 .../vk-features/vk.spec-constant.hlsl         | 45 +++++++++++++------
 3 files changed, 47 insertions(+), 17 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d003967a522a1..e36db2ee10c20 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -120,7 +120,7 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
-static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+static Builtin::ID getSpecConstBuiltinId(const Type *Type) {
   const auto *BT = dyn_cast<BuiltinType>(Type);
   if (!BT) {
     if (!Type->isEnumeralType())
@@ -654,7 +654,8 @@ SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
 
   auto *VD = cast<VarDecl>(D);
 
-  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+  if (getSpecConstBuiltinId(VD->getType()->getUnqualifiedDesugaredType()) ==
+      Builtin::NotBuiltin) {
     Diag(VD->getLocation(), diag::err_specialization_const);
     return nullptr;
   }
@@ -3920,7 +3921,8 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
     return false;
   }
 
-  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+  Builtin::ID BID =
+      getSpecConstBuiltinId(VDecl->getType()->getUnqualifiedDesugaredType());
 
   // Argument 1: The ID from the attribute
   int ConstantID = ConstIdAttr->getId();
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
index c0955c1ea7b43..733c4e2ee5a36 100644
--- a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -64,6 +64,17 @@ const unsigned short ushort_const = 10;
 [[vk::constant_id(6)]]
 const unsigned int uint_const = 12;
 
+// CHECK: VarDecl {{.*}} uint_const_2 'const hlsl_private uint':'const hlsl_private unsigned int' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'unsigned int'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int (*)(unsigned int, unsigned int) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'unsigned int (unsigned int, unsigned int) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_uint' 'unsigned int (unsigned int, unsigned int) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 6
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 12
+[[vk::constant_id(6)]]
+const uint uint_const_2 = 12;
+
 
 // CHECK: VarDecl {{.*}} ulong_const 'const hlsl_private unsigned long long' static cinit
 // CHECK-NEXT: CallExpr {{.*}} 'unsigned long long'
diff --git a/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl b/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl
index cbc1fa61eae2b..15c54beb03d38 100644
--- a/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl
+++ b/clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl
@@ -21,6 +21,9 @@ const unsigned short ushort_const = 10;
 [[vk::constant_id(6)]]
 const unsigned int uint_const = 12;
 
+[[vk::constant_id(6)]]
+const uint uint_const_2 = 12;
+
 [[vk::constant_id(7)]]
 const unsigned long long ulong_const = 25;
 
@@ -50,6 +53,7 @@ void main() {
     long long l = long_const;
     unsigned short us = ushort_const;
     unsigned int ui = uint_const;
+    uint ui2 = uint_const_2;
     unsigned long long ul = ulong_const;
     half h = half_const;
     float f = float_const;
@@ -63,6 +67,7 @@ void main() {
 // CHECK: @_ZL10long_const = internal addrspace(10) global i64 0, align 8
 // CHECK: @_ZL12ushort_const = internal addrspace(10) global i16 0, align 2
 // CHECK: @_ZL10uint_const = internal addrspace(10) global i32 0, align 4
+// CHECK: @_ZL12uint_const_2 = internal addrspace(10) global i32 0, align 4
 // CHECK: @_ZL11ulong_const = internal addrspace(10) global i64 0, align 8
 // CHECK: @_ZL10half_const = internal addrspace(10) global float 0.000000e+00, align 4
 // CHECK: @_ZL11float_const = internal addrspace(10) global float 0.000000e+00, align 4
@@ -79,6 +84,7 @@ void main() {
 // CHECK-NEXT:    [[L:%.*]] = alloca i64, align 8
 // CHECK-NEXT:    [[US:%.*]] = alloca i16, align 2
 // CHECK-NEXT:    [[UI:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[UI2:%.*]] = alloca i32, align 4
 // CHECK-NEXT:    [[UL:%.*]] = alloca i64, align 8
 // CHECK-NEXT:    [[H:%.*]] = alloca float, align 4
 // CHECK-NEXT:    [[F:%.*]] = alloca float, align 4
@@ -98,16 +104,18 @@ void main() {
 // CHECK-NEXT:    store i16 [[TMP5]], ptr [[US]], align 2
 // CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(10) @_ZL10uint_const, align 4
 // CHECK-NEXT:    store i32 [[TMP6]], ptr [[UI]], align 4
-// CHECK-NEXT:    [[TMP7:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8
-// CHECK-NEXT:    store i64 [[TMP7]], ptr [[UL]], align 8
-// CHECK-NEXT:    [[TMP8:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4
-// CHECK-NEXT:    store float [[TMP8]], ptr [[H]], align 4
-// CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4
-// CHECK-NEXT:    store float [[TMP9]], ptr [[F]], align 4
-// CHECK-NEXT:    [[TMP10:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8
-// CHECK-NEXT:    store double [[TMP10]], ptr [[D]], align 8
-// CHECK-NEXT:    [[TMP11:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4
-// CHECK-NEXT:    store i32 [[TMP11]], ptr [[E]], align 4
+// CHECK-NEXT:    [[TMP7:%.*]] = load i32, ptr addrspace(10) @_ZL12uint_const_2, align 4
+// CHECK-NEXT:    store i32 [[TMP7]], ptr [[UI2]], align 4
+// CHECK-NEXT:    [[TMP8:%.*]] = load i64, ptr addrspace(10) @_ZL11ulong_const, align 8
+// CHECK-NEXT:    store i64 [[TMP8]], ptr [[UL]], align 8
+// CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr addrspace(10) @_ZL10half_const, align 4
+// CHECK-NEXT:    store float [[TMP9]], ptr [[H]], align 4
+// CHECK-NEXT:    [[TMP10:%.*]] = load float, ptr addrspace(10) @_ZL11float_const, align 4
+// CHECK-NEXT:    store float [[TMP10]], ptr [[F]], align 4
+// CHECK-NEXT:    [[TMP11:%.*]] = load double, ptr addrspace(10) @_ZL12double_const, align 8
+// CHECK-NEXT:    store double [[TMP11]], ptr [[D]], align 8
+// CHECK-NEXT:    [[TMP12:%.*]] = load i32, ptr addrspace(10) @_ZL10enum_const, align 4
+// CHECK-NEXT:    store i32 [[TMP12]], ptr [[E]], align 4
 // CHECK-NEXT:    ret void
 //
 // CHECK-LABEL: define internal spir_func void @__cxx_global_var_init(
@@ -169,12 +177,21 @@ void main() {
 // CHECK-SAME: ) #[[ATTR3]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
+// CHECK-NEXT:    [[TMP1:%.*]] = call i32 @_Z20__spirv_SpecConstantij(i32 6, i32 12)
+// CHECK-NEXT:    store i32 [[TMP1]], ptr addrspace(10) @_ZL12uint_const_2, align 4
+// CHECK-NEXT:    ret void
+//
+//
+// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7(
+// CHECK-SAME: ) #[[ATTR3]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
 // CHECK-NEXT:    [[TMP1:%.*]] = call i64 @_Z20__spirv_SpecConstantiy(i32 7, i64 25)
 // CHECK-NEXT:    store i64 [[TMP1]], ptr addrspace(10) @_ZL11ulong_const, align 8
 // CHECK-NEXT:    ret void
 //
 //
-// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.7(
+// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8(
 // CHECK-SAME: ) #[[ATTR3]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -183,7 +200,7 @@ void main() {
 // CHECK-NEXT:    ret void
 //
 //
-// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.8(
+// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.9(
 // CHECK-SAME: ) #[[ATTR3]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -192,7 +209,7 @@ void main() {
 // CHECK-NEXT:    ret void
 //
 //
-// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.9(
+// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10(
 // CHECK-SAME: ) #[[ATTR3]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()
@@ -201,7 +218,7 @@ void main() {
 // CHECK-NEXT:    ret void
 //
 //
-// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.10(
+// CHECK-LABEL: define internal spir_func void @__cxx_global_var_init.11(
 // CHECK-SAME: ) #[[ATTR3]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[TMP0:%.*]] = call token @llvm.experimental.convergence.entry()

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to