================
@@ -459,7 +467,412 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
   D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
 }
 
-void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
+// Flags describing the declaration a register binding is applied to.
+struct RegisterBindingFlags {
+  // Broad kind of the declaration.
+  bool Resource = false;
+  bool UDT = false;
+  bool Other = false;
+  bool Basic = false;
+
+  // Resource classes found on the declaration or on its (nested) members.
+  bool SRV = false;
+  bool UAV = false;
+  bool CBV = false;
+  bool Sampler = false;
+
+  // Set when a (nested) member has a numeric type, or is an array of one.
+  bool ContainsNumeric = false;
+  // Set for numeric variables (or arrays of them) declared outside any
+  // cbuffer/tbuffer.
+  bool DefaultGlobals = false;
+};
+
+// Returns true if TheDecl is declared directly inside a cbuffer or tbuffer
+// (an HLSLBufferDecl). Only the immediate DeclContext is examined.
+static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
+  if (!TheDecl)
+    return false;
+
+  return isa<HLSLBufferDecl>(TheDecl->getDeclContext());
+}
+
+// Returns the C++ record declaration for the pointee / array element type of
+// VD, or nullptr when that type is not a C++ record (builtin scalars, vectors,
+// enums, and typedefs of those). getAsCXXRecordDecl() looks through type
+// sugar, so typedef'd numeric types are not mistaken for malformed resources.
+static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource must have an element type.");
+
+  return Ty->getAsCXXRecordDecl();
+}
+
+// Sets the flag in Flags that corresponds to DeclResourceClass.
+static void setResourceClassFlagsFromDeclResourceClass(
+    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
+  switch (DeclResourceClass) {
+  case llvm::hlsl::ResourceClass::SRV:
+    Flags.SRV = true;
+    break;
+  case llvm::hlsl::ResourceClass::UAV:
+    Flags.UAV = true;
+    break;
+  case llvm::hlsl::ResourceClass::CBuffer:
+    Flags.CBV = true;
+    break;
+  case llvm::hlsl::ResourceClass::Sampler:
+    Flags.Sampler = true;
+    break;
+  }
+}
+
+// Returns the first attribute of type T found on a field of a record.
+// If VD is non-null, the record is taken from VD's type, and nullptr is
+// returned when that type is not a record. Otherwise TheRecordDecl is used
+// and must be non-null. For class template specializations, the fields of the
+// canonical templated (pattern) declaration are searched instead.
+template <typename T>
+static const T *
+getSpecifiedHLSLAttrFromVarDeclOrRecordDecl(VarDecl *VD,
+                                            RecordDecl *TheRecordDecl) {
+  if (VD) {
+    TheRecordDecl = getRecordDeclFromVarDecl(VD);
+    if (!TheRecordDecl)
+      return nullptr;
+  }
+
+  if (!TheRecordDecl)
+    llvm_unreachable("TheRecordDecl should not be null");
+
+  // Returns the first field of RD that carries attribute T, if any.
+  auto FindAttrOnFields = [](const RecordDecl *RD) -> const T * {
+    for (const auto *FD : RD->fields())
+      if (const T *Attr = FD->getAttr<T>())
+        return Attr;
+    return nullptr;
+  };
+
+  // For a ClassTemplateSpecializationDecl, look for the attributed handle
+  // member on the templated declaration; this is necessary while resources
+  // like RWBuffer are defined externally.
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    return FindAttrOnFields(TDecl->getSpecializedTemplate()
+                                ->getTemplatedDecl()
+                                ->getCanonicalDecl());
+
+  return FindAttrOnFields(TheRecordDecl);
+}
+
+// Accumulates into Flags what a (possibly nested) member of type TheQualTy
+// contributes: ContainsNumeric for numeric scalars and arrays of them, the
+// resource class of resource records, or recursively the contribution of the
+// fields of any other record.
+static void setFlagsFromType(QualType TheQualTy, RegisterBindingFlags &Flags) {
+  // Look through array dimensions first so that arrays are classified like
+  // their element type, as HLSLFillRegisterBindingFlags does for top-level
+  // variables.
+  const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+  while (TheBaseType->isArrayType())
+    TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+  // NOTE(review): vector/matrix members do not appear to satisfy these
+  // predicates, so they do not set ContainsNumeric -- confirm this is intended.
+  if (TheBaseType->isIntegralOrEnumerationType() ||
+      TheBaseType->isFloatingType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+  if (!TheRecordTy)
+    return;
+
+  RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
+  // The resource class attribute is looked up on the record's members (the
+  // handle) first, then on the record declaration itself.
+  const HLSLResourceClassAttr *ResClassAttr =
+      getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+          nullptr, SubRecordDecl);
+  if (!ResClassAttr)
+    ResClassAttr = SubRecordDecl->getAttr<HLSLResourceClassAttr>();
+
+  if (ResClassAttr) {
+    setResourceClassFlagsFromDeclResourceClass(
+        Flags, ResClassAttr->getResourceClass());
+    return;
+  }
+
+  // Not a resource: classify each field of the nested record.
+  for (const auto *Field : SubRecordDecl->fields())
+    setFlagsFromType(Field->getType(), Flags);
+}
+
+// Accumulates into Flags the contribution of every field of RD. A null RD or
+// an incomplete definition contributes nothing.
+static void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                                const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+
+  for (const auto *Field : RD->fields())
+    setFlagsFromType(Field->getType(), Flags);
+}
+
+static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
+                                                         Decl *TheDecl) {
+
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+
+  assert(((TheVarDecl && !CBufferOrTBuffer) ||
+          (!TheVarDecl && CBufferOrTBuffer)) &&
+         "either TheVarDecl or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Flags;
+
+  // check if the decl type is groupshared
+  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    Flags.Other = true;
+    return Flags;
+  }
+
+  if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
+    // make sure the type is a basic / numeric type
+    if (TheVarDecl) {
+      QualType TheQualTy = TheVarDecl->getType();
+      // a numeric variable or an array of numeric variables
+      // will inevitably end up in $Globals buffer
+      const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+      while (TheBaseType->isArrayType())
+        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+      if (TheBaseType->isIntegralType(S.getASTContext())
|| + TheBaseType->isFloatingType()) + Flags.DefaultGlobals = true; + } + } + + if (CBufferOrTBuffer) { + Flags.Resource = true; + if (CBufferOrTBuffer->isCBuffer()) + Flags.CBV = true; + else + Flags.SRV = true; + } else if (TheVarDecl) { + const HLSLResourceClassAttr *resClassAttr = + getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>( + TheVarDecl, nullptr); + const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); + while (TheBaseType->isArrayType()) + TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); ---------------- hekota wrote:
`TheBaseType` is not used until line `679: if (TheBaseType->isArithmeticType())`. Can you move this closer to where it is used? https://github.com/llvm/llvm-project/pull/97103 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits