https://github.com/lizhengxing created https://github.com/llvm/llvm-project/pull/117781
Support SV_GroupThreadId attribute. Translate it into dx.thread.id.in.group in clang codeGen. Fixes: #70122 >From 9d5ffe00f2a1093ca2c28cce184cad0324f53de2 Mon Sep 17 00:00:00 2001 From: Zhengxing Li <zhengxin...@microsoft.com> Date: Wed, 13 Nov 2024 10:54:16 -0800 Subject: [PATCH] [HLSL] Implement SV_GroupThreadId semantic Support SV_GroupThreadId attribute. Translate it into dx.thread.id.in.group in clang codeGen. Fixes: #70122 --- clang/include/clang/Basic/Attr.td | 7 ++++ clang/include/clang/Basic/AttrDocs.td | 11 +++++++ clang/include/clang/Sema/SemaHLSL.h | 1 + clang/lib/CodeGen/CGHLSLRuntime.cpp | 5 +++ clang/lib/Parse/ParseHLSL.cpp | 1 + clang/lib/Sema/SemaDeclAttr.cpp | 3 ++ clang/lib/Sema/SemaHLSL.cpp | 10 ++++++ .../semantics/SV_GroupThreadID.hlsl | 32 +++++++++++++++++++ .../SemaHLSL/Semantics/entry_parameter.hlsl | 13 +++++--- .../Semantics/invalid_entry_parameter.hlsl | 22 +++++++++++++ .../Semantics/valid_entry_parameter.hlsl | 25 +++++++++++++++ 11 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index b055cbd769bb50..9c8e27c0f34e93 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr { let Documentation = [NumThreadsDocs]; } +def HLSLSV_GroupThreadID: HLSLAnnotationAttr { + let Spellings = [HLSLAnnotation<"SV_GroupThreadID">]; + let Subjects = SubjectList<[ParmVar, Field]>; + let LangOpts = [HLSL]; + let Documentation = [HLSLSV_GroupThreadIDDocs]; +} + def HLSLSV_GroupID: HLSLAnnotationAttr { let Spellings = [HLSLAnnotation<"SV_GroupID">]; let Subjects = SubjectList<[ParmVar, Field]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index aafd4449e47004..88bf9a020586ea 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td 
@@ -7934,6 +7934,17 @@ randomized. }]; } +def HLSLSV_GroupThreadIDDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies the +index of the executing thread within its thread group. This attribute is +only supported in compute shaders. + +The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid + }]; +} + def HLSLSV_GroupIDDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index ee685d95c96154..f4cd11f423a84a 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); + void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL); void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 2c293523fca8ca..19db7faddaeac0 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, CGM.getIntrinsic(getThreadIdIntrinsic()); return buildVectorInput(B, ThreadIDIntrinsic, Ty); } + if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) { + llvm::Function *GroupThreadIDIntrinsic = + CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group); + return buildVectorInput(B, GroupThreadIDIntrinsic, Ty); + } if (D.hasAttr<HLSLSV_GroupIDAttr>()) { llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id); return buildVectorInput(B, 
GroupIDIntrinsic, Ty); diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index 4de342b63ed802..443bf2b9ec626a 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs, case ParsedAttr::UnknownAttribute: Diag(Loc, diag::err_unknown_hlsl_semantic) << II; return; + case ParsedAttr::AT_HLSLSV_GroupThreadID: case ParsedAttr::AT_HLSLSV_GroupID: case ParsedAttr::AT_HLSLSV_GroupIndex: case ParsedAttr::AT_HLSLSV_DispatchThreadID: diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 53cc8cb6afd7dc..47e946c3ee64bc 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; + case ParsedAttr::AT_HLSLSV_GroupThreadID: + S.HLSL().handleSV_GroupThreadIDAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupID: S.HLSL().handleSV_GroupIDAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 8b2f24a8e4be0a..7f3c6cb566bcbf 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation( switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: + case attr::HLSLSV_GroupThreadID: case attr::HLSLSV_GroupID: if (ST == llvm::Triple::Compute) return; @@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); } +void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) { + auto *VD = cast<ValueDecl>(D); + if (!diagnoseInputIDType(VD->getType(), AL)) + return; + + D->addAttr(::new (getASTContext()) + HLSLSV_GroupThreadIDAttr(getASTContext(), AL)); +} + void 
SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { auto *VD = cast<ValueDecl>(D); if (!diagnoseInputIDType(VD->getType(), AL)) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl new file mode 100644 index 00000000000000..3533331c6f091c --- /dev/null +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +// Make sure SV_GroupThreadID is translated into dx.thread.id.in.group. + +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +[shader("compute")] +[numthreads(8,8,1)] +void foo(uint Idx : SV_GroupThreadID) {} + +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +[shader("compute")] +[numthreads(8,8,1)] +void bar(uint2 Idx : SV_GroupThreadID) {} + +// CHECK: define void @test() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2) +// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2 +// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +[shader("compute")] +[numthreads(8,8,1)] +void test(uint3 Idx : SV_GroupThreadID) {} diff 
--git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 13c07038d2e4a4..71d32cd13832e1 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -2,15 +2,18 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s [numthreads(8,8,1)] -// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} -void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) { -// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)' +// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}} +void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)' // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' // CHECK-NEXT: HLSLSV_GroupIndexAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint' // CHECK-NEXT: HLSLSV_GroupIDAttr +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 
'uint' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr } diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index 4e1f88aa2294b5..a24112c8e1bb8f 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -49,3 +49,25 @@ struct ST2_GID { static uint GID : SV_GroupID; uint s_gid : SV_GroupID; }; + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain_GThreadID(float ID : SV_GroupThreadID) { +} + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain2_GThreadID(ST GID : SV_GroupThreadID) { + +} + +void foo_GThreadID() { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + uint GThreadIS : SV_GroupThreadID; +} + +struct ST2_GThreadID { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + static uint GThreadID : SV_GroupThreadID; + uint s_gthreadid : SV_GroupThreadID; +}; diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl index 10a5e5dabac87b..6781f9241df240 100644 --- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl @@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) { // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3' // CHECK-NEXT: HLSLSV_GroupIDAttr } + +[numthreads(8,8,1)] +void CSMain_GThreadID(uint ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint' +// CHECK-NEXT: 
HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain3_GThreadID(uint3 : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits