Author: Finn Plummer
Date: 2025-07-10T10:52:20-07:00
New Revision: d60da27400cd96855542cd992d326c10a34dd0f7

URL: 
https://github.com/llvm/llvm-project/commit/d60da27400cd96855542cd992d326c10a34dd0f7
DIFF: 
https://github.com/llvm/llvm-project/commit/d60da27400cd96855542cd992d326c10a34dd0f7.diff

LOG: [HLSL][RootSignature] Implement diagnostic for missed comma (#147350)

This PR fixes a bug that allowed parameters to be specified without an
intervening comma.

After this PR, we correctly produce a diagnostic for, e.g.:
```
RootFlags(0) CBV(b0)
```

This PR replaces the problematic chain of independent `if` statements with
an `else if` chain, so that at most one element is parsed per iteration
before checking for a comma.

This PR also makes two small updates in the same region:
1. Simplifies the `do` loop that contains these `if` statements. This
improves readability and makes it easier to improve the diagnostics
further.
2. Moves the `consumeExpectedToken` calls to immediately after the
`parse.*Params` invocations. This ensures that the missing-comma or
invalid-token error is reported before any "missing mandatory parameter"
diagnostic.

- Replaces all occurrences of the `if` chains with `else if` chains
- Simplifies the surrounding `do` loop into an easier-to-understand
`while` loop
- Moves the `consumeExpectedToken` diagnostic to right after the
`parse.*Params` call, so that the missing-comma diagnostic is produced
before checking for any missing mandatory arguments
- Adds unit and Sema tests for this scenario
- Fixes the end-of-parameters diagnostic of root descriptors to use their
respective descriptor `Token` (`DescriptorKind`) instead of `RootConstants`

Resolves: https://github.com/llvm/llvm-project/issues/147337

Added: 
    clang/test/SemaHLSL/RootSignature.hlsl

Modified: 
    clang/lib/Parse/ParseHLSLRootSignature.cpp
    clang/test/SemaHLSL/RootSignature-err.hlsl
    clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index cf86c62f3b671..dc5f6faefbab4 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,44 +25,41 @@ RootSignatureParser::RootSignatureParser(
       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
 
 bool RootSignatureParser::parse() {
-  // Iterate as many RootElements as possible
-  do {
+  // Iterate as many RootSignatureElements as possible, until we hit the
+  // end of the stream
+  while (!peekExpectedToken(TokenKind::end_of_stream)) {
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
       auto Flags = parseRootFlags();
       if (!Flags.has_value())
         return true;
       Elements.push_back(*Flags);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
         return true;
       Elements.push_back(*Constants);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
       auto Table = parseDescriptorTable();
       if (!Table.has_value())
         return true;
       Elements.push_back(*Table);
-    }
-
-    if (tryConsumeExpectedToken(
-            {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+    } else if (tryConsumeExpectedToken(
+                   {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
       auto Descriptor = parseRootDescriptor();
       if (!Descriptor.has_value())
         return true;
       Elements.push_back(*Descriptor);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
       auto Sampler = parseStaticSampler();
       if (!Sampler.has_value())
         return true;
       Elements.push_back(*Sampler);
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at end of stream
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return consumeExpectedToken(TokenKind::end_of_stream,
                               diag::err_hlsl_unexpected_end_of_params,
@@ -139,6 +136,11 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_RootConstants))
+    return std::nullopt;
+
   // Check mandatory parameters where provided
   if (!Params->Num32BitConstants.has_value()) {
     reportDiag(diag::err_hlsl_rootsig_missing_param)
@@ -162,11 +164,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (Params->Space.has_value())
     Constants.Space = Params->Space.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Constants;
 }
 
@@ -206,6 +203,11 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/DescriptorKind))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
     reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
@@ -224,11 +226,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (Params->Flags.has_value())
     Descriptor.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Descriptor;
 }
 
@@ -243,18 +240,18 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   DescriptorTable Table;
   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
 
-  // Iterate as many Clauses as possible
-  do {
+  // Iterate as many Clauses as possible, until we hit ')'
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      // DescriptorTableClause - CBV, SRV, UAV, or Sampler
       auto Clause = parseDescriptorTableClause();
       if (!Clause.has_value())
         return std::nullopt;
       Elements.push_back(*Clause);
       Table.NumClauses++;
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // visibility = SHADER_VISIBILITY
       if (Visibility.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -267,17 +264,21 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
       if (!Visibility.has_value())
         return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
 
-  // Fill in optional visibility
-  if (Visibility.has_value())
-    Table.Visibility = Visibility.value();
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_DescriptorTable))
     return std::nullopt;
 
+  // Fill in optional visibility
+  if (Visibility.has_value())
+    Table.Visibility = Visibility.value();
+
   return Table;
 }
 
@@ -323,6 +324,11 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/ParamKind))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
     reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
@@ -344,11 +350,6 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Flags.has_value())
     Clause.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/ParamKind))
-    return std::nullopt;
-
   return Clause;
 }
 
@@ -366,6 +367,11 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren,
+                           diag::err_hlsl_unexpected_end_of_params,
+                           /*param of=*/TokenKind::kw_StaticSampler))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
     reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
@@ -411,11 +417,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (Params->Visibility.has_value())
     Sampler.Visibility = Params->Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_StaticSampler))
-    return std::nullopt;
-
   return Sampler;
 }
 
@@ -428,9 +429,9 @@ RootSignatureParser::parseRootConstantParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedConstantParams Params;
-  do {
-    // `num32BitConstants` `=` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      // `num32BitConstants` `=` POS_INT
       if (Params.Num32BitConstants.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -443,10 +444,8 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Num32BitConstants.has_value())
         return std::nullopt;
       Params.Num32BitConstants = Num32BitConstants;
-    }
-
-    // `b` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+    } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      // `b` POS_INT
       if (Params.Reg.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -455,10 +454,8 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -471,10 +468,8 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -488,7 +483,11 @@ RootSignatureParser::parseRootConstantParams() {
         return std::nullopt;
       Params.Visibility = Visibility;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
@@ -499,9 +498,9 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
          "Expects to only be invoked starting at given token");
 
   ParsedRootDescriptorParams Params;
-  do {
-    // ( `b` | `t` | `u`) POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(RegType)) {
+      // ( `b` | `t` | `u`) POS_INT
       if (Params.Reg.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -510,10 +509,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -526,10 +523,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -542,10 +537,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
       if (!Visibility.has_value())
         return std::nullopt;
       Params.Visibility = Visibility;
-    }
-
-    // `flags` `=` ROOT_DESCRIPTOR_FLAGS
-    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      // `flags` `=` ROOT_DESCRIPTOR_FLAGS
       if (Params.Flags.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -559,7 +552,11 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
         return std::nullopt;
       Params.Flags = Flags;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
@@ -570,9 +567,9 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
          "Expects to only be invoked starting at given token");
 
   ParsedClauseParams Params;
-  do {
-    // ( `b` | `t` | `u` | `s`) POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(RegType)) {
+      // ( `b` | `t` | `u` | `s`) POS_INT
       if (Params.Reg.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -581,10 +578,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `numDescriptors` `=` POS_INT | unbounded
-    if (tryConsumeExpectedToken(TokenKind::kw_numDescriptors)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_numDescriptors)) {
+      // `numDescriptors` `=` POS_INT | unbounded
       if (Params.NumDescriptors.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -603,10 +598,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       }
 
       Params.NumDescriptors = NumDescriptors;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -619,10 +612,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
-    if (tryConsumeExpectedToken(TokenKind::kw_offset)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_offset)) {
+      // `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
       if (Params.Offset.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -641,10 +632,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       }
 
       Params.Offset = Offset;
-    }
-
-    // `flags` `=` DESCRIPTOR_RANGE_FLAGS
-    if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
+      // `flags` `=` DESCRIPTOR_RANGE_FLAGS
       if (Params.Flags.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -659,7 +648,10 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
       Params.Flags = Flags;
     }
 
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
@@ -670,9 +662,9 @@ RootSignatureParser::parseStaticSamplerParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedStaticSamplerParams Params;
-  do {
-    // `s` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::sReg)) {
+      // `s` POS_INT
       if (Params.Reg.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -681,10 +673,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `filter` `=` FILTER
-    if (tryConsumeExpectedToken(TokenKind::kw_filter)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_filter)) {
+      // `filter` `=` FILTER
       if (Params.Filter.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -697,10 +687,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!Filter.has_value())
         return std::nullopt;
       Params.Filter = Filter;
-    }
-
-    // `addressU` `=` TEXTURE_ADDRESS
-    if (tryConsumeExpectedToken(TokenKind::kw_addressU)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_addressU)) {
+      // `addressU` `=` TEXTURE_ADDRESS
       if (Params.AddressU.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -713,10 +701,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!AddressU.has_value())
         return std::nullopt;
       Params.AddressU = AddressU;
-    }
-
-    // `addressV` `=` TEXTURE_ADDRESS
-    if (tryConsumeExpectedToken(TokenKind::kw_addressV)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_addressV)) {
+      // `addressV` `=` TEXTURE_ADDRESS
       if (Params.AddressV.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -729,10 +715,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!AddressV.has_value())
         return std::nullopt;
       Params.AddressV = AddressV;
-    }
-
-    // `addressW` `=` TEXTURE_ADDRESS
-    if (tryConsumeExpectedToken(TokenKind::kw_addressW)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_addressW)) {
+      // `addressW` `=` TEXTURE_ADDRESS
       if (Params.AddressW.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -745,10 +729,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!AddressW.has_value())
         return std::nullopt;
       Params.AddressW = AddressW;
-    }
-
-    // `mipLODBias` `=` NUMBER
-    if (tryConsumeExpectedToken(TokenKind::kw_mipLODBias)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_mipLODBias)) {
+      // `mipLODBias` `=` NUMBER
       if (Params.MipLODBias.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -761,10 +743,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!MipLODBias.has_value())
         return std::nullopt;
       Params.MipLODBias = MipLODBias;
-    }
-
-    // `maxAnisotropy` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_maxAnisotropy)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_maxAnisotropy)) {
+      // `maxAnisotropy` `=` POS_INT
       if (Params.MaxAnisotropy.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -777,10 +757,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!MaxAnisotropy.has_value())
         return std::nullopt;
       Params.MaxAnisotropy = MaxAnisotropy;
-    }
-
-    // `comparisonFunc` `=` COMPARISON_FUNC
-    if (tryConsumeExpectedToken(TokenKind::kw_comparisonFunc)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_comparisonFunc)) {
+      // `comparisonFunc` `=` COMPARISON_FUNC
       if (Params.CompFunc.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -793,10 +771,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!CompFunc.has_value())
         return std::nullopt;
       Params.CompFunc = CompFunc;
-    }
-
-    // `borderColor` `=` STATIC_BORDER_COLOR
-    if (tryConsumeExpectedToken(TokenKind::kw_borderColor)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_borderColor)) {
+      // `borderColor` `=` STATIC_BORDER_COLOR
       if (Params.BorderColor.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -809,10 +785,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!BorderColor.has_value())
         return std::nullopt;
       Params.BorderColor = BorderColor;
-    }
-
-    // `minLOD` `=` NUMBER
-    if (tryConsumeExpectedToken(TokenKind::kw_minLOD)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_minLOD)) {
+      // `minLOD` `=` NUMBER
       if (Params.MinLOD.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -825,10 +799,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!MinLOD.has_value())
         return std::nullopt;
       Params.MinLOD = MinLOD;
-    }
-
-    // `maxLOD` `=` NUMBER
-    if (tryConsumeExpectedToken(TokenKind::kw_maxLOD)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_maxLOD)) {
+      // `maxLOD` `=` NUMBER
       if (Params.MaxLOD.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -841,10 +813,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!MaxLOD.has_value())
         return std::nullopt;
       Params.MaxLOD = MaxLOD;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -857,10 +827,8 @@ RootSignatureParser::parseStaticSamplerParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
         reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
@@ -874,7 +842,11 @@ RootSignatureParser::parseStaticSamplerParams() {
         return std::nullopt;
       Params.Visibility = Visibility;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }

diff  --git a/clang/test/SemaHLSL/RootSignature-err.hlsl b/clang/test/SemaHLSL/RootSignature-err.hlsl
index 118fc38daf3f2..04013974d28b9 100644
--- a/clang/test/SemaHLSL/RootSignature-err.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-err.hlsl
@@ -34,3 +34,7 @@ void bad_root_signature_5() {}
 // expected-error@+1 {{expected ')' to denote end of parameters, or, another valid parameter of RootConstants}}
 [RootSignature(MultiLineRootSignature)]
 void bad_root_signature_6() {}
+
+// expected-error@+1 {{expected end of stream to denote end of parameters, or, another valid parameter of RootSignature}}
+[RootSignature("RootFlags() RootConstants(b0, num32BitConstants = 1)")]
+void bad_root_signature_7() {}

diff  --git a/clang/test/SemaHLSL/RootSignature.hlsl b/clang/test/SemaHLSL/RootSignature.hlsl
new file mode 100644
index 0000000000000..810f81479caab
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature.hlsl
@@ -0,0 +1,13 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -fsyntax-only %s -verify
+
+// expected-no-diagnostics
+
+// Test that we have consistent behaviour for comma parsing. Namely:
+// - a single trailing comma is allowed after any parameter
+// - a trailing comma is not required
+
+[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0,),),")]
+void maximum_commas() {}
+
+[RootSignature("CBV(b0, flags = DATA_VOLATILE), DescriptorTable(Sampler(s0))")]
+void minimal_commas() {}

diff  --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index ff1697f1bbb9a..e82dcadebba3f 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -1238,4 +1238,166 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidRootElementMissingCommaTest) {
+  // Checks that an error is produced when a comma is missing between two
+  // root elements (here RootFlags() and RootConstants(...))
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags()
+    RootConstants(num32BitConstants = 1, b0)
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorTableMissingCommaTest) {
+  // Checks that an error is produced when a comma is missing between a
+  // DescriptorTable clause (CBV) and the table's visibility parameter
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      CBV(b0)
+      visibility = SHADER_VISIBILITY_ALL
+    )
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRootConstantParamsCommaTest) {
+  // Checks that an error is produced when a comma is missing between two
+  // RootConstants parameters (num32BitConstants and b0)
+  const llvm::StringLiteral Source = R"cc(
+    RootConstants(
+      num32BitConstants = 1
+      b0
+    )
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidRootDescriptorParamsCommaTest) {
+  // Checks that an error is produced when a comma is missing between two
+  // root descriptor (CBV) parameters (b0 and flags)
+  const llvm::StringLiteral Source = R"cc(
+    CBV(
+      b0
+      flags = 0
+    )
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorClauseParamsCommaTest) {
+  // Checks that an error is produced when a comma is missing between two
+  // descriptor table clause (UAV) parameters (u0 and flags)
+  const llvm::StringLiteral Source = R"cc(
+    DescriptorTable(
+      UAV(
+        u0
+        flags = 0
+      )
+    )
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, InvalidStaticSamplerCommaTest) {
+  // Checks that an error is produced when a comma is missing between two
+  // StaticSampler parameters (s0 and maxLOD)
+  const llvm::StringLiteral Source = R"cc(
+    StaticSampler(
+      s0
+      maxLOD = 3
+    )
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  SmallVector<RootElement> Elements;
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
+                                   Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to