https://github.com/lizhengxing updated https://github.com/llvm/llvm-project/pull/115911
>From 6418461717614d5879688d32a0ab9bf9d9137328 Mon Sep 17 00:00:00 2001 From: Zhengxing Li <zhengxin...@microsoft.com> Date: Tue, 1 Oct 2024 15:13:34 -0700 Subject: [PATCH] [HLSL] Implement SV_GroupID semantic Support SV_GroupID attribute. Translate it into dx.group.id in clang codeGen. Fixes: #70120 --- clang/include/clang/Basic/Attr.td | 7 ++++++ clang/include/clang/Basic/AttrDocs.td | 11 ++++++++ clang/include/clang/Sema/SemaHLSL.h | 1 + clang/lib/CodeGen/CGHLSLRuntime.cpp | 4 +++ clang/lib/Parse/ParseHLSL.cpp | 1 + clang/lib/Sema/SemaDeclAttr.cpp | 3 +++ clang/lib/Sema/SemaHLSL.cpp | 16 ++++++++++-- .../CodeGenHLSL/semantics/SV_GroupID.hlsl | 21 ++++++++++++++++ .../SemaHLSL/Semantics/entry_parameter.hlsl | 11 +++++--- .../Semantics/invalid_entry_parameter.hlsl | 22 ++++++++++++++++ .../Semantics/valid_entry_parameter.hlsl | 25 +++++++++++++++++++ 11 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index a631e81d40aa68..52df814deecc02 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4587,6 +4587,13 @@ def HLSLNumThreads: InheritableAttr { let Documentation = [NumThreadsDocs]; } +def HLSLSV_GroupID: HLSLAnnotationAttr { + let Spellings = [HLSLAnnotation<"SV_GroupID">]; + let Subjects = SubjectList<[ParmVar, Field]>; + let LangOpts = [HLSL]; + let Documentation = [HLSLSV_GroupIDDocs]; +} + def HLSLSV_GroupIndex: HLSLAnnotationAttr { let Spellings = [HLSLAnnotation<"SV_GroupIndex">]; let Subjects = SubjectList<[ParmVar, GlobalVar]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index b64dbef6332e6a..0a20f87d0ddb37 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7816,6 +7816,17 @@ randomized. 
}]; } +def HLSLSV_GroupIDDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``SV_GroupID`` semantic, when applied to an input parameter, specifies a +data binding to map the group id to the specified parameter. This attribute is +only supported in compute shaders. + +The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid + }]; +} + def HLSLSV_GroupIndexDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 06c541dec08cc8..f36b46d0c7ad43 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); + void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 7ba0d615018181..2c293523fca8ca 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, CGM.getIntrinsic(getThreadIdIntrinsic()); return buildVectorInput(B, ThreadIDIntrinsic, Ty); } + if (D.hasAttr<HLSLSV_GroupIDAttr>()) { + llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id); + return buildVectorInput(B, GroupIDIntrinsic, Ty); + } assert(false && "Unhandled parameter attribute"); return nullptr; } diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index b36ea4012c26e1..2f67718f94c68c 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ 
b/clang/lib/Parse/ParseHLSL.cpp @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs, case ParsedAttr::UnknownAttribute: Diag(Loc, diag::err_unknown_hlsl_semantic) << II; return; + case ParsedAttr::AT_HLSLSV_GroupID: case ParsedAttr::AT_HLSLSV_GroupIndex: case ParsedAttr::AT_HLSLSV_DispatchThreadID: break; diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index d05d326178e1b8..cea51bd507f288 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6990,6 +6990,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; + case ParsedAttr::AT_HLSLSV_GroupID: + S.HLSL().handleSV_GroupIDAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupIndex: handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 65b0d9cd65637f..92847c32b18d12 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation( switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: + case attr::HLSLSV_GroupID: if (ST == llvm::Triple::Compute) return; DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute}); @@ -764,7 +765,7 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } -static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { +static bool isLegalTypeForHLSLSV_ThreadOrGroupID(QualType T) { if (!T->hasUnsignedIntegerRepresentation()) return false; if (const auto *VT = T->getAs<VectorType>()) @@ -774,7 +775,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { auto *VD = cast<ValueDecl>(D); - if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) { + if 
(!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) { Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) << AL << "uint/uint2/uint3"; return; @@ -784,6 +785,17 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); } +void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { + auto *VD = cast<ValueDecl>(D); + if (!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) { + Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) + << AL << "uint/uint2/uint3"; + return; + } + + D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL)); +} + void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) { if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) { Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl new file mode 100644 index 00000000000000..d3cf2339335246 --- /dev/null +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl @@ -0,0 +1,21 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +// Make sure SV_GroupID is translated into dx.group.id. 
+ +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0) +// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +[shader("compute")] +[numthreads(8,8,1)] +void foo(uint Idx : SV_GroupID) {} + +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +[shader("compute")] +[numthreads(8,8,1)] +void bar(uint2 Idx : SV_GroupID) {} + diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 8484259f84692b..13c07038d2e4a4 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -2,12 +2,15 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s [numthreads(8,8,1)] -// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) { -// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)' +// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} +void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)' 
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' // CHECK-NEXT: HLSLSV_GroupIndexAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint' +// CHECK-NEXT: HLSLSV_GroupIDAttr } diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index bc3cf8bc51daf4..e78c4ce25c49f6 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -27,3 +27,25 @@ struct ST2 { static uint X : SV_DispatchThreadID; uint s : SV_DispatchThreadID; }; + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain_GID(float ID : SV_GroupID) { +} + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain2_GID(ST GID : SV_GroupID) { + +} + +void foo_GID() { +// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}} + uint GIS : SV_GroupID; +} + +struct ST2_GID { +// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}} + static uint GID : SV_GroupID; + uint s_gid : SV_GroupID; +}; \ No newline at end of file diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl index 8e79fc4d85ec91..10a5e5dabac87b 100644 --- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl @@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) { // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr } + +[numthreads(8,8,1)] +void CSMain_GID(uint 
ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain1_GID(uint2 ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain2_GID(uint3 ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain3_GID(uint3 : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits