python3kgae created this revision.
python3kgae added reviewers: aaron.ballman, rnk, jdoerfert, MaskRay, rsmith.
Herald added a subscriber: StephenFan.
Herald added a project: All.
python3kgae requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

The shader attribute is used in shader libraries to identify entry functions.
Here's an example:

[shader("pixel")]
float ps_main() : SV_Target {

  return 1;

}

When compiling this shader for a library target such as -T lib_6_3, the compiler
needs to know that ps_main is the entry function for a pixel shader. The shader
attribute provides this information.

A new attribute, HLSLShader, is added to support the shader attribute. It has an
EnumArgument that includes all possible shader stages.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D123907

Files:
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/AttrDocs.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/Sema/SemaDecl.cpp
  clang/lib/Sema/SemaDeclAttr.cpp
  clang/test/SemaHLSL/shader_attr.hlsl

Index: clang/test/SemaHLSL/shader_attr.hlsl
===================================================================
--- /dev/null
+++ clang/test/SemaHLSL/shader_attr.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -ast-dump -o - %s -DFAIL -verify
+
+// The FAIL section is compiled only by the second RUN line (-DFAIL -verify).
+#ifdef FAIL
+
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+[shader("compute")]
+struct Fido {
+  // expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("pixel")]
+  void wag() {}
+
+  // expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("vertex")]
+  static void oops() {}
+};
+
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("vertex")]
+static void oops() {}
+
+namespace spec {
+// expected-warning@+1 {{'shader' attribute only applies to global functions}}
+  [shader("vertex")]
+static void oops() {}
+}
+
+// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
+[shader("compute")]
+// expected-note@+1 {{conflicting attribute is here}}
+[shader("vertex")]
+int doubledUp() {
+  return 1;
+}
+
+// expected-note@+1 {{conflicting attribute is here}}
+[shader("vertex")]
+int forwardDecl();
+
+// expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
+[shader("compute")]
+int forwardDecl() {
+  return 1;
+}
+
+// NOTE(review): no decl follows these before #endif; presumably bind to entry().
+// expected-error@+1 {{'shader' attribute takes one argument}}
+[shader()]
+// expected-error@+1 {{'shader' attribute takes one argument}}
+[shader(1,2)]
+// expected-error@+1 {{'shader' attribute requires a string}}
+[shader(1)]
+// expected-error@+1 {{'shader' attribute argument not supported: 'cs'}}
+[shader("cs")]
+
+#endif  // END of FAIL
+
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:60:2, col:18> Compute
+[shader("compute")]
+int entry() {
+ return 1;
+}
+
+// Matching stages on a redeclaration are accepted; each decl keeps its attr.
+[shader("compute")]
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:66:2, col:18> Compute
+int secondFn();
+
+[shader("compute")]
+// CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:70:2, col:18> Compute
+int secondFn() {
+  return 1;
+}
+
+
Index: clang/lib/Sema/SemaDeclAttr.cpp
===================================================================
--- clang/lib/Sema/SemaDeclAttr.cpp
+++ clang/lib/Sema/SemaDeclAttr.cpp
@@ -6940,6 +6940,44 @@
   D->addAttr(::new (S.Context) HLSLSV_GroupIndexAttr(S.Context, AL));
 }
 
+/// Handles [shader("<stage>")]: converts the string argument to a ShaderStage
+/// and adds the attribute returned by mergeHLSLShaderAttr, if any. Argument
+/// count is already checked by the common attribute handling.
+static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
+  StringRef Str;
+  SourceLocation ArgLoc;
+  if (!S.checkStringLiteralArgumentAttr(AL, 0, Str, &ArgLoc))
+    return;
+
+  HLSLShaderAttr::ShaderStage Stage;
+  if (!HLSLShaderAttr::ConvertStrToShaderStage(Str, Stage)) {
+    // Point at the offending string literal rather than the attribute name.
+    S.Diag(ArgLoc, diag::err_hlsl_invalid_attribute_argument)
+        << AL << "'" + std::string(Str) + "'";
+    return;
+  }
+
+  // TODO: check function match the shader stage.
+
+  HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, Stage);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
+/// Returns null if \p D already has a shader attribute (errors on mismatch).
+HLSLShaderAttr *Sema::mergeHLSLShaderAttr(Decl *D,
+                                          const AttributeCommonInfo &AL,
+                                          HLSLShaderAttr::ShaderStage Stage) {
+  if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
+    if (NT->getStage() != Stage) {
+      Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+  return HLSLShaderAttr::Create(Context, Stage, AL);
+}
+
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
@@ -8815,6 +8853,9 @@
   case ParsedAttr::AT_HLSLSV_GroupIndex:
     handleHLSLSVGroupIndexAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLShader:
+    handleHLSLShaderAttr(S, D, AL);
+    break;
 
   case ParsedAttr::AT_AbiTag:
     handleAbiTagAttr(S, D, AL);
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -2806,6 +2806,8 @@
   else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
     NewAttr =
         S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
+  else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
+    NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getStage());
   else if (Attr->shouldInheritEvenIfAlreadyPresent() || !DeclHasAttr(D, Attr))
     NewAttr = cast<InheritableAttr>(Attr->clone(S.Context));
 
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -3489,6 +3489,8 @@
   HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
                                               const AttributeCommonInfo &AL,
                                               int X, int Y, int Z);
+  HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
+                                      HLSLShaderAttr::ShaderStage Stage);
 
   void mergeDeclAttributes(NamedDecl *New, Decl *Old,
                            AvailabilityMergeKind AMK = AMK_Redeclaration);
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11595,6 +11595,8 @@
 def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
 def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">;
+def err_hlsl_invalid_attribute_argument : Error<
+  "%0 attribute argument not supported: %1">; // %0: attribute, %1: argument
 
 def err_hlsl_pointers_unsupported : Error<
   "%select{pointers|references}0 are unsupported in HLSL">;
Index: clang/include/clang/Basic/AttrDocs.td
===================================================================
--- clang/include/clang/Basic/AttrDocs.td
+++ clang/include/clang/Basic/AttrDocs.td
@@ -6380,6 +6380,14 @@
   }];
 }
 
+def HLSLSV_ShaderAttrDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``shader`` attribute marks an HLSL function as a shader entry point and
+names its stage, e.g. ``[shader("pixel")]`` or ``[shader("compute")]``.
+  }];
+}
+
 def ClangRandomizeLayoutDocs : Documentation {
   let Category = DocCatDecl;
   let Heading = "randomize_layout, no_randomize_layout";
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -3968,6 +3968,23 @@
   let Documentation = [HLSLSV_GroupIndexDocs];
 }
 
+def HLSLShader : InheritableAttr {
+  let Spellings = [Microsoft<"shader">];
+  let Subjects = SubjectList<[HLSLEntry]>;
+  let LangOpts = [HLSL];
+  let Args = [EnumArgument<"Stage", "ShaderStage", // spellings: lower-case
+                           ["pixel", "vertex", "geometry", "hull", "domain",
+                            "compute", "raygeneration", "intersection",
+                            "anyhit", "closesthit", "miss", "callable", "mesh",
+                            "amplification"],
+                           ["Pixel", "Vertex", "Geometry", "Hull", "Domain",
+                            "Compute", "RayGeneration", "Intersection",
+                            "AnyHit", "ClosestHit", "Miss", "Callable", "Mesh",
+                            "Amplification"]
+                           >];
+  let Documentation = [HLSLSV_ShaderAttrDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to