https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/111203
>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Fri, 4 Oct 2024 12:21:51 -0700 Subject: [PATCH 1/7] [HLSL] Collect explicit resource binding information (part 1) - Do not create resource binding attribute if it is not valid - Store basic resource binding information on HLSLResourceBindingAttr - Move UDT type checking to ActOnVariableDeclarator Part 1 of #110719 --- clang/include/clang/Basic/Attr.td | 29 +++ clang/include/clang/Sema/SemaHLSL.h | 2 + clang/lib/Sema/SemaDecl.cpp | 3 + clang/lib/Sema/SemaHLSL.cpp | 227 ++++++++++++------ .../resource_binding_attr_error_udt.hlsl | 8 +- 5 files changed, 188 insertions(+), 81 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index fbcbf0ed416416..668c599da81390 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr { let LangOpts = [HLSL]; let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>]; let Documentation = [HLSLResourceBindingDocs]; + let AdditionalMembers = [{ + enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; + + const FieldDecl *ResourceField = nullptr; + RegisterType RegType; + unsigned SlotNumber; + unsigned SpaceNumber; + + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { + RegType = RT; + SlotNumber = SlotNum; + SpaceNumber = SpaceNum; + } + void setResourceField(const FieldDecl *FD) { + ResourceField = FD; + } + const FieldDecl *getResourceField() { + return ResourceField; + } + RegisterType getRegisterType() { + return RegType; + } + unsigned getSlotNumber() { + return SlotNumber; + } + unsigned getSpaceNumber() { + return SpaceNumber; + } + }]; } def HLSLPackOffset: HLSLAnnotationAttr { diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index fa957abc9791af..018e7ea5901a2b 100644 --- 
a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -28,6 +28,7 @@ class AttributeCommonInfo; class IdentifierInfo; class ParsedAttr; class Scope; +class VarDecl; // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no // longer need to create builtin buffer types in HLSLExternalSemaSource. @@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase { const Attr *A, llvm::Triple::EnvironmentType Stage, std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); + void ProcessResourceBindingOnDecl(VarDecl *D); QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 2bf610746bc317..8e27a5e068e702 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator( // Handle attributes prior to checking for duplicates in MergeVarDecl ProcessDeclAttributes(S, NewVD, D); + if (getLangOpts().HLSL) + HLSL().ProcessResourceBindingOnDecl(NewVD); + // FIXME: This is probably the wrong location to be doing this and we should // probably be doing this for more attributes (especially for function // pointer attributes such as format, warn_unused_result, etc.). 
Ideally diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index fbcba201a351a6..568a8de30c1fc5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -41,9 +41,7 @@ using namespace clang; using llvm::dxil::ResourceClass; - -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - +using RegisterType = HLSLResourceBindingAttr::RegisterType; static RegisterType getRegisterType(ResourceClass RC) { switch (RC) { case ResourceClass::SRV: @@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// get the record decl from a var decl that we expect -// represents a resource -static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); - assert(Ty && "Resource must have an element type."); - - if (Ty->isBuiltinType()) - return nullptr; - - CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); - assert(TheRecordDecl && "Resource should have a resource type declaration."); - return TheRecordDecl; -} - +// Returns handle type of a resource, if the VarDecl is a resource +// or an array of resources static const HLSLAttributedResourceType * -findAttributedResourceTypeOnField(VarDecl *VD) { +FindHandleTypeOnResource(const VarDecl *VD) { + // If VarDecl is a resource class, the first field must + // be the resource handle of type HLSLAttributedResourceType assert(VD != nullptr && "expected VarDecl"); - if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) { - for (auto *FD : RD->fields()) { - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr())) - return AttrResType; + const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); + if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { + if (!RD->fields().empty()) { + const auto &FirstFD = RD->fields().begin(); + return dyn_cast<HLSLAttributedResourceType>( + 
FirstFD->getType().getTypePtr()); } } return nullptr; } -// Iterate over RecordType fields and return true if any of them matched the -// register type -static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, - RegisterType RegType) { +// Walks though the user defined record type, finds resource class +// that matches the RegisterBinding.Type and assigns it to +// RegisterBinding::Decl. +static bool +ProcessResourceBindingOnUserRecordDecl(const RecordType *RT, + HLSLResourceBindingAttr *RBA) { + llvm::SmallVector<const Type *> TypesToScan; TypesToScan.emplace_back(RT); + RegisterType RegType = RBA->getRegisterType(); while (!TypesToScan.empty()) { const Type *T = TypesToScan.pop_back_val(); - while (T->isArrayType()) + + while (T->isArrayType()) { + // FIXME: calculate the binding size from the array dimensions (or + // unbounded for unsized array) size *= (size_of_array); T = T->getArrayElementTypeNoTypeQual(); + } if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { if (RegType == RegisterType::C) return true; @@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, if (const HLSLAttributedResourceType *AttrResType = dyn_cast<HLSLAttributedResourceType>(FieldTy)) { ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) + if (getRegisterType(RC) == RegType) { + assert(RBA->getResourceField() == nullptr && + "multiple register bindings of the same type are not allowed"); + RBA->setResourceField(FD); return true; + } } else { TypesToScan.emplace_back(FD->getType().getTypePtr()); } @@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, return false; } -static void CheckContainsResourceForRegisterType(Sema &S, - SourceLocation &ArgLoc, - Decl *D, RegisterType RegType, - bool SpecifiedSpace) { +// return false if the register binding is not valid +static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation 
&ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { int RegTypeNum = static_cast<int>(RegType); // check if the decl type is groupshared if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - return; + return false; } // Cbuffers and Tbuffers are HLSLBufferDecl types if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer : ResourceClass::SRV; - if (RegType != getRegisterType(RC)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + if (RegType == getRegisterType(RC)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } // Samplers, UAVs, and SRVs are VarDecl types @@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Resource if (const HLSLAttributedResourceType *AttrResType = - findAttributedResourceTypeOnField(VD)) { - if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + FindHandleTypeOnResource(VD)) { + if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } const clang::Type *Ty = VD->getType().getTypePtr(); @@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Basic types if (Ty->isArithmeticType()) { + bool IsValid = true; bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); - if (SpecifiedSpace && !DeclaredInCOrTBuffer) + if (SpecifiedSpace && !DeclaredInCOrTBuffer) { S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); + IsValid = false; + } if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { // Default Globals if 
(RegType == RegisterType::CBuffer) S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) + else if (RegType != RegisterType::C) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else + else { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } - } else if (Ty->isRecordType()) { - // Class/struct types - walk the declaration and check each field and - // subclass - if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType)) - S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) - << RegTypeNum; - } else { - // Anything else is an error - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return IsValid; } + if (Ty->isRecordType()) + // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl + // that is called from ActOnVariableDeclarator + return true; + + // Anything else is an error + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return false; } -static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, +static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType) { // make sure that there are no two register annotations // applied to the decl with the same register type @@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType otherRegType = getRegisterType(attr->getSlot()); if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { - if (PreviousConflicts[TheDecl].count(otherRegType)) - continue; int otherRegTypeNum = static_cast<int>(otherRegType); S.Diag(TheDecl->getLocation(), diag::err_hlsl_duplicate_register_annotation) << otherRegTypeNum; - PreviousConflicts[TheDecl].insert(otherRegType); - } else { - 
RegisterTypesDetected[static_cast<int>(otherRegType)] = true; + return false; } + RegisterTypesDetected[static_cast<int>(otherRegType)] = true; } } + return true; } -static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, +static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace) { @@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, "expecting VarDecl or HLSLBufferDecl"); // check if the declaration contains resource matching the register type - CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace); + if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace)) + return false; // next, if multiple register annotations exist, check that none conflict. - ValidateMultipleRegisterAnnotations(S, D, RegType); + return ValidateMultipleRegisterAnnotations(S, D, RegType); } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { @@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Slot = Str; } - RegisterType regType; + RegisterType RegType; + unsigned SlotNum = 0; + unsigned SpaceNum = 0; // Validate. 
if (!Slot.empty()) { - regType = getRegisterType(Slot); - if (regType == RegisterType::I) { + RegType = getRegisterType(Slot); + if (RegType == RegisterType::I) { Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); return; } - if (regType == RegisterType::Invalid) { + if (RegType == RegisterType::Invalid) { Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1); return; } - StringRef SlotNum = Slot.substr(1); - unsigned Num = 0; - if (SlotNum.getAsInteger(10, Num)) { + StringRef SlotNumStr = Slot.substr(1); + if (SlotNumStr.getAsInteger(10, SlotNum)) { Diag(ArgLoc, diag::err_hlsl_unsupported_register_number); return; } @@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - StringRef SpaceNum = Space.substr(5); - unsigned Num = 0; - if (SpaceNum.getAsInteger(10, Num)) { + StringRef SpaceNumStr = Space.substr(5); + if (SpaceNumStr.getAsInteger(10, SpaceNum)) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType, - SpecifiedSpace); + if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType, + SpecifiedSpace)) + return; HLSLResourceBindingAttr *NewAttr = HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL); - if (NewAttr) + if (NewAttr) { + NewAttr->setBinding(RegType, SlotNum, SpaceNum); TheDecl->addAttr(NewAttr); + } } void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) { @@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { Ty.addRestrict(); return Ty; } + +// Walks though existing explicit bindings, finds the actual resource class +// decl the binding applies to and sets it to attr->ResourceField. +// Additional processing of resource binding can be added here later on, +// such as preparation for overapping resource detection or implicit binding. 
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { + if (!D->hasGlobalStorage()) + return; + + for (Attr *A : D->attrs()) { + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); + if (!RBA) + continue; + + // // Cbuffers and Tbuffers are HLSLBufferDecl types + if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { + assert(RBA->getRegisterType() == + getRegisterType(CBufferOrTBuffer->isCBuffer() + ? ResourceClass::CBuffer + : ResourceClass::SRV) && + "this should have been handled in DiagnoseLocalRegisterBinding"); + // should we handle HLSLBufferDecl here? + continue; + } + + // Samplers, UAVs, and SRVs are VarDecl types + assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); + const VarDecl *VD = cast<VarDecl>(D); + + // Register binding directly on global resource class variable + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(VD)) { + // FIXME: if array, calculate the binding size from the array dimensions + // (or unbounded for unsized array) + assert(RBA->getResourceField() == nullptr); + continue; + } + + // Global array + const clang::Type *Ty = VD->getType().getTypePtr(); + while (Ty->isArrayType()) { + Ty = Ty->getArrayElementTypeNoTypeQual(); + } + + // Basic types + if (Ty->isArithmeticType()) { + continue; + } + + if (Ty->isRecordType()) { + if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(), + RBA)) { + SemaRef.Diag(D->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RBA->getRegisterType()); + } + } + } +} diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl index ea2d576e4cca55..40517f393e1284 100644 --- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl +++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl @@ -106,7 +106,6 @@ struct Eg12{ MySRV s1; MySRV s2; }; -// expected-warning@+3{{binding type 'u' 
only applies to types containing UAV resources}} // expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg12 e12 : register(u9) : register(u10); @@ -115,12 +114,14 @@ struct Eg13{ MySRV s1; MySRV s2; }; -// expected-warning@+4{{binding type 'u' only applies to types containing UAV resources}} // expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}} -// expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} +// expected-error@+2{{binding type 'u' cannot be applied more than once}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg13 e13 : register(u9) : register(u10) : register(u11); +// expected-error@+1{{binding type 't' cannot be applied more than once}} +Eg13 e13_2 : register(t11) : register(t12); + struct Eg14{ MyTemplatedUAV<int> r1; }; @@ -132,4 +133,3 @@ struct Eg15 { }; // expected no error Eg15 e15 : register(c0); - >From a6c06943ce5df79e6765e12874c96c907b20d030 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Fri, 4 Oct 2024 13:52:47 -0700 Subject: [PATCH 2/7] clang-format --- clang/lib/Sema/SemaHLSL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 568a8de30c1fc5..5c27a74a853bba 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2250,7 +2250,7 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { if (!D->hasGlobalStorage()) return; - + for (Attr *A : D->attrs()) { HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); if (!RBA) >From a6a52327bef4325a00a2b8a1715b8b5b1315994f Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 9 Oct 2024 16:34:06 -0700 Subject: [PATCH 3/7] Collect all resource binding requirements and 
analyze explicit bindings based on that. Also adds binding size calculation and removes the ResourceField member from HLSLResourceBindingAttr. --- clang/include/clang/Basic/Attr.td | 25 ++- clang/include/clang/Sema/SemaHLSL.h | 59 +++++- clang/lib/Sema/SemaDecl.cpp | 2 +- clang/lib/Sema/SemaHLSL.cpp | 276 ++++++++++++++++++---------- 4 files changed, 256 insertions(+), 106 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 668c599da81390..3997ffe78fbf96 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4591,22 +4591,20 @@ def HLSLResourceBinding: InheritableAttr { let AdditionalMembers = [{ enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - const FieldDecl *ResourceField = nullptr; RegisterType RegType; unsigned SlotNumber; unsigned SpaceNumber; + + // Size of the binding + // 0 == not set + //-1 == unbounded + int Size; - void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) { RegType = RT; SlotNumber = SlotNum; SpaceNumber = SpaceNum; } - void setResourceField(const FieldDecl *FD) { - ResourceField = FD; - } - const FieldDecl *getResourceField() { - return ResourceField; - } RegisterType getRegisterType() { return RegType; } @@ -4616,6 +4614,17 @@ def HLSLResourceBinding: InheritableAttr { unsigned getSpaceNumber() { return SpaceNumber; } + unsigned getSize() { + assert(Size == -1 || Size > 0 && "size not set"); + return Size; + } + void setSize(int N) { + assert(N == -1 || N > 0 && "unexpected size value"); + Size = N; + } + bool isSizeUnbounded() { + return Size == -1; + } }]; } diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 018e7ea5901a2b..ce262fd41dff37 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -30,12 +30,60 @@ class ParsedAttr; class Scope; 
class VarDecl; +using llvm::dxil::ResourceClass; + // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no // longer need to create builtin buffer types in HLSLExternalSemaSource. bool CreateHLSLAttributedResourceType( Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo = nullptr); +enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit }; + +// DeclBindingInfo struct stores information about required/assigned resource +// binding onon a declaration for specific resource class. +struct DeclBindingInfo { + const VarDecl *Decl; + ResourceClass ResClass; + int Size; // -1 == unbounded array + const HLSLResourceBindingAttr *Attr; + BindingType BindType; + + DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0, + BindingType BindType = BindingType::NotAssigned, + const HLSLResourceBindingAttr *Attr = nullptr) + : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr), + BindType(BindType) {} + + void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) { + assert(Attr == nullptr && BindType == BindingType::NotAssigned && + "binding attribute already assigned"); + Attr = A; + BindType = BT; + } +}; + +// ResourceBindings class stores information about all resource bindings +// in a shader. It is used for binding diagnostics and implicit binding +// assigments. +class ResourceBindings { +public: + DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass, + int Size); + DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass); + bool hasBindingInfoForDecl(const VarDecl *VD); + +private: + // List of all resource bindings required by the shader. + // A global declaration can have multiple bindings for different + // resource classes. They are all stored sequentially in this list. + // The DeclToBindingListIndex hashtable maps a declaration to the + // index of the first binding info in the list. 
+ llvm::SmallVector<DeclBindingInfo> BindingsList; + llvm::DenseMap<const VarDecl *, unsigned> DeclToBindingListIndex; +}; + class SemaHLSL : public SemaBase { public: SemaHLSL(Sema &S); @@ -56,6 +104,7 @@ class SemaHLSL : public SemaBase { mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling); void ActOnTopLevelFunction(FunctionDecl *FD); + void ActOnVariableDeclarator(VarDecl *VD); void CheckEntryPoint(FunctionDecl *FD); void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param, const HLSLAnnotationAttr *AnnotationAttr); @@ -63,7 +112,6 @@ class SemaHLSL : public SemaBase { const Attr *A, llvm::Triple::EnvironmentType Stage, std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); - void ProcessResourceBindingOnDecl(VarDecl *D); QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, @@ -104,6 +152,15 @@ class SemaHLSL : public SemaBase { llvm::DenseMap<const HLSLAttributedResourceType *, HLSLAttributedResourceLocInfo> LocsForHLSLAttributedResources; + + // List of all resource bindings + ResourceBindings Bindings; + +private: + void FindResourcesOnVarDecl(VarDecl *D); + void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT, + int Size); + void ProcessExplicitBindingsOnDecl(VarDecl *D); }; } // namespace clang diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 8e27a5e068e702..770d00710a6816 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -7877,7 +7877,7 @@ NamedDecl *Sema::ActOnVariableDeclarator( ProcessDeclAttributes(S, NewVD, D); if (getLangOpts().HLSL) - HLSL().ProcessResourceBindingOnDecl(NewVD); + HLSL().ActOnVariableDeclarator(NewVD); // FIXME: This is probably the wrong location to be doing this and we should // probably be doing this for more attributes (especially for function diff --git 
a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 5c27a74a853bba..197ee63c07deeb 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -40,8 +40,8 @@ #include <utility> using namespace clang; -using llvm::dxil::ResourceClass; using RegisterType = HLSLResourceBindingAttr::RegisterType; + static RegisterType getRegisterType(ResourceClass RC) { switch (RC) { case ResourceClass::SRV: @@ -81,6 +81,49 @@ static RegisterType getRegisterType(StringRef Slot) { } } +static ResourceClass getResourceClass(RegisterType RT) { + switch (RT) { + case RegisterType::SRV: + return ResourceClass::SRV; + case RegisterType::UAV: + return ResourceClass::UAV; + case RegisterType::CBuffer: + return ResourceClass::CBuffer; + case RegisterType::Sampler: + return ResourceClass::Sampler; + default: + llvm_unreachable("unexpected RegisterType value"); + } +} + +DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass, + int Size) { + assert(getDeclBindingInfo(VD, ResClass) == nullptr && + "DeclBindingInfo already added"); + if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end()) + DeclToBindingListIndex[VD] = BindingsList.size(); + return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size)); +} + +DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass) { + auto Entry = DeclToBindingListIndex.find(VD); + if (Entry != DeclToBindingListIndex.end()) { + unsigned Index = Entry->getSecond(); + while (Index < BindingsList.size() && BindingsList[Index].Decl == VD) { + if (BindingsList[Index].ResClass == ResClass) + return &BindingsList[Index]; + Index++; + } + } + return nullptr; +} + +bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) { + return DeclToBindingListIndex.contains(VD); +} + SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, @@ -983,14 +1026,11 @@ 
SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// Returns handle type of a resource, if the VarDecl is a resource -// or an array of resources +// Returns handle type of a resource, if the type is a resource static const HLSLAttributedResourceType * -FindHandleTypeOnResource(const VarDecl *VD) { - // If VarDecl is a resource class, the first field must +FindHandleTypeOnResource(const Type *Ty) { + // If Ty is a resource class, the first field must // be the resource handle of type HLSLAttributedResourceType - assert(VD != nullptr && "expected VarDecl"); - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { if (!RD->fields().empty()) { const auto &FirstFD = RD->fields().begin(); @@ -1001,51 +1041,53 @@ FindHandleTypeOnResource(const VarDecl *VD) { return nullptr; } -// Walks though the user defined record type, finds resource class -// that matches the RegisterBinding.Type and assigns it to -// RegisterBinding::Decl. 
-static bool -ProcessResourceBindingOnUserRecordDecl(const RecordType *RT, - HLSLResourceBindingAttr *RBA) { - - llvm::SmallVector<const Type *> TypesToScan; - TypesToScan.emplace_back(RT); - RegisterType RegType = RBA->getRegisterType(); - - while (!TypesToScan.empty()) { - const Type *T = TypesToScan.pop_back_val(); - - while (T->isArrayType()) { - // FIXME: calculate the binding size from the array dimensions (or - // unbounded for unsized array) size *= (size_of_array); - T = T->getArrayElementTypeNoTypeQual(); - } - if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { - if (RegType == RegisterType::C) - return true; +// Returns handle type of a resource, if the VarDecl is a resource +static const HLSLAttributedResourceType * +FindHandleTypeOnResource(const VarDecl *VD) { + assert(VD != nullptr && "expected VarDecl"); + return FindHandleTypeOnResource(VD->getType().getTypePtr()); +} + +// Walks though the global variable declaration, collects all resource binding +// requirements and adds them to Bindings +void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD, + const RecordType *RT, int Size) { + const RecordDecl *RD = RT->getDecl(); + for (FieldDecl *FD : RD->fields()) { + const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); + + // Calculate array size and unwrap + int ArraySize = 1; + assert(!Ty->isIncompleteArrayType() && + "incomplete arrays inside user defined types are not supported"); + while (Ty->isConstantArrayType()) { + const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); + ArraySize *= CAT->getSize().getSExtValue(); + Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); } - const RecordType *RT = T->getAs<RecordType>(); - if (!RT) + + if (!Ty->isRecordType()) continue; - const RecordDecl *RD = RT->getDecl(); - for (FieldDecl *FD : RD->fields()) { - const Type *FieldTy = FD->getType().getTypePtr(); - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FieldTy)) { - 
ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) { - assert(RBA->getResourceField() == nullptr && - "multiple register bindings of the same type are not allowed"); - RBA->setResourceField(FD); - return true; - } - } else { - TypesToScan.emplace_back(FD->getType().getTypePtr()); - } + // Field is a resource or array of resources + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(Ty)) { + ResourceClass RC = AttrResType->getAttrs().ResourceClass; + + // Add a new DeclBindingInfo to Bindings. Update the binding size if + // a binding info already exists (there are multiple resources of same + // resource class in this user decl) + if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC)) + DBI->Size += Size * ArraySize; + else + Bindings.addDeclBindingInfo(VD, RC, Size); + } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) { + // Recursively scan embedded struct or class; it would be nice to do this + // without recursion, but tricky to corrently calculate the size. + // Hopefully nesting of structs in structs too many levels is unlikely. 
+ FindResourcesOnUserRecordDecl(VD, RT, Size); } } - return false; } // return false if the register binding is not valid @@ -1093,11 +1135,9 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, // Basic types if (Ty->isArithmeticType()) { - bool IsValid = true; bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); if (SpecifiedSpace && !DeclaredInCOrTBuffer) { S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); - IsValid = false; } if (!DeclaredInCOrTBuffer && @@ -1107,17 +1147,15 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); else if (RegType != RegisterType::C) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - IsValid = false; } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); else { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - IsValid = false; } } - return IsValid; + return false; } if (Ty->isRecordType()) // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl @@ -2057,6 +2095,7 @@ bool SemaHLSL::IsIntangibleType(clang::QualType QT) { CXXRecordDecl *RD = RT->getAsCXXRecordDecl(); assert(RD != nullptr && "all HLSL struct and classes should be CXXRecordDecl"); + assert(RD->isCompleteDefinition() && "expecting complete type"); return RD->isHLSLIntangible(); } @@ -2243,61 +2282,106 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { return Ty; } -// Walks though existing explicit bindings, finds the actual resource class -// decl the binding applies to and sets it to attr->ResourceField. -// Additional processing of resource binding can be added here later on, -// such as preparation for overapping resource detection or implicit binding. 
-void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { - if (!D->hasGlobalStorage()) +void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { + if (VD->hasGlobalStorage()) { + // make sure the declaration has a complete type + if (SemaRef.RequireCompleteType( + VD->getLocation(), + SemaRef.getASTContext().getBaseElementType(VD->getType()), + diag::err_typecheck_decl_incomplete_type)) { + VD->setInvalidDecl(); + return; + } + + // find all resources on decl + if (IsIntangibleType(VD->getType())) + FindResourcesOnVarDecl(VD); + + // process explicit bindings + ProcessExplicitBindingsOnDecl(VD); + } +} + +// Walks though the global variable declaration, collects all resource binding +// requirements and adds them to Bindings +void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) { + assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) && + "expected global variable that contains HLSL resource"); + + // Cbuffers and Tbuffers are HLSLBufferDecl types + if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) { + Bindings.addDeclBindingInfo(VD, + CBufferOrTBuffer->isCBuffer() + ? ResourceClass::CBuffer + : ResourceClass::SRV, + 1); return; + } - for (Attr *A : D->attrs()) { - HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); - if (!RBA) - continue; + // Calculate size of array and unwrap + int Size = 1; + const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); + if (Ty->isIncompleteArrayType()) + Size = -1; + while (Ty->isConstantArrayType()) { + const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); + Size *= CAT->getSize().getSExtValue(); + Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); + } - // // Cbuffers and Tbuffers are HLSLBufferDecl types - if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { - assert(RBA->getRegisterType() == - getRegisterType(CBufferOrTBuffer->isCBuffer() - ? 
ResourceClass::CBuffer - : ResourceClass::SRV) && - "this should have been handled in DiagnoseLocalRegisterBinding"); - // should we handle HLSLBufferDecl here? - continue; - } + // Resource (or array of resources) + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(Ty)) { + Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass, + Size); + return; + } - // Samplers, UAVs, and SRVs are VarDecl types - assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); - const VarDecl *VD = cast<VarDecl>(D); + assert(Size != -1 && + "unbounded arrays of user defined types are not supported"); - // Register binding directly on global resource class variable - if (const HLSLAttributedResourceType *AttrResType = - FindHandleTypeOnResource(VD)) { - // FIXME: if array, calculate the binding size from the array dimensions - // (or unbounded for unsized array) - assert(RBA->getResourceField() == nullptr); + // User defined record type + if (const RecordType *RT = dyn_cast<RecordType>(Ty)) + FindResourcesOnUserRecordDecl(VD, RT, Size); +} + +// Walks though the explicit resource binding attributes on the declaration, +// and makes sure there is a resource that matched the binding and updates +// DeclBindingInfoLists +void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) { + assert(VD->hasGlobalStorage() && "expected global variable"); + + for (Attr *A : VD->attrs()) { + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); + if (!RBA) continue; - } - // Global array - const clang::Type *Ty = VD->getType().getTypePtr(); - while (Ty->isArrayType()) { - Ty = Ty->getArrayElementTypeNoTypeQual(); - } + RegisterType RT = RBA->getRegisterType(); + assert(RT != RegisterType::I && RT != RegisterType::Invalid && + "invalid or obsolete register type should never have an attribute " + "created"); - // Basic types - if (Ty->isArithmeticType()) { + // These were already diagnosed earlier + if (RT == RegisterType::C) { 
+ if (Bindings.hasBindingInfoForDecl(VD)) + SemaRef.Diag(VD->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RT); continue; } - if (Ty->isRecordType()) { - if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(), - RBA)) { - SemaRef.Diag(D->getLocation(), - diag::warn_hlsl_user_defined_type_missing_member) - << static_cast<int>(RBA->getRegisterType()); - } + // Find DeclBindingInfo for this binding and update it, or report error + // if it does not exist (user type does not contain resources with the + // expected resource class). + ResourceClass RC = getResourceClass(RT); + if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) { + // update binding info + RBA->setSize(BI->Size); + BI->setBindingAttribute(RBA, BindingType::Explicit); + } else { + SemaRef.Diag(VD->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RT); } } } >From aa6247f414b2bd3d39f349646f3a97ec72d5d517 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 9 Oct 2024 17:08:25 -0700 Subject: [PATCH 4/7] removed unused variable, cleanup --- clang/lib/Sema/SemaHLSL.cpp | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 197ee63c07deeb..0423340ee5fc4f 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1136,24 +1136,21 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, // Basic types if (Ty->isArithmeticType()) { bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); - if (SpecifiedSpace && !DeclaredInCOrTBuffer) { + if (SpecifiedSpace && !DeclaredInCOrTBuffer) S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); - } if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { // Default Globals if (RegType == RegisterType::CBuffer) S.Diag(ArgLoc, 
diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) { + else if (RegType != RegisterType::C) S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else { + else S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - } } return false; } @@ -1172,13 +1169,8 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, // make sure that there are no two register annotations // applied to the decl with the same register type bool RegisterTypesDetected[5] = {false}; - RegisterTypesDetected[static_cast<int>(regType)] = true; - // we need a static map to keep track of previous conflicts - // so that we don't emit the same error multiple times - static std::map<Decl *, std::set<RegisterType>> PreviousConflicts; - for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) { if (HLSLResourceBindingAttr *attr = dyn_cast<HLSLResourceBindingAttr>(*it)) { >From a6edabe43eefc2957932498ee35b71e800af9fdd Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Tue, 15 Oct 2024 21:14:19 -0700 Subject: [PATCH 5/7] Code review feedback - remove size calculation and storage - it is currently not used or tested - remove invalid register type - set fields on HLSLResourceBindingAttr as private and accessors public, add const - update function names - update comments - use more effective SmallVector and DenseMap methods --- clang/include/clang/Basic/Attr.td | 31 +++---- clang/include/clang/Sema/SemaHLSL.h | 16 ++-- clang/lib/Sema/SemaHLSL.cpp | 121 ++++++++++++++-------------- 3 files changed, 75 insertions(+), 93 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 916757ccbe2d47..0259b6e40ca962 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4594,42 +4594,29 @@ def HLSLResourceBinding: 
InheritableAttr { let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>]; let Documentation = [HLSLResourceBindingDocs]; let AdditionalMembers = [{ - enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - + public: + enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I }; + + private: RegisterType RegType; unsigned SlotNumber; unsigned SpaceNumber; - - // Size of the binding - // 0 == not set - //-1 == unbounded - int Size; - void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) { + public: + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { RegType = RT; SlotNumber = SlotNum; SpaceNumber = SpaceNum; } - RegisterType getRegisterType() { + RegisterType getRegisterType() const { return RegType; } - unsigned getSlotNumber() { + unsigned getSlotNumber() const { return SlotNumber; } - unsigned getSpaceNumber() { + unsigned getSpaceNumber() const { return SpaceNumber; } - unsigned getSize() { - assert(Size == -1 || Size > 0 && "size not set"); - return Size; - } - void setSize(int N) { - assert(N == -1 || N > 0 && "unexpected size value"); - Size = N; - } - bool isSizeUnbounded() { - return Size == -1; - } }]; } diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index ce262fd41dff37..5eda4d544a5ae5 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -45,15 +45,13 @@ enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit }; struct DeclBindingInfo { const VarDecl *Decl; ResourceClass ResClass; - int Size; // -1 == unbounded array const HLSLResourceBindingAttr *Attr; BindingType BindType; DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0, BindingType BindType = BindingType::NotAssigned, const HLSLResourceBindingAttr *Attr = nullptr) - : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr), - BindType(BindType) {} + : Decl(Decl), ResClass(ResClass), 
Attr(Attr), BindType(BindType) {} void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) { assert(Attr == nullptr && BindType == BindingType::NotAssigned && @@ -68,8 +66,8 @@ struct DeclBindingInfo { // assigments. class ResourceBindings { public: - DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass, - int Size); + DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass); DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass); bool hasBindingInfoForDecl(const VarDecl *VD); @@ -157,10 +155,10 @@ class SemaHLSL : public SemaBase { ResourceBindings Bindings; private: - void FindResourcesOnVarDecl(VarDecl *D); - void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT, - int Size); - void ProcessExplicitBindingsOnDecl(VarDecl *D); + void collectResourcesOnVarDecl(VarDecl *D); + void collectResourcesOnUserRecordDecl(const VarDecl *VD, + const RecordType *RT); + void processExplicitBindingsOnDecl(VarDecl *D); }; } // namespace clang diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 668d3ad9ecd6ba..a58c4281eeb375 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -56,28 +56,37 @@ static RegisterType getRegisterType(ResourceClass RC) { llvm_unreachable("unexpected ResourceClass value"); } -static RegisterType getRegisterType(StringRef Slot) { +// Converts the first letter of string Slot to RegisterType. +// Returns false if the letter does not correspond to a valid register type. 
+static bool convertToRegisterType(StringRef Slot, RegisterType *RT) { + assert(RT != nullptr); switch (Slot[0]) { case 't': case 'T': - return RegisterType::SRV; + *RT = RegisterType::SRV; + return true; case 'u': case 'U': - return RegisterType::UAV; + *RT = RegisterType::UAV; + return true; case 'b': case 'B': - return RegisterType::CBuffer; + *RT = RegisterType::CBuffer; + return true; case 's': case 'S': - return RegisterType::Sampler; + *RT = RegisterType::Sampler; + return true; case 'c': case 'C': - return RegisterType::C; + *RT = RegisterType::C; + return true; case 'i': case 'I': - return RegisterType::I; + *RT = RegisterType::I; + return true; default: - return RegisterType::Invalid; + return false; } } @@ -91,19 +100,18 @@ static ResourceClass getResourceClass(RegisterType RT) { return ResourceClass::CBuffer; case RegisterType::Sampler: return ResourceClass::Sampler; - default: + case RegisterType::C: + case RegisterType::I: llvm_unreachable("unexpected RegisterType value"); } } DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, - ResourceClass ResClass, - int Size) { + ResourceClass ResClass) { assert(getDeclBindingInfo(VD, ResClass) == nullptr && "DeclBindingInfo already added"); - if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end()) - DeclToBindingListIndex[VD] = BindingsList.size(); - return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size)); + DeclToBindingListIndex.try_emplace(VD, BindingsList.size()); + return &BindingsList.emplace_back(VD, ResClass); } DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, @@ -1050,19 +1058,18 @@ FindHandleTypeOnResource(const VarDecl *VD) { // Walks though the global variable declaration, collects all resource binding // requirements and adds them to Bindings -void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD, - const RecordType *RT, int Size) { +void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD, + const RecordType 
*RT) { const RecordDecl *RD = RT->getDecl(); for (FieldDecl *FD : RD->fields()) { const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); - // Calculate array size and unwrap - int ArraySize = 1; + // Unwrap arrays + // FIXME: Calculate array size while unwrapping assert(!Ty->isIncompleteArrayType() && "incomplete arrays inside user defined types are not supported"); while (Ty->isConstantArrayType()) { const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); - ArraySize *= CAT->getSize().getSExtValue(); Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); } @@ -1074,23 +1081,26 @@ void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD, FindHandleTypeOnResource(Ty)) { ResourceClass RC = AttrResType->getAttrs().ResourceClass; - // Add a new DeclBindingInfo to Bindings. Update the binding size if - // a binding info already exists (there are multiple resources of same - // resource class in this user decl) - if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC)) - DBI->Size += Size * ArraySize; - else - Bindings.addDeclBindingInfo(VD, RC, Size); + // Add a new DeclBindingInfo to Bindings if it does not already exist + DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC); + if (!DBI) + Bindings.addDeclBindingInfo(VD, RC); } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) { // Recursively scan embedded struct or class; it would be nice to do this - // without recursion, but tricky to corrently calculate the size. - // Hopefully nesting of structs in structs too many levels is unlikely. - FindResourcesOnUserRecordDecl(VD, RT, Size); + // without recursion, but tricky to correctly calculate the size of the + // binding, which is something we are probably going to need to do later + // on. Hopefully nesting of structs in structs too many levels is + // unlikely. 
+ collectResourcesOnUserRecordDecl(VD, RT); } } } -// return false if the register binding is not valid +// Diagnose localized register binding errors for a single binding; does not +// diagnose resource binding on user record types, that will be done later +// in processExplicitBindingsOnDecl based on the information collected in +// collectResourcesOnVarDecl. +// Returns false if the register binding is not valid. static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace) { @@ -1155,7 +1165,7 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, return false; } if (Ty->isRecordType()) - // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl + // RecordTypes will be diagnosed in processExplicitBindingsOnDecl // that is called from ActOnVariableDeclarator return true; @@ -1175,7 +1185,7 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, if (HLSLResourceBindingAttr *attr = dyn_cast<HLSLResourceBindingAttr>(*it)) { - RegisterType otherRegType = getRegisterType(attr->getSlot()); + RegisterType otherRegType = attr->getRegisterType(); if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { int otherRegTypeNum = static_cast<int>(otherRegType); S.Diag(TheDecl->getLocation(), @@ -1250,13 +1260,12 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { // Validate. 
if (!Slot.empty()) { - RegType = getRegisterType(Slot); - if (RegType == RegisterType::I) { - Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); + if (!convertToRegisterType(Slot, &RegType)) { + Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1); return; } - if (RegType == RegisterType::Invalid) { - Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1); + if (RegType == RegisterType::I) { + Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); return; } @@ -2294,60 +2303,51 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { // find all resources on decl if (IsIntangibleType(VD->getType())) - FindResourcesOnVarDecl(VD); + collectResourcesOnVarDecl(VD); // process explicit bindings - ProcessExplicitBindingsOnDecl(VD); + processExplicitBindingsOnDecl(VD); } } // Walks though the global variable declaration, collects all resource binding // requirements and adds them to Bindings -void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) { +void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) { assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) && "expected global variable that contains HLSL resource"); // Cbuffers and Tbuffers are HLSLBufferDecl types if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) { - Bindings.addDeclBindingInfo(VD, - CBufferOrTBuffer->isCBuffer() - ? ResourceClass::CBuffer - : ResourceClass::SRV, - 1); + Bindings.addDeclBindingInfo(VD, CBufferOrTBuffer->isCBuffer() + ? 
ResourceClass::CBuffer + : ResourceClass::SRV); return; } - // Calculate size of array and unwrap - int Size = 1; + // Unwrap arrays + // FIXME: Calculate array size while unwrapping const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); - if (Ty->isIncompleteArrayType()) - Size = -1; while (Ty->isConstantArrayType()) { const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); - Size *= CAT->getSize().getSExtValue(); Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); } // Resource (or array of resources) if (const HLSLAttributedResourceType *AttrResType = FindHandleTypeOnResource(Ty)) { - Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass, - Size); + Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass); return; } - assert(Size != -1 && - "unbounded arrays of user defined types are not supported"); - // User defined record type if (const RecordType *RT = dyn_cast<RecordType>(Ty)) - FindResourcesOnUserRecordDecl(VD, RT, Size); + collectResourcesOnUserRecordDecl(VD, RT); } // Walks though the explicit resource binding attributes on the declaration, // and makes sure there is a resource that matched the binding and updates // DeclBindingInfoLists -void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) { +void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) { assert(VD->hasGlobalStorage() && "expected global variable"); for (Attr *A : VD->attrs()) { @@ -2356,11 +2356,9 @@ void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) { continue; RegisterType RT = RBA->getRegisterType(); - assert(RT != RegisterType::I && RT != RegisterType::Invalid && - "invalid or obsolete register type should never have an attribute " - "created"); + assert(RT != RegisterType::I && "invalid or obsolete register type should " + "never have an attribute created"); - // These were already diagnosed earlier if (RT == RegisterType::C) { if (Bindings.hasBindingInfoForDecl(VD)) SemaRef.Diag(VD->getLocation(), @@ -2375,7 +2373,6 @@ void 
SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) { ResourceClass RC = getResourceClass(RT); if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) { // update binding info - RBA->setSize(BI->Size); BI->setBindingAttribute(RBA, BindingType::Explicit); } else { SemaRef.Diag(VD->getLocation(), >From a7fbeaebf2a5ac502ab1c00787e0c51ee807210e Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 16 Oct 2024 10:14:57 -0700 Subject: [PATCH 6/7] More cleanup - rename function to lowerCase and remove one overload, remove Size argument, remove comment --- clang/include/clang/Sema/SemaHLSL.h | 2 +- clang/lib/Sema/SemaHLSL.cpp | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 5eda4d544a5ae5..31b4c5b748e189 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -48,7 +48,7 @@ struct DeclBindingInfo { const HLSLResourceBindingAttr *Attr; BindingType BindType; - DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0, + DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, BindingType BindType = BindingType::NotAssigned, const HLSLResourceBindingAttr *Attr = nullptr) : Decl(Decl), ResClass(ResClass), Attr(Attr), BindType(BindType) {} diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 98c0100afbc5c9..de7951aa1b9088 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1036,7 +1036,7 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { // Returns handle type of a resource, if the type is a resource static const HLSLAttributedResourceType * -FindHandleTypeOnResource(const Type *Ty) { +findHandleTypeOnResource(const Type *Ty) { // If Ty is a resource class, the first field must // be the resource handle of type HLSLAttributedResourceType if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { @@ 
-1049,13 +1049,6 @@ FindHandleTypeOnResource(const Type *Ty) { return nullptr; } -// Returns handle type of a resource, if the VarDecl is a resource -static const HLSLAttributedResourceType * -FindHandleTypeOnResource(const VarDecl *VD) { - assert(VD != nullptr && "expected VarDecl"); - return FindHandleTypeOnResource(VD->getType().getTypePtr()); -} - // Walks though the global variable declaration, collects all resource binding // requirements and adds them to Bindings void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD, @@ -1076,12 +1069,10 @@ void SemaHLSL::collectResourcesOnUserRecordDecl(const VarDecl *VD, if (!Ty->isRecordType()) continue; - // Field is a resource or array of resources if (const HLSLAttributedResourceType *AttrResType = - FindHandleTypeOnResource(Ty)) { - ResourceClass RC = AttrResType->getAttrs().ResourceClass; - + findHandleTypeOnResource(Ty)) { // Add a new DeclBindingInfo to Bindings if it does not already exist + ResourceClass RC = AttrResType->getAttrs().ResourceClass; DeclBindingInfo *DBI = Bindings.getDeclBindingInfo(VD, RC); if (!DBI) Bindings.addDeclBindingInfo(VD, RC); @@ -1130,7 +1121,7 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, // Resource if (const HLSLAttributedResourceType *AttrResType = - FindHandleTypeOnResource(VD)) { + findHandleTypeOnResource(VD->getType().getTypePtr())) { if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) return true; @@ -2373,7 +2364,7 @@ void SemaHLSL::collectResourcesOnVarDecl(VarDecl *VD) { // Resource (or array of resources) if (const HLSLAttributedResourceType *AttrResType = - FindHandleTypeOnResource(Ty)) { + findHandleTypeOnResource(Ty)) { Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass); return; } >From 1e79dcfab082c00b09cd7245f22bf3442ffdd31a Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 16 Oct 2024 13:41:50 -0700 Subject: [PATCH 7/7] cr feedback - add const, rearrange 
loop, add debug-only checks --- clang/include/clang/Sema/SemaHLSL.h | 2 +- clang/lib/Sema/SemaHLSL.cpp | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 31b4c5b748e189..4f1fc9a31404c6 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -70,7 +70,7 @@ class ResourceBindings { ResourceClass ResClass); DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass); - bool hasBindingInfoForDecl(const VarDecl *VD); + bool hasBindingInfoForDecl(const VarDecl *VD) const; private: // List of all resource bindings required by the shader. diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index de7951aa1b9088..0d23c4935e9196 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -110,6 +110,19 @@ DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass) { assert(getDeclBindingInfo(VD, ResClass) == nullptr && "DeclBindingInfo already added"); +#ifndef NDEBUG + // Verify that existing bindings for this decl are stored sequentially + // and at the end of the BindingsList + auto I = DeclToBindingListIndex.find(VD); + if (I != DeclToBindingListIndex.end()) { + for (unsigned Index = I->getSecond(); Index < BindingsList.size(); ++Index) + assert(BindingsList[Index].Decl == VD); + } +#endif + // VarDecl may have multiple entries for different resource classes. + // DeclToBindingListIndex stores the index of the first binding we saw + // for this decl. If there are any additional ones then that index + // shouldn't be updated. 
DeclToBindingListIndex.try_emplace(VD, BindingsList.size()); return &BindingsList.emplace_back(VD, ResClass); } @@ -118,17 +131,17 @@ DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass) { auto Entry = DeclToBindingListIndex.find(VD); if (Entry != DeclToBindingListIndex.end()) { - unsigned Index = Entry->getSecond(); - while (Index < BindingsList.size() && BindingsList[Index].Decl == VD) { + for (unsigned Index = Entry->getSecond(); + Index < BindingsList.size() && BindingsList[Index].Decl == VD; + ++Index) { if (BindingsList[Index].ResClass == ResClass) return &BindingsList[Index]; - Index++; } } return nullptr; } -bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) { +bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const { return DeclToBindingListIndex.contains(VD); } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits