llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: None (joaosaffran) <details> <summary>Changes</summary> This patch enhances error handling in the DirectX backend's root signature parsing, specifically in DXILRootSignature.cpp. The changes include: 1. Modify error handling to accumulate errors: - Replace early returns with error accumulation using HasError - Allow validation to continue after encountering an invalid type - Maintain original error reporting functionality while collecting multiple errors 2. Fix root flag parsing: - Use a boolean accumulator (HasError) for validation errors - Improve invalid type reporting for root flag nodes - Maintain consistency with existing error reporting patterns Before this change, the parser would stop at the first error encountered. Now it continues validation, collecting all errors before returning. This provides a better developer experience by showing all issues that need to be fixed at once. Example of changes: ```cpp bool HasError = false; if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else HasError = reportInvalidTypeError<ConstantInt>( Ctx, "RootFlagNode", RootFlagNode, 1) || HasError; return HasError; ``` The report call is deliberately placed on the left of `||`: written as `HasError || reportInvalidTypeError(...)`, short-circuit evaluation would skip the call, and therefore its diagnostic, once an earlier error had already been recorded. Testing: - All existing DirectX backend tests pass - Verified error accumulation with multiple validation failures - Root signature parsing continues to work as expected --- Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144465.diff 2 Files Affected: - (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+171-110) - (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error-Accumulation.ll (+23) ``````````diff diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp index 57d5ee8ac467c..eea46e714b756 100644 --- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp +++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp @@ -141,14 +141,15 @@ 
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (RootFlagNode->getNumOperands() != 2) return reportError(Ctx, "Invalid format for RootFlag Element"); - + bool HasError = false; if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1)) RSD.Flags = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode", - RootFlagNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode", + RootFlagNode, 1) || + HasError; - return false; + return HasError; } static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, @@ -157,6 +158,7 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (RootConstantNode->getNumOperands() != 5) return reportError(Ctx, "Invalid format for RootConstants Element"); + bool HasError = false; dxbc::RTS0::v1::RootParameterHeader Header; // The parameter offset doesn't matter here - we recalculate it during // serialization Header.ParameterOffset = 0; @@ -166,31 +168,35 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1)) Header.ShaderVisibility = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", - RootConstantNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", + RootConstantNode, 1) || + HasError; dxbc::RTS0::v1::RootConstants Constants; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2)) Constants.ShaderRegister = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", - RootConstantNode, 2); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", + RootConstantNode, 2) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3)) Constants.RegisterSpace = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", - RootConstantNode, 3); + 
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", + RootConstantNode, 3) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4)) Constants.Num32BitValues = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", - RootConstantNode, 4); - - RSD.ParametersContainer.addParameter(Header, Constants); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode", + RootConstantNode, 4) || + HasError; + if (!HasError) + RSD.ParametersContainer.addParameter(Header, Constants); - return false; + return HasError; } static bool parseRootDescriptors(LLVMContext *Ctx, @@ -205,6 +211,7 @@ static bool parseRootDescriptors(LLVMContext *Ctx, if (RootDescriptorNode->getNumOperands() != 5) return reportError(Ctx, "Invalid format for Root Descriptor Element"); + bool HasError = false; dxbc::RTS0::v1::RootParameterHeader Header; switch (ElementKind) { case RootSignatureElementKind::SRV: @@ -224,36 +231,41 @@ static bool parseRootDescriptors(LLVMContext *Ctx, if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1)) Header.ShaderVisibility = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", - RootDescriptorNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", + RootDescriptorNode, 1) || + HasError; dxbc::RTS0::v2::RootDescriptor Descriptor; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2)) Descriptor.ShaderRegister = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", - RootDescriptorNode, 2); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", + RootDescriptorNode, 2) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3)) Descriptor.RegisterSpace = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", - RootDescriptorNode, 3); + HasError = 
reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", + RootDescriptorNode, 3) || + HasError; if (RSD.Version == 1) { - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + if (!HasError) + RSD.ParametersContainer.addParameter(Header, Descriptor); + return HasError; } assert(RSD.Version > 1); if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4)) Descriptor.Flags = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", - RootDescriptorNode, 4); - - RSD.ParametersContainer.addParameter(Header, Descriptor); - return false; + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode", + RootDescriptorNode, 4) || + HasError; + if (!HasError) + RSD.ParametersContainer.addParameter(Header, Descriptor); + return HasError; } static bool parseDescriptorRange(LLVMContext *Ctx, @@ -264,14 +276,16 @@ static bool parseDescriptorRange(LLVMContext *Ctx, if (RangeDescriptorNode->getNumOperands() != 6) return reportError(Ctx, "Invalid format for Descriptor Range"); + bool HasError = false; dxbc::RTS0::v2::DescriptorRange Range; std::optional<StringRef> ElementText = extractMdStringValue(RangeDescriptorNode, 0); if (!ElementText.has_value()) - return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 0); + HasError = reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 0) || + HasError; Range.RangeType = StringSwitch<uint32_t>(*ElementText) @@ -283,40 +297,47 @@ static bool parseDescriptorRange(LLVMContext *Ctx, .Default(-1u); if (Range.RangeType == -1u) - return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText); + HasError = + reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1)) Range.NumDescriptors = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", - 
RangeDescriptorNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 1) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2)) Range.BaseShaderRegister = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 2); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 2) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3)) Range.RegisterSpace = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 3); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 3) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4)) Range.OffsetInDescriptorsFromTableStart = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 4); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 4) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5)) Range.Flags = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", - RangeDescriptorNode, 5); - - Table.Ranges.push_back(Range); - return false; + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode", + RangeDescriptorNode, 5) || + HasError; + if (!HasError) + Table.Ranges.push_back(Range); + return HasError; } static bool parseDescriptorTable(LLVMContext *Ctx, @@ -325,13 +346,14 @@ static bool parseDescriptorTable(LLVMContext *Ctx, const unsigned int NumOperands = DescriptorTableNode->getNumOperands(); if (NumOperands < 2) return reportError(Ctx, "Invalid format for Descriptor Table"); - + bool HasError = false; dxbc::RTS0::v1::RootParameterHeader Header; if (std::optional<uint32_t> Val = 
extractMdIntValue(DescriptorTableNode, 1)) Header.ShaderVisibility = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode", - DescriptorTableNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode", + DescriptorTableNode, 1) || + HasError; mcdxbc::DescriptorTable Table; Header.ParameterType = @@ -340,15 +362,16 @@ static bool parseDescriptorTable(LLVMContext *Ctx, for (unsigned int I = 2; I < NumOperands; I++) { MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I)); if (Element == nullptr) - return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode", - DescriptorTableNode, I); + HasError = reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode", + DescriptorTableNode, I) || + HasError; if (parseDescriptorRange(Ctx, RSD, Table, Element)) - return true; + HasError = true || HasError; } - - RSD.ParametersContainer.addParameter(Header, Table); - return false; + if (!HasError) + RSD.ParametersContainer.addParameter(Header, Table); + return HasError; } static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, @@ -356,87 +379,101 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (StaticSamplerNode->getNumOperands() != 14) return reportError(Ctx, "Invalid format for Static Sampler"); + bool HasError = false; dxbc::RTS0::v1::StaticSampler Sampler; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1)) Sampler.Filter = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 1); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 1) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2)) Sampler.AddressU = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 2); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + 
StaticSamplerNode, 2) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3)) Sampler.AddressV = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 3); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 3) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4)) Sampler.AddressW = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 4); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 4) || + HasError; if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5)) Sampler.MipLODBias = Val->convertToFloat(); else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 5); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 5) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6)) Sampler.MaxAnisotropy = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 6); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 6) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7)) Sampler.ComparisonFunc = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 7); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 7) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8)) Sampler.BorderColor = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 8); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 8) || + HasError; if (std::optional<APFloat> Val = 
extractMdFloatValue(StaticSamplerNode, 9)) Sampler.MinLOD = Val->convertToFloat(); else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 9); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 9) || + HasError; if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10)) Sampler.MaxLOD = Val->convertToFloat(); else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 10); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 10) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11)) Sampler.ShaderRegister = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 11); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 11) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12)) Sampler.RegisterSpace = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 12); + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 12) || + HasError; if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13)) Sampler.ShaderVisibility = *Val; else - return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", - StaticSamplerNode, 13); - - RSD.StaticSamplers.push_back(Sampler); - return false; + HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode", + StaticSamplerNode, 13) || + HasError; + if (!HasError) + RSD.StaticSamplers.push_back(Sampler); + return HasError; } static bool parseRootSignatureElement(LLVMContext *Ctx, @@ -488,7 +525,7 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD, if (Element == nullptr) return reportError(Ctx, "Missing Root Element Metadata Node."); - HasError = HasError || 
parseRootSignatureElement(Ctx, RSD, Element); + HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError; } return HasError; @@ -699,19 +736,20 @@ static bool verifyBorderColor(uint32_t BorderColor) { static bool verifyLOD(float LOD) { return !std::isnan(LOD); } static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { - + bool HasError = false; if (!verifyVersion(RSD.Version)) { - return reportValueError(Ctx, "Version", RSD.Version); + HasError = reportValueError(Ctx, "Version", RSD.Version) || HasError; } if (!verifyRootFlag(RSD.Flags)) { - return reportValueError(Ctx, "RootFlags", RSD.Flags); + HasError = reportValueError(Ctx, "RootFlags", RSD.Flags) || HasError; } for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) { if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility)) - return reportValueError(Ctx, "ShaderVisibility", - Info.Header.ShaderVisibility); + HasError = reportValueError(Ctx, "ShaderVisibility", + Info.Header.ShaderVisibility) || + HasError; assert(dxbc::isValidParameterType(Info.Header.ParameterType) && "Invalid value for ParameterType"); @@ -724,15 +762,20 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) { const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location); if (!verifyRegisterValue(Descriptor.ShaderRegister)) - return reportValueError(Ctx, "ShaderRegister", - Descriptor.ShaderRegister); + HasError = reportValueError(Ctx, "ShaderRegister", + Descriptor.ShaderRegister) || + HasError; if (!verifyRegisterSpace(Descriptor.RegisterSpace)) - return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace); + HasError = + reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace) || + HasError; if (RSD.Version > 1)... 
[truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/144465 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits