Author: Helena Kotas Date: 2024-06-07T21:30:04-07:00 New Revision: 5d87ba1c1f584dfbd5afaf187099b43681b2206d
URL: https://github.com/llvm/llvm-project/commit/5d87ba1c1f584dfbd5afaf187099b43681b2206d DIFF: https://github.com/llvm/llvm-project/commit/5d87ba1c1f584dfbd5afaf187099b43681b2206d.diff LOG: [HLSL] Use llvm::Triple::EnvironmentType instead of HLSLShaderAttr::ShaderType (#93847) `HLSLShaderAttr::ShaderType` enum is a subset of `llvm::Triple::EnvironmentType`. We can use `llvm::Triple::EnvironmentType` directly and avoid converting one enum to another. Added: Modified: clang/include/clang/Basic/Attr.td clang/include/clang/Sema/SemaHLSL.h clang/lib/CodeGen/CGHLSLRuntime.cpp clang/lib/Sema/SemaHLSL.cpp Removed: ################################################################################ diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 17d9a710d948b..b70b0c8b836a5 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4470,37 +4470,20 @@ def HLSLShader : InheritableAttr { let Subjects = SubjectList<[HLSLEntry]>; let LangOpts = [HLSL]; let Args = [ - EnumArgument<"Type", "ShaderType", /*is_string=*/true, + EnumArgument<"Type", "llvm::Triple::EnvironmentType", /*is_string=*/true, ["pixel", "vertex", "geometry", "hull", "domain", "compute", "raygeneration", "intersection", "anyhit", "closesthit", "miss", "callable", "mesh", "amplification"], ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute", "RayGeneration", "Intersection", "AnyHit", "ClosestHit", - "Miss", "Callable", "Mesh", "Amplification"]> + "Miss", "Callable", "Mesh", "Amplification"], + /*opt=*/0, /*fake=*/0, /*isExternalType=*/1> ]; let Documentation = [HLSLSV_ShaderTypeAttrDocs]; let AdditionalMembers = [{ - static const unsigned ShaderTypeMaxValue = (unsigned)HLSLShaderAttr::Amplification; - - static llvm::Triple::EnvironmentType getTypeAsEnvironment(HLSLShaderAttr::ShaderType ShaderType) { - switch (ShaderType) { - case HLSLShaderAttr::Pixel: return llvm::Triple::Pixel; - case HLSLShaderAttr::Vertex: return 
llvm::Triple::Vertex; - case HLSLShaderAttr::Geometry: return llvm::Triple::Geometry; - case HLSLShaderAttr::Hull: return llvm::Triple::Hull; - case HLSLShaderAttr::Domain: return llvm::Triple::Domain; - case HLSLShaderAttr::Compute: return llvm::Triple::Compute; - case HLSLShaderAttr::RayGeneration: return llvm::Triple::RayGeneration; - case HLSLShaderAttr::Intersection: return llvm::Triple::Intersection; - case HLSLShaderAttr::AnyHit: return llvm::Triple::AnyHit; - case HLSLShaderAttr::ClosestHit: return llvm::Triple::ClosestHit; - case HLSLShaderAttr::Miss: return llvm::Triple::Miss; - case HLSLShaderAttr::Callable: return llvm::Triple::Callable; - case HLSLShaderAttr::Mesh: return llvm::Triple::Mesh; - case HLSLShaderAttr::Amplification: return llvm::Triple::Amplification; - } - llvm_unreachable("unknown enumeration value"); + static bool isValidShaderType(llvm::Triple::EnvironmentType ShaderType) { + return ShaderType >= llvm::Triple::Pixel && ShaderType <= llvm::Triple::Amplification; } }]; } diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index e145f5e7f43f8..0e41a72e444ef 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -39,7 +39,7 @@ class SemaHLSL : public SemaBase { const AttributeCommonInfo &AL, int X, int Y, int Z); HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, - HLSLShaderAttr::ShaderType ShaderType); + llvm::Triple::EnvironmentType ShaderType); HLSLParamModifierAttr * mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling); @@ -48,8 +48,8 @@ class SemaHLSL : public SemaBase { void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param, const HLSLAnnotationAttr *AnnotationAttr); void DiagnoseAttrStageMismatch( - const Attr *A, HLSLShaderAttr::ShaderType Stage, - std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages); + const Attr *A, llvm::Triple::EnvironmentType Stage, + 
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 5e6a3dd4878f4..55ba21ae2ba69 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -313,7 +313,7 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); const StringRef ShaderAttrKindStr = "hlsl.shader"; Fn->addFnAttr(ShaderAttrKindStr, - ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); + llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType())); if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { const StringRef NumThreadsKindStr = "hlsl.numthreads"; std::string NumThreadsStr = diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 0a2face7afe65..144cdcc0d98ef 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -146,7 +146,7 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, HLSLShaderAttr * SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, - HLSLShaderAttr::ShaderType ShaderType) { + llvm::Triple::EnvironmentType ShaderType) { if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { if (NT->getType() != ShaderType) { Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; @@ -184,13 +184,12 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) return; - StringRef Env = TargetInfo.getTriple().getEnvironmentName(); - HLSLShaderAttr::ShaderType ShaderType; - if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { + llvm::Triple::EnvironmentType Env = TargetInfo.getTriple().getEnvironment(); + if (HLSLShaderAttr::isValidShaderType(Env) && Env != llvm::Triple::Library) { if 
(const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { // The entry point is already annotated - check that it matches the // triple. - if (Shader->getType() != ShaderType) { + if (Shader->getType() != Env) { Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) << Shader; FD->setInvalidDecl(); @@ -198,11 +197,11 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { } else { // Implicitly add the shader attribute if the entry function isn't // explicitly annotated. - FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType, + FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), Env, FD->getBeginLoc())); } } else { - switch (TargetInfo.getTriple().getEnvironment()) { + switch (Env) { case llvm::Triple::UnknownEnvironment: case llvm::Triple::Library: break; @@ -215,38 +214,40 @@ void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); assert(ShaderAttr && "Entry point has no shader attribute"); - HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); + llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); switch (ST) { - case HLSLShaderAttr::Pixel: - case HLSLShaderAttr::Vertex: - case HLSLShaderAttr::Geometry: - case HLSLShaderAttr::Hull: - case HLSLShaderAttr::Domain: - case HLSLShaderAttr::RayGeneration: - case HLSLShaderAttr::Intersection: - case HLSLShaderAttr::AnyHit: - case HLSLShaderAttr::ClosestHit: - case HLSLShaderAttr::Miss: - case HLSLShaderAttr::Callable: + case llvm::Triple::Pixel: + case llvm::Triple::Vertex: + case llvm::Triple::Geometry: + case llvm::Triple::Hull: + case llvm::Triple::Domain: + case llvm::Triple::RayGeneration: + case llvm::Triple::Intersection: + case llvm::Triple::AnyHit: + case llvm::Triple::ClosestHit: + case llvm::Triple::Miss: + case llvm::Triple::Callable: if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { DiagnoseAttrStageMismatch(NT, ST, - {HLSLShaderAttr::Compute, 
- HLSLShaderAttr::Amplification, - HLSLShaderAttr::Mesh}); + {llvm::Triple::Compute, + llvm::Triple::Amplification, + llvm::Triple::Mesh}); FD->setInvalidDecl(); } break; - case HLSLShaderAttr::Compute: - case HLSLShaderAttr::Amplification: - case HLSLShaderAttr::Mesh: + case llvm::Triple::Compute: + case llvm::Triple::Amplification: + case llvm::Triple::Mesh: if (!FD->hasAttr<HLSLNumThreadsAttr>()) { Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) - << HLSLShaderAttr::ConvertShaderTypeToStr(ST); + << llvm::Triple::getEnvironmentTypeName(ST); FD->setInvalidDecl(); } break; + default: + llvm_unreachable("Unhandled environment in triple"); } for (ParmVarDecl *Param : FD->parameters()) { @@ -268,14 +269,14 @@ void SemaHLSL::CheckSemanticAnnotation( const HLSLAnnotationAttr *AnnotationAttr) { auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); assert(ShaderAttr && "Entry point has no shader attribute"); - HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); + llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: - if (ST == HLSLShaderAttr::Compute) + if (ST == llvm::Triple::Compute) return; - DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute}); + DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute}); break; default: llvm_unreachable("Unknown HLSLAnnotationAttr"); @@ -283,16 +284,16 @@ void SemaHLSL::CheckSemanticAnnotation( } void SemaHLSL::DiagnoseAttrStageMismatch( - const Attr *A, HLSLShaderAttr::ShaderType Stage, - std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) { + const Attr *A, llvm::Triple::EnvironmentType Stage, + std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages) { SmallVector<StringRef, 8> StageStrings; llvm::transform(AllowedStages, std::back_inserter(StageStrings), - [](HLSLShaderAttr::ShaderType ST) { + [](llvm::Triple::EnvironmentType ST) { return StringRef( - 
HLSLShaderAttr::ConvertShaderTypeToStr(ST)); + HLSLShaderAttr::ConvertEnvironmentTypeToStr(ST)); }); Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) - << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) + << A << llvm::Triple::getEnvironmentTypeName(Stage) << (AllowedStages.size() != 1) << join(StageStrings, ", "); } @@ -430,8 +431,8 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc)) return; - HLSLShaderAttr::ShaderType ShaderType; - if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) { + llvm::Triple::EnvironmentType ShaderType; + if (!HLSLShaderAttr::ConvertStrToEnvironmentType(Str, ShaderType)) { Diag(AL.getLoc(), diag::warn_attribute_type_not_supported) << AL << Str << ArgLoc; return; @@ -549,16 +550,22 @@ class DiagnoseHLSLAvailability // // Maps FunctionDecl to an unsigned number that represents the set of shader // environments the function has been scanned for. - // Since HLSLShaderAttr::ShaderType enum is generated from Attr.td and is - // defined without any assigned values, it is guaranteed to be numbered - // sequentially from 0 up and we can use it to 'index' individual bits - // in the set. + // The llvm::Triple::EnvironmentType enum values for shader stages are + // guaranteed to be numbered from llvm::Triple::Pixel to llvm::Triple::Amplification + // (verified by static_asserts in Triple.cpp), so we can use them to index + // individual bits in the set, as long as we shift the values to start with 0 + // by subtracting the value of llvm::Triple::Pixel first. + // // The N'th bit in the set will be set if the function has been scanned - // in shader environment whose ShaderType integer value equals N. + // in the shader environment whose llvm::Triple::EnvironmentType integer value + // equals (llvm::Triple::Pixel + N).
+ // // For example, if a function has been scanned in compute and pixel stage - // environment, the value will be 0x21 (100001 binary) because - // (int)HLSLShaderAttr::ShaderType::Pixel == 1 and - // (int)HLSLShaderAttr::ShaderType::Compute == 5. + // environment, the value will be 0x21 (100001 binary) because: + // + // (int)(llvm::Triple::Pixel - llvm::Triple::Pixel) == 0 + // (int)(llvm::Triple::Compute - llvm::Triple::Pixel) == 5 + // // A FunctionDecl is mapped to 0 (or not included in the map) if it has not // been scanned in any environment. llvm::DenseMap<const FunctionDecl *, unsigned> ScannedDecls; @@ -574,12 +581,16 @@ class DiagnoseHLSLAvailability bool ReportOnlyShaderStageIssues; // Helper methods for dealing with current stage context / environment - void SetShaderStageContext(HLSLShaderAttr::ShaderType ShaderType) { + void SetShaderStageContext(llvm::Triple::EnvironmentType ShaderType) { static_assert(sizeof(unsigned) >= 4); - assert((unsigned)ShaderType < 31); // 31 is reserved for "unknown" - - CurrentShaderEnvironment = HLSLShaderAttr::getTypeAsEnvironment(ShaderType); - CurrentShaderStageBit = (1 << ShaderType); + assert(HLSLShaderAttr::isValidShaderType(ShaderType)); + assert((unsigned)(ShaderType - llvm::Triple::Pixel) < 31 && + "ShaderType is too big for this bitmap"); // 31 is reserved for + // "unknown" + + unsigned bitmapIndex = ShaderType - llvm::Triple::Pixel; + CurrentShaderEnvironment = ShaderType; + CurrentShaderStageBit = (1 << bitmapIndex); } void SetUnknownShaderStageContext() { _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits