https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/146785
>From a49aa19297811e5800ffce364d8d6a225109d93f Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Thu, 26 Jun 2025 19:28:01 +0000 Subject: [PATCH 1/7] refactoring --- .../lib/Target/DirectX/DXContainerGlobals.cpp | 4 ++- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 14 +++----- llvm/lib/Target/DirectX/DXILRootSignature.h | 33 +++++++++---------- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 6c8ae8eaaea77..e076283b65193 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -160,11 +160,13 @@ void DXContainerGlobals::addRootSignature(Module &M, assert(MMI.EntryPropertyVec.size() == 1); + auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry; const auto &RS = RSA.getDescForFunction(EntryFunction); + const auto &RS = RSA.getDescForFunction(EntryFunction); - if (!RS) + if (!RS ) return; SmallString<256> Data; diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 5a53ea8a3631b..4094df160ef6f 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -554,12 +554,9 @@ analyzeModule(Module &M) { AnalysisKey RootSignatureAnalysis::Key; -RootSignatureAnalysis::Result -RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) { - if (!AnalysisResult) - AnalysisResult = std::make_unique<RootSignatureBindingInfo>( - RootSignatureBindingInfo(analyzeModule(M))); - return *AnalysisResult; +RootSignatureBindingInfo RootSignatureAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + return RootSignatureBindingInfo(analyzeModule(M)); } //===----------------------------------------------------------------------===// @@ -638,9 +635,8 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, //===----------------------------------------------------------------------===// bool RootSignatureAnalysisWrapper::runOnModule(Module &M) { - if (!FuncToRsMap) - FuncToRsMap = std::make_unique<RootSignatureBindingInfo>( - RootSignatureBindingInfo(analyzeModule(M))); + FuncToRsMap = std::make_unique<RootSignatureBindingInfo>( + RootSignatureBindingInfo(analyzeModule(M))); return false; } diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index 3832182277050..24b1a8d3d2abe 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -37,30 +37,28 @@ enum class RootSignatureElementKind { }; class RootSignatureBindingInfo { -private: - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; + private: + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; -public: + public: using iterator = - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; - RootSignatureBindingInfo() = default; - RootSignatureBindingInfo( - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) - : FuncToRsMap(Map) {}; + RootSignatureBindingInfo () = default; + RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {}; iterator find(const Function *F) { return FuncToRsMap.find(F); } iterator end() { return FuncToRsMap.end(); } - std::optional<mcdxbc::RootSignatureDesc> - getDescForFunction(const Function *F) { + std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) { const auto FuncRs = find(F); if (FuncRs == end()) return std::nullopt; return FuncRs->second; } + }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { @@ -68,14 +66,13 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { static AnalysisKey Key; public: - RootSignatureAnalysis() = default; - - using Result = RootSignatureBindingInfo; - Result run(Module &M, ModuleAnalysisManager &AM); +RootSignatureAnalysis() = default; -private: - std::unique_ptr<RootSignatureBindingInfo> AnalysisResult; + using Result = RootSignatureBindingInfo; + + RootSignatureBindingInfo + run(Module &M, ModuleAnalysisManager &AM); }; /// Wrapper pass for the legacy pass manager. @@ -92,8 +89,8 @@ class RootSignatureAnalysisWrapper : public ModulePass { RootSignatureAnalysisWrapper() : ModulePass(ID) {} - RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; } - + RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;} + bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override; >From d90676feb6bfc0ca8bbdaee5c347ecc49e396b5b Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Thu, 26 Jun 2025 21:37:11 +0000 Subject: [PATCH 2/7] init refactoring --- .../SemaHLSL/RootSignature-Validation.hlsl | 42 +++++++++++++++++ .../lib/Target/DirectX/DXContainerGlobals.cpp | 2 +- .../DXILPostOptimizationValidation.cpp | 47 +++++++++++++++++-- llvm/lib/Target/DirectX/DXILRootSignature.h | 30 ++++++------ 4 files changed, 102 insertions(+), 19 deletions(-) create mode 100644 clang/test/SemaHLSL/RootSignature-Validation.hlsl diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl new file mode 100644 index 0000000000000..8a4a97f87cb65 --- /dev/null +++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl @@ -0,0 +1,42 @@ +// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify + +#define ROOT_SIGNATURE \ + "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \ + "CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \ + "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \ + "DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)" + +cbuffer CB : register(b3, space2) { + float a; +} + +StructuredBuffer<int> In : register(t0); +RWStructuredBuffer<int> Out : register(u0); + +RWBuffer<float> UAV : register(u3); + +RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4); + +RWBuffer<float> UAV3 : register(space5); + +float f : register(c5); + +int4 intv : register(c2); + +double dar[5] : register(c3); + +struct S { + int a; +}; + +S s : register(c10); + +// Compute Shader for UAV testing +[numthreads(8, 8, 1)] +[RootSignature(ROOT_SIGNATURE)] +void CSMain(uint3 id : SV_DispatchThreadID) +{ + In[0] = id; + Out[0] = In[0]; +} diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index e076283b65193..5c763c24a210a 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -166,7 +166,7 @@ void DXContainerGlobals::addRootSignature(Module &M, const auto &RS = RSA.getDescForFunction(EntryFunction); const auto &RS = RSA.getDescForFunction(EntryFunction); - if (!RS ) + if (!RS) return; SmallString<256> Data; diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index 398dcbb8d1737..daf53fefe5f17 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -7,11 +7,14 @@ //===----------------------------------------------------------------------===// #include "DXILPostOptimizationValidation.h" +#include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" +#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsDirectX.h" @@ -85,7 +88,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) { } static void reportErrors(Module &M, DXILResourceMap &DRM, - DXILResourceBindingInfo &DRBI) { + DXILResourceBindingInfo &DRBI, + RootSignatureBindingInfo &RSBI, + dxil::ModuleMetadataInfo &MMI) { if (DRM.hasInvalidCounterDirection()) reportInvalidDirection(M, DRM); @@ -94,6 +99,30 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in " "DXILResourceImplicitBinding pass"); + // Assuming this is used to validate only the root signature assigned to the + // entry function. + std::optional<mcdxbc::RootSignatureDesc> RootSigDesc = + RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry); + if (!RootSigDesc) + return; + + for (const mcdxbc::RootParameterInfo &Info : + RootSigDesc->ParametersContainer) { + const auto &[Type, Loc] = + RootSigDesc->ParametersContainer.getTypeAndLocForParameter( + Info.Location); + switch (Type) { + case llvm::to_underlying(dxbc::RootParameterType::CBV): + dxbc::RTS0::v2::RootDescriptor Desc = + RootSigDesc->ParametersContainer.getRootDescriptor(Loc); + + llvm::dxil::ResourceInfo::ResourceBinding Binding; + Binding.LowerBound = Desc.ShaderRegister; + Binding.Space = Desc.RegisterSpace; + Binding.Size = 1; + break; + } + } } } // namespace @@ -101,7 +130,10 @@ PreservedAnalyses DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) { DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M); DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M); - reportErrors(M, DRM, DRBI); + RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M); + ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M); + + reportErrors(M, DRM, DRBI, RSBI, MMI); return PreservedAnalyses::all(); } @@ -113,7 +145,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass { getAnalysis<DXILResourceWrapperPass>().getResourceMap(); DXILResourceBindingInfo &DRBI = getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo(); - reportErrors(M, DRM, DRBI); + + RootSignatureBindingInfo &RSBI = + getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); + dxil::ModuleMetadataInfo &MMI = + getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); + + reportErrors(M, DRM, DRBI, RSBI, MMI); return false; } StringRef getPassName() const override { @@ -125,10 +163,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass { void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { AU.addRequired<DXILResourceWrapperPass>(); AU.addRequired<DXILResourceBindingWrapperPass>(); + AU.addRequired<RootSignatureAnalysisWrapper>(); + AU.addRequired<DXILMetadataAnalysisWrapperPass>(); AU.addPreserved<DXILResourceWrapperPass>(); AU.addPreserved<DXILResourceBindingWrapperPass>(); AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); AU.addPreserved<ShaderFlagsAnalysisWrapper>(); + AU.addPreserved<RootSignatureAnalysisWrapper>(); } }; char DXILPostOptimizationValidationLegacy::ID = 0; diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index 24b1a8d3d2abe..ecfc577d1b97d 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -37,28 +37,30 @@ enum class RootSignatureElementKind { }; class RootSignatureBindingInfo { - private: - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; +private: + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; - public: +public: using iterator = - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; - RootSignatureBindingInfo () = default; - RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {}; + RootSignatureBindingInfo() = default; + RootSignatureBindingInfo( + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) + : FuncToRsMap(Map){}; iterator find(const Function *F) { return FuncToRsMap.find(F); } iterator end() { return FuncToRsMap.end(); } - std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function* F) { + std::optional<mcdxbc::RootSignatureDesc> + getDescForFunction(const Function *F) { const auto FuncRs = find(F); if (FuncRs == end()) return std::nullopt; return FuncRs->second; } - }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { @@ -66,13 +68,11 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { static AnalysisKey Key; public: - -RootSignatureAnalysis() = default; + RootSignatureAnalysis() = default; using Result = RootSignatureBindingInfo; - - RootSignatureBindingInfo - run(Module &M, ModuleAnalysisManager &AM); + + RootSignatureBindingInfo run(Module &M, ModuleAnalysisManager &AM); }; /// Wrapper pass for the legacy pass manager. @@ -89,8 +89,8 @@ class RootSignatureAnalysisWrapper : public ModulePass { RootSignatureAnalysisWrapper() : ModulePass(ID) {} - RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;} - + RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; } + bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override; >From a04eb9ff37d20499f05c7b1cc0ab3187f729609b Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Wed, 2 Jul 2025 17:58:56 +0000 Subject: [PATCH 3/7] adding validation --- .../SemaHLSL/RootSignature-Validation.hlsl | 28 ++++--------- .../DXILPostOptimizationValidation.cpp | 42 +++++++++++++++---- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl index 8a4a97f87cb65..62ba704b95c7d 100644 --- a/clang/test/SemaHLSL/RootSignature-Validation.hlsl +++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl @@ -1,42 +1,30 @@ -// RUN: %clang_dxc -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -verify #define ROOT_SIGNATURE \ "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \ - "CBV(b0, visibility=SHADER_VISIBILITY_ALL), " \ - "DescriptorTable(SRV(t0, numDescriptors=3), visibility=SHADER_VISIBILITY_PIXEL), " \ - "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_PIXEL), " \ - "DescriptorTable(UAV(u0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL)" + "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)" cbuffer CB : register(b3, space2) { float a; } -StructuredBuffer<int> In : register(t0); +StructuredBuffer<int> In : register(t0, space0); RWStructuredBuffer<int> Out : register(u0); RWBuffer<float> UAV : register(u3); RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4); -RWBuffer<float> UAV3 : register(space5); +RWBuffer<float> UAV3 : register(space0); -float f : register(c5); -int4 intv : register(c2); - -double dar[5] : register(c3); - -struct S { - int a; -}; - -S s : register(c10); // Compute Shader for UAV testing [numthreads(8, 8, 1)] [RootSignature(ROOT_SIGNATURE)] -void CSMain(uint3 id : SV_DispatchThreadID) +void CSMain(uint id : SV_GroupID) { - In[0] = id; - Out[0] = In[0]; + Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0]; } diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index daf53fefe5f17..3e542e502c2d5 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -10,6 +10,7 @@ #include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" @@ -86,7 +87,9 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) { } } } - + uint64_t combine_uint32_to_uint64(uint32_t high, uint32_t low) { + return (static_cast<uint64_t>(high) << 32) | low; + } static void reportErrors(Module &M, DXILResourceMap &DRM, DXILResourceBindingInfo &DRBI, RootSignatureBindingInfo &RSBI, @@ -101,18 +104,24 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, "DXILResourceImplicitBinding pass"); // Assuming this is used to validate only the root signature assigned to the // entry function. + //Start test stuff + if(MMI.EntryPropertyVec.size() == 0) + return; + std::optional<mcdxbc::RootSignatureDesc> RootSigDesc = RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry); if (!RootSigDesc) return; - for (const mcdxbc::RootParameterInfo &Info : - RootSigDesc->ParametersContainer) { + using MapT = llvm::IntervalMap<uint64_t, llvm::dxil::ResourceInfo::ResourceBinding, sizeof(llvm::dxil::ResourceInfo::ResourceBinding), llvm::IntervalMapInfo<uint64_t>>; + MapT::Allocator Allocator; + MapT BindingsMap(Allocator); + auto RSD = *RootSigDesc; + for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) { const auto &[Type, Loc] = - RootSigDesc->ParametersContainer.getTypeAndLocForParameter( - Info.Location); + RootSigDesc->ParametersContainer.getTypeAndLocForParameter(I); switch (Type) { - case llvm::to_underlying(dxbc::RootParameterType::CBV): + case llvm::to_underlying(dxbc::RootParameterType::CBV):{ dxbc::RTS0::v2::RootDescriptor Desc = RootSigDesc->ParametersContainer.getRootDescriptor(Loc); @@ -120,8 +129,27 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, Binding.LowerBound = Desc.ShaderRegister; Binding.Space = Desc.RegisterSpace; Binding.Size = 1; + + BindingsMap.insert(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1), Binding); break; } + // case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):{ + // mcdxbc::DescriptorTable Table = + // RootSigDesc->ParametersContainer.getDescriptorTable(Loc); + // for (const dxbc::RTS0::v2::DescriptorRange &Range : Table){ + // Range. + // } + + // break; + // } + } + + } + + for(const auto &CBuf : DRM.cbuffers()) { + auto Binding = CBuf.getBinding(); + if(!BindingsMap.overlaps(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1))) + auto X = 1; } } } // namespace @@ -146,7 +174,7 @@ class DXILPostOptimizationValidationLegacy : public ModulePass { DXILResourceBindingInfo &DRBI = getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo(); - RootSignatureBindingInfo &RSBI = + RootSignatureBindingInfo& RSBI = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); dxil::ModuleMetadataInfo &MMI = getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); >From 5994b8f8f4ea24115a66c0046c8fc344905b41d4 Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Wed, 2 Jul 2025 21:19:37 +0000 Subject: [PATCH 4/7] clean --- .../DXILPostOptimizationValidation.cpp | 6 +---- .../DirectX/DXILPostOptimizationValidation.h | 3 +++ llvm/lib/Target/DirectX/DXILRootSignature.h | 24 +++++++++---------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index 3e542e502c2d5..4c29b56304391 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -10,12 +10,9 @@ #include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" -#include "llvm/ADT/IntervalMap.h" -#include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/SmallString.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" -#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicsDirectX.h" @@ -173,8 +170,7 @@ class DXILPostOptimizationValidationLegacy : public ModulePass { getAnalysis<DXILResourceWrapperPass>().getResourceMap(); DXILResourceBindingInfo &DRBI = getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo(); - - RootSignatureBindingInfo& RSBI = + RootSignatureBindingInfo &RSBI = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); dxil::ModuleMetadataInfo &MMI = getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h index cb5e624514272..151843daf068d 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h @@ -14,6 +14,9 @@ #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H +#include "DXILRootSignature.h" +#include "llvm/ADT/IntervalMap.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/PassManager.h" namespace llvm { diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index ecfc577d1b97d..d0d5c7785bda3 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -37,30 +37,28 @@ enum class RootSignatureElementKind { }; class RootSignatureBindingInfo { -private: - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; + private: + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; -public: + public: using iterator = - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; - RootSignatureBindingInfo() = default; - RootSignatureBindingInfo( - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) - : FuncToRsMap(Map){}; +RootSignatureBindingInfo () = default; + RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {}; iterator find(const Function *F) { return FuncToRsMap.find(F); } iterator end() { return FuncToRsMap.end(); } - std::optional<mcdxbc::RootSignatureDesc> - getDescForFunction(const Function *F) { + std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function *F) { const auto FuncRs = find(F); if (FuncRs == end()) return std::nullopt; return FuncRs->second; } + }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { @@ -68,7 +66,7 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { static AnalysisKey Key; public: - RootSignatureAnalysis() = default; +RootSignatureAnalysis() = default; using Result = RootSignatureBindingInfo; @@ -88,8 +86,8 @@ class RootSignatureAnalysisWrapper : public ModulePass { using Result = RootSignatureBindingInfo; RootSignatureAnalysisWrapper() : ModulePass(ID) {} - - RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; } + + RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;} bool runOnModule(Module &M) override; >From e8b14bf32e47cf8c059d2f492e57a602375ceeaa Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Fri, 4 Jul 2025 02:03:26 +0000 Subject: [PATCH 5/7] implementing --- .../RootSignature-Validation-Fail.hlsl | 35 ++++ .../SemaHLSL/RootSignature-Validation.hlsl | 11 +- .../DXILPostOptimizationValidation.cpp | 166 +++++++++++++----- .../DirectX/DXILPostOptimizationValidation.h | 88 ++++++++++ llvm/lib/Target/DirectX/DXILRootSignature.h | 24 +-- .../RootSignature-DescriptorTable.ll | 4 +- 6 files changed, 271 insertions(+), 57 deletions(-) create mode 100644 clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl diff --git a/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl new file mode 100644 index 0000000000000..b590ed67e7085 --- /dev/null +++ b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl @@ -0,0 +1,35 @@ +// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s + +// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature +// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature +// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature + + +#define ROOT_SIGNATURE \ + "CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \ + "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)" + +cbuffer CB : register(b3, space665) { + float a; +} + +StructuredBuffer<int> In : register(t0, space0); +RWStructuredBuffer<int> Out : register(u0); + +RWBuffer<float> UAV : register(u4294967295); + +RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4); + +RWBuffer<float> UAV3 : register(space0); + + + +// Compute Shader for UAV testing +[numthreads(8, 8, 1)] +[RootSignature(ROOT_SIGNATURE)] +void CSMain(uint id : SV_GroupID) +{ + Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0]; +} diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl index 62ba704b95c7d..5a7f5baf00619 100644 --- a/clang/test/SemaHLSL/RootSignature-Validation.hlsl +++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl @@ -1,19 +1,22 @@ +// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 + +// expected-no-diagnostics + #define ROOT_SIGNATURE \ - "RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT), " \ "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \ "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \ - "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \ + "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \ "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)" -cbuffer CB : register(b3, space2) { +cbuffer CB : register(b3, space1) { float a; } StructuredBuffer<int> In : register(t0, space0); RWStructuredBuffer<int> Out : register(u0); -RWBuffer<float> UAV : register(u3); +RWBuffer<float> UAV : register(u4294967294); RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4); diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index 4c29b56304391..23bb5d1a7f651 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -84,9 +84,57 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) { } } } - uint64_t combine_uint32_to_uint64(uint32_t high, uint32_t low) { - return (static_cast<uint64_t>(high) << 32) | low; + +static void reportRegNotBound(Module &M, Twine Type, + ResourceInfo::ResourceBinding Binding) { + SmallString<128> Message; + raw_svector_ostream OS(Message); + OS << "register " << Type << " (space=" << Binding.Space + << ", register=" << Binding.LowerBound << ")" + << " is not defined in Root Signature"; + M.getContext().diagnose(DiagnosticInfoGeneric(Message)); +} + +static dxbc::ShaderVisibility +tripleToVisibility(llvm::Triple::EnvironmentType ET) { + assert((ET == Triple::Pixel || ET == Triple::Vertex || + ET == Triple::Geometry || ET == Triple::Hull || + ET == Triple::Domain || ET == Triple::Mesh || + ET == Triple::Compute) && + "Invalid Triple to shader stage conversion"); + + switch (ET) { + case Triple::Pixel: + return dxbc::ShaderVisibility::Pixel; + case Triple::Vertex: + return dxbc::ShaderVisibility::Vertex; + case Triple::Geometry: + return dxbc::ShaderVisibility::Geometry; + case Triple::Hull: + return dxbc::ShaderVisibility::Hull; + case Triple::Domain: + return dxbc::ShaderVisibility::Domain; + case Triple::Mesh: + return dxbc::ShaderVisibility::Mesh; + case Triple::Compute: + return dxbc::ShaderVisibility::All; + default: + llvm_unreachable("Invalid triple to shader stage conversion"); } +} + +std::optional<mcdxbc::RootSignatureDesc> +getRootSignature(RootSignatureBindingInfo &RSBI, + dxil::ModuleMetadataInfo &MMI) { + if (MMI.EntryPropertyVec.size() == 0) + return std::nullopt; + std::optional<mcdxbc::RootSignatureDesc> RootSigDesc = + RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry); + if (!RootSigDesc) + return std::nullopt; + return RootSigDesc; +} + static void reportErrors(Module &M, DXILResourceMap &DRM, DXILResourceBindingInfo &DRBI, RootSignatureBindingInfo &RSBI, @@ -99,57 +147,95 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in " "DXILResourceImplicitBinding pass"); - // Assuming this is used to validate only the root signature assigned to the - // entry function. - //Start test stuff - if(MMI.EntryPropertyVec.size() == 0) - return; - std::optional<mcdxbc::RootSignatureDesc> RootSigDesc = - RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry); - if (!RootSigDesc) - return; + if (auto RSD = getRootSignature(RSBI, MMI)) { + + RootSignatureBindingValidation Validation; + Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile)); + + for (const auto &CBuf : DRM.cbuffers()) { + ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); + if (!Validation.checkCregBinding(Binding)) + reportRegNotBound(M, "cbuffer", Binding); + } + + for (const auto &CBuf : DRM.srvs()) { + ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); + if (!Validation.checkTRegBinding(Binding)) + reportRegNotBound(M, "srv", Binding); + } - using MapT = llvm::IntervalMap<uint64_t, llvm::dxil::ResourceInfo::ResourceBinding, sizeof(llvm::dxil::ResourceInfo::ResourceBinding), llvm::IntervalMapInfo<uint64_t>>; - MapT::Allocator Allocator; - MapT BindingsMap(Allocator); - auto RSD = *RootSigDesc; - for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) { + for (const auto &CBuf : DRM.uavs()) { + ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); + if (!Validation.checkURegBinding(Binding)) + reportRegNotBound(M, "uav", Binding); + } + } +} +} // namespace + +void RootSignatureBindingValidation::addRsBindingInfo( + mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) { + for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) { const auto &[Type, Loc] = - RootSigDesc->ParametersContainer.getTypeAndLocForParameter(I); + RSD.ParametersContainer.getTypeAndLocForParameter(I); + + const auto &Header = RSD.ParametersContainer.getHeader(I); switch (Type) { - case llvm::to_underlying(dxbc::RootParameterType::CBV):{ + case llvm::to_underlying(dxbc::RootParameterType::SRV): + case llvm::to_underlying(dxbc::RootParameterType::UAV): + case llvm::to_underlying(dxbc::RootParameterType::CBV): { dxbc::RTS0::v2::RootDescriptor Desc = - RootSigDesc->ParametersContainer.getRootDescriptor(Loc); + RSD.ParametersContainer.getRootDescriptor(Loc); - llvm::dxil::ResourceInfo::ResourceBinding Binding; - Binding.LowerBound = Desc.ShaderRegister; - Binding.Space = Desc.RegisterSpace; - Binding.Size = 1; + if (Header.ShaderVisibility == + llvm::to_underlying(dxbc::ShaderVisibility::All) || + Header.ShaderVisibility == llvm::to_underlying(Visibility)) + addRange(Desc, Type); + break; + } + case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): { + const mcdxbc::DescriptorTable &Table = + RSD.ParametersContainer.getDescriptorTable(Loc); - BindingsMap.insert(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1), Binding); + for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) { + if (Range.RangeType == + llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) + continue; + + if (Header.ShaderVisibility == + llvm::to_underlying(dxbc::ShaderVisibility::All) || + Header.ShaderVisibility == llvm::to_underlying(Visibility)) + addRange(Range); + } break; } - // case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):{ - // mcdxbc::DescriptorTable Table = - // RootSigDesc->ParametersContainer.getDescriptorTable(Loc); - // for (const dxbc::RTS0::v2::DescriptorRange &Range : Table){ - // Range. - // } - - // break; - // } } - } +} - for(const auto &CBuf : DRM.cbuffers()) { - auto Binding = CBuf.getBinding(); - if(!BindingsMap.overlaps(combine_uint32_to_uint64(Binding.Space, Binding.LowerBound), combine_uint32_to_uint64(Binding.Space, Binding.LowerBound + Binding.Size -1))) - auto X = 1; - } +bool RootSignatureBindingValidation::checkCregBinding( + ResourceInfo::ResourceBinding Binding) { + return CRegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); +} + +bool RootSignatureBindingValidation::checkTRegBinding( + ResourceInfo::ResourceBinding Binding) { + return TRegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, Binding.LowerBound + Binding.Size)); +} + +bool RootSignatureBindingValidation::checkURegBinding( + ResourceInfo::ResourceBinding Binding) { + return URegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); } -} // namespace PreservedAnalyses DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) { diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h index 151843daf068d..58113bf9f93c7 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h @@ -21,6 +21,94 @@ namespace llvm { +static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) { + return (static_cast<uint64_t>(High) << 32) | Low; +} + +class RootSignatureBindingValidation { + using MapT = + llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding, + sizeof(llvm::dxil::ResourceInfo::ResourceBinding), + llvm::IntervalMapInfo<uint64_t>>; + +private: + MapT::Allocator Allocator; + MapT CRegBindingsMap; + MapT TRegBindingsMap; + MapT URegBindingsMap; + + void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) { + assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) || + Type == llvm::to_underlying(dxbc::RootParameterType::SRV) || + Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) && + "Invalid Type"); + + llvm::dxil::ResourceInfo::ResourceBinding Binding; + Binding.LowerBound = Desc.ShaderRegister; + Binding.Space = Desc.RegisterSpace; + Binding.Size = 1; + + uint64_t LowRange = + combineUint32ToUint64(Binding.Space, Binding.LowerBound); + uint64_t HighRange = combineUint32ToUint64( + Binding.Space, Binding.LowerBound + Binding.Size - 1); + + switch (Type) { + + case llvm::to_underlying(dxbc::RootParameterType::CBV): + CRegBindingsMap.insert(LowRange, HighRange, Binding); + return; + case llvm::to_underlying(dxbc::RootParameterType::SRV): + TRegBindingsMap.insert(LowRange, HighRange, Binding); + return; + case llvm::to_underlying(dxbc::RootParameterType::UAV): + URegBindingsMap.insert(LowRange, HighRange, Binding); + return; + } + llvm_unreachable("Invalid Type in add Range Method"); + } + + void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) { + + llvm::dxil::ResourceInfo::ResourceBinding Binding; + Binding.LowerBound = Range.BaseShaderRegister; + Binding.Space = Range.RegisterSpace; + Binding.Size = Range.NumDescriptors; + + uint64_t LowRange = + combineUint32ToUint64(Binding.Space, Binding.LowerBound); + uint64_t HighRange = combineUint32ToUint64( + Binding.Space, Binding.LowerBound + Binding.Size - 1); + + switch (Range.RangeType) { + case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): + CRegBindingsMap.insert(LowRange, HighRange, Binding); + return; + case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): + TRegBindingsMap.insert(LowRange, HighRange, Binding); + return; + case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): + URegBindingsMap.insert(LowRange, HighRange, Binding); + return; + } + llvm_unreachable("Invalid Type in add Range Method"); + } + +public: + RootSignatureBindingValidation() + : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator), + URegBindingsMap(Allocator) {} + + void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD, + dxbc::ShaderVisibility Visibility); + + bool checkCregBinding(dxil::ResourceInfo::ResourceBinding Binding); + + bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding); + + bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding); +}; + class DXILPostOptimizationValidation : public PassInfoMixin<DXILPostOptimizationValidation> { public: diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h index d0d5c7785bda3..ecfc577d1b97d 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.h +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -37,28 +37,30 @@ enum class RootSignatureElementKind { }; class RootSignatureBindingInfo { - private: - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; +private: + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap; - public: +public: using iterator = - SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator; -RootSignatureBindingInfo () = default; - RootSignatureBindingInfo(SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) : FuncToRsMap(Map) {}; + RootSignatureBindingInfo() = default; + RootSignatureBindingInfo( + SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> Map) + : FuncToRsMap(Map){}; iterator find(const Function *F) { return FuncToRsMap.find(F); } iterator end() { return FuncToRsMap.end(); } - std::optional<mcdxbc::RootSignatureDesc> getDescForFunction(const Function *F) { + std::optional<mcdxbc::RootSignatureDesc> + getDescForFunction(const Function *F) { const auto FuncRs = find(F); if (FuncRs == end()) return std::nullopt; return FuncRs->second; } - }; class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { @@ -66,7 +68,7 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { static AnalysisKey Key; public: -RootSignatureAnalysis() = default; + RootSignatureAnalysis() = default; using Result = RootSignatureBindingInfo; @@ -86,8 +88,8 @@ class RootSignatureAnalysisWrapper : public ModulePass { using Result = RootSignatureBindingInfo; RootSignatureAnalysisWrapper() : ModulePass(ID) {} - - RootSignatureBindingInfo& getRSInfo() {return *FuncToRsMap;} + + RootSignatureBindingInfo &getRSInfo() { return *FuncToRsMap; } bool runOnModule(Module &M) override; diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll index b516d66180247..8e9b4b43b11a6 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll @@ -16,7 +16,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } !2 = !{ ptr @main, !3, i32 2 } ; function, root signature !3 = !{ !5 } ; list of root signature elements !5 = !{ !"DescriptorTable", i32 0, !6, !7 } -!6 = !{ !"SRV", i32 0, i32 1, i32 0, i32 -1, i32 4 } +!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 4 } !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 } ; DXC: - Name: RTS0 @@ -35,7 +35,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } ; DXC-NEXT: RangesOffset: 44 ; DXC-NEXT: Ranges: ; DXC-NEXT: - RangeType: 0 -; DXC-NEXT: NumDescriptors: 0 +; DXC-NEXT: NumDescriptors: 1 ; DXC-NEXT: BaseShaderRegister: 1 ; DXC-NEXT: RegisterSpace: 0 ; DXC-NEXT: OffsetInDescriptorsFromTableStart: 4294967295 >From 8f40e83ab0db147e90070f15708d0a0f4e1a9d1f Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Fri, 4 Jul 2025 19:24:25 +0000 Subject: [PATCH 6/7] finish implementing && fix tests --- .../DXILPostOptimizationValidation.cpp | 45 +++++----------- .../DirectX/DXILPostOptimizationValidation.h | 54 ++++++++++++++----- llvm/lib/Target/DirectX/DXILRootSignature.cpp | 5 +- ...criptorTable-AllValidFlagCombinationsV1.ll | 4 +- llvm/test/CodeGen/DirectX/llc-pipeline.ll | 1 + 5 files changed, 59 insertions(+), 50 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp index 23bb5d1a7f651..a52a04323514c 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp @@ -153,23 +153,29 @@ static void reportErrors(Module &M, DXILResourceMap &DRM, RootSignatureBindingValidation Validation; Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile)); - for (const auto &CBuf : DRM.cbuffers()) { + for (const ResourceInfo &CBuf : DRM.cbuffers()) { ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); - if (!Validation.checkCregBinding(Binding)) + if (!Validation.checkCRegBinding(Binding)) reportRegNotBound(M, "cbuffer", Binding); } - for (const auto &CBuf : DRM.srvs()) { - ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); + for (const ResourceInfo &SRV : DRM.srvs()) { + ResourceInfo::ResourceBinding Binding = SRV.getBinding(); if (!Validation.checkTRegBinding(Binding)) reportRegNotBound(M, "srv", Binding); } - for (const auto &CBuf : DRM.uavs()) { - ResourceInfo::ResourceBinding Binding = CBuf.getBinding(); + for (const ResourceInfo &UAV : DRM.uavs()) { + ResourceInfo::ResourceBinding Binding = UAV.getBinding(); if (!Validation.checkURegBinding(Binding)) reportRegNotBound(M, "uav", Binding); } + + for (const ResourceInfo &Sampler : DRM.samplers()) { + ResourceInfo::ResourceBinding Binding = Sampler.getBinding(); + if (!Validation.checkSamplerBinding(Binding)) + reportRegNotBound(M, "sampler", Binding); + } } } } // namespace @@ -199,10 +205,6 @@ void RootSignatureBindingValidation::addRsBindingInfo( RSD.ParametersContainer.getDescriptorTable(Loc); for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) { - if (Range.RangeType == - llvm::to_underlying(dxbc::DescriptorRangeType::Sampler)) - continue; - if (Header.ShaderVisibility == llvm::to_underlying(dxbc::ShaderVisibility::All) || Header.ShaderVisibility == llvm::to_underlying(Visibility)) @@ -214,29 +216,6 @@ void RootSignatureBindingValidation::addRsBindingInfo( } } -bool RootSignatureBindingValidation::checkCregBinding( - ResourceInfo::ResourceBinding Binding) { - return CRegBindingsMap.overlaps( - combineUint32ToUint64(Binding.Space, Binding.LowerBound), - combineUint32ToUint64(Binding.Space, - Binding.LowerBound + Binding.Size - 1)); -} - -bool RootSignatureBindingValidation::checkTRegBinding( - ResourceInfo::ResourceBinding Binding) { - return TRegBindingsMap.overlaps( - combineUint32ToUint64(Binding.Space, Binding.LowerBound), - combineUint32ToUint64(Binding.Space, Binding.LowerBound + Binding.Size)); -} - -bool RootSignatureBindingValidation::checkURegBinding( - ResourceInfo::ResourceBinding Binding) { - return URegBindingsMap.overlaps( - combineUint32ToUint64(Binding.Space, Binding.LowerBound), - combineUint32ToUint64(Binding.Space, - Binding.LowerBound + Binding.Size - 1)); -} - PreservedAnalyses DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) { DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M); diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h index 58113bf9f93c7..0fa0285425d7e 100644 --- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h +++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h @@ -36,12 +36,13 @@ class RootSignatureBindingValidation { MapT CRegBindingsMap; MapT TRegBindingsMap; MapT URegBindingsMap; + MapT SamplersBindingsMap; void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) { assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) || Type == llvm::to_underlying(dxbc::RootParameterType::SRV) || Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) && - "Invalid Type"); + "Invalid Type in add Range Method"); llvm::dxil::ResourceInfo::ResourceBinding Binding; Binding.LowerBound = Desc.ShaderRegister; @@ -53,19 +54,20 @@ class RootSignatureBindingValidation { uint64_t HighRange = combineUint32ToUint64( Binding.Space, Binding.LowerBound + Binding.Size - 1); + assert(LowRange <= HighRange && "Invalid range configuration"); + switch (Type) { case llvm::to_underlying(dxbc::RootParameterType::CBV): CRegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; case llvm::to_underlying(dxbc::RootParameterType::SRV): TRegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; case llvm::to_underlying(dxbc::RootParameterType::UAV): URegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; } - llvm_unreachable("Invalid Type in add Range Method"); } void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) { @@ -80,33 +82,59 @@ class RootSignatureBindingValidation { uint64_t HighRange = combineUint32ToUint64( Binding.Space, Binding.LowerBound + Binding.Size - 1); + assert(LowRange <= HighRange && "Invalid range configuration"); + switch (Range.RangeType) { case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): CRegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): TRegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): URegBindingsMap.insert(LowRange, HighRange, Binding); - return; + break; + case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): + SamplersBindingsMap.insert(LowRange, HighRange, Binding); + break; } - llvm_unreachable("Invalid Type in add Range Method"); } public: RootSignatureBindingValidation() : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator), - URegBindingsMap(Allocator) {} + URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {} void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility); - bool checkCregBinding(dxil::ResourceInfo::ResourceBinding Binding); + bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) { + return CRegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); + } - bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding); + bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding) { + return TRegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); + } - bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding); + bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding) { + return URegBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); + } + + bool checkSamplerBinding(dxil::ResourceInfo::ResourceBinding Binding) { + return SamplersBindingsMap.overlaps( + combineUint32ToUint64(Binding.Space, Binding.LowerBound), + combineUint32ToUint64(Binding.Space, + Binding.LowerBound + Binding.Size - 1)); + } }; class DXILPostOptimizationValidation diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 4094df160ef6f..2a68a4c324a09 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -635,8 +635,9 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, //===----------------------------------------------------------------------===// bool RootSignatureAnalysisWrapper::runOnModule(Module &M) { - FuncToRsMap = std::make_unique<RootSignatureBindingInfo>( - RootSignatureBindingInfo(analyzeModule(M))); + if (!FuncToRsMap) + FuncToRsMap = std::make_unique<RootSignatureBindingInfo>( + RootSignatureBindingInfo(analyzeModule(M))); return false; } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll index 9d89dbdd9107b..053721de1eb1f 100644 --- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll @@ -13,7 +13,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } !2 = !{ ptr @main, !3, i32 1 } ; function, root signature !3 = !{ !5 } ; list of root signature elements !5 = !{ !"DescriptorTable", i32 0, !6, !7 } -!6 = !{ !"Sampler", i32 0, i32 1, i32 0, i32 -1, i32 1 } +!6 = !{ !"Sampler", i32 1, i32 1, i32 0, i32 -1, i32 1 } !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 3 } @@ -33,7 +33,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } ; DXC-NEXT: RangesOffset: 44 ; DXC-NEXT: Ranges: ; DXC-NEXT: - RangeType: 3 -; DXC-NEXT: NumDescriptors: 0 +; DXC-NEXT: NumDescriptors: 1 ; DXC-NEXT: BaseShaderRegister: 1 ; DXC-NEXT: RegisterSpace: 0 ; DXC-NEXT: OffsetInDescriptorsFromTableStart: 4294967295 diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index 2b29fd30a7a56..8d75249dc6ecb 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -31,6 +31,7 @@ ; CHECK-NEXT: DXIL Module Metadata analysis ; CHECK-NEXT: DXIL Shader Flag Analysis ; CHECK-NEXT: DXIL Translate Metadata +; CHECK-NEXT: DXIL Root Signature Analysis ; CHECK-NEXT: DXIL Post Optimization Validation ; CHECK-NEXT: DXIL Op Lowering ; CHECK-NEXT: DXIL Prepare Module >From 28350b2dfe2a896b2199260953c1d061550badba Mon Sep 17 00:00:00 2001 From: joaosaffran <joao.saff...@microsoft.com> Date: Sat, 5 Jul 2025 00:35:07 +0000 Subject: [PATCH 7/7] fix issue --- llvm/lib/Target/DirectX/DXContainerGlobals.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 5c763c24a210a..6c8ae8eaaea77 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -160,11 +160,9 @@ void DXContainerGlobals::addRootSignature(Module &M, assert(MMI.EntryPropertyVec.size() == 1); - auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry; const auto &RS = RSA.getDescForFunction(EntryFunction); - const auto &RS = RSA.getDescForFunction(EntryFunction); if (!RS) return; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits