llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: None (joaosaffran) <details> <summary>Changes</summary> --- Full diff: https://github.com/llvm/llvm-project/pull/137284.diff 4 Files Affected: - (modified) llvm/include/llvm/MC/DXContainerRootSignature.h (+73-7) - (modified) llvm/lib/MC/DXContainerRootSignature.cpp (+42-35) - (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+22-12) - (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+27-21) ``````````diff diff --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h index 44e26c81eedc1..c8af613a57094 100644 --- a/llvm/include/llvm/MC/DXContainerRootSignature.h +++ b/llvm/include/llvm/MC/DXContainerRootSignature.h @@ -6,21 +6,87 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLForwardCompat.h" #include "llvm/BinaryFormat/DXContainer.h" +#include <cstddef> #include <cstdint> -#include <limits> +#include <optional> +#include <utility> +#include <variant> namespace llvm { class raw_ostream; namespace mcdxbc { -struct RootParameter { +struct RootParameterInfo { dxbc::RootParameterHeader Header; - union { - dxbc::RootConstants Constants; - dxbc::RST0::v1::RootDescriptor Descriptor; - }; + size_t Location; + + RootParameterInfo() = default; + + RootParameterInfo(dxbc::RootParameterHeader H, size_t L) + : Header(H), Location(L) {} +}; + +using RootDescriptor = std::variant<dxbc::RST0::v0::RootDescriptor, + dxbc::RST0::v1::RootDescriptor>; +using ParametersView = std::variant<const dxbc::RootConstants *, + const dxbc::RST0::v0::RootDescriptor *, + const dxbc::RST0::v1::RootDescriptor *>; +struct RootParametersContainer { + SmallVector<RootParameterInfo> ParametersInfo; + + SmallVector<dxbc::RootConstants> Constants; + SmallVector<RootDescriptor> Descriptors; + + void addInfo(dxbc::RootParameterHeader H, size_t L) { + ParametersInfo.push_back(RootParameterInfo(H, L)); + } + + void addParameter(dxbc::RootParameterHeader H, dxbc::RootConstants C) { + addInfo(H, Constants.size()); + Constants.push_back(C); + } + + void addParameter(dxbc::RootParameterHeader H, + dxbc::RST0::v0::RootDescriptor D) { + addInfo(H, Descriptors.size()); + Descriptors.push_back(D); + } + + void addParameter(dxbc::RootParameterHeader H, + dxbc::RST0::v1::RootDescriptor D) { + addInfo(H, Descriptors.size()); + Descriptors.push_back(D); + } + + std::optional<ParametersView> getParameter(const RootParameterInfo *H) const { + switch (H->Header.ParameterType) { + case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + return &Constants[H->Location]; + case llvm::to_underlying(dxbc::RootParameterType::CBV): + case llvm::to_underlying(dxbc::RootParameterType::SRV): + case llvm::to_underlying(dxbc::RootParameterType::UAV): + const RootDescriptor &VersionedParam = Descriptors[H->Location]; + if (std::holds_alternative<dxbc::RST0::v0::RootDescriptor>( + VersionedParam)) { + return &std::get<dxbc::RST0::v0::RootDescriptor>(VersionedParam); + } + return &std::get<dxbc::RST0::v1::RootDescriptor>(VersionedParam); + } + + return std::nullopt; + } + + size_t size() const { return ParametersInfo.size(); } + + SmallVector<RootParameterInfo>::const_iterator begin() const { + return ParametersInfo.begin(); + } + SmallVector<RootParameterInfo>::const_iterator end() const { + return ParametersInfo.end(); + } }; struct RootSignatureDesc { @@ -29,7 +95,7 @@ struct RootSignatureDesc { uint32_t RootParameterOffset = 0U; uint32_t StaticSamplersOffset = 0u; uint32_t NumStaticSamplers = 0u; - SmallVector<mcdxbc::RootParameter> Parameters; + mcdxbc::RootParametersContainer ParametersContainer; void write(raw_ostream &OS) const; diff --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp index 2693cb9943d5e..641c2f5fa1b1b 100644 --- a/llvm/lib/MC/DXContainerRootSignature.cpp +++ b/llvm/lib/MC/DXContainerRootSignature.cpp @@ -8,7 +8,9 @@ #include "llvm/MC/DXContainerRootSignature.h" #include "llvm/ADT/SmallString.h" +#include "llvm/BinaryFormat/DXContainer.h" #include "llvm/Support/EndianStream.h" +#include <variant> using namespace llvm; using namespace llvm::mcdxbc; @@ -30,24 +32,20 @@ static void rewriteOffsetToCurrentByte(raw_svector_ostream &Stream, size_t RootSignatureDesc::getSize() const { size_t Size = sizeof(dxbc::RootSignatureHeader) + - Parameters.size() * sizeof(dxbc::RootParameterHeader); + ParametersContainer.size() * sizeof(dxbc::RootParameterHeader); - for (const mcdxbc::RootParameter &P : Parameters) { - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - Size += sizeof(dxbc::RootConstants); - break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - if (Version == 1) - Size += sizeof(dxbc::RST0::v0::RootDescriptor); - else - Size += sizeof(dxbc::RST0::v1::RootDescriptor); - - break; - } + for (const auto &I : ParametersContainer) { + std::optional<ParametersView> P = ParametersContainer.getParameter(&I); + if (!P) + continue; + std::visit( + [&Size](auto &Value) -> void { + using T = std::decay_t<decltype(*Value)>; + Size += sizeof(T); + }, + *P); } + return Size; } @@ -56,7 +54,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const { raw_svector_ostream BOS(Storage); BOS.reserveExtraSpace(getSize()); - const uint32_t NumParameters = Parameters.size(); + const uint32_t NumParameters = ParametersContainer.size(); support::endian::write(BOS, Version, llvm::endianness::little); support::endian::write(BOS, NumParameters, llvm::endianness::little); @@ -66,7 +64,7 @@ void RootSignatureDesc::write(raw_ostream &OS) const { support::endian::write(BOS, Flags, llvm::endianness::little); SmallVector<uint32_t> ParamsOffsets; - for (const mcdxbc::RootParameter &P : Parameters) { + for (const auto &P : ParametersContainer) { support::endian::write(BOS, P.Header.ParameterType, llvm::endianness::little); support::endian::write(BOS, P.Header.ShaderVisibility, @@ -76,29 +74,38 @@ void RootSignatureDesc::write(raw_ostream &OS) const { } assert(NumParameters == ParamsOffsets.size()); - for (size_t I = 0; I < NumParameters; ++I) { + const RootParameterInfo *H = ParametersContainer.begin(); + for (size_t I = 0; I < NumParameters; ++I, H++) { rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]); - const mcdxbc::RootParameter &P = Parameters[I]; - - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - support::endian::write(BOS, P.Constants.ShaderRegister, + auto P = ParametersContainer.getParameter(H); + if (!P) + continue; + if (std::holds_alternative<const dxbc::RootConstants *>(P.value())) { + auto *Constants = std::get<const dxbc::RootConstants *>(P.value()); + support::endian::write(BOS, Constants->ShaderRegister, llvm::endianness::little); - support::endian::write(BOS, P.Constants.RegisterSpace, + support::endian::write(BOS, Constants->RegisterSpace, llvm::endianness::little); - support::endian::write(BOS, P.Constants.Num32BitValues, + support::endian::write(BOS, Constants->Num32BitValues, llvm::endianness::little); - break; - case llvm::to_underlying(dxbc::RootParameterType::CBV): - case llvm::to_underlying(dxbc::RootParameterType::SRV): - case llvm::to_underlying(dxbc::RootParameterType::UAV): - support::endian::write(BOS, P.Descriptor.ShaderRegister, + } else if (std::holds_alternative<const dxbc::RST0::v0::RootDescriptor *>( + *P)) { + auto *Descriptor = + std::get<const dxbc::RST0::v0::RootDescriptor *>(P.value()); + support::endian::write(BOS, Descriptor->ShaderRegister, + llvm::endianness::little); + support::endian::write(BOS, Descriptor->RegisterSpace, + llvm::endianness::little); + } else if (std::holds_alternative<const dxbc::RST0::v1::RootDescriptor *>( + *P)) { + auto *Descriptor = + std::get<const dxbc::RST0::v1::RootDescriptor *>(P.value()); + + support::endian::write(BOS, Descriptor->ShaderRegister, llvm::endianness::little); - support::endian::write(BOS, P.Descriptor.RegisterSpace, + support::endian::write(BOS, Descriptor->RegisterSpace, llvm::endianness::little); - if (Version > 1) - support::endian::write(BOS, P.Descriptor.Flags, - llvm::endianness::little); + support::endian::write(BOS, Descriptor->Flags, llvm::endianness::little); } } assert(Storage.size() == getSize()); diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp index 239ee9e3de9b1..b8ea1b048edfe 100644 --- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp +++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp @@ -274,27 +274,37 @@ void DXContainerWriter::writeParts(raw_ostream &OS) { RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset; for (const auto &Param : P.RootSignature->Parameters) { - mcdxbc::RootParameter NewParam; - NewParam.Header = dxbc::RootParameterHeader{ - Param.Type, Param.Visibility, Param.Offset}; + auto Header = dxbc::RootParameterHeader{Param.Type, Param.Visibility, + Param.Offset}; switch (Param.Type) { case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): - NewParam.Constants.Num32BitValues = Param.Constants.Num32BitValues; - NewParam.Constants.RegisterSpace = Param.Constants.RegisterSpace; - NewParam.Constants.ShaderRegister = Param.Constants.ShaderRegister; + dxbc::RootConstants Constants; + Constants.Num32BitValues = Param.Constants.Num32BitValues; + Constants.RegisterSpace = Param.Constants.RegisterSpace; + Constants.ShaderRegister = Param.Constants.ShaderRegister; + RS.ParametersContainer.addParameter(Header, Constants); break; case llvm::to_underlying(dxbc::RootParameterType::SRV): case llvm::to_underlying(dxbc::RootParameterType::UAV): case llvm::to_underlying(dxbc::RootParameterType::CBV): - NewParam.Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; - NewParam.Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; - if (P.RootSignature->Version > 1) - NewParam.Descriptor.Flags = Param.Descriptor.getEncodedFlags(); + if (RS.Version == 1) { + dxbc::RST0::v0::RootDescriptor Descriptor; + Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; + Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; + RS.ParametersContainer.addParameter(Header, Descriptor); + } else { + dxbc::RST0::v1::RootDescriptor Descriptor; + Descriptor.RegisterSpace = Param.Descriptor.RegisterSpace; + Descriptor.ShaderRegister = Param.Descriptor.ShaderRegister; + Descriptor.Flags = Param.Descriptor.getEncodedFlags(); + RS.ParametersContainer.addParameter(Header, Descriptor); + } break; + default: + // Handling invalid parameter type edge case + RS.ParametersContainer.addInfo(Header, -1); } - - RS.Parameters.push_back(NewParam); } RS.write(OS); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index ef299c17baf76..30ca4d8f7c8ed 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -30,6 +30,7 @@ #include <cstdint> #include <optional> #include <utility> +#include <variant> using namespace llvm; using namespace llvm::dxil; @@ -75,31 +76,32 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (RootConstantNode->getNumOperands() != 5) return reportError(Ctx, "Invalid format for RootConstants Element"); - mcdxbc::RootParameter NewParameter; - NewParameter.Header.ParameterType = + dxbc::RootParameterHeader Header; + Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::Constants32Bit); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) - NewParameter.Header.ShaderVisibility = *Val; + Header.ShaderVisibility = *Val; else return reportError(Ctx, "Invalid value for ShaderVisibility"); + dxbc::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) - NewParameter.Constants.ShaderRegister = *Val; + Constants.ShaderRegister = *Val; else return reportError(Ctx, "Invalid value for ShaderRegister"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) - NewParameter.Constants.RegisterSpace = *Val; + Constants.RegisterSpace = *Val; else return reportError(Ctx, "Invalid value for RegisterSpace"); if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) - NewParameter.Constants.Num32BitValues = *Val; + Constants.Num32BitValues = *Val; else return reportError(Ctx, "Invalid value for Num32BitValues"); - RSD.Parameters.push_back(NewParameter); + RSD.ParametersContainer.addParameter(Header, Constants); return false; } @@ -164,12 +166,12 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { return reportValueError(Ctx, "RootFlags", RSD.Flags); } - for (const mcdxbc::RootParameter &P : RSD.Parameters) { - if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility)) + for (const llvm::mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { + if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) return reportValueError(Ctx, "ShaderVisibility", - P.Header.ShaderVisibility); + Info.Header.ShaderVisibility); - assert(dxbc::isValidParameterType(P.Header.ParameterType) && + assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); } @@ -287,22 +289,26 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M, OS << indent(Space) << "Version: " << RS.Version << "\n"; OS << indent(Space) << "RootParametersOffset: " << RS.RootParameterOffset << "\n"; - OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n"; + OS << indent(Space) << "NumParameters: " << RS.ParametersContainer.size() + << "\n"; Space++; - for (auto const &P : RS.Parameters) { - OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType + for (auto const &Info : RS.ParametersContainer) { + OS << indent(Space) << "- Parameter Type: " << Info.Header.ParameterType << "\n"; OS << indent(Space + 2) - << "Shader Visibility: " << P.Header.ShaderVisibility << "\n"; - switch (P.Header.ParameterType) { - case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): + << "Shader Visibility: " << Info.Header.ShaderVisibility << "\n"; + std::optional<mcdxbc::ParametersView> P = + RS.ParametersContainer.getParameter(&Info); + if (!P) + continue; + if (std::holds_alternative<const dxbc::RootConstants *>(*P)) { + auto *Constants = std::get<const dxbc::RootConstants *>(*P); OS << indent(Space + 2) - << "Register Space: " << P.Constants.RegisterSpace << "\n"; + << "Register Space: " << Constants->RegisterSpace << "\n"; OS << indent(Space + 2) - << "Shader Register: " << P.Constants.ShaderRegister << "\n"; + << "Shader Register: " << Constants->ShaderRegister << "\n"; OS << indent(Space + 2) - << "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n"; - break; + << "Num 32 Bit Values: " << Constants->Num32BitValues << "\n"; } } Space--; `````````` </details> https://github.com/llvm/llvm-project/pull/137284 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits