================ @@ -14,10 +14,129 @@ #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H +#include "DXILRootSignature.h" +#include "llvm/ADT/IntervalMap.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/PassManager.h" namespace llvm { +static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) { + return (static_cast<uint64_t>(High) << 32) | Low; +} + +class RootSignatureBindingValidation { + using MapT = + llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding, + sizeof(llvm::dxil::ResourceInfo::ResourceBinding), + llvm::IntervalMapInfo<uint64_t>>; + +private: + MapT::Allocator Allocator; + MapT CRegBindingsMap; + MapT TRegBindingsMap; + MapT URegBindingsMap; + MapT SamplersBindingsMap; + + void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) { + assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) || + Type == llvm::to_underlying(dxbc::RootParameterType::SRV) || + Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) && + "Invalid Type in add Range Method"); + + llvm::dxil::ResourceInfo::ResourceBinding Binding; + Binding.LowerBound = Desc.ShaderRegister; + Binding.Space = Desc.RegisterSpace; + Binding.Size = 1; + + uint64_t LowRange = + combineUint32ToUint64(Binding.Space, Binding.LowerBound); + uint64_t HighRange = combineUint32ToUint64( + Binding.Space, Binding.LowerBound + Binding.Size - 1); + + assert(LowRange <= HighRange && "Invalid range configuration"); + + switch (Type) { + + case llvm::to_underlying(dxbc::RootParameterType::CBV): + CRegBindingsMap.insert(LowRange, HighRange, Binding); + break; + case llvm::to_underlying(dxbc::RootParameterType::SRV): + TRegBindingsMap.insert(LowRange, HighRange, Binding); + break; + case llvm::to_underlying(dxbc::RootParameterType::UAV): + URegBindingsMap.insert(LowRange, HighRange, Binding); + break; + } + } + + void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) { + 
+ llvm::dxil::ResourceInfo::ResourceBinding Binding; + Binding.LowerBound = Range.BaseShaderRegister; + Binding.Space = Range.RegisterSpace; + Binding.Size = Range.NumDescriptors; + + uint64_t LowRange = + combineUint32ToUint64(Binding.Space, Binding.LowerBound); + uint64_t HighRange = combineUint32ToUint64( + Binding.Space, Binding.LowerBound + Binding.Size - 1); + + assert(LowRange <= HighRange && "Invalid range configuration"); + + switch (Range.RangeType) { + case llvm::to_underlying(dxbc::DescriptorRangeType::CBV): + CRegBindingsMap.insert(LowRange, HighRange, Binding); + break; + case llvm::to_underlying(dxbc::DescriptorRangeType::SRV): + TRegBindingsMap.insert(LowRange, HighRange, Binding); + break; + case llvm::to_underlying(dxbc::DescriptorRangeType::UAV): + URegBindingsMap.insert(LowRange, HighRange, Binding); + break; + case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler): + SamplersBindingsMap.insert(LowRange, HighRange, Binding); + break; + } + } + +public: + RootSignatureBindingValidation() + : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator), + URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {} + + void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD, + dxbc::ShaderVisibility Visibility); + + bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) { + return CRegBindingsMap.overlaps( ---------------- inbelic wrote:
I don't think this checks exactly what you want. `IntervalMap::overlaps(a, b)` only determines whether any part of the interval `[a,b]` overlaps some already-mapped interval `[x,y]`; it does not check that `[a,b]` is completely covered by the mapped intervals. As written, a binding that is only partially covered by the root signature would still be accepted. https://github.com/llvm/llvm-project/pull/146785 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits