bogner created this revision.
bogner added reviewers: beanz, python3kgae, bob80905, tex3d.
Herald added subscribers: Anastasia, arphaman, mcrosier.
Herald added a reviewer: aaron.ballman.
Herald added a project: All.
bogner requested review of this revision.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

This moves the Sema checking of the entry-point-sensitive HLSL
attributes into one place. The change ended up being fairly large for a
few reasons:

- I had to move the call to CheckHLSLEntryPoint later in
ActOnFunctionDeclarator so that the check runs after redeclarations are
handled and has access to all of the attributes, including inherited ones.

- The target's shader stage has to be applied to the specified entry point
(as an implicit 'shader' attribute) before the checks run.

- I removed "library" from the HLSLShader attribute's enum values, and the
conversion from the triple now goes through the environment name string. The
previous approach (casting between two enums whose orders had to match) was
confusing and brittle.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D158803

Files:
  clang/include/clang/Basic/Attr.td
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/Sema/SemaDecl.cpp
  clang/lib/Sema/SemaDeclAttr.cpp
  clang/test/CodeGenHLSL/GlobalDestructors.hlsl
  clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
  clang/test/SemaHLSL/Semantics/groupindex.hlsl
  clang/test/SemaHLSL/entry.hlsl
  clang/test/SemaHLSL/entry_shader.hlsl
  clang/test/SemaHLSL/entry_shader_redecl.hlsl
  clang/test/SemaHLSL/num_threads.hlsl
  clang/test/SemaHLSL/shader_type_attr.hlsl

Index: clang/test/SemaHLSL/shader_type_attr.hlsl
===================================================================
--- clang/test/SemaHLSL/shader_type_attr.hlsl
+++ clang/test/SemaHLSL/shader_type_attr.hlsl
@@ -28,7 +28,7 @@
 } // namespace spec
 
 // expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
-[shader("compute")]
+[shader("pixel")]
 // expected-note@+1 {{conflicting attribute is here}}
 [shader("vertex")]
 int doubledUp() {
@@ -40,7 +40,7 @@
 int forwardDecl();
 
 // expected-error@+1 {{'shader' attribute parameters do not match the previous declaration}}
-[shader("compute")]
+[shader("compute"), numthreads(8,1,1)]
 int forwardDecl() {
   return 1;
 }
@@ -58,17 +58,17 @@
 #endif // END of FAIL
 
 // CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:61:2, col:18> Compute
-[shader("compute")]
+[shader("compute"), numthreads(8,1,1)]
 int entry() {
   return 1;
 }
 
 // Because these two attributes match, they should both appear in the AST
-[shader("compute")]
+[shader("compute"), numthreads(8,1,1)]
 // CHECK:HLSLShaderAttr 0x{{[0-9a-fA-F]+}} <line:67:2, col:18> Compute
 int secondFn();
 
-[shader("compute")]
+[shader("compute"), numthreads(8,1,1)]
 // CHECK:HLSLShaderAttr 0x{{[0-9a-fAl-F]+}} <line:71:2, col:18> Compute
 int secondFn() {
   return 1;
Index: clang/test/SemaHLSL/num_threads.hlsl
===================================================================
--- clang/test/SemaHLSL/num_threads.hlsl
+++ clang/test/SemaHLSL/num_threads.hlsl
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s 
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s 
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-amplification -x hlsl -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -o - %s | FileCheck %s
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-pixel -x hlsl -ast-dump -o - %s -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-vertex -x hlsl -ast-dump -o - %s -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-hull -x hlsl -ast-dump -o - %s -verify
@@ -97,14 +97,20 @@
   return 1;
 }
 
+[numthreads(4,2,1)]
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> 4 2 1
+int onlyOnForwardDecl();
+
+// CHECK: HLSLNumThreadsAttr 0x{{[0-9a-fA-F]+}} <line:100:2, col:18> Inherited 4 2 1
+int onlyOnForwardDecl() {
+  return 1;
+}
 
 #else // Vertex and Pixel only beyond here
-// expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires Compute, Amplification, Mesh or Library}}
+// expected-error-re@+1 {{attribute 'numthreads' is unsupported in {{[A-Za-z]+}} shaders, requires compute, amplification, or mesh}}
 [numthreads(1,1,1)]
 int main() {
  return 1;
 }
 
 #endif
-
-
Index: clang/test/SemaHLSL/entry_shader_redecl.hlsl
===================================================================
--- /dev/null
+++ clang/test/SemaHLSL/entry_shader_redecl.hlsl
@@ -0,0 +1,75 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs1 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs1 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs2 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs2 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry cs3 -o - %s -ast-dump -verify | FileCheck -DSHADERFN=cs3 -check-prefix=CHECK-ENV %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -o - %s -ast-dump -verify | FileCheck -check-prefix=CHECK-LIB %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3 -x hlsl -o - %s -ast-dump -verify | FileCheck -check-prefix=CHECK-LIB %s
+
+// expected-no-diagnostics
+
+// CHECK-ENV: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} [[SHADERFN]] 'void ()'
+// CHECK-ENV: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} [[SHADERFN]] 'void ()'
+// CHECK-ENV-NEXT: CompoundStmt 0x
+// CHECK-ENV-NEXT: HLSLNumThreadsAttr 0x
+// CHECK-ENV-NEXT: HLSLShaderAttr 0x{{.*}} Implicit Compute
+void cs1();
+[numthreads(1,1,1)] void cs1() {}
+[numthreads(1,1,1)] void cs2();
+void cs2() {}
+[numthreads(1,1,1)] void cs3();
+[numthreads(1,1,1)] void cs3() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s1 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s1 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+void s1();
+[shader("compute"), numthreads(1,1,1)] void s1() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s2 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s2 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute")] void s2();
+[shader("compute"), numthreads(1,1,1)] void s2() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s3 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s3 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[numthreads(1,1,1)] void s3();
+[shader("compute"), numthreads(1,1,1)] void s3() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s4 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s4 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s4();
+[shader("compute")][numthreads(1,1,1)] void s4() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s5 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s5 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Inherited Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s5();
+void s5() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s6 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s6 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Compute
+[shader("compute"), numthreads(1,1,1)] void s6();
+[shader("compute")] void s6() {}
+
+// CHECK-LIB: FunctionDecl [[PROTO:0x[0-9a-f]+]] {{.*}} s7 'void ()'
+// CHECK-LIB: FunctionDecl 0x{{.*}} prev [[PROTO]] {{.*}} s7 'void ()'
+// CHECK-LIB-NEXT: CompoundStmt 0x
+// CHECK-LIB-NEXT: HLSLShaderAttr 0x{{.*}} Inherited Compute
+// CHECK-LIB-NEXT: HLSLNumThreadsAttr 0x
+[shader("compute"), numthreads(1,1,1)] void s7();
+[numthreads(1,1,1)] void s7() {}
Index: clang/test/SemaHLSL/entry_shader.hlsl
===================================================================
--- clang/test/SemaHLSL/entry_shader.hlsl
+++ clang/test/SemaHLSL/entry_shader.hlsl
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo  -o - %s -DSHADER='"anyHit"' -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo  -o - %s -DSHADER='"mesh"' -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -x hlsl -hlsl-entry foo  -o - %s -DSHADER='"compute"'
 
-// expected-error@+1 {{'shader' attribute on entry function does not match the pipeline stage}}
+// expected-error@+1 {{'shader' attribute on entry function does not match the target profile}}
 [numthreads(1,1,1), shader(SHADER)]
 void foo() {
 
Index: clang/test/SemaHLSL/entry.hlsl
===================================================================
--- clang/test/SemaHLSL/entry.hlsl
+++ clang/test/SemaHLSL/entry.hlsl
@@ -4,7 +4,7 @@
 
 // Make sure add HLSLShaderAttr along with HLSLNumThreadsAttr.
 // CHECK:HLSLNumThreadsAttr 0x{{.*}} <line:10:2, col:18> 1 1 1
-// CHECK:HLSLShaderAttr 0x{{.*}} <line:13:1> Compute
+// CHECK:HLSLShaderAttr 0x{{.*}} <line:13:1> Implicit Compute
 
 #ifdef WITH_NUM_THREADS
 [numthreads(1,1,1)]
Index: clang/test/SemaHLSL/Semantics/groupindex.hlsl
===================================================================
--- /dev/null
+++ clang/test/SemaHLSL/Semantics/groupindex.hlsl
@@ -0,0 +1,57 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0 -x hlsl -o - %s -verify
+
+// No diagnostics are expected here: SV_GroupIndex is valid in compute shaders.
+[shader("compute"), numthreads(32,1,1)]
+void compute(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in pixel shaders}}
+[shader("pixel")]
+void pixel(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in vertex shaders}}
+[shader("vertex")]
+void vertex(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in geometry shaders}}
+[shader("geometry")]
+void geometry(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in hull shaders}}
+[shader("hull")]
+void hull(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in domain shaders}}
+[shader("domain")]
+void domain(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in raygeneration shaders}}
+[shader("raygeneration")]
+void raygeneration(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in intersection shaders}}
+[shader("intersection")]
+void intersection(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in anyhit shaders}}
+[shader("anyhit")]
+void anyhit(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in closesthit shaders}}
+[shader("closesthit")]
+void closesthit(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in miss shaders}}
+[shader("miss")]
+void miss(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in callable shaders}}
+[shader("callable")]
+void callable(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in amplification shaders}}
+[shader("amplification"), numthreads(32,1,1)]
+void amplification(int GI : SV_GroupIndex) {}
+
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in mesh shaders}}
+[shader("mesh"), numthreads(32,1,1)]
+void mesh(int GI : SV_GroupIndex) {}
Index: clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
===================================================================
--- clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
+++ clang/test/SemaHLSL/Semantics/entry_parameter.hlsl
@@ -1,9 +1,9 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl  -finclude-default-header  -ast-dump -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -x hlsl -ast-dump  -finclude-default-header  -verify -o - %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl  -finclude-default-header  -ast-dump -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header  -verify -o - %s
 
 [numthreads(8,8,1)]
-// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in Mesh shaders, requires Compute}}
-// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in Mesh shaders, requires Compute}}
+// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in mesh shaders, requires compute}}
+// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in mesh shaders, requires compute}}
 void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
 // CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
 // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
Index: clang/test/CodeGenHLSL/GlobalDestructors.hlsl
===================================================================
--- clang/test/CodeGenHLSL/GlobalDestructors.hlsl
+++ clang/test/CodeGenHLSL/GlobalDestructors.hlsl
@@ -41,6 +41,7 @@
 int Pupper::Count = 0;
 
 [numthreads(1,1,1)]
+[shader("compute")]
 void main(unsigned GI : SV_GroupIndex) {
   Wag();
 }
Index: clang/lib/Sema/SemaDeclAttr.cpp
===================================================================
--- clang/lib/Sema/SemaDeclAttr.cpp
+++ clang/lib/Sema/SemaDeclAttr.cpp
@@ -7063,20 +7063,8 @@
 }
 
 static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  using llvm::Triple;
-  Triple Target = S.Context.getTargetInfo().getTriple();
-  auto Env = S.Context.getTargetInfo().getTriple().getEnvironment();
-  if (!llvm::is_contained({Triple::Compute, Triple::Mesh, Triple::Amplification,
-                           Triple::Library},
-                          Env)) {
-    uint32_t Pipeline =
-        static_cast<uint32_t>(hlsl::getStageFromEnvironment(Env));
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << Pipeline << "Compute, Amplification, Mesh or Library";
-    return;
-  }
-
-  llvm::VersionTuple SMVersion = Target.getOSVersion();
+  llvm::VersionTuple SMVersion =
+      S.Context.getTargetInfo().getTriple().getOSVersion();
   uint32_t ZMax = 1024;
   uint32_t ThreadMax = 1024;
   if (SMVersion.getMajor() <= 4) {
@@ -7135,21 +7123,6 @@
   return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
 }
 
-static void handleHLSLSVGroupIndexAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  using llvm::Triple;
-  auto Env = S.Context.getTargetInfo().getTriple().getEnvironment();
-  if (Env != Triple::Compute && Env != Triple::Library) {
-    // FIXME: it is OK for a compute shader entry and pixel shader entry live in
-    // same HLSL file. Issue https://github.com/llvm/llvm-project/issues/57880.
-    ShaderStage Pipeline = hlsl::getStageFromEnvironment(Env);
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << (uint32_t)Pipeline << "Compute";
-    return;
-  }
-
-  D->addAttr(::new (S.Context) HLSLSV_GroupIndexAttr(S.Context, AL));
-}
-
 static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
   if (!T->hasUnsignedIntegerRepresentation())
     return false;
@@ -7160,23 +7133,6 @@
 
 static void handleHLSLSV_DispatchThreadIDAttr(Sema &S, Decl *D,
                                               const ParsedAttr &AL) {
-  using llvm::Triple;
-  Triple Target = S.Context.getTargetInfo().getTriple();
-  // FIXME: it is OK for a compute shader entry and pixel shader entry live in
-  // same HLSL file.Issue https://github.com/llvm/llvm-project/issues/57880.
-  if (Target.getEnvironment() != Triple::Compute &&
-      Target.getEnvironment() != Triple::Library) {
-    uint32_t Pipeline =
-        (uint32_t)S.Context.getTargetInfo().getTriple().getEnvironment() -
-        (uint32_t)llvm::Triple::Pixel;
-    S.Diag(AL.getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
-        << AL << Pipeline << "Compute";
-    return;
-  }
-
-  // FIXME: report warning and ignore semantic when cannot apply on the Decl.
-  // See https://github.com/llvm/llvm-project/issues/57916.
-
   // FIXME: support semantic on field.
   // See https://github.com/llvm/llvm-project/issues/57889.
   if (isa<FieldDecl>(D)) {
@@ -7202,11 +7158,7 @@
     return;
 
   HLSLShaderAttr::ShaderType ShaderType;
-  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType) ||
-      // Library is added to help convert HLSLShaderAttr::ShaderType to
-      // llvm::Triple::EnviromentType. It is not a legal
-      // HLSLShaderAttr::ShaderType.
-      ShaderType == HLSLShaderAttr::Library) {
+  if (!HLSLShaderAttr::ConvertStrToShaderType(Str, ShaderType)) {
     S.Diag(AL.getLoc(), diag::warn_attribute_type_not_supported)
         << AL << Str << ArgLoc;
     return;
@@ -9345,7 +9297,7 @@
     handleHLSLNumThreadsAttr(S, D, AL);
     break;
   case ParsedAttr::AT_HLSLSV_GroupIndex:
-    handleHLSLSVGroupIndexAttr(S, D, AL);
+    handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
     break;
   case ParsedAttr::AT_HLSLSV_DispatchThreadID:
     handleHLSLSV_DispatchThreadIDAttr(S, D, AL);
Index: clang/lib/Sema/SemaDecl.cpp
===================================================================
--- clang/lib/Sema/SemaDecl.cpp
+++ clang/lib/Sema/SemaDecl.cpp
@@ -10325,7 +10325,7 @@
     return NewFD;
   }
 
-  if (getLangOpts().OpenCL) {
+  if (getLangOpts().OpenCL || getLangOpts().HLSL) {
     // OpenCL v1.1 s6.5: Using an address space qualifier in a function return
     // type declaration will generate a compilation error.
     LangAS AddressSpace = NewFD->getReturnType().getAddressSpace();
@@ -10335,40 +10335,6 @@
     }
   }
 
-  if (getLangOpts().HLSL) {
-    auto &TargetInfo = getASTContext().getTargetInfo();
-    // Skip operator overload which not identifier.
-    // Also make sure NewFD is in translation-unit scope.
-    if (!NewFD->isInvalidDecl() && Name.isIdentifier() &&
-        NewFD->getName() == TargetInfo.getTargetOpts().HLSLEntry &&
-        S->getDepth() == 0) {
-      CheckHLSLEntryPoint(NewFD);
-      if (!NewFD->isInvalidDecl()) {
-        auto Env = TargetInfo.getTriple().getEnvironment();
-        HLSLShaderAttr::ShaderType ShaderType =
-            static_cast<HLSLShaderAttr::ShaderType>(
-                hlsl::getStageFromEnvironment(Env));
-        // To share code with HLSLShaderAttr, add HLSLShaderAttr to entry
-        // function.
-        if (HLSLShaderAttr *NT = NewFD->getAttr<HLSLShaderAttr>()) {
-          if (NT->getType() != ShaderType)
-            Diag(NT->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
-                << NT;
-        } else {
-          NewFD->addAttr(HLSLShaderAttr::Create(Context, ShaderType,
-                                                NewFD->getBeginLoc()));
-        }
-      }
-    }
-    // HLSL does not support specifying an address space on a function return
-    // type.
-    LangAS AddressSpace = NewFD->getReturnType().getAddressSpace();
-    if (AddressSpace != LangAS::Default) {
-      Diag(NewFD->getLocation(), diag::err_return_value_with_address_space);
-      NewFD->setInvalidDecl();
-    }
-  }
-
   if (!getLangOpts().CPlusPlus) {
     // Perform semantic checking on the function declaration.
     if (!NewFD->isInvalidDecl() && NewFD->isMain())
@@ -10658,6 +10624,15 @@
     }
   }
 
+  if (getLangOpts().HLSL && D.isFunctionDefinition()) {
+    // Any top level function could potentially be specified as an entry.
+    if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
+      ActOnHLSLTopLevelFunction(NewFD);
+
+    if (NewFD->hasAttr<HLSLShaderAttr>())
+      CheckHLSLEntryPoint(NewFD);
+  }
+
   // If this is the first declaration of a library builtin function, add
   // attributes as appropriate.
   if (!D.isRedeclaration()) {
@@ -12385,24 +12360,83 @@
   }
 }
 
-void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
+void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
   auto &TargetInfo = getASTContext().getTargetInfo();
-  auto const Triple = TargetInfo.getTriple();
-  switch (Triple.getEnvironment()) {
-  default:
-    // FIXME: check all shader profiles.
+
+  if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
+    return;
+
+  StringRef Env = TargetInfo.getTriple().getEnvironmentName();
+  HLSLShaderAttr::ShaderType ShaderType;
+  if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
+    if (HLSLShaderAttr *Shader = FD->getAttr<HLSLShaderAttr>()) {
+      // The entry point is already annotated - check that it matches the
+      // triple.
+      if (Shader->getType() != ShaderType) {
+        Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
+            << Shader;
+        FD->setInvalidDecl();
+      }
+    } else {
+      // Implicitly add the shader attribute if the entry function isn't
+      // explicitly annotated.
+      FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
+                                                 FD->getBeginLoc()));
+    }
+  } else {
+    switch (TargetInfo.getTriple().getEnvironment()) {
+    case llvm::Triple::UnknownEnvironment:
+    case llvm::Triple::Library:
+      break;
+    default:
+      // TODO: This should probably just be llvm_unreachable and we should
+      // reject triples with random ABIs and such when we build the target.
+      // For now, crash.
+      llvm::report_fatal_error("Unhandled environment in triple");
+    }
+  }
+}
+
+void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
+  auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (ST) {
+  case HLSLShaderAttr::Pixel:
+  case HLSLShaderAttr::Vertex:
+  case HLSLShaderAttr::Geometry:
+  case HLSLShaderAttr::Hull:
+  case HLSLShaderAttr::Domain:
+  case HLSLShaderAttr::RayGeneration:
+  case HLSLShaderAttr::Intersection:
+  case HLSLShaderAttr::AnyHit:
+  case HLSLShaderAttr::ClosestHit:
+  case HLSLShaderAttr::Miss:
+  case HLSLShaderAttr::Callable:
+    if (auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
+      Diag(NT->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+          << NT << HLSLShaderAttr::ConvertShaderTypeToStr(ST)
+          << "compute, amplification, or mesh";
+      FD->setInvalidDecl();
+    }
     break;
-  case llvm::Triple::EnvironmentType::Compute:
+
+  case HLSLShaderAttr::Compute:
+  case HLSLShaderAttr::Amplification:
+  case HLSLShaderAttr::Mesh:
     if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
       Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
-          << Triple.getEnvironmentName();
+          << HLSLShaderAttr::ConvertShaderTypeToStr(ST);
       FD->setInvalidDecl();
     }
     break;
   }
 
-  for (const auto *Param : FD->parameters()) {
-    if (!Param->hasAttr<HLSLAnnotationAttr>()) {
+  for (ParmVarDecl *Param : FD->parameters()) {
+    if (auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
+      CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
+    } else {
       // FIXME: Handle struct parameters where annotations are on struct fields.
       // See: https://github.com/llvm/llvm-project/issues/57875
       Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
@@ -12413,6 +12447,26 @@
   // FIXME: Verify return type semantic annotation.
 }
 
+void Sema::CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, Decl *Param,
+                                       HLSLAnnotationAttr *AnnotationAttr) {
+  auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
+  assert(ShaderAttr && "Entry point has no shader attribute");
+  HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();
+
+  switch (AnnotationAttr->getKind()) {
+  case attr::HLSLSV_DispatchThreadID:
+  case attr::HLSLSV_GroupIndex:
+    if (ST == HLSLShaderAttr::Compute)
+      return;
+    Diag(AnnotationAttr->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
+        << AnnotationAttr << HLSLShaderAttr::ConvertShaderTypeToStr(ST)
+        << "compute";
+    break;
+  default:
+    llvm_unreachable("Unknown HLSLAnnotationAttr");
+  }
+}
+
 bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
   // FIXME: Need strict checking.  In C89, we need to check for
   // any assignment, increment, decrement, function-calls, or
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -3005,7 +3005,10 @@
                                       QualType NewT, QualType OldT);
   void CheckMain(FunctionDecl *FD, const DeclSpec &D);
   void CheckMSVCRTEntryPoint(FunctionDecl *FD);
+  void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
   void CheckHLSLEntryPoint(FunctionDecl *FD);
+  void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, Decl *Param,
+                                   HLSLAnnotationAttr *AnnotationAttr);
   Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
                                                    bool IsDefinition);
   void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11864,13 +11864,13 @@
   "'std::source_location::__impl' must be standard-layout and have only two 'const char *' fields '_M_file_name' and '_M_function_name', and two integral fields '_M_line' and '_M_column'">;
 
 // HLSL Diagnostics
-def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %select{Pixel|Vertex|Geometry|Hull|Domain|Compute|Library|RayGeneration|Intersection|AnyHit|ClosestHit|Miss|Callable|Mesh|Amplification|Invalid}1 shaders, requires %2">;
+def err_hlsl_attr_unsupported_in_stage : Error<"attribute %0 is unsupported in %1 shaders, requires %2">;
 def err_hlsl_attr_invalid_type : Error<
    "attribute %0 only applies to a field or parameter of type '%1'">;
 def err_hlsl_attr_invalid_ast_node : Error<
    "attribute %0 only applies to %1">;
 def err_hlsl_entry_shader_attr_mismatch : Error<
-   "%0 attribute on entry function does not match the pipeline stage">;
+   "%0 attribute on entry function does not match the target profile">;
 def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numthreads attribute cannot exceed %1">;
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
 def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">;
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -4133,24 +4133,14 @@
   let Spellings = [Microsoft<"shader">];
   let Subjects = SubjectList<[HLSLEntry]>;
   let LangOpts = [HLSL];
-  // NOTE:
-  // order for the enum should match order in llvm::Triple::EnvironmentType.
-  // ShaderType will be converted to llvm::Triple::EnvironmentType like
-  //   (llvm::Triple::EnvironmentType)((uint32_t)ShaderType +
-  //      (uint32_t)llvm::Triple::EnvironmentType::Pixel).
-  // This will avoid update code for convert when new shader type is added.
   let Args = [
     EnumArgument<"Type", "ShaderType",
-                 [
-                   "pixel", "vertex", "geometry", "hull", "domain", "compute",
-                   "library", "raygeneration", "intersection", "anyHit",
-                   "closestHit", "miss", "callable", "mesh", "amplification"
-                 ],
-                 [
-                   "Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
-                   "Library", "RayGeneration", "Intersection", "AnyHit",
-                   "ClosestHit", "Miss", "Callable", "Mesh", "Amplification"
-                 ]>
+                 ["pixel", "vertex", "geometry", "hull", "domain", "compute",
+                  "raygeneration", "intersection", "anyhit", "closesthit",
+                  "miss", "callable", "mesh", "amplification"],
+                 ["Pixel", "Vertex", "Geometry", "Hull", "Domain", "Compute",
+                  "RayGeneration", "Intersection", "AnyHit", "ClosestHit",
+                  "Miss", "Callable", "Mesh", "Amplification"]>
   ];
   let Documentation = [HLSLSV_ShaderTypeAttrDocs];
 }
@@ -4232,4 +4222,3 @@
   let Subjects = SubjectList<[TypedefName], ErrorDiag>;
   let Documentation = [Undocumented];
 }
-
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to