llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

This PR implements the following validations:

1. Check that descriptor tables don't mix Sampler and non-Sampler resources
2. Ensure that descriptor ranges don't append onto an unbounded range
3. Ensure that descriptor range offsets don't overflow the 32-bit offset range
4. Add a missing validation to ensure that only a single `RootFlags` parameter is provided

Resolves: https://github.com/llvm/llvm-project/issues/153868.

---
Full diff: https://github.com/llvm/llvm-project/pull/156754.diff


9 Files Affected:

- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3) 
- (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+10) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+37-2) 
- (modified) clang/test/SemaHLSL/RootSignature-err.hlsl (+6-2) 
- (modified) clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl (+25) 
- (modified) clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl (+3) 
- (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+82-11) 
- (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h (+4) 
- (modified) llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp (+15) 


``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c934fed2c7462..8bb47e3a4d46d 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13184,6 +13184,9 @@ def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Litera
 
 def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, 
%1]">;
 def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
+def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource 
mixed in descriptor table">;
+def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded 
descriptor range">;
+def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, 
%1]">;
 
 def subst_hlsl_format_ranges: TextSubstitution<
   "%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 1af72f8b1c934..7dd0c3e90886b 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -37,8 +37,18 @@ bool RootSignatureParser::parse() {
   // Iterate as many RootSignatureElements as possible, until we hit the
   // end of the stream
   bool HadError = false;
+  bool HasRootFlags = false;
   while (!peekExpectedToken(TokenKind::end_of_stream)) {
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      if (HasRootFlags) {
+        reportDiag(diag::err_hlsl_rootsig_repeat_param)
+            << TokenKind::kw_RootFlags;
+        HadError = true;
+        skipUntilExpectedToken(RootElementKeywords);
+        continue;
+      }
+      HasRootFlags = true;
+
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value()) {
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1e5ec952c1ecf..4cf08eac6d171 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1331,12 +1331,47 @@ bool SemaHLSL::handleRootSignatureElements(
                    std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
       assert(UnboundClauses.size() == Table->NumClauses &&
              "Number of unbound elements must match the number of clauses");
+      bool HasSampler = false;
+      bool HasNonSampler = false;
+      uint32_t Offset = 0;
       for (const auto &[Clause, ClauseElem] : UnboundClauses) {
-        uint32_t LowerBound(Clause->Reg.Number);
+        SourceLocation Loc = RootSigElem.getLocation();
+        if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
+          HasSampler = true;
+        else
+          HasNonSampler = true;
+
+        if (HasSampler && HasNonSampler)
+          Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
+
         // Relevant error will have already been reported above and needs to be
-        // fixed before we can conduct range analysis, so shortcut error return
+        // fixed before we can conduct further analysis, so shortcut error
+        // return
         if (Clause->NumDescriptors == 0)
           return true;
+
+        if (Clause->Offset !=
+            llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
+          // Manually specified the offset
+          Offset = Clause->Offset;
+        }
+
+        uint64_t NextOffset =
+            llvm::hlsl::rootsig::nextOffset(Offset, Clause->NumDescriptors);
+
+        if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
+          // Trying to append onto unbound offset
+          Diag(Loc, diag::err_hlsl_appending_onto_unbound);
+        } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(NextOffset -
+                                                                  1)) {
+          // Upper bound overflows maximum offset
+          Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << NextOffset - 
1;
+        }
+
+        Offset = uint32_t(NextOffset);
+
+        // Compute the register bounds and track resource binding
+        uint32_t LowerBound(Clause->Reg.Number);
         uint32_t UpperBound = Clause->NumDescriptors == ~0u
                                   ? ~0u
                                   : LowerBound + Clause->NumDescriptors - 1;
diff --git a/clang/test/SemaHLSL/RootSignature-err.hlsl b/clang/test/SemaHLSL/RootSignature-err.hlsl
index ccfa093baeb87..89c684cd8d11f 100644
--- a/clang/test/SemaHLSL/RootSignature-err.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-err.hlsl
@@ -179,7 +179,7 @@ void basic_validation_3() {}
 
 // expected-error@+2 {{value must be in the range [1, 4294967294]}}
 // expected-error@+1 {{value must be in the range [1, 4294967294]}}
-[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, 
numDescriptors = 0))")]
+[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), 
DescriptorTable(Sampler(s0, numDescriptors = 0))")]
 void basic_validation_4() {}
 
 // expected-error@+2 {{value must be in the range [0, 16]}}
@@ -189,4 +189,8 @@ void basic_validation_5() {}
 
 // expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
 [RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
-void basic_validation_6() {}
\ No newline at end of file
+void basic_validation_6() {}
+
+// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor 
table}}
+[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
+void mixed_resource_table() {}
diff --git a/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl b/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
index fd098b01cc723..2d025d0e6e5ce 100644
--- a/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
@@ -117,3 +117,28 @@ void bad_root_signature_14() {}
 // expected-note@+1 {{overlapping resource range here}}
 [RootSignature(DuplicatesRootSignature)]
 void valid_root_signature_15() {}
+
+#define AppendingToUnbound \
+  "DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"
+
+// expected-error@+1 {{offset appends to unbounded descriptor range}}
+[RootSignature(AppendingToUnbound)]
+void append_to_unbound_signature() {}
+
+#define DirectOffsetOverflow \
+  "DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967294, 
4294967299]}}
+[RootSignature(DirectOffsetOverflow)]
+void direct_offset_overflow_signature() {}
+
+#define AppendOffsetOverflow \
+  "DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967293, 
4294967299]}}
+[RootSignature(AppendOffsetOverflow)]
+void append_offset_overflow_signature() {}
+
+// expected-error@+1 {{descriptor range offset overflows [4294967292, 
4294967296]}}
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 
5))")]
+void offset_() {}
diff --git a/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl b/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
index 09a1110b0fbc1..10e7215eccf6e 100644
--- a/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
@@ -22,3 +22,6 @@ void valid_root_signature_5() {}
 
 [RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
 void valid_root_signature_6() {}
+
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, 
numDescriptors = 3))")]
+void valid_root_signature_7() {}
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 44c0978a243bc..9b9f5dd8a63bb 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -501,8 +501,6 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
   using llvm::dxbc::RootFlags;
   const llvm::StringLiteral Source = R"cc(
-    RootFlags(),
-    RootFlags(0),
     RootFlags(
       deny_domain_shader_root_access |
       deny_pixel_shader_root_access |
@@ -533,18 +531,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
   ASSERT_FALSE(Parser.parse());
 
   auto Elements = Parser.getElements();
-  ASSERT_EQ(Elements.size(), 3u);
+  ASSERT_EQ(Elements.size(), 1u);
 
   RootElement Elem = Elements[0].getElement();
   ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
-  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
-  Elem = Elements[1].getElement();
-  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
-  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
-  Elem = Elements[2].getElement();
-  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
   auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
                         RootFlags::DenyVertexShaderRootAccess |
                         RootFlags::DenyHullShaderRootAccess |
@@ -562,6 +552,64 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
+  using llvm::dxbc::RootFlags;
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(),
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  auto Elements = Parser.getElements();
+  ASSERT_EQ(Elements.size(), 1u);
+
+  RootElement Elem = Elements[0].getElement();
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
+  using llvm::dxbc::RootFlags;
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(0),
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  auto Elements = Parser.getElements();
+  ASSERT_EQ(Elements.size(), 1u);
+
+  RootElement Elem = Elements[0].getElement();
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
   using llvm::dxbc::RootDescriptorFlags;
   const llvm::StringLiteral Source = R"cc(
@@ -1658,4 +1706,27 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorRangeFlagsValueTest) {
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
+  // This test will check that an error is produced when there are multiple
+  // root flags provided
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
+    RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index fde32a1fff591..5ffd31ecb2650 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -41,6 +41,10 @@ LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
 LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
 LLVM_ABI bool verifyLOD(float LOD);
 
+LLVM_ABI bool verifyBoundOffset(uint32_t Offset);
+LLVM_ABI bool verifyNoOverflowedOffset(uint64_t Offset);
+LLVM_ABI uint64_t nextOffset(uint32_t Offset, uint32_t Size);
+
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 72308a3de5fd4..f19354ceb6072 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -180,6 +180,21 @@ bool verifyBorderColor(uint32_t BorderColor) {
 
 bool verifyLOD(float LOD) { return !std::isnan(LOD); }
 
+bool verifyBoundOffset(uint32_t Offset) {
+  return Offset != NumDescriptorsUnbounded;
+}
+
+bool verifyNoOverflowedOffset(uint64_t Offset) {
+  return Offset <= std::numeric_limits<uint32_t>::max();
+}
+
+uint64_t nextOffset(uint32_t Offset, uint32_t Size) {
+  if (Size == NumDescriptorsUnbounded)
+    return NumDescriptorsUnbounded;
+
+  return uint64_t(Offset) + uint64_t(Size);
+}
+
 } // namespace rootsig
 } // namespace hlsl
 } // namespace llvm

``````````

</details>


https://github.com/llvm/llvm-project/pull/156754
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to