================
@@ -0,0 +1,417 @@
+#include "clang/Parse/ParseHLSLRootSignature.h"
+
+#include "clang/Lex/LiteralSupport.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm::hlsl::rootsig;
+
+namespace clang {
+namespace hlsl {
+
+// Render the spellings of Kinds as a ", "-separated string, for use as a
+// diagnostic argument. Spellings come from the TOK(X, SPELLING) entries of
+// HLSLRootSignatureTokenKinds.def.
+static std::string FormatTokenKinds(ArrayRef<TokenKind> Kinds) {
+  std::string TokenString;
+  llvm::raw_string_ostream Out(TokenString);
+  bool First = true;
+  for (auto Kind : Kinds) {
+    if (!First)
+      Out << ", ";
+    switch (Kind) {
+#define TOK(X, SPELLING)                                                       \
+  case TokenKind::X:                                                           \
+    Out << SPELLING;                                                           \
+    break;
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+    }
+    First = false;
+  }
+
+  return TokenString;
+}
+
+// Parser Definitions
+
+// Parsed root elements are appended to the caller-provided Elements vector,
+// which is held by reference.
+RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
+                                         RootSignatureLexer &Lexer,
+                                         Preprocessor &PP)
+    : Elements(Elements), Lexer(Lexer), PP(PP), CurToken(SourceLocation()) {}
+
+// Parse a comma-separated list of root elements until the end of the buffer.
+// Like the other Parse* methods here, returns true on error and false on
+// success.
+bool RootSignatureParser::Parse() {
+  // Handle edge-case of empty RootSignature()
+  if (Lexer.EndOfBuffer())
+    return false;
+
+  // Iterate as many RootElements as possible
+  while (!ParseRootElement()) {
+    if (Lexer.EndOfBuffer())
+      return false;
+    // Elements must be separated by commas
+    if (ConsumeExpectedToken(TokenKind::pu_comma))
+      return true;
+  }
+
+  return true;
+}
+
+// Parse a single root element. Only DescriptorTable is accepted so far, so
+// the switch below has a single real case.
+bool RootSignatureParser::ParseRootElement() {
+  if (ConsumeExpectedToken(TokenKind::kw_DescriptorTable))
+    return true;
+
+  // Dispatch onto the correct parse method
+  switch (CurToken.Kind) {
+  case TokenKind::kw_DescriptorTable:
+    return ParseDescriptorTable();
+  default:
+    break;
+  }
+  llvm_unreachable("Switch for an expected token was not provided");
+}
+
+// Parse `DescriptorTable '(' [item (',' item)*] ')'`, where each item is
+// either `visibility = <ShaderVisibility>` (at most once) or a clause.
+// Each clause is appended to Elements by ParseDescriptorTableClause, so the
+// DescriptorTable itself is appended *after* its clauses. NumClauses records
+// how many clauses preceded it.
+bool RootSignatureParser::ParseDescriptorTable() {
+  DescriptorTable Table;
+
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  // Empty case: `DescriptorTable()`
+  if (TryConsumeExpectedToken(TokenKind::pu_r_paren)) {
+    Elements.push_back(Table);
+    return false;
+  }
+
+  bool SeenVisibility = false;
+  // Iterate through all the defined clauses
+  do {
+    // Handle the visibility parameter
+    if (TryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      if (SeenVisibility) {
+        Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << FormatTokenKinds(CurToken.Kind);
+        return true;
+      }
+      SeenVisibility = true;
+      if (ParseParam(&Table.Visibility))
+        return true;
+      // In a do/while, `continue` jumps to the pu_comma check below
+      continue;
+    }
+
+    // Otherwise, we expect a clause
+    if (ParseDescriptorTableClause())
+      return true;
+    Table.NumClauses++;
+  } while (TryConsumeExpectedToken(TokenKind::pu_comma));
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  Elements.push_back(Table);
+  return false;
+}
+
+// Parse `(CBV|SRV|UAV|Sampler) '(' register (',' keyword '=' value)* ')'`
+// and append the resulting clause to Elements.
+bool RootSignatureParser::ParseDescriptorTableClause() {
+  if (ConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
+                            TokenKind::kw_UAV, TokenKind::kw_Sampler}))
+    return true;
+
+  DescriptorTableClause Clause;
+  switch (CurToken.Kind) {
+  case TokenKind::kw_CBV:
+    Clause.Type = ClauseType::CBuffer;
+    break;
+  case TokenKind::kw_SRV:
+    Clause.Type = ClauseType::SRV;
+    break;
+  case TokenKind::kw_UAV:
+    Clause.Type = ClauseType::UAV;
+    break;
+  case TokenKind::kw_Sampler:
+    Clause.Type = ClauseType::Sampler;
+    break;
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+  }
+  // Flag defaults presumably depend on Clause.Type, so they are set only
+  // after the type is known
+  Clause.SetDefaultFlags();
+
+  if (ConsumeExpectedToken(TokenKind::pu_l_paren))
+    return true;
+
+  // Consume mandatory Register parameter
+  if (ParseRegister(&Clause.Register))
+    return true;
+
+  // Define optional parameters: each keyword maps to the Clause field that
+  // ParseParam writes through
+  llvm::SmallDenseMap<TokenKind, ParamType> RefMap = {
+      {TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
+      {TokenKind::kw_space, &Clause.Space},
+      {TokenKind::kw_offset, &Clause.Offset},
+      {TokenKind::kw_flags, &Clause.Flags},
+  };
+  if (ParseOptionalParams({RefMap}))
+    return true;
+
+  if (ConsumeExpectedToken(TokenKind::pu_r_paren))
+    return true;
+
+  Elements.push_back(Clause);
+  return false;
+}
+
+// Helper struct so that we can use the overloaded notation of std::visit
+template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
+template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
+
+// Parse `'=' value` and write the value through the pointer held in Ref.
+// Which value parser runs depends on the pointee type of the active
+// ParamType alternative.
+bool RootSignatureParser::ParseParam(ParamType Ref) {
+  if (ConsumeExpectedToken(TokenKind::pu_equal))
+    return true;
+
+  // Assigned by whichever lambda std::visit selects. std::visit requires the
+  // visitor to accept every ParamType alternative, so one always runs.
+  bool Error;
+  std::visit(
+      ParseMethods{
+          [&](uint32_t *X) { Error = ParseUInt(X); },
+          [&](DescriptorRangeOffset *X) {
+            Error = ParseDescriptorRangeOffset(X);
+          },
+          [&](ShaderVisibility *Enum) { Error = ParseShaderVisibility(Enum); },
+          [&](DescriptorRangeFlags *Flags) {
+            Error = ParseDescriptorRangeFlags(Flags);
+          },
+      },
+      Ref);
+
+  return Error;
+}
+
+// Parse zero or more `',' keyword '=' value` pairs, where keyword is a key of
+// RefMap. Keywords may appear in any order but at most once each; a repeat
+// is reported as err_hlsl_rootsig_repeat_param.
+bool RootSignatureParser::ParseOptionalParams(
+    llvm::SmallDenseMap<TokenKind, ParamType> &RefMap) {
+  // NOTE(review): SmallDenseMap iteration order is not insertion order, so
+  // the expected-keyword list in an unexpected-token diagnostic may not
+  // follow the order RefMap was written in.
+  SmallVector<TokenKind> ParamKeywords;
+  for (auto RefPair : RefMap)
+    ParamKeywords.push_back(RefPair.first);
+
+  // Keep track of which keywords have been seen to report duplicates
+  llvm::SmallDenseSet<TokenKind> Seen;
+
+  while (TryConsumeExpectedToken(TokenKind::pu_comma)) {
+    if (ConsumeExpectedToken(ParamKeywords))
+      return true;
+
+    TokenKind ParamKind = CurToken.Kind;
+    if (Seen.contains(ParamKind)) {
+      Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+          << FormatTokenKinds({ParamKind});
+      return true;
+    }
+    Seen.insert(ParamKind);
+
+    if (ParseParam(RefMap[ParamKind]))
+      return true;
+  }
+
+  return false;
+}
+
+// Convert CurToken.NumSpelling into a 32-bit unsigned value stored in X.
+// Returns true if NumericLiteralParser reported an error, or if the value
+// does not fit in 32 bits (reported as err_hlsl_number_literal_overflow).
+bool RootSignatureParser::HandleUIntLiteral(uint32_t &X) {
+  // Parse the numeric value and do semantic checks on its specification
+  clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
+                                     PP.getSourceManager(), PP.getLangOpts(),
+                                     PP.getTargetInfo(), PP.getDiagnostics());
+  if (Literal.hadError)
+    return true; // Error has already been reported so just return
+
+  assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
+
+  llvm::APSInt Val = llvm::APSInt(32, false);
+  if (Literal.GetIntegerValue(Val)) {
+    // Report that the value has overflowed
+    PP.getDiagnostics().Report(CurToken.TokLoc,
+                               diag::err_hlsl_number_literal_overflow)
+        << 0 << CurToken.NumSpelling;
+    return true;
+  }
+
+  X = Val.getExtValue();
+  return false;
+}
+
+// Parse a register token (b/t/u/s) into Register's view type and number.
+// The number is read via HandleUIntLiteral from CurToken.NumSpelling, so the
+// register token presumably carries its numeric suffix there — confirm
+// against the lexer.
+bool RootSignatureParser::ParseRegister(Register *Register) {
+  if (ConsumeExpectedToken(
+          {TokenKind::bReg, TokenKind::tReg, TokenKind::uReg, TokenKind::sReg}))
+    return true;
+
+  switch (CurToken.Kind) {
+  case TokenKind::bReg:
+    Register->ViewType = RegisterType::BReg;
+    break;
+  case TokenKind::tReg:
+    Register->ViewType = RegisterType::TReg;
+    break;
+  case TokenKind::uReg:
+    Register->ViewType = RegisterType::UReg;
+    break;
+  case TokenKind::sReg:
+    Register->ViewType = RegisterType::SReg;
+    break;
+  default:
+    llvm_unreachable("Switch for an expected token was not provided");
+  }
+
+  if (HandleUIntLiteral(Register->Number))
+    return true; // propagate NumericLiteralParser error
+
+  return false;
+}
+
+// Parse an unsigned integer literal, optionally preceded by '+'.
+bool RootSignatureParser::ParseUInt(uint32_t *X) {
+  // Treat a positively signed integer as though it is unsigned to match DXC
+  TryConsumeExpectedToken(TokenKind::pu_plus);
+  if (ConsumeExpectedToken(TokenKind::int_literal))
+    return true;
+
+  if (HandleUIntLiteral(*X))
+    return true; // propagate NumericLiteralParser error
+
+  return false;
+}
+
+// Parse either an integer literal or DESCRIPTOR_RANGE_OFFSET_APPEND. The
+// latter maps to the DescriptorTableOffsetAppend constant.
+// NOTE(review): unlike ParseUInt, a leading '+' is not accepted here —
+// confirm this difference is intended.
+bool RootSignatureParser::ParseDescriptorRangeOffset(DescriptorRangeOffset *X) {
+  if (ConsumeExpectedToken(
+          {TokenKind::int_literal, TokenKind::en_DescriptorRangeOffsetAppend}))
+    return true;
+
+  // Edge case for the offset enum -> static value
+  if (CurToken.Kind == TokenKind::en_DescriptorRangeOffsetAppend) {
+    *X = DescriptorTableOffsetAppend;
+    return false;
+  }
+
+  uint32_t Temp;
+  if (HandleUIntLiteral(Temp))
+    return true; // propagate NumericLiteralParser error
+  *X = DescriptorRangeOffset(Temp);
+  return false;
+}
+
+// Parse one token from the keys of EnumMap and store the mapped value in
+// *Enum. When AllowZero is true, the literal `0` is also accepted and maps
+// to EnumType(0); any other integer is rejected with
+// err_hlsl_rootsig_non_zero_flag.
+// NOTE(review): AllowZero is a template parameter, so `if constexpr` would
+// express intent more directly. ParseShaderVisibility calls ParseEnum
+// without specifying AllowZero, which presumably relies on a default
+// declared in the header — confirm.
+template <bool AllowZero, typename EnumType>
+bool RootSignatureParser::ParseEnum(
+    llvm::SmallDenseMap<TokenKind, EnumType> &EnumMap, EnumType *Enum) {
+  SmallVector<TokenKind> EnumToks;
+  if (AllowZero)
+    EnumToks.push_back(TokenKind::int_literal); // '0' is a valid flag value
+  for (auto EnumPair : EnumMap)
+    EnumToks.push_back(EnumPair.first);
+
+  // If invoked we expect to have an enum
+  if (ConsumeExpectedToken(EnumToks))
+    return true;
+
+  // Handle the edge case when '0' is used to specify None
+  if (CurToken.Kind == TokenKind::int_literal) {
+    uint32_t Temp;
+    if (HandleUIntLiteral(Temp))
+      return true; // propagate NumericLiteralParser error
+    if (Temp != 0) {
+      Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
+      return true;
+    }
+    // Set enum to None equivalent
+    *Enum = EnumType(0);
+    return false;
+  }
+
+  // Effectively a switch statement on the token kinds
+  for (auto EnumPair : EnumMap)
+    if (CurToken.Kind == EnumPair.first) {
+      *Enum = EnumPair.second;
+      return false;
+    }
+
+  llvm_unreachable("Switch for an expected token was not provided");
+}
+
+// Parse a single ShaderVisibility enumerator. The token-to-value map is
+// generated from the SHADER_VISIBILITY_ENUM entries of the .def file.
+bool RootSignatureParser::ParseShaderVisibility(ShaderVisibility *Enum) {
+  // Define the possible flag kinds
+  llvm::SmallDenseMap<TokenKind, ShaderVisibility> EnumMap = {
+#define SHADER_VISIBILITY_ENUM(NAME, LIT)                                      \
+  {TokenKind::en_##NAME, ShaderVisibility::NAME},
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  return ParseEnum(EnumMap, Enum);
+}
+
+// Parse `flag ('|' flag)*` and OR every flag into *Flags. Each flag is a key
+// of FlagMap or the literal `0` (via ParseEnum<true>).
+template <typename FlagType>
+bool RootSignatureParser::ParseFlags(
+    llvm::SmallDenseMap<TokenKind, FlagType> &FlagMap, FlagType *Flags) {
+  // Override the default value to 0 so that we can correctly 'or' the values
+  *Flags = FlagType(0);
+
+  do {
+    FlagType Flag;
+    if (ParseEnum<true>(FlagMap, &Flag))
+      return true;
+    // Store the 'or'
+    *Flags |= Flag;
+  } while (TryConsumeExpectedToken(TokenKind::pu_or));
+
+  return false;
+}
+
+// Parse a '|'-separated list of DescriptorRangeFlags. The token-to-value map
+// is generated from the DESCRIPTOR_RANGE_FLAG_ENUM entries of the .def file.
+bool RootSignatureParser::ParseDescriptorRangeFlags(
+    DescriptorRangeFlags *Flags) {
+  // Define the possible flag kinds
+  llvm::SmallDenseMap<TokenKind, DescriptorRangeFlags> FlagMap = {
+#define DESCRIPTOR_RANGE_FLAG_ENUM(NAME, LIT, ON)                              \
+  {TokenKind::en_##NAME, DescriptorRangeFlags::NAME},
+#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
+  };
+
+  return ParseFlags(FlagMap, Flags);
+}
+
+// Is given token one of the expected kinds
+static bool IsExpectedToken(TokenKind Kind, ArrayRef<TokenKind> AnyExpected) {
+  for (auto Expected : AnyExpected)
+    if (Kind == Expected)
+      return true;
+  return false;
+}
+
+// Single-kind convenience overload of PeekExpectedToken.
+bool RootSignatureParser::PeekExpectedToken(TokenKind Expected) {
+  return PeekExpectedToken(ArrayRef{Expected});
+}
+
+// Return true if the next token (looked at via Lexer.PeekNextToken, without
+// updating CurToken) is one of AnyExpected. Reports no diagnostic.
+bool RootSignatureParser::PeekExpectedToken(ArrayRef<TokenKind> AnyExpected) {
+  RootSignatureToken Result = Lexer.PeekNextToken();
+  return IsExpectedToken(Result.Kind, AnyExpected);
+}
+
+// Single-kind convenience overload of ConsumeExpectedToken.
+bool RootSignatureParser::ConsumeExpectedToken(TokenKind Expected) {
+  return ConsumeExpectedToken(ArrayRef{Expected});
+}
+
+// Consume the next token unconditionally. Return false if it is one of
+// AnyExpected; otherwise report err_hlsl_rootsig_unexpected_token_kind.
+bool RootSignatureParser::ConsumeExpectedToken(
+    ArrayRef<TokenKind> AnyExpected) {
+  ConsumeNextToken();
+  if (IsExpectedToken(CurToken.Kind, AnyExpected))
+    return false;
+
+  // Report unexpected token kind error
+  Diags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_unexpected_token_kind)
----------------
inbelic wrote:
I think this effectively follows the same pattern as Clang's parser, albeit with different names, which is confusing: this function is equivalent to the `ExpectAndConsume` function in Clang's implementation. `ExpectAndConsume` does, however, allow passing down custom, context-specific diagnostic messages. I have added a quick prototype commit that lets `ConsumeExpectedToken` take custom diagnostic messages as well. Is this what you had intended? Do you think we should also rename these functions to align with the Clang parser? https://github.com/llvm/llvm-project/pull/122982 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits