================ @@ -612,57 +588,61 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, return Flags; } - if (!isDeclaredWithinCOrTBuffer(TheDecl)) { - // make sure the type is a basic / numeric type - if (TheVarDecl) { - QualType TheQualTy = TheVarDecl->getType(); - // a numeric variable or an array of numeric variables - // will inevitably end up in $Globals buffer - const clang::Type *TheBaseType = TheQualTy.getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isIntegralType(S.getASTContext()) || - TheBaseType->isFloatingType()) - Flags.DefaultGlobals = true; - } - } - - if (CBufferOrTBuffer) { + // Cbuffers and Tbuffers are HLSLBufferDecl types + if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) { Flags.Resource = true; - if (CBufferOrTBuffer->isCBuffer()) - Flags.CBV = true; - else - Flags.SRV = true; - } else if (TheVarDecl) { + Flags.ResourceClass = CBufferOrTBuffer->isCBuffer() + ? 
llvm::dxil::ResourceClass::CBuffer + : llvm::dxil::ResourceClass::SRV; + } + // Samplers, UAVs, and SRVs are VarDecl types + else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) { const HLSLResourceClassAttr *resClassAttr = getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl); - if (resClassAttr) { - llvm::hlsl::ResourceClass DeclResourceClass = - resClassAttr->getResourceClass(); Flags.Resource = true; - updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass); + Flags.ResourceClass = resClassAttr->getResourceClass(); } else { const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); while (TheBaseType->isArrayType()) TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isArithmeticType()) + + if (TheBaseType->isArithmeticType()) { Flags.Basic = true; - else if (TheBaseType->isRecordType()) { + if (!isDeclaredWithinCOrTBuffer(TheDecl) && + (TheBaseType->isIntegralType(S.getASTContext()) || + TheBaseType->isFloatingType())) + Flags.DefaultGlobals = true; + } else if (TheBaseType->isRecordType()) { Flags.UDT = true; const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - assert(TheRecordTy && "The Qual Type should be Record Type"); - const RecordDecl *TheRecordDecl = TheRecordTy->getDecl(); - // recurse through members, set appropriate resource class flags. 
- updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl); + updateResourceClassFlagsFromRecordType(Flags, TheRecordTy); } else Flags.Other = true; } + } else { + llvm_unreachable("expected be VarDecl or HLSLBufferDecl"); } return Flags; } -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; +enum class RegisterType { + SRV = static_cast<int>(llvm::dxil::ResourceClass::SRV), + UAV = static_cast<int>(llvm::dxil::ResourceClass::UAV), + CBuffer = static_cast<int>(llvm::dxil::ResourceClass::CBuffer), + Sampler = static_cast<int>(llvm::dxil::ResourceClass::Sampler), + C, + I, + Invalid ---------------- hekota wrote:
I've also renamed it to `getRegisterType` to match the similar function next to it. https://github.com/llvm/llvm-project/pull/106657 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits