Author: Helena Kotas
Date: 2025-03-31T10:05:59-07:00
New Revision: dcc2faecd8aebc64eb541aebe0005ecceffef558

URL: https://github.com/llvm/llvm-project/commit/dcc2faecd8aebc64eb541aebe0005ecceffef558
DIFF: https://github.com/llvm/llvm-project/commit/dcc2faecd8aebc64eb541aebe0005ecceffef558.diff

LOG: [HLSL] Fix codegen to support classes in `cbuffer` (#132828)

Fixes #132309

Added: 
    

Modified: 
    clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
    clang/test/CodeGenHLSL/cbuffer.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
index e0f5b0f59ef40..b546b6dd574ff 100644
--- a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
+++ b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
@@ -52,11 +52,11 @@ static unsigned getScalarOrVectorSizeInBytes(llvm::Type *Ty) {
 namespace clang {
 namespace CodeGen {
 
-// Creates a layout type for given struct with HLSL constant buffer layout
-// taking into account PackOffsets, if provided.
+// Creates a layout type for given struct or class with HLSL constant buffer
+// layout taking into account PackOffsets, if provided.
 // Previously created layout types are cached by CGHLSLRuntime.
 //
-// The function iterates over all fields of the StructType (including base
+// The function iterates over all fields of the record type (including base
 // classes) and calls layoutField to converts each field to its corresponding
 // LLVM type and to calculate its HLSL constant buffer layout. Any embedded
 // structs (or arrays of structs) are converted to target layout types as well.
@@ -67,12 +67,11 @@ namespace CodeGen {
 // -1 value instead. These elements must be placed at the end of the layout
 // after all of the elements with specific offset.
 llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
-    const RecordType *StructType,
-    const llvm::SmallVector<int32_t> *PackOffsets) {
+    const RecordType *RT, const llvm::SmallVector<int32_t> *PackOffsets) {
 
   // check if we already have the layout type for this struct
   if (llvm::TargetExtType *Ty =
-          CGM.getHLSLRuntime().getHLSLBufferLayoutType(StructType))
+          CGM.getHLSLRuntime().getHLSLBufferLayoutType(RT))
     return Ty;
 
   SmallVector<unsigned> Layout;
@@ -87,7 +86,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
 
   // iterate over all fields of the record, including fields on base classes
   llvm::SmallVector<const RecordType *> RecordTypes;
-  RecordTypes.push_back(StructType);
+  RecordTypes.push_back(RT);
   while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
     CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
     assert(D->getNumBases() == 1 &&
@@ -148,7 +147,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
 
   // create the layout struct type; anonymous struct have empty name but
   // non-empty qualified name
-  const CXXRecordDecl *Decl = StructType->getAsCXXRecordDecl();
+  const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
   std::string Name =
       Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
   llvm::StructType *StructTy =
@@ -158,7 +157,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
   llvm::TargetExtType *NewLayoutTy = llvm::TargetExtType::get(
       CGM.getLLVMContext(), LayoutTypeName, {StructTy}, Layout);
   if (NewLayoutTy)
-    CGM.getHLSLRuntime().addHLSLBufferLayoutType(StructType, NewLayoutTy);
+    CGM.getHLSLRuntime().addHLSLBufferLayoutType(RT, NewLayoutTy);
   return NewLayoutTy;
 }
 
@@ -202,9 +201,9 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
     }
     // For array of structures, create a new array with a layout type
     // instead of the structure type.
-    if (Ty->isStructureType()) {
+    if (Ty->isStructureOrClassType()) {
       llvm::Type *NewTy =
-          cast<llvm::TargetExtType>(createLayoutType(Ty->getAsStructureType()));
+          cast<llvm::TargetExtType>(createLayoutType(Ty->getAs<RecordType>()));
       if (!NewTy)
         return false;
       assert(isa<llvm::TargetExtType>(NewTy) && "expected target type");
@@ -220,9 +219,10 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
     ArrayStride = llvm::alignTo(ElemSize, CBufferRowSizeInBytes);
     ElemOffset = (Packoffset != -1) ? Packoffset : NextRowOffset;
 
-  } else if (FieldTy->isStructureType()) {
+  } else if (FieldTy->isStructureOrClassType()) {
     // Create a layout type for the structure
-    ElemLayoutTy = createLayoutType(FieldTy->getAsStructureType());
+    ElemLayoutTy =
+        createLayoutType(cast<RecordType>(FieldTy->getAs<RecordType>()));
     if (!ElemLayoutTy)
       return false;
     assert(isa<llvm::TargetExtType>(ElemLayoutTy) && "expected target type");

diff  --git a/clang/test/CodeGenHLSL/cbuffer.hlsl b/clang/test/CodeGenHLSL/cbuffer.hlsl
index 98948ea6811e3..db06cea808b62 100644
--- a/clang/test/CodeGenHLSL/cbuffer.hlsl
+++ b/clang/test/CodeGenHLSL/cbuffer.hlsl
@@ -13,6 +13,12 @@
 // CHECK: %C = type <{ i32, target("dx.Layout", %A, 8, 0) }>
 // CHECK: %__cblayout_D = type <{ [2 x [3 x target("dx.Layout", %B, 14, 0, 8)]] }>
 
+// CHECK: %__cblayout_CBClasses = type <{ target("dx.Layout", %K, 4, 0), target("dx.Layout", %L, 8, 0, 4),
+// CHECK-SAME: target("dx.Layout", %M, 68, 0), [10 x target("dx.Layout", %K, 4, 0)] }>
+// CHECK: %K = type <{ float }>
+// CHECK: %L = type <{ float, float }>
+// CHECK: %M = type <{ [5 x target("dx.Layout", %K, 4, 0)] }>
+
 // CHECK: %__cblayout_CBMix = type <{ [2 x target("dx.Layout", %Test, 8, 0, 4)], float, [3 x [2 x <2 x float>]], float,
 // CHECK-SAME: target("dx.Layout", %anon, 4, 0), double, target("dx.Layout", %anon.0, 8, 0), float, <1 x double>, i16 }>
 
@@ -133,6 +139,33 @@ cbuffer CBStructs {
   uint16_t3 f;
 };
 
+
+class K {
+  float i;
+};
+
+class L : K {
+  float j;
+};
+
+class M {
+  K array[5];
+};
+
+cbuffer CBClasses {
+  K k;
+  L l;
+  M m;
+  K ka[10];
+};
+
+// CHECK: @CBClasses.cb = global target("dx.CBuffer", target("dx.Layout", %__cblayout_CBClasses,
+// CHECK-SAME: 260, 0, 16, 32, 112))
+// CHECK: @k = external addrspace(2) global target("dx.Layout", %K, 4, 0), align 4
+// CHECK: @l = external addrspace(2) global target("dx.Layout", %L, 8, 0, 4), align 4
+// CHECK: @m = external addrspace(2) global target("dx.Layout", %M, 68, 0), align 4
+// CHECK: @ka = external addrspace(2) global [10 x target("dx.Layout", %K, 4, 0)], align 4
+
 struct Test {
     float a, b;
 };
@@ -237,7 +270,7 @@ RWBuffer<float> Buf;
 
 [numthreads(4,1,1)]
 void main() {
-  Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + B10.z + D1.B2;
+  Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + ka[2].i + B10.z + D1.B2;
 }
 
 // CHECK: define internal void @_GLOBAL__sub_I_cbuffer.hlsl()
@@ -245,8 +278,8 @@ void main() {
 // CHECK-NEXT: call void @_init_resource_CBScalars.cb()
 // CHECK-NEXT: call void @_init_resource_CBArrays.cb()
 
-// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBMIX:[0-9]+]],
-// CHECK-SAME: ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
+// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBCLASSES:[0-9]+]],
+// CHECK-SAME: ![[CBMIX:[0-9]+]], ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
 
 // CHECK: ![[CBSCALARS]] = !{ptr @CBScalars.cb, ptr addrspace(2) @a1, ptr addrspace(2) @a2, ptr addrspace(2) @a3, ptr addrspace(2) @a4,
 // CHECK-SAME: ptr addrspace(2) @a5, ptr addrspace(2) @a6, ptr addrspace(2) @a7, ptr addrspace(2) @a8}
@@ -260,6 +293,8 @@ void main() {
 // CHECK: ![[CBSTRUCTS]] = !{ptr @CBStructs.cb, ptr addrspace(2) @a, ptr addrspace(2) @b, ptr addrspace(2) @c, ptr addrspace(2) @array_of_A,
 // CHECK-SAME: ptr addrspace(2) @d, ptr addrspace(2) @e, ptr addrspace(2) @f}
 
+// CHECK: ![[CBCLASSES]] = !{ptr @CBClasses.cb, ptr addrspace(2) @k, ptr addrspace(2) @l, ptr addrspace(2) @m, ptr addrspace(2) @ka}
+
 // CHECK: ![[CBMIX]] = !{ptr @CBMix.cb, ptr addrspace(2) @test, ptr addrspace(2) @f1, ptr addrspace(2) @f2, ptr addrspace(2) @f3,
 // CHECK-SAME: ptr addrspace(2) @f4, ptr addrspace(2) @f5, ptr addrspace(2) @f6, ptr addrspace(2) @f7, ptr addrspace(2) @f8, ptr addrspace(2) @f9}
 


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to