llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: None (joaosaffran) <details> <summary>Changes</summary> Adding support for Root Signature Flags Element extraction and writing to DXContainer. - Adding an analysis to deal with RootSignature metadata definition - Adding validation for Flag - writing RootSignature blob into DXIL --- Full diff: https://github.com/llvm/llvm-project/pull/123147.diff 7 Files Affected: - (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1-1) - (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+22) - (added) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+147) - (added) llvm/lib/Target/DirectX/DXILRootSignature.h (+74) - (modified) llvm/lib/Target/DirectX/DirectX.h (+3) - (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+1) - (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll (+38) ``````````diff diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index 26315db891b577..89fe494dea71cc 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -33,7 +33,7 @@ add_llvm_target(DirectXCodeGen DXILResourceAccess.cpp DXILShaderFlags.cpp DXILTranslateMetadata.cpp - + DXILRootSignature.cpp LINK_COMPONENTS Analysis AsmPrinter diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 7a0bd6a7c88692..ac70bd3771dadf 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/SmallVector.h" @@ -26,6 +27,7 @@ #include "llvm/Pass.h" #include "llvm/Support/MD5.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <optional> using namespace llvm; using namespace llvm::dxil; @@ -41,6 +43,7 @@ class DXContainerGlobals : public llvm::ModulePass { GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name, StringRef SectionName); void addSignature(Module &M, SmallVector<GlobalValue *> &Globals); + void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals); void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV); void addPipelineStateValidationInfo(Module &M, SmallVector<GlobalValue *> &Globals); @@ -60,6 +63,7 @@ class DXContainerGlobals : public llvm::ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired<ShaderFlagsAnalysisWrapper>(); + AU.addRequired<RootSignatureAnalysisWrapper>(); AU.addRequired<DXILMetadataAnalysisWrapperPass>(); AU.addRequired<DXILResourceTypeWrapperPass>(); AU.addRequired<DXILResourceBindingWrapperPass>(); @@ -73,6 +77,7 @@ bool DXContainerGlobals::runOnModule(Module &M) { Globals.push_back(getFeatureFlags(M)); Globals.push_back(computeShaderHash(M)); addSignature(M, Globals); + addRootSignature(M, Globals); addPipelineStateValidationInfo(M, Globals); appendToCompilerUsed(M, Globals); return true; @@ -144,6 +149,23 @@ void DXContainerGlobals::addSignature(Module &M, Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1")); } +void DXContainerGlobals::addRootSignature(Module &M, + SmallVector<GlobalValue *> &Globals) { + + std::optional<ModuleRootSignature> MRS = + getAnalysis<RootSignatureAnalysisWrapper>().getRootSignature(); + if (!MRS.has_value()) + return; + + SmallString<256> Data; + raw_svector_ostream OS(Data); + MRS->write(OS); + + Constant *Constant = + ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); + Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0")); +} + void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { const DXILBindingMap &DBM = getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp new file mode 100644 index 00000000000000..cabaec3671078e --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -0,0 +1,147 @@ +//===- DXILRootSignature.cpp - DXIL Root Signature helper objects +//---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helper objects and APIs for working with DXIL +/// Root Signatures. +/// +//===----------------------------------------------------------------------===// +#include "DXILRootSignature.h" +#include "DirectX.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/BinaryFormat/DXContainer.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include <cassert> + +using namespace llvm; +using namespace llvm::dxil; + +static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) { + + assert(RootFlagNode->getNumOperands() == 2 && + "Invalid format for RootFlag Element"); + auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1)); + auto Value = Flag->getZExtValue(); + + // Root Element validation, as specified: + // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#validations-during-dxil-generation + assert((Value & ~0x80000fff) != 0 && "Invalid flag for RootFlag Element"); + + MRS->Flags = Value; + return false; +} + +static bool parseRootSignatureElement(ModuleRootSignature *MRS, + MDNode *Element) { + MDString *ElementText = cast<MDString>(Element->getOperand(0)); + assert(ElementText != nullptr && + "First preoperty of element is not a string"); + + RootSignatureElementKind ElementKind = + StringSwitch<RootSignatureElementKind>(ElementText->getString()) + .Case("RootFlags", RootSignatureElementKind::RootFlags) + .Case("RootConstants", RootSignatureElementKind::RootConstants) + .Case("RootCBV", RootSignatureElementKind::RootDescriptor) + .Case("RootSRV", RootSignatureElementKind::RootDescriptor) + .Case("RootUAV", RootSignatureElementKind::RootDescriptor) + .Case("Sampler", RootSignatureElementKind::RootDescriptor) + .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable) + .Case("StaticSampler", RootSignatureElementKind::StaticSampler) + .Default(RootSignatureElementKind::None); + + switch (ElementKind) { + + case RootSignatureElementKind::RootFlags: { + return parseRootFlags(MRS, Element); + break; + } + + case RootSignatureElementKind::RootConstants: + case RootSignatureElementKind::RootDescriptor: + case RootSignatureElementKind::DescriptorTable: + case RootSignatureElementKind::StaticSampler: + case RootSignatureElementKind::None: + llvm_unreachable("Not Implemented yet"); + break; + } + + return true; +} + +bool ModuleRootSignature::parse(int32_t Version, NamedMDNode *Root) { + this->Version = Version; + bool HasError = false; + + for (unsigned int Sid = 0; Sid < Root->getNumOperands(); Sid++) { + // This should be an if, for error handling + MDNode *Node = cast<MDNode>(Root->getOperand(Sid)); + + // Not sure what use this for... + // Metadata *Func = Node->getOperand(0).get(); + + MDNode *Elements = cast<MDNode>(Node->getOperand(1).get()); + assert(Elements && "Invalid Metadata type on root signature"); + + for (unsigned int Eid = 0; Eid < Elements->getNumOperands(); Eid++) { + MDNode *Element = cast<MDNode>(Elements->getOperand(Eid)); + assert(Element && "Invalid Metadata type on root element"); + + HasError = HasError || parseRootSignatureElement(this, Element); + } + } + return HasError; +} + +void ModuleRootSignature::write(raw_ostream &OS) { + dxbc::RootSignatureDesc Out{this->Version, this->Flags}; + + if (sys::IsBigEndianHost) { + Out.swapBytes(); + } + + OS.write(reinterpret_cast<const char *>(&Out), + sizeof(dxbc::RootSignatureDesc)); +} + +AnalysisKey RootSignatureAnalysis::Key; + +ModuleRootSignature RootSignatureAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + ModuleRootSignature MRSI; + + NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures"); + if (RootSignatureNode) { + MRSI.parse(1, RootSignatureNode); + } + + return MRSI; +} + +//===----------------------------------------------------------------------===// +bool RootSignatureAnalysisWrapper::runOnModule(Module &M) { + ModuleRootSignature MRS; + + NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures"); + if (RootSignatureNode) { + MRS.parse(1, RootSignatureNode); + this->MRS = MRS; + } + + return false; +} + +void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); +} + +char RootSignatureAnalysisWrapper::ID = 0; + +INITIALIZE_PASS(RootSignatureAnalysisWrapper, "dx-root-signature-analysis", + "DXIL Root Signature Analysis", true, true) diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h new file mode 100644 index 00000000000000..de82afcdc8c467 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILRootSignature.h @@ -0,0 +1,74 @@ +//===- DXILRootSignature.h - DXIL Root Signature helper objects +//---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helper objects and APIs for working with DXIL +/// Root Signatures. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include <optional> + +namespace llvm { +namespace dxil { + +enum class RootSignatureElementKind { + None = 0, + RootFlags = 1, + RootConstants = 2, + RootDescriptor = 3, + DescriptorTable = 4, + StaticSampler = 5 +}; + +struct ModuleRootSignature { + uint32_t Version; + uint32_t Flags; + + ModuleRootSignature() = default; + + bool parse(int32_t Version, NamedMDNode *Root); + void write(raw_ostream &OS); +}; + +class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> { + friend AnalysisInfoMixin<RootSignatureAnalysis>; + static AnalysisKey Key; + +public: + RootSignatureAnalysis() = default; + + using Result = ModuleRootSignature; + + ModuleRootSignature run(Module &M, ModuleAnalysisManager &AM); +}; + +/// Wrapper pass for the legacy pass manager. +/// +/// This is required because the passes that will depend on this are codegen +/// passes which run through the legacy pass manager. +class RootSignatureAnalysisWrapper : public ModulePass { + std::optional<ModuleRootSignature> MRS; + +public: + static char ID; + + RootSignatureAnalysisWrapper() : ModulePass(ID) {} + + const std::optional<ModuleRootSignature> &getRootSignature() { return MRS; } + + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; +}; + +} // namespace dxil +} // namespace llvm diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index add23587de7d58..953ac3eb820987 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -77,6 +77,9 @@ void initializeDXILPrettyPrinterLegacyPass(PassRegistry &); /// Initializer for dxil::ShaderFlagsAnalysisWrapper pass. void initializeShaderFlagsAnalysisWrapperPass(PassRegistry &); +/// Initializer for dxil::RootSignatureAnalysisWrapper pass. +void initializeRootSignatureAnalysisWrapperPass(PassRegistry &); + /// Initializer for DXContainerGlobals pass. void initializeDXContainerGlobalsPass(PassRegistry &); diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index ecb1bf775f8578..93745d7a5cb0d2 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -61,6 +61,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { initializeDXILTranslateMetadataLegacyPass(*PR); initializeDXILResourceMDWrapperPass(*PR); initializeShaderFlagsAnalysisWrapperPass(*PR); + initializeRootSignatureAnalysisWrapperPass(*PR); initializeDXILFinalizeLinkageLegacyPass(*PR); } diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll new file mode 100644 index 00000000000000..ffbf5e9ffd1d32 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll @@ -0,0 +1,38 @@ +; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s +; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC + +target triple = "dxil-unknown-shadermodel6.0-compute" + +; CHECK: @dx.rts0 = private constant [8 x i8] c"{{.*}}", section "RTS0", align 4 + + +define void @main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } + + +!dx.rootsignatures = !{!2} ; list of function/root signature pairs +!2 = !{ ptr @main, !3 } ; function, root signature +!3 = !{ !4 } ; list of root signature elements +!4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout + + +; DXC: - Name: RTS0 +; DXC-NEXT: Size: 8 +; DXC-NEXT: RootSignature: +; DXC-NEXT: Version: 1 +; DXC-NEXT: AllowInputAssemblerInputLayout: true +; DXC-NEXT: DenyVertexShaderRootAccess: false +; DXC-NEXT: DenyHullShaderRootAccess: false +; DXC-NEXT: DenyDomainShaderRootAccess: false +; DXC-NEXT: DenyGeometryShaderRootAccess: false +; DXC-NEXT: DenyPixelShaderRootAccess: false +; DXC-NEXT: AllowStreamOutput: false +; DXC-NEXT: LocalRootSignature: false +; DXC-NEXT: DenyAmplificationShaderRootAccess: false +; DXC-NEXT: DenyMeshShaderRootAccess: false +; DXC-NEXT: CBVSRVUAVHeapDirectlyIndexed: false +; DXC-NEXT: SamplerHeapDirectlyIndexed: false `````````` </details> https://github.com/llvm/llvm-project/pull/123147 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits