llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

<details>
<summary>Changes</summary>

This patch enhances error handling and validation in the DirectX backend's root 
signature parsing. The changes include:

1. **Improved Error Reporting**:
   - Introduced `reportInvalidTypeError` utility to provide detailed error 
messages for type mismatches.
   - Enhanced diagnostic messages for invalid metadata nodes and values.

2. **Validation Updates**:
   - Added stricter validation for descriptor tables and static samplers.
   - Improved handling of invalid values for filter modes, address modes, and 
LOD parameters.

Example changes:
```cpp
if (Element == nullptr)
  return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
                                        DescriptorTableNode, I);

if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
  Sampler.Filter = *Val;
else
  return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
                                             StaticSamplerNode, 1);
```

Testing:
- Validation of invalid metadata nodes and values.
- Proper diagnostic messages for type mismatches.
- All existing DirectX backend tests continue to pass.


---
Full diff: https://github.com/llvm/llvm-project/pull/144577.diff


4 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+125-31)
- (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll (+1-1)
- (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll (+1-1)
- (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll (+1-1)


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp 
b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 3aef7d3eb1e69..57d5ee8ac467c 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -48,6 +48,71 @@ static bool reportValueError(LLVMContext *Ctx, Twine 
ParamName,
   return true;
 }
 
+// Template function to get formatted type string based on C++ type
+template <typename T> std::string getTypeFormatted() {
+  if constexpr (std::is_same_v<T, MDString>) {
+    return "string";
+  } else if constexpr (std::is_same_v<T, MDNode *> ||
+                       std::is_same_v<T, const MDNode *>) {
+    return "metadata";
+  } else if constexpr (std::is_same_v<T, ConstantAsMetadata *> ||
+                       std::is_same_v<T, const ConstantAsMetadata *>) {
+    return "constant";
+  } else if constexpr (std::is_same_v<T, ConstantAsMetadata>) {
+    return "constant";
+  } else if constexpr (std::is_same_v<T, ConstantInt *> ||
+                       std::is_same_v<T, const ConstantInt *>) {
+    return "constant int";
+  } else if constexpr (std::is_same_v<T, ConstantInt>) {
+    return "constant int";
+  }
+  return "unknown";
+}
+
+// Helper function to get the actual type of a metadata operand
+std::string getActualMDType(const MDNode *Node, unsigned Index) {
+  if (!Node || Index >= Node->getNumOperands())
+    return "null";
+
+  Metadata *Op = Node->getOperand(Index);
+  if (!Op)
+    return "null";
+
+  if (isa<MDString>(Op))
+    return getTypeFormatted<MDString>();
+
+  if (isa<ConstantAsMetadata>(Op)) {
+    if (auto *CAM = dyn_cast<ConstantAsMetadata>(Op)) {
+      Type *T = CAM->getValue()->getType();
+      if (T->isIntegerTy())
+        return (Twine("i") + Twine(T->getIntegerBitWidth())).str();
+      if (T->isFloatingPointTy())
+        return T->isFloatTy()    ? getTypeFormatted<float>()
+               : T->isDoubleTy() ? getTypeFormatted<double>()
+                                 : "fp";
+
+      return getTypeFormatted<ConstantAsMetadata>();
+    }
+  }
+  if (isa<MDNode>(Op))
+    return getTypeFormatted<MDNode *>();
+
+  return "unknown";
+}
+
+// Helper function to simplify error reporting for invalid metadata values
+template <typename ET>
+auto reportInvalidTypeError(LLVMContext *Ctx, Twine ParamName,
+                            const MDNode *Node, unsigned Index) {
+  std::string ExpectedType = getTypeFormatted<ET>();
+  std::string ActualType = getActualMDType(Node, Index);
+
+  return reportError(Ctx, "Root Signature Node: " + ParamName +
+                              " expected metadata node of type " +
+                              ExpectedType + " at index " + Twine(Index) +
+                              " but got " + ActualType);
+}
+
 static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                  unsigned int OpId) {
   if (auto *CI =
@@ -80,7 +145,8 @@ static bool parseRootFlags(LLVMContext *Ctx, 
mcdxbc::RootSignatureDesc &RSD,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
     RSD.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for RootFlag");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
+                                               RootFlagNode, 1);
 
   return false;
 }
@@ -100,23 +166,27 @@ static bool parseRootConstants(LLVMContext *Ctx, 
mcdxbc::RootSignatureDesc &RSD,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                               RootConstantNode, 1);
 
   dxbc::RTS0::v1::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
     Constants.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                               RootConstantNode, 2);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
     Constants.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                               RootConstantNode, 3);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
     Constants.Num32BitValues = *Val;
   else
-    return reportError(Ctx, "Invalid value for Num32BitValues");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+                                               RootConstantNode, 4);
 
   RSD.ParametersContainer.addParameter(Header, Constants);
 
@@ -154,18 +224,21 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                               RootDescriptorNode, 1);
 
   dxbc::RTS0::v2::RootDescriptor Descriptor;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
     Descriptor.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                               RootDescriptorNode, 2);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
     Descriptor.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                               RootDescriptorNode, 3);
 
   if (RSD.Version == 1) {
     RSD.ParametersContainer.addParameter(Header, Descriptor);
@@ -176,7 +249,8 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
     Descriptor.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+                                               RootDescriptorNode, 4);
 
   RSD.ParametersContainer.addParameter(Header, Descriptor);
   return false;
@@ -196,7 +270,8 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
       extractMdStringValue(RangeDescriptorNode, 0);
 
   if (!ElementText.has_value())
-    return reportError(Ctx, "Descriptor Range, first element is not a 
string.");
+    return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+                                            RangeDescriptorNode, 0);
 
   Range.RangeType =
       StringSwitch<uint32_t>(*ElementText)
@@ -213,28 +288,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
     Range.NumDescriptors = *Val;
   else
-    return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                               RangeDescriptorNode, 1);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
     Range.BaseShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for BaseShaderRegister");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                               RangeDescriptorNode, 2);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
     Range.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                               RangeDescriptorNode, 3);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
     Range.OffsetInDescriptorsFromTableStart = *Val;
   else
-    return reportError(Ctx,
-                       "Invalid value for OffsetInDescriptorsFromTableStart");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                               RangeDescriptorNode, 4);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
     Range.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+                                               RangeDescriptorNode, 5);
 
   Table.Ranges.push_back(Range);
   return false;
@@ -251,7 +330,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+                                               DescriptorTableNode, 1);
 
   mcdxbc::DescriptorTable Table;
   Header.ParameterType =
@@ -260,7 +340,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
   for (unsigned int I = 2; I < NumOperands; I++) {
     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
     if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
+      return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
+                                            DescriptorTableNode, I);
 
     if (parseDescriptorRange(Ctx, RSD, Table, Element))
       return true;
@@ -279,67 +360,80 @@ static bool parseStaticSampler(LLVMContext *Ctx, 
mcdxbc::RootSignatureDesc &RSD,
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
     Sampler.Filter = *Val;
   else
-    return reportError(Ctx, "Invalid value for Filter");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 1);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
     Sampler.AddressU = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressU");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 2);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
     Sampler.AddressV = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressV");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 3);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
     Sampler.AddressW = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressW");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 4);
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
     Sampler.MipLODBias = Val->convertToFloat();
   else
-    return reportError(Ctx, "Invalid value for MipLODBias");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 5);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
     Sampler.MaxAnisotropy = *Val;
   else
-    return reportError(Ctx, "Invalid value for MaxAnisotropy");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 6);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
     Sampler.ComparisonFunc = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 7);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
     Sampler.BorderColor = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 8);
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
     Sampler.MinLOD = Val->convertToFloat();
   else
-    return reportError(Ctx, "Invalid value for MinLOD");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 9);
 
   if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
     Sampler.MaxLOD = Val->convertToFloat();
   else
-    return reportError(Ctx, "Invalid value for MaxLOD");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 10);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
     Sampler.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 11);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
     Sampler.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 12);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
     Sampler.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+                                               StaticSamplerNode, 13);
 
   RSD.StaticSamplers.push_back(Sampler);
   return false;
diff --git 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
index 552c128e5ab57..0d5bbdfc097c4 100644
--- 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
+++ 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-Num32BitValues.ll
@@ -2,7 +2,7 @@
 
 target triple = "dxil-unknown-shadermodel6.0-compute"
 
-; CHECK: error: Invalid value for Num32BitValues
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node 
of type constant int at index 4 but got string 
 ; CHECK-NOT: Root Signature Definitions
 
 define void @main() {
diff --git 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
index 1087b414942e2..1384da4baca98 100644
--- 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
+++ 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-RegisterSpace.ll
@@ -2,7 +2,7 @@
 
 target triple = "dxil-unknown-shadermodel6.0-compute"
 
-; CHECK: error: Invalid value for RegisterSpace
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node 
of type constant int at index 3 but got string 
 ; CHECK-NOT: Root Signature Definitions
 
 define void @main() #0 {
diff --git 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
index 53fd924e8f46e..e1fd6a4414609 100644
--- 
a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
+++ 
b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootConstants-Invalid-ShaderRegister.ll
@@ -2,7 +2,7 @@
 
 target triple = "dxil-unknown-shadermodel6.0-compute"
 
-; CHECK: error: Invalid value for ShaderRegister
+; CHECK: error: Root Signature Node: RootConstantNode expected metadata node 
of type constant int at index 2 but got string 
 ; CHECK-NOT: Root Signature Definitions
 
 define void @main() #0 {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144577
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to