GitHub user inbelic (https://github.com/inbelic) created pull request https://github.com/llvm/llvm-project/pull/146150
None >From 18edab1fc3ab4a137de935abff5bb4716cc9a5fd Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 27 Jun 2025 18:36:38 +0000 Subject: [PATCH 1/5] nfc: introduce wrapper `RootSignatureElement` around `RootElement` to retain clang diag info --- .../clang/Parse/ParseHLSLRootSignature.h | 16 ++- clang/include/clang/Sema/SemaHLSL.h | 10 +- clang/lib/Parse/ParseDeclCXX.cpp | 2 +- clang/lib/Parse/ParseHLSLRootSignature.cpp | 25 ++-- clang/lib/Sema/SemaHLSL.cpp | 6 +- .../Parse/ParseHLSLRootSignatureTest.cpp | 115 +++++++++--------- 6 files changed, 100 insertions(+), 74 deletions(-) diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index 18cd1f379e62c..c51c3950b73df 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -25,9 +25,21 @@ namespace clang { namespace hlsl { +// Introduce a wrapper struct around the underlying RootElement. This structure +// will retain extra clang diagnostic information that is not available in llvm. 
+struct RootSignatureElement { + RootSignatureElement(llvm::hlsl::rootsig::RootElement Element) + : Element(Element) {} + + const llvm::hlsl::rootsig::RootElement &getElement() const { return Element; } + +private: + llvm::hlsl::rootsig::RootElement Element; +}; + class RootSignatureParser { public: - RootSignatureParser(SmallVector<llvm::hlsl::rootsig::RootElement> &Elements, + RootSignatureParser(SmallVector<RootSignatureElement> &Elements, RootSignatureLexer &Lexer, clang::Preprocessor &PP); /// Consumes tokens from the Lexer and constructs the in-memory @@ -187,7 +199,7 @@ class RootSignatureParser { bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected); private: - SmallVector<llvm::hlsl::rootsig::RootElement> &Elements; + SmallVector<RootSignatureElement> &Elements; RootSignatureLexer &Lexer; clang::Preprocessor &PP; diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 7d7eae4db532c..1af706da702c2 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -32,6 +32,10 @@ class ParsedAttr; class Scope; class VarDecl; +namespace hlsl { +struct RootSignatureElement; +} + using llvm::dxil::ResourceClass; // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no @@ -130,9 +134,9 @@ class SemaHLSL : public SemaBase { /// Creates the Root Signature decl of the parsed Root Signature elements /// onto the AST and push it onto current Scope - void ActOnFinishRootSignatureDecl( - SourceLocation Loc, IdentifierInfo *DeclIdent, - SmallVector<llvm::hlsl::rootsig::RootElement> &Elements); + void + ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent, + ArrayRef<hlsl::RootSignatureElement> Elements); // Returns true when D is invalid and a diagnostic was produced bool handleRootSignatureDecl(HLSLRootSignatureDecl *D, SourceLocation Loc); diff --git a/clang/lib/Parse/ParseDeclCXX.cpp b/clang/lib/Parse/ParseDeclCXX.cpp index 
c1493a5bfd3b3..f561da1801004 100644 --- a/clang/lib/Parse/ParseDeclCXX.cpp +++ b/clang/lib/Parse/ParseDeclCXX.cpp @@ -4956,7 +4956,7 @@ void Parser::ParseHLSLRootSignatureAttributeArgs(ParsedAttributes &Attrs) { StrLiteral.value()->getExprLoc().getLocWithOffset(1); // Invoke the root signature parser to construct the in-memory constructs hlsl::RootSignatureLexer Lexer(Signature, SignatureLoc); - SmallVector<llvm::hlsl::rootsig::RootElement> RootElements; + SmallVector<hlsl::RootSignatureElement> RootElements; hlsl::RootSignatureParser Parser(RootElements, Lexer, PP); if (Parser.parse()) { T.consumeClose(); diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 18d3644114eef..dd0c8d4ac9ca9 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -17,33 +17,34 @@ namespace hlsl { using TokenKind = RootSignatureToken::Kind; -RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements, - RootSignatureLexer &Lexer, - Preprocessor &PP) +RootSignatureParser::RootSignatureParser( + SmallVector<RootSignatureElement> &Elements, RootSignatureLexer &Lexer, + Preprocessor &PP) : Elements(Elements), Lexer(Lexer), PP(PP), CurToken(SourceLocation()) {} bool RootSignatureParser::parse() { - // Iterate as many RootElements as possible + // Iterate as many RootSignatureElements as possible do { + std::optional<RootSignatureElement> Element = std::nullopt; if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) { auto Flags = parseRootFlags(); if (!Flags.has_value()) return true; - Elements.push_back(*Flags); + Element = RootSignatureElement(*Flags); } if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) { auto Constants = parseRootConstants(); if (!Constants.has_value()) return true; - Elements.push_back(*Constants); + Element = RootSignatureElement(*Constants); } if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) { auto Table = parseDescriptorTable(); if 
(!Table.has_value()) return true; - Elements.push_back(*Table); + Element = RootSignatureElement(*Table); } if (tryConsumeExpectedToken( @@ -51,15 +52,19 @@ bool RootSignatureParser::parse() { auto Descriptor = parseRootDescriptor(); if (!Descriptor.has_value()) return true; - Elements.push_back(*Descriptor); + Element = RootSignatureElement(*Descriptor); } if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) { auto Sampler = parseStaticSampler(); if (!Sampler.has_value()) return true; - Elements.push_back(*Sampler); + Element = RootSignatureElement(*Sampler); } + + if (Element.has_value()) + Elements.push_back(*Element); + } while (tryConsumeExpectedToken(TokenKind::pu_comma)); return consumeExpectedToken(TokenKind::end_of_stream, @@ -250,7 +255,7 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() { auto Clause = parseDescriptorTableClause(); if (!Clause.has_value()) return std::nullopt; - Elements.push_back(*Clause); + Elements.push_back(RootSignatureElement(*Clause)); Table.NumClauses++; } diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index ca66c71370d60..86902f17dab11 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1063,7 +1063,11 @@ SemaHLSL::ActOnStartRootSignatureDecl(StringRef Signature) { void SemaHLSL::ActOnFinishRootSignatureDecl( SourceLocation Loc, IdentifierInfo *DeclIdent, - SmallVector<llvm::hlsl::rootsig::RootElement> &Elements) { + ArrayRef<hlsl::RootSignatureElement> RootElements) { + + SmallVector<llvm::hlsl::rootsig::RootElement> Elements; + for (auto &RootSigElement : RootElements) + Elements.push_back(RootSigElement.getElement()); auto *SignatureDecl = HLSLRootSignatureDecl::Create( SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc, diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index ba42895afce6c..c82667ae84e52 100644 --- 
a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -25,6 +25,7 @@ #include "gtest/gtest.h" using namespace clang; +using namespace clang::hlsl; using namespace llvm::hlsl::rootsig; namespace { @@ -114,7 +115,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -147,7 +148,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -156,7 +157,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { ASSERT_FALSE(Parser.parse()); // First Descriptor Table with 4 elements - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, @@ -169,7 +170,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, llvm::dxbc::DescriptorRangeFlags::DataStaticWhileSetAtExecute); - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, @@ -181,7 +182,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, llvm::dxbc::DescriptorRangeFlags::None); - 
Elem = Elements[2]; + Elem = Elements[2].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, @@ -194,7 +195,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, llvm::dxbc::DescriptorRangeFlags::None); - Elem = Elements[3]; + Elem = Elements[3].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, @@ -209,14 +210,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, ValidDescriptorRangeFlags); - Elem = Elements[4]; + Elem = Elements[4].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem)); ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, (uint32_t)4); ASSERT_EQ(std::get<DescriptorTable>(Elem).Visibility, llvm::dxbc::ShaderVisibility::Pixel); // Empty Descriptor Table - Elem = Elements[5]; + Elem = Elements[5].getElement(); ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem)); ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u); ASSERT_EQ(std::get<DescriptorTable>(Elem).Visibility, @@ -245,7 +246,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseStaticSamplerTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -256,7 +257,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseStaticSamplerTest) { ASSERT_EQ(Elements.size(), 2u); // Check default values are as expected - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); 
ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_EQ(std::get<StaticSampler>(Elem).Reg.ViewType, RegisterType::SReg); ASSERT_EQ(std::get<StaticSampler>(Elem).Reg.Number, 0u); @@ -281,7 +282,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseStaticSamplerTest) { llvm::dxbc::ShaderVisibility::All); // Check values can be set as expected - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_EQ(std::get<StaticSampler>(Elem).Reg.ViewType, RegisterType::SReg); ASSERT_EQ(std::get<StaticSampler>(Elem).Reg.Number, 0u); @@ -330,7 +331,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseFloatsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -338,55 +339,55 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseFloatsTest) { ASSERT_FALSE(Parser.parse()); - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 0.f); - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 1.f); - Elem = Elements[2]; + Elem = Elements[2].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, -1.f); - Elem = Elements[3]; + Elem = Elements[3].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 42.f); - Elem = Elements[4]; + Elem = Elements[4].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 4.2f); - Elem = Elements[5]; + Elem = 
Elements[5].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, -.42f); - Elem = Elements[6]; + Elem = Elements[6].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 420.f); - Elem = Elements[7]; + Elem = Elements[7].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 0.000000000042f); - Elem = Elements[8]; + Elem = Elements[8].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 42.f); - Elem = Elements[9]; + Elem = Elements[9].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 4.2f); - Elem = Elements[10]; + Elem = Elements[10].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 420000000000.f); - Elem = Elements[11]; + Elem = Elements[11].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, -2147483648.f); - Elem = Elements[12]; + Elem = Elements[12].getElement(); ASSERT_TRUE(std::holds_alternative<StaticSampler>(Elem)); ASSERT_FLOAT_EQ(std::get<StaticSampler>(Elem).MipLODBias, 2147483648.f); @@ -405,7 +406,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -413,7 +414,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) { ASSERT_FALSE(Parser.parse()); - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); 
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Flags, @@ -435,7 +436,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -445,7 +446,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { ASSERT_EQ(Elements.size(), 2u); - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem)); ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u); ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg); @@ -454,7 +455,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { ASSERT_EQ(std::get<RootConstants>(Elem).Visibility, llvm::dxbc::ShaderVisibility::All); - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem)); ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u); ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg); @@ -491,7 +492,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -501,15 +502,15 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) { ASSERT_EQ(Elements.size(), 3u); - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); ASSERT_TRUE(std::holds_alternative<llvm::dxbc::RootFlags>(Elem)); 
ASSERT_EQ(std::get<llvm::dxbc::RootFlags>(Elem), llvm::dxbc::RootFlags::None); - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<llvm::dxbc::RootFlags>(Elem)); ASSERT_EQ(std::get<llvm::dxbc::RootFlags>(Elem), llvm::dxbc::RootFlags::None); - Elem = Elements[2]; + Elem = Elements[2].getElement(); ASSERT_TRUE(std::holds_alternative<llvm::dxbc::RootFlags>(Elem)); auto ValidRootFlags = llvm::dxbc::RootFlags(0xfff); ASSERT_EQ(std::get<llvm::dxbc::RootFlags>(Elem), ValidRootFlags); @@ -532,7 +533,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -542,7 +543,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) { ASSERT_EQ(Elements.size(), 4u); - RootElement Elem = Elements[0]; + RootElement Elem = Elements[0].getElement(); ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem)); ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer); ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::BReg); @@ -553,7 +554,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) { ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, llvm::dxbc::RootDescriptorFlags::DataStaticWhileSetAtExecute); - Elem = Elements[1]; + Elem = Elements[1].getElement(); ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem)); ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::SRV); ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::TReg); @@ -564,7 +565,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) { auto ValidRootDescriptorFlags = llvm::dxbc::RootDescriptorFlags(0xe); ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, ValidRootDescriptorFlags); - Elem = Elements[2]; + Elem = 
Elements[2].getElement(); ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem)); ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::UAV); ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::UReg); @@ -577,7 +578,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) { ASSERT_EQ(std::get<RootDescriptor>(Elem).Flags, llvm::dxbc::RootDescriptorFlags::DataVolatile); - Elem = Elements[3]; + Elem = Elements[3].getElement(); ASSERT_EQ(std::get<RootDescriptor>(Elem).Type, DescriptorType::CBuffer); ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.ViewType, RegisterType::BReg); ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u); @@ -604,7 +605,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test no diagnostics produced @@ -628,7 +629,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedTokenTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -648,7 +649,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseInvalidTokenTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced - invalid token @@ -668,7 +669,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser 
Parser(Elements, Lexer, *PP); // Test correct diagnostic produced - end of stream @@ -693,7 +694,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidMissingDTParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -715,7 +716,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidMissingRDParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -737,7 +738,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidMissingRCParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -761,7 +762,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryDTParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -783,7 +784,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryRCParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -807,7 +808,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalDTParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - 
SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -833,7 +834,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalRCParameterTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -856,7 +857,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -878,7 +879,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseOverflowedNegativeNumberTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -899,7 +900,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedFloatTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -920,7 +921,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexNegOverflowedFloatTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -941,7 +942,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedDoubleTest) 
{ auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -962,7 +963,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidLexUnderflowFloatTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced @@ -986,7 +987,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) { auto TokLoc = SourceLocation(); hlsl::RootSignatureLexer Lexer(Source, TokLoc); - SmallVector<RootElement> Elements; + SmallVector<RootSignatureElement> Elements; hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced >From be1cab36738cf32cc0eaddaa846ffca9469fbddf Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 27 Jun 2025 18:53:10 +0000 Subject: [PATCH 2/5] let `RootSignatureElement` retain its source location --- .../clang/Parse/ParseHLSLRootSignature.h | 7 +++++-- clang/lib/Parse/ParseHLSLRootSignature.cpp | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index c51c3950b73df..a3b69c014a716 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -28,12 +28,15 @@ namespace hlsl { // Introduce a wrapper struct around the underlying RootElement. This structure // will retain extra clang diagnostic information that is not available in llvm. 
struct RootSignatureElement { - RootSignatureElement(llvm::hlsl::rootsig::RootElement Element) - : Element(Element) {} + RootSignatureElement(SourceLocation Loc, + llvm::hlsl::rootsig::RootElement Element) + : Loc(Loc), Element(Element) {} const llvm::hlsl::rootsig::RootElement &getElement() const { return Element; } + const SourceLocation &getLocation() const { return Loc; } private: + SourceLocation Loc; llvm::hlsl::rootsig::RootElement Element; }; diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index dd0c8d4ac9ca9..8b529b9de9228 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -27,39 +27,44 @@ bool RootSignatureParser::parse() { do { std::optional<RootSignatureElement> Element = std::nullopt; if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Flags = parseRootFlags(); if (!Flags.has_value()) return true; - Element = RootSignatureElement(*Flags); + Element = RootSignatureElement(ElementLoc, *Flags); } if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Constants = parseRootConstants(); if (!Constants.has_value()) return true; - Element = RootSignatureElement(*Constants); + Element = RootSignatureElement(ElementLoc, *Constants); } if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Table = parseDescriptorTable(); if (!Table.has_value()) return true; - Element = RootSignatureElement(*Table); + Element = RootSignatureElement(ElementLoc, *Table); } if (tryConsumeExpectedToken( {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Descriptor = parseRootDescriptor(); if (!Descriptor.has_value()) return true; - Element = RootSignatureElement(*Descriptor); + Element = RootSignatureElement(ElementLoc, *Descriptor); } if 
(tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Sampler = parseStaticSampler(); if (!Sampler.has_value()) return true; - Element = RootSignatureElement(*Sampler); + Element = RootSignatureElement(ElementLoc, *Sampler); } if (Element.has_value()) @@ -252,10 +257,11 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() { do { if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV, TokenKind::kw_Sampler})) { + SourceLocation ElementLoc = CurToken.TokLoc; auto Clause = parseDescriptorTableClause(); if (!Clause.has_value()) return std::nullopt; - Elements.push_back(RootSignatureElement(*Clause)); + Elements.push_back(RootSignatureElement(ElementLoc, *Clause)); Table.NumClauses++; } >From 77a2194f4085f512c0faf6b67965926f1b2cbf38 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 27 Jun 2025 19:24:18 +0000 Subject: [PATCH 3/5] update resource range analysis to use retained source loc --- clang/include/clang/Sema/SemaHLSL.h | 4 +- clang/lib/Sema/SemaHLSL.cpp | 37 +++++++++++++++---- .../Frontend/HLSL/RootSignatureValidations.h | 3 ++ 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 1af706da702c2..5c944cbbd966b 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -139,7 +139,9 @@ class SemaHLSL : public SemaBase { ArrayRef<hlsl::RootSignatureElement> Elements); // Returns true when D is invalid and a diagnostic was produced - bool handleRootSignatureDecl(HLSLRootSignatureDecl *D, SourceLocation Loc); + bool + handleRootSignatureElements(ArrayRef<hlsl::RootSignatureElement> Elements, + SourceLocation Loc); void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL); void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); diff --git 
a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 86902f17dab11..a7ebce1b0c498 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -28,6 +28,7 @@ #include "clang/Basic/SourceLocation.h" #include "clang/Basic/Specifiers.h" #include "clang/Basic/TargetInfo.h" +#include "clang/Parse/ParseHLSLRootSignature.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/ParsedAttr.h" @@ -1065,6 +1066,9 @@ void SemaHLSL::ActOnFinishRootSignatureDecl( SourceLocation Loc, IdentifierInfo *DeclIdent, ArrayRef<hlsl::RootSignatureElement> RootElements) { + if (handleRootSignatureElements(RootElements, Loc)) + return; + SmallVector<llvm::hlsl::rootsig::RootElement> Elements; for (auto &RootSigElement : RootElements) Elements.push_back(RootSigElement.getElement()); @@ -1073,15 +1077,12 @@ void SemaHLSL::ActOnFinishRootSignatureDecl( SemaRef.getASTContext(), /*DeclContext=*/SemaRef.CurContext, Loc, DeclIdent, SemaRef.getLangOpts().HLSLRootSigVer, Elements); - if (handleRootSignatureDecl(SignatureDecl, Loc)) - return; - SignatureDecl->setImplicit(); SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope()); } -bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, - SourceLocation Loc) { +bool SemaHLSL::handleRootSignatureElements( + ArrayRef<hlsl::RootSignatureElement> Elements, SourceLocation Loc) { // The following conducts analysis on resource ranges to detect and report // any overlaps in resource ranges. // @@ -1106,9 +1107,15 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, using ResourceRange = llvm::hlsl::rootsig::ResourceRange; using GroupT = std::pair<ResourceClass, /*Space*/ uint32_t>; + // Introduce a mapping from the collected RangeInfos back to the + // RootSignatureElement that will retain its diagnostics info + llvm::DenseMap<size_t, const hlsl::RootSignatureElement *> InfoIndexMap; + size_t InfoIndex = 0; + // 1. 
Collect RangeInfos llvm::SmallVector<RangeInfo> Infos; - for (const llvm::hlsl::rootsig::RootElement &Elem : D->getRootElements()) { + for (const hlsl::RootSignatureElement &RootSigElem : Elements) { + const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement(); if (const auto *Descriptor = std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) { RangeInfo Info; @@ -1119,6 +1126,9 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type)); Info.Space = Descriptor->Space; Info.Visibility = Descriptor->Visibility; + + Info.Index = InfoIndex++; + InfoIndexMap[Info.Index] = &RootSigElem; Infos.push_back(Info); } else if (const auto *Constants = std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) { @@ -1129,6 +1139,9 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = llvm::dxil::ResourceClass::CBuffer; Info.Space = Constants->Space; Info.Visibility = Constants->Visibility; + + Info.Index = InfoIndex++; + InfoIndexMap[Info.Index] = &RootSigElem; Infos.push_back(Info); } else if (const auto *Sampler = std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) { @@ -1139,6 +1152,9 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = llvm::dxil::ResourceClass::Sampler; Info.Space = Sampler->Space; Info.Visibility = Sampler->Visibility; + + Info.Index = InfoIndex++; + InfoIndexMap[Info.Index] = &RootSigElem; Infos.push_back(Info); } else if (const auto *Clause = std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>( @@ -1153,7 +1169,10 @@ bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, Info.Class = Clause->Type; Info.Space = Clause->Space; + // Note: Clause does not hold the visibility this will need to + Info.Index = InfoIndex++; + InfoIndexMap[Info.Index] = &RootSigElem; Infos.push_back(Info); } else if (const auto *Table = std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) { @@ -1200,13 +1219,15 @@ 
bool SemaHLSL::handleRootSignatureDecl(HLSLRootSignatureDecl *D, }; // Helper to report diagnostics - auto ReportOverlap = [this, Loc, &HadOverlap](const RangeInfo *Info, + auto ReportOverlap = [this, InfoIndexMap, &HadOverlap](const RangeInfo *Info, const RangeInfo *OInfo) { HadOverlap = true; auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All ? OInfo->Visibility : Info->Visibility; - this->Diag(Loc, diag::err_hlsl_resource_range_overlap) + const hlsl::RootSignatureElement *Elem = InfoIndexMap.at(Info->Index); + SourceLocation InfoLoc = Elem->getLocation(); + this->Diag(InfoLoc, diag::err_hlsl_resource_range_overlap) << llvm::to_underlying(Info->Class) << Info->LowerBound << /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded) << Info->UpperBound << llvm::to_underlying(OInfo->Class) diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 14eb7c482c08c..6ea07d43ad573 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -32,6 +32,9 @@ struct RangeInfo { llvm::dxil::ResourceClass Class; uint32_t Space; llvm::dxbc::ShaderVisibility Visibility; + + // The index retains its original position before being sorted by group. 
+ size_t Index; }; class ResourceRange { >From fc958bdbf0804db42f6326123bb6dd057c1eeea1 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 27 Jun 2025 19:54:53 +0000 Subject: [PATCH 4/5] move wrapper definition to `SemaHLSL` - this struct needs to be accessible to both `Sema` and `Parse` and since `Parse` depends on `Sema` then we need to have it be included from there, so as to not introduce a circular dependency --- .../clang/Parse/ParseHLSLRootSignature.h | 16 +--------------- clang/include/clang/Sema/SemaHLSL.h | 19 +++++++++++++++++-- clang/lib/Sema/SemaHLSL.cpp | 1 - 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index a3b69c014a716..0ddcb5640ae6d 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -16,6 +16,7 @@ #include "clang/Basic/DiagnosticParse.h" #include "clang/Lex/LexHLSLRootSignature.h" #include "clang/Lex/Preprocessor.h" +#include "clang/Sema/SemaHLSL.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -25,21 +26,6 @@ namespace clang { namespace hlsl { -// Introduce a wrapper struct around the underlying RootElement. This structure -// will retain extra clang diagnostic information that is not available in llvm. 
-struct RootSignatureElement { - RootSignatureElement(SourceLocation Loc, - llvm::hlsl::rootsig::RootElement Element) - : Loc(Loc), Element(Element) {} - - const llvm::hlsl::rootsig::RootElement &getElement() const { return Element; } - const SourceLocation &getLocation() const { return Loc; } - -private: - SourceLocation Loc; - llvm::hlsl::rootsig::RootElement Element; -}; - class RootSignatureParser { public: RootSignatureParser(SmallVector<RootSignatureElement> &Elements, diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 5c944cbbd966b..910e0e640796b 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -33,8 +33,23 @@ class Scope; class VarDecl; namespace hlsl { -struct RootSignatureElement; -} + +// Introduce a wrapper struct around the underlying RootElement. This structure +// will retain extra clang diagnostic information that is not available in llvm. +struct RootSignatureElement { + RootSignatureElement(SourceLocation Loc, + llvm::hlsl::rootsig::RootElement Element) + : Loc(Loc), Element(Element) {} + + const llvm::hlsl::rootsig::RootElement &getElement() const { return Element; } + const SourceLocation &getLocation() const { return Loc; } + +private: + SourceLocation Loc; + llvm::hlsl::rootsig::RootElement Element; +}; + +} // namespace hlsl using llvm::dxil::ResourceClass; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index a7ebce1b0c498..dfae7c8119478 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -28,7 +28,6 @@ #include "clang/Basic/SourceLocation.h" #include "clang/Basic/Specifiers.h" #include "clang/Basic/TargetInfo.h" -#include "clang/Parse/ParseHLSLRootSignature.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/ParsedAttr.h" >From fcd69bfaff4334169442c84e42dd5c24ff0d6634 Mon Sep 17 00:00:00 2001 From: Finn Plummer <canadienf...@gmail.com> Date: Fri, 27 Jun 2025 
19:58:15 +0000 Subject: [PATCH 5/5] nfc: move collection of overlaps to `RootSignatureValidations` --- clang/lib/Sema/SemaHLSL.cpp | 101 ++---------------- .../Frontend/HLSL/RootSignatureValidations.h | 32 ++++++ .../HLSL/RootSignatureValidations.cpp | 73 +++++++++++++ 3 files changed, 115 insertions(+), 91 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index dfae7c8119478..231c80137790d 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1082,29 +1082,8 @@ void SemaHLSL::ActOnFinishRootSignatureDecl( bool SemaHLSL::handleRootSignatureElements( ArrayRef<hlsl::RootSignatureElement> Elements, SourceLocation Loc) { - // The following conducts analysis on resource ranges to detect and report - // any overlaps in resource ranges. - // - // A resource range overlaps with another resource range if they have: - // - equivalent ResourceClass (SRV, UAV, CBuffer, Sampler) - // - equivalent resource space - // - overlapping visbility - // - // The following algorithm is implemented in the following steps: - // - // 1. Collect RangeInfo from relevant RootElements: - // - RangeInfo will retain the interval, ResourceClass, Space and Visibility - // 2. Sort the RangeInfo's such that they are grouped together by - // ResourceClass and Space (GroupT defined below) - // 3. 
Iterate through the collected RangeInfos by their groups - // - For each group we will have a ResourceRange for each visibility - // - As we iterate through we will: - // A: Insert the current RangeInfo into the corresponding Visibility - // ResourceRange - // B: Check for overlap with any overlapping Visibility ResourceRange using RangeInfo = llvm::hlsl::rootsig::RangeInfo; - using ResourceRange = llvm::hlsl::rootsig::ResourceRange; - using GroupT = std::pair<ResourceClass, /*Space*/ uint32_t>; + using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges; // Introduce a mapping from the collected RangeInfos back to the // RootSignatureElement that will retain its diagnostics info @@ -1187,40 +1166,10 @@ bool SemaHLSL::handleRootSignatureElements( } } - // 2. Sort the RangeInfo's by their GroupT to form groupings - std::sort(Infos.begin(), Infos.end(), [](RangeInfo A, RangeInfo B) { - return std::tie(A.Class, A.Space) < std::tie(B.Class, B.Space); - }); - - // 3. First we will init our state to track: - if (Infos.size() == 0) - return false; // No ranges to overlap - GroupT CurGroup = {Infos[0].Class, Infos[0].Space}; - bool HadOverlap = false; - - // Create a ResourceRange for each Visibility - ResourceRange::MapT::Allocator Allocator; - std::array<ResourceRange, 8> Ranges = { - ResourceRange(Allocator), // All - ResourceRange(Allocator), // Vertex - ResourceRange(Allocator), // Hull - ResourceRange(Allocator), // Domain - ResourceRange(Allocator), // Geometry - ResourceRange(Allocator), // Pixel - ResourceRange(Allocator), // Amplification - ResourceRange(Allocator), // Mesh - }; - - // Reset the ResourceRanges for when we iterate through a new group - auto ClearRanges = [&Ranges]() { - for (ResourceRange &Range : Ranges) - Range.clear(); - }; - // Helper to report diagnostics - auto ReportOverlap = [this, InfoIndexMap, &HadOverlap](const RangeInfo *Info, - const RangeInfo *OInfo) { - HadOverlap = true; + auto ReportOverlap = [this, 
InfoIndexMap](OverlappingRanges Overlap) { + const RangeInfo *Info = Overlap.A; + const RangeInfo *OInfo = Overlap.B; auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All ? OInfo->Visibility : Info->Visibility; @@ -1235,42 +1184,12 @@ bool SemaHLSL::handleRootSignatureElements( << OInfo->UpperBound << Info->Space << CommonVis; }; - // 3: Iterate through collected RangeInfos - for (const RangeInfo &Info : Infos) { - GroupT InfoGroup = {Info.Class, Info.Space}; - // Reset our ResourceRanges when we enter a new group - if (CurGroup != InfoGroup) { - ClearRanges(); - CurGroup = InfoGroup; - } - - // 3A: Insert range info into corresponding Visibility ResourceRange - ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)]; - if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info)) - ReportOverlap(&Info, Overlapping.value()); - - // 3B: Check for overlap in all overlapping Visibility ResourceRanges - // - // If the range that we are inserting has ShaderVisiblity::All it needs to - // check for an overlap in all other visibility types as well. - // Otherwise, the range that is inserted needs to check that it does not - // overlap with ShaderVisibility::All. - // - // OverlapRanges will be an ArrayRef to all non-all visibility - // ResourceRanges in the former case and it will be an ArrayRef to just the - // all visiblity ResourceRange in the latter case. - ArrayRef<ResourceRange> OverlapRanges = - Info.Visibility == llvm::dxbc::ShaderVisibility::All - ? 
ArrayRef<ResourceRange>{Ranges}.drop_front() - : ArrayRef<ResourceRange>{Ranges}.take_front(); - - for (const ResourceRange &Range : OverlapRanges) - if (std::optional<const RangeInfo *> Overlapping = - Range.getOverlapping(Info)) - ReportOverlap(&Info, Overlapping.value()); - } - - return HadOverlap; + llvm::SmallVector<OverlappingRanges> Overlaps = + llvm::hlsl::rootsig::findOverlappingRanges(Infos); + for (OverlappingRanges Overlap : Overlaps) + ReportOverlap(Overlap); + + return Overlaps.size() != 0; } void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) { diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h index 6ea07d43ad573..56c3e202519fd 100644 --- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h +++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h @@ -87,6 +87,38 @@ class ResourceRange { std::optional<const RangeInfo *> insert(const RangeInfo &Info); }; +struct OverlappingRanges { + const RangeInfo *A; + const RangeInfo *B; + + OverlappingRanges(const RangeInfo *A, const RangeInfo *B) : A(A), B(B) {} +}; + +/// The following conducts analysis on resource ranges to detect and report +/// any overlaps in resource ranges. +/// +/// A resource range overlaps with another resource range if they have: +/// - equivalent ResourceClass (SRV, UAV, CBuffer, Sampler) +/// - equivalent resource space +/// - overlapping visibility +/// +/// The algorithm is implemented in the following steps: +/// +/// 1. The user will collect RangeInfo from relevant RootElements: +/// - RangeInfo will retain the interval, ResourceClass, Space and Visibility +/// - It will also contain an index so that it can be associated to +/// additional diagnostic information +/// 2. Sort the RangeInfo's such that they are grouped together by +/// ResourceClass and Space +/// 3.
Iterate through the collected RangeInfos by their groups +/// - For each group we will have a ResourceRange for each visibility +/// - As we iterate through we will: +/// A: Insert the current RangeInfo into the corresponding Visibility +/// ResourceRange +/// B: Check for overlap with any overlapping Visibility ResourceRange +llvm::SmallVector<OverlappingRanges> +findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos); + } // namespace rootsig } // namespace hlsl } // namespace llvm diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp index 9825946d59690..118b570538f9e 100644 --- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp +++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp @@ -79,6 +79,79 @@ std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) { return Res; } +llvm::SmallVector<OverlappingRanges> +findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) { + // 1. The user has provided the corresponding range information + llvm::SmallVector<OverlappingRanges> Overlaps; + using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>; + + // 2. Sort the RangeInfo's by their GroupT to form groupings + std::sort(Infos.begin(), Infos.end(), [](RangeInfo A, RangeInfo B) { + return std::tie(A.Class, A.Space) < std::tie(B.Class, B.Space); + }); + + // 3. 
First we will init our state to track: + if (Infos.size() == 0) + return Overlaps; // No ranges to overlap + GroupT CurGroup = {Infos[0].Class, Infos[0].Space}; + + // Create a ResourceRange for each Visibility + ResourceRange::MapT::Allocator Allocator; + std::array<ResourceRange, 8> Ranges = { + ResourceRange(Allocator), // All + ResourceRange(Allocator), // Vertex + ResourceRange(Allocator), // Hull + ResourceRange(Allocator), // Domain + ResourceRange(Allocator), // Geometry + ResourceRange(Allocator), // Pixel + ResourceRange(Allocator), // Amplification + ResourceRange(Allocator), // Mesh + }; + + // Reset the ResourceRanges for when we iterate through a new group + auto ClearRanges = [&Ranges]() { + for (ResourceRange &Range : Ranges) + Range.clear(); + }; + + // 3: Iterate through collected RangeInfos + for (const RangeInfo &Info : Infos) { + GroupT InfoGroup = {Info.Class, Info.Space}; + // Reset our ResourceRanges when we enter a new group + if (CurGroup != InfoGroup) { + ClearRanges(); + CurGroup = InfoGroup; + } + + // 3A: Insert range info into corresponding Visibility ResourceRange + ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)]; + if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info)) + Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); + + // 3B: Check for overlap in all overlapping Visibility ResourceRanges + // + // If the range that we are inserting has ShaderVisibility::All it needs to + // check for an overlap in all other visibility types as well. + // Otherwise, the range that is inserted needs to check that it does not + // overlap with ShaderVisibility::All. + // + // OverlapRanges will be an ArrayRef to all non-all visibility + // ResourceRanges in the former case and it will be an ArrayRef to just the + // all visibility ResourceRange in the latter case. + ArrayRef<ResourceRange> OverlapRanges = + Info.Visibility == llvm::dxbc::ShaderVisibility::All + ?
ArrayRef<ResourceRange>{Ranges}.drop_front() + : ArrayRef<ResourceRange>{Ranges}.take_front(); + + for (const ResourceRange &Range : OverlapRanges) + if (std::optional<const RangeInfo *> Overlapping = + Range.getOverlapping(Info)) + Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value())); + } + + return Overlaps; +} + } // namespace rootsig } // namespace hlsl } // namespace llvm _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits