================ @@ -459,7 +467,506 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc)); } -void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) { +struct RegisterBindingFlags { + bool Resource = false; + bool UDT = false; + bool Other = false; + bool Basic = false; + + bool SRV = false; + bool UAV = false; + bool CBV = false; + bool Sampler = false; + + bool ContainsNumeric = false; + bool DefaultGlobals = false; +}; + +bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) { + if (!TheDecl) + return false; + + // Traverse up the parent contexts + const DeclContext *context = TheDecl->getDeclContext(); + if (isa<HLSLBufferDecl>(context)) { + return true; + } + + return false; +} + +const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { + const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); + assert(Ty && "Resource class must have an element type."); + + if (const auto *TheBuiltinTy = dyn_cast<BuiltinType>(Ty)) + return nullptr; + + const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); + assert(TheRecordDecl && + "Resource class should have a resource type declaration."); + + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl)) + TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + return TheRecordDecl; +} + +const HLSLResourceClassAttr * +getHLSLResourceClassAttrFromEitherDecl(VarDecl *VD, + HLSLBufferDecl *CBufferOrTBuffer) { + + if (VD) { + const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD); + if (!TheRecordDecl) + return nullptr; + + // the resource class attr could be on the record decl itself or on one of + // its fields (the resource handle, most commonly) + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceClassAttr>(); + if (!Attr) { + for (auto *FD : TheRecordDecl->fields()) { + Attr = 
FD->getAttr<HLSLResourceClassAttr>(); + if (Attr) + break; + } + } + return Attr; + } else if (CBufferOrTBuffer) { + const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>(); + return Attr; + } + llvm_unreachable("one of the two conditions should be true."); + return nullptr; +} + +const HLSLResourceAttr * +getHLSLResourceAttrFromEitherDecl(VarDecl *VD, + HLSLBufferDecl *CBufferOrTBuffer) { + + if (VD) { + const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD); + if (!TheRecordDecl) + return nullptr; + + // the resource attr could be on the record decl itself or on one of + // its fields (the resource handle, most commonly) + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>(); + if (!Attr) { + for (auto *FD : TheRecordDecl->fields()) { + Attr = FD->getAttr<HLSLResourceAttr>(); + if (Attr) + break; + } + } + return Attr; + } else if (CBufferOrTBuffer) { + const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>(); + return Attr; + } + llvm_unreachable("one of the two conditions should be true."); + return nullptr; +} + +void traverseType(QualType TheQualTy, RegisterBindingFlags &Flags) { + // if the member's type is a numeric type, set the ContainsNumeric flag + if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) { + Flags.ContainsNumeric = true; + return; + } + + // otherwise, if the member's base type is not a record type, return + const clang::Type *TheBaseType = TheQualTy.getTypePtr(); + while (TheBaseType->isArrayType()) + TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); + + const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); + if (!TheRecordTy) + return; + + RecordDecl *SubRecordDecl = TheRecordTy->getDecl(); + bool resClassSet = false; + // if the member's base type is a ClassTemplateSpecializationDecl, + // check if it has a resource class attr + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRecordDecl)) { + auto TheRecordDecl = 
TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceClassAttr>(); + if (!Attr) { + for (auto *FD : TheRecordDecl->fields()) { + Attr = FD->getAttr<HLSLResourceClassAttr>(); + if (Attr) + break; + } + } + llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: + Flags.SRV = true; + break; + case llvm::hlsl::ResourceClass::UAV: + Flags.UAV = true; + break; + case llvm::hlsl::ResourceClass::CBuffer: + Flags.CBV = true; + break; + case llvm::hlsl::ResourceClass::Sampler: + Flags.Sampler = true; + break; + } + resClassSet = true; + } + // otherwise, check if the member has a resource class attr + else if (auto *Attr = SubRecordDecl->getAttr<HLSLResourceClassAttr>()) { + llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: + Flags.SRV = true; + break; + case llvm::hlsl::ResourceClass::UAV: + Flags.UAV = true; + break; + case llvm::hlsl::ResourceClass::CBuffer: + Flags.CBV = true; + break; + case llvm::hlsl::ResourceClass::Sampler: + Flags.Sampler = true; + break; + } + resClassSet = true; + } + + if (!resClassSet) { + for (auto Field : SubRecordDecl->fields()) { + traverseType(Field->getType(), Flags); + } + } +} + +void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags, + const RecordDecl *RD) { + if (!RD) + return; + + if (RD->isCompleteDefinition()) { + for (auto Field : RD->fields()) { + QualType T = Field->getType(); + traverseType(T, Flags); + } + } +} + +RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *TheDecl) { + + // Cbuffers and Tbuffers are HLSLBufferDecl types + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl); + // Samplers, UAVs, and SRVs are VarDecl types + VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl); + + 
assert(((TheVarDecl && !CBufferOrTBuffer) || + (!TheVarDecl && CBufferOrTBuffer)) && + "either VD or CBufferOrTBuffer should be set"); + + RegisterBindingFlags Flags; + + // check if the decl type is groupshared + if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { + Flags.Other = true; + return Flags; + } + + if (!isDeclaredWithinCOrTBuffer(TheDecl)) { + // make sure the type is a basic / numeric type + if (TheVarDecl) { + QualType TheQualTy = TheVarDecl->getType(); + // a numeric variable or an array of numeric variables + // will inevitably end up in $Globals buffer + const clang::Type *TheBaseType = TheQualTy.getTypePtr(); + while (TheBaseType->isArrayType()) + TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); + if (TheBaseType->isIntegralType(S.getASTContext()) || + TheBaseType->isFloatingType()) + Flags.DefaultGlobals = true; + } + } + + if (CBufferOrTBuffer) { + Flags.Resource = true; + if (CBufferOrTBuffer->isCBuffer()) + Flags.CBV = true; + else + Flags.SRV = true; + } else if (TheVarDecl) { + const HLSLResourceClassAttr *resClassAttr = + getHLSLResourceClassAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer); + const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); + while (TheBaseType->isArrayType()) + TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); + + if (resClassAttr) { + llvm::hlsl::ResourceClass DeclResourceClass = + resClassAttr->getResourceClass(); + Flags.Resource = true; + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: + Flags.SRV = true; + break; + case llvm::hlsl::ResourceClass::UAV: + Flags.UAV = true; + break; + case llvm::hlsl::ResourceClass::CBuffer: + Flags.CBV = true; + break; + case llvm::hlsl::ResourceClass::Sampler: + Flags.Sampler = true; + break; + } + } else { + if (TheBaseType->isArithmeticType()) + Flags.Basic = true; + else if (TheBaseType->isRecordType()) { + Flags.UDT = true; + const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); + assert(TheRecordTy && 
"The Qual Type should be Record Type"); + const RecordDecl *TheRecordDecl = TheRecordTy->getDecl(); + // recurse through members, set appropriate resource class flags. + setResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl); + } else + Flags.Other = true; + } + } + return Flags; +} + +enum RegisterType { SRV, UAV, CBuffer, Sampler, C, I }; + +int getRegisterTypeIndex(StringRef Slot) { + switch (Slot[0]) { + case 't': + case 'T': + return RegisterType::SRV; + case 'u': + case 'U': + return RegisterType::UAV; + case 'b': + case 'B ': + return RegisterType::CBuffer; + case 's': + case 'S': + return RegisterType::Sampler; + case 'c': + case 'C': + return RegisterType::C; + case 'i': + case 'I': + return RegisterType::I; + default: + llvm_unreachable("invalid register type"); + } +} + +static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, + StringRef &Slot) { + // make sure that there are no tworegister annotations + // applied to the decl with the same register type + bool RegisterTypesDetected[6] = {false}; + RegisterTypesDetected[getRegisterTypeIndex(Slot)] = true; + + for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) { + if (HLSLResourceBindingAttr *attr = + dyn_cast<HLSLResourceBindingAttr>(*it)) { + + int registerTypeIndex = getRegisterTypeIndex(attr->getSlot()); + if (RegisterTypesDetected[registerTypeIndex]) { + S.Diag(TheDecl->getLocation(), + diag::err_hlsl_duplicate_register_annotation) + << registerTypeIndex; + } else { + RegisterTypesDetected[registerTypeIndex] = true; + } + } + } +} + +std::string getHLSLResourceTypeStr(Sema &S, Decl *TheDecl) { + VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl); + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl); + + if (TheVarDecl) { + QualType TheQualTy = TheVarDecl->getType(); + PrintingPolicy PP = S.getPrintingPolicy(); + return QualType::getAsString(TheQualTy.split(), PP); + } else { + return CBufferOrTBuffer->isCBuffer() ? 
"cbuffer" : "tbuffer"; + } +} + +static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, + Decl *TheDecl, StringRef &Slot) { + + // Samplers, UAVs, and SRVs are VarDecl types + VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl); + // Cbuffers and Tbuffers are HLSLBufferDecl types + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl); + + // exactly one of these two types should be set + assert(((TheVarDecl && !CBufferOrTBuffer) || + (!TheVarDecl && CBufferOrTBuffer)) && + "either TheVarDecl or CBufferOrTBuffer should be set"); + + RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl); + assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic + + (int)Flags.UDT == + 1 && + "only one resource analysis result should be expected"); + + int regType = getRegisterTypeIndex(Slot); + + // first, if "other" is set, emit an error + if (Flags.Other) { + if (regType == RegisterType::I) { + S.Diag(TheDecl->getLocation(), + diag::warn_hlsl_deprecated_register_type_i); + return; + } + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType; + return; + } + + // next, if multiple register annotations exist, check that none conflict. + ValidateMultipleRegisterAnnotations(S, TheDecl, Slot); + + // next, if resource is set, make sure the register type in the register + // annotation is compatible with the variable's resource type. 
+ if (Flags.Resource) { + if (regType == RegisterType::I) { + S.Diag(TheDecl->getLocation(), + diag::warn_hlsl_deprecated_register_type_i); + return; + } + const HLSLResourceAttr *resAttr = + getHLSLResourceAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer); + const HLSLResourceClassAttr *resClassAttr = + getHLSLResourceClassAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer); + assert(resAttr && resClassAttr && + "any decl that set the resource flag on analysis should " + "have a resource attribute and resource class attribute attached."); + const llvm::hlsl::ResourceClass DeclResourceClass = + resClassAttr->getResourceClass(); + + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: + if (regType != RegisterType::SRV) + S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) + << regType; + break; + case llvm::hlsl::ResourceClass::UAV: + if (regType != RegisterType::UAV) + S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) + << regType; + break; + case llvm::hlsl::ResourceClass::CBuffer: + if (regType != RegisterType::CBuffer) + S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) + << regType; + break; + case llvm::hlsl::ResourceClass::Sampler: + if (regType != RegisterType::Sampler) + S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) + << regType; + break; + } + return; + } + + // next, handle diagnostics for when the "basic" flag is set, + // including the legacy "i" and "b" register types. 
+ if (Flags.Basic) { + if (Flags.DefaultGlobals) { + if (regType == RegisterType::CBuffer) + S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); + else if (regType == RegisterType::I) + S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); + else if (regType != RegisterType::C) + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType; + return; + } + + if (regType == RegisterType::C) + S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); + else if (regType == RegisterType::I) + S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); + else + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType; + + return; + } + + // finally, we handle the udt case + if (Flags.UDT) { + if (regType == RegisterType::I) { + S.Diag(TheDecl->getLocation(), + diag::warn_hlsl_deprecated_register_type_i); + return; + } + switch (getRegisterTypeIndex(Slot)) { ---------------- damyanp wrote:
```suggestion
switch (regType) {
```
Or am I missing something?

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits