Author: Finn Plummer Date: 2025-04-23T11:51:24-07:00 New Revision: b8e420e424b41f67019155055f4f600ba0454189
URL: https://github.com/llvm/llvm-project/commit/b8e420e424b41f67019155055f4f600ba0454189 DIFF: https://github.com/llvm/llvm-project/commit/b8e420e424b41f67019155055f4f600ba0454189.diff LOG: Reland "[HLSL][RootSignature] Implement initial parsing of the descriptor table clause params" (#136740) This PR relands #133800. It addresses the compilation error of using a shadowed name `Register` for both the struct name and the data member holding this type: `Register Register`. It resolves the issue by renaming the data members called `Register` to `Reg`. This issue was not caught because the current pre-merge checks do not include a build of `llvm;clang` using the gcc/g++ compilers, and the code is not rejected by clang/clang++. Second part of #126569 --------- Co-authored-by: Finn Plummer <finnplum...@microsoft.com> Added: Modified: clang/include/clang/Basic/DiagnosticParseKinds.td clang/include/clang/Parse/ParseHLSLRootSignature.h clang/lib/Parse/ParseHLSLRootSignature.cpp clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h Removed: ################################################################################ diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td index 9975520f4f9ff..72e765bcb800d 100644 --- a/clang/include/clang/Basic/DiagnosticParseKinds.td +++ b/clang/include/clang/Basic/DiagnosticParseKinds.td @@ -1836,8 +1836,11 @@ def err_hlsl_virtual_function def err_hlsl_virtual_inheritance : Error<"virtual inheritance is unsupported in HLSL">; -// HLSL Root Siganture diagnostic messages +// HLSL Root Signature Parser Diagnostics def err_hlsl_unexpected_end_of_params : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">; +def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">; +def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">; +def 
err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">; } // end of Parser diagnostics diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index a8dd6b02501ae..3eb3f8ea8422d 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -40,26 +40,31 @@ class RootSignatureParser { private: DiagnosticsEngine &getDiags() { return PP.getDiagnostics(); } - // All private Parse.* methods follow a similar pattern: + // All private parse.* methods follow a similar pattern: // - Each method will start with an assert to denote what the CurToken is // expected to be and will parse from that token forward // // - Therefore, it is the callers responsibility to ensure that you are // at the correct CurToken. This should be done with the pattern of: // - // if (TryConsumeExpectedToken(RootSignatureToken::Kind)) - // if (Parse.*()) - // return true; + // if (tryConsumeExpectedToken(RootSignatureToken::Kind)) { + // auto ParsedObject = parse.*(); + // if (!ParsedObject.has_value()) + // return std::nullopt; + // ... + // } // // or, // - // if (ConsumeExpectedToken(RootSignatureToken::Kind, ...)) - // return true; - // if (Parse.*()) - // return true; + // if (consumeExpectedToken(RootSignatureToken::Kind, ...)) + // return std::nullopt; + // auto ParsedObject = parse.*(); + // if (!ParsedObject.has_value()) + // return std::nullopt; + // ... // - // - All methods return true if a parsing error is encountered. It is the - // callers responsibility to propogate this error up, or deal with it + // - All methods return std::nullopt if a parsing error is encountered. 
It + // is the callers responsibility to propogate this error up, or deal with it // otherwise // // - An error will be raised if the proceeding tokens are not what is @@ -69,6 +74,23 @@ class RootSignatureParser { bool parseDescriptorTable(); bool parseDescriptorTableClause(); + /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any + /// order and only exactly once. `ParsedClauseParams` denotes the current + /// state of parsed params + struct ParsedClauseParams { + std::optional<llvm::hlsl::rootsig::Register> Reg; + std::optional<uint32_t> Space; + }; + std::optional<ParsedClauseParams> + parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType); + + std::optional<uint32_t> parseUIntParam(); + std::optional<llvm::hlsl::rootsig::Register> parseRegister(); + + /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned + /// 32-bit integer + std::optional<uint32_t> handleUIntLiteral(); + /// Invoke the Lexer to consume a token and update CurToken with the result void consumeNextToken() { CurToken = Lexer.consumeToken(); } diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 3513ef454f750..4f8bfccfa2243 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -8,6 +8,8 @@ #include "clang/Parse/ParseHLSLRootSignature.h" +#include "clang/Lex/LiteralSupport.h" + #include "llvm/Support/raw_ostream.h" using namespace llvm::hlsl::rootsig; @@ -41,12 +43,11 @@ bool RootSignatureParser::parse() { break; } - if (!tryConsumeExpectedToken(TokenKind::end_of_stream)) { - getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params) - << /*expected=*/TokenKind::end_of_stream - << /*param of=*/TokenKind::kw_RootSignature; + if (consumeExpectedToken(TokenKind::end_of_stream, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/TokenKind::kw_RootSignature)) return true; - } + return false; } @@ -72,12 +73,10 @@ bool 
RootSignatureParser::parseDescriptorTable() { break; } - if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) { - getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params) - << /*expected=*/TokenKind::pu_r_paren - << /*param of=*/TokenKind::kw_DescriptorTable; + if (consumeExpectedToken(TokenKind::pu_r_paren, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/TokenKind::kw_DescriptorTable)) return true; - } Elements.push_back(Table); return false; @@ -90,36 +89,170 @@ bool RootSignatureParser::parseDescriptorTableClause() { CurToken.TokKind == TokenKind::kw_Sampler) && "Expects to only be invoked starting at given keyword"); + TokenKind ParamKind = CurToken.TokKind; + + if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after, + CurToken.TokKind)) + return true; + DescriptorTableClause Clause; - switch (CurToken.TokKind) { + TokenKind ExpectedReg; + switch (ParamKind) { default: llvm_unreachable("Switch for consumed token was not provided"); case TokenKind::kw_CBV: Clause.Type = ClauseType::CBuffer; + ExpectedReg = TokenKind::bReg; break; case TokenKind::kw_SRV: Clause.Type = ClauseType::SRV; + ExpectedReg = TokenKind::tReg; break; case TokenKind::kw_UAV: Clause.Type = ClauseType::UAV; + ExpectedReg = TokenKind::uReg; break; case TokenKind::kw_Sampler: Clause.Type = ClauseType::Sampler; + ExpectedReg = TokenKind::sReg; break; } - if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after, - CurToken.TokKind)) + auto Params = parseDescriptorTableClauseParams(ExpectedReg); + if (!Params.has_value()) return true; - if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after, - CurToken.TokKind)) + // Check mandatory parameters were provided + if (!Params->Reg.has_value()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param) + << ExpectedReg; + return true; + } + + Clause.Reg = Params->Reg.value(); + + // Fill in optional values + if (Params->Space.has_value()) + Clause.Space = 
Params->Space.value(); + + if (consumeExpectedToken(TokenKind::pu_r_paren, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/ParamKind)) return true; Elements.push_back(Clause); return false; } +std::optional<RootSignatureParser::ParsedClauseParams> +RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) { + assert(CurToken.TokKind == TokenKind::pu_l_paren && + "Expects to only be invoked starting at given token"); + + // Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any + // order and only exactly once. Parse through as many arguments as possible + // reporting an error if a duplicate is seen. + ParsedClauseParams Params; + do { + // ( `b` | `t` | `u` | `s`) POS_INT + if (tryConsumeExpectedToken(RegType)) { + if (Params.Reg.has_value()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param) + << CurToken.TokKind; + return std::nullopt; + } + auto Reg = parseRegister(); + if (!Reg.has_value()) + return std::nullopt; + Params.Reg = Reg; + } + + // `space` `=` POS_INT + if (tryConsumeExpectedToken(TokenKind::kw_space)) { + if (Params.Space.has_value()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param) + << CurToken.TokKind; + return std::nullopt; + } + + if (consumeExpectedToken(TokenKind::pu_equal)) + return std::nullopt; + + auto Space = parseUIntParam(); + if (!Space.has_value()) + return std::nullopt; + Params.Space = Space; + } + } while (tryConsumeExpectedToken(TokenKind::pu_comma)); + + return Params; +} + +std::optional<uint32_t> RootSignatureParser::parseUIntParam() { + assert(CurToken.TokKind == TokenKind::pu_equal && + "Expects to only be invoked starting at given keyword"); + tryConsumeExpectedToken(TokenKind::pu_plus); + if (consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after, + CurToken.TokKind)) + return std::nullopt; + return handleUIntLiteral(); +} + +std::optional<Register> RootSignatureParser::parseRegister() { + assert((CurToken.TokKind 
== TokenKind::bReg || + CurToken.TokKind == TokenKind::tReg || + CurToken.TokKind == TokenKind::uReg || + CurToken.TokKind == TokenKind::sReg) && + "Expects to only be invoked starting at given keyword"); + + Register Reg; + switch (CurToken.TokKind) { + default: + llvm_unreachable("Switch for consumed token was not provided"); + case TokenKind::bReg: + Reg.ViewType = RegisterType::BReg; + break; + case TokenKind::tReg: + Reg.ViewType = RegisterType::TReg; + break; + case TokenKind::uReg: + Reg.ViewType = RegisterType::UReg; + break; + case TokenKind::sReg: + Reg.ViewType = RegisterType::SReg; + break; + } + + auto Number = handleUIntLiteral(); + if (!Number.has_value()) + return std::nullopt; // propogate NumericLiteralParser error + + Reg.Number = *Number; + return Reg; +} + +std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() { + // Parse the numeric value and do semantic checks on its specification + clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc, + PP.getSourceManager(), PP.getLangOpts(), + PP.getTargetInfo(), PP.getDiagnostics()); + if (Literal.hadError) + return true; // Error has already been reported so just return + + assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits"); + + llvm::APSInt Val = llvm::APSInt(32, false); + if (Literal.GetIntegerValue(Val)) { + // Report that the value has overflowed + PP.getDiagnostics().Report(CurToken.TokLoc, + diag::err_hlsl_number_literal_overflow) + << 0 << CurToken.NumSpelling; + return std::nullopt; + } + + return Val.getExtValue(); +} + bool RootSignatureParser::peekExpectedToken(TokenKind Expected) { return peekExpectedToken(ArrayRef{Expected}); } @@ -141,6 +274,7 @@ bool RootSignatureParser::consumeExpectedToken(TokenKind Expected, case diag::err_expected: DB << Expected; break; + case diag::err_hlsl_unexpected_end_of_params: case diag::err_expected_either: case diag::err_expected_after: DB << Expected << Context; diff --git 
a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index 19d5b267f310a..e382a1b26d366 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -129,10 +129,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) { TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { const llvm::StringLiteral Source = R"cc( DescriptorTable( - CBV(), - SRV(), - Sampler(), - UAV() + CBV(b0), + SRV(space = 3, t42), + Sampler(s987, space = +2), + UAV(u4294967294) ), DescriptorTable() )cc"; @@ -154,18 +154,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { RootElement Elem = Elements[0]; ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, + RegisterType::BReg); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 0u); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u); Elem = Elements[1]; ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, + RegisterType::TReg); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 42u); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u); Elem = Elements[2]; ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, + RegisterType::SReg); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 987u); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u); Elem = Elements[3]; ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem)); ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, 
ClauseType::UAV); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.ViewType, + RegisterType::UReg); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Reg.Number, 4294967294u); + ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u); Elem = Elements[4]; ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem)); @@ -175,6 +191,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { Elem = Elements[5]; ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem)); ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + +TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) { + // This test will checks we can handling trailing commas ',' + const llvm::StringLiteral Source = R"cc( + DescriptorTable( + CBV(b0, ), + SRV(t42), + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test no diagnostics produced + Consumer->setNoDiag(); + + ASSERT_FALSE(Parser.parse()); + ASSERT_TRUE(Consumer->isSatisfied()); } @@ -236,6 +278,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) { // Test correct diagnostic produced - end of stream Consumer->setExpected(diag::err_expected_after); + + ASSERT_TRUE(Parser.parse()); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + +TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) { + // This test will check that the parsing fails due a mandatory + // parameter (register) not being specified + const llvm::StringLiteral Source = R"cc( + DescriptorTable( + CBV() + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, 
*PP); + + // Test correct diagnostic produced + Consumer->setExpected(diag::err_hlsl_rootsig_missing_param); + ASSERT_TRUE(Parser.parse()); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + +TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) { + // This test will check that the parsing fails due the same mandatory + // parameter being specified multiple times + const llvm::StringLiteral Source = R"cc( + DescriptorTable( + CBV(b32, b84) + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test correct diagnostic produced + Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param); + ASSERT_TRUE(Parser.parse()); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + +TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) { + // This test will check that the parsing fails due the same optional + // parameter being specified multiple times + const llvm::StringLiteral Source = R"cc( + DescriptorTable( + CBV(space = 2, space = 0) + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test correct diagnostic produced + Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param); + ASSERT_TRUE(Parser.parse()); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + +TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) { + // This test will check that the lexing fails due to an integer overflow + const llvm::StringLiteral Source = R"cc( + DescriptorTable( + CBV(b4294967296) + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + 
hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector<RootElement> Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test correct diagnostic produced + Consumer->setExpected(diag::err_hlsl_number_literal_overflow); ASSERT_TRUE(Parser.parse()); ASSERT_TRUE(Consumer->isSatisfied()); diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index c1b67844c747f..778b0c397f9cf 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -23,6 +23,13 @@ namespace rootsig { // Definitions of the in-memory data layout structures +// Models the diff erent registers: bReg | tReg | uReg | sReg +enum class RegisterType { BReg, TReg, UReg, SReg }; +struct Register { + RegisterType ViewType; + uint32_t Number; +}; + // Models the end of a descriptor table and stores its visibility struct DescriptorTable { uint32_t NumClauses = 0; // The number of clauses in the table @@ -32,6 +39,8 @@ struct DescriptorTable { using ClauseType = llvm::dxil::ResourceClass; struct DescriptorTableClause { ClauseType Type; + Register Reg; + uint32_t Space = 0; }; // Models RootElement : DescriptorTable | DescriptorTableClause _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits