https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/142602
>From c69258d78459b8dcc89bec38a8a795763cd3dc80 Mon Sep 17 00:00:00 2001 From: pvanhout <pierre.vanhoutr...@amd.com> Date: Tue, 3 Jun 2025 14:40:38 +0200 Subject: [PATCH] [AMDGPU] New RegBankSelect: Add Ptr32/Ptr64/Ptr128 There are quite a few opcodes that do not care about the exact address space (AS) of the pointer, just its size. Adding generic types for these will help reduce duplication in the rule definitions. I also moved the usual B types to use the new `isAnyPtr` helper I added, to make sure they're supersets of the `Ptr` cases. --- .../AMDGPU/AMDGPURegBankLegalizeHelper.cpp | 42 +++++++++++++++---- .../AMDGPU/AMDGPURegBankLegalizeRules.cpp | 29 +++++++++++-- .../AMDGPU/AMDGPURegBankLegalizeRules.h | 19 +++++++++ 3 files changed, 77 insertions(+), 13 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp index 89af982636590..b2ddc6e88966b 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp @@ -595,17 +595,23 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) { case VgprB32: case UniInVgprB32: if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) || - Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) || - Ty == LLT::pointer(6, 32)) + isAnyPtr(Ty, 32)) return Ty; return LLT(); + case SgprPtr32: + case VgprPtr32: + return isAnyPtr(Ty, 32) ? Ty : LLT(); + case SgprPtr64: + case VgprPtr64: + return isAnyPtr(Ty, 64) ? Ty : LLT(); + case SgprPtr128: + case VgprPtr128: + return isAnyPtr(Ty, 128) ?
Ty : LLT(); case SgprB64: case VgprB64: case UniInVgprB64: if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) || - Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) || - Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64) || - (Ty.isPointer() && Ty.getAddressSpace() > AMDGPUAS::MAX_AMDGPU_ADDRESS)) + Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64)) return Ty; return LLT(); case SgprB96: @@ -619,7 +625,7 @@ LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) { case VgprB128: case UniInVgprB128: if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) || - Ty == LLT::fixed_vector(2, 64)) + Ty == LLT::fixed_vector(2, 64) || isAnyPtr(Ty, 128)) return Ty; return LLT(); case SgprB256: @@ -654,6 +660,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { case SgprP3: case SgprP4: case SgprP5: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: case SgprV2S16: case SgprV2S32: case SgprV4S32: @@ -688,6 +697,9 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) { case VgprP3: case VgprP4: case VgprP5: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: case VgprV2S16: case VgprV2S32: case VgprV4S32: @@ -754,12 +766,18 @@ void RegBankLegalizeHelper::applyMappingDst( case SgprB128: case SgprB256: case SgprB512: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: case VgprB32: case VgprB64: case VgprB96: case VgprB128: case VgprB256: - case VgprB512: { + case VgprB512: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty)); assert(RB == getRegBankFromID(MethodIDs[OpIdx])); break; @@ -864,7 +882,10 @@ void RegBankLegalizeHelper::applyMappingSrc( case SgprB96: case SgprB128: case SgprB256: - case SgprB512: { + case SgprB512: + case SgprPtr32: + case SgprPtr64: + case SgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[i], Ty)); assert(RB == getRegBankFromID(MethodIDs[i])); break; @@ -895,7 +916,10 @@ void 
RegBankLegalizeHelper::applyMappingSrc( case VgprB96: case VgprB128: case VgprB256: - case VgprB512: { + case VgprB512: + case VgprPtr32: + case VgprPtr64: + case VgprPtr128: { assert(Ty == getBTyFromID(MethodIDs[i], Ty)); if (RB != VgprRB) { auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg); diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp index 672fc5b79abc2..5402129e41887 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp @@ -26,6 +26,10 @@ using namespace llvm; using namespace AMDGPU; +bool AMDGPU::isAnyPtr(LLT Ty, unsigned Width) { + return Ty.isPointer() && Ty.getSizeInBits() == Width; +} + RegBankLLTMapping::RegBankLLTMapping( std::initializer_list<RegBankLLTMappingApplyID> DstOpMappingList, std::initializer_list<RegBankLLTMappingApplyID> SrcOpMappingList, @@ -62,6 +66,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(4, 64); case P5: return MRI.getType(Reg) == LLT::pointer(5, 32); + case Ptr32: + return isAnyPtr(MRI.getType(Reg), 32); + case Ptr64: + return isAnyPtr(MRI.getType(Reg), 64); + case Ptr128: + return isAnyPtr(MRI.getType(Reg), 128); case V2S32: return MRI.getType(Reg) == LLT::fixed_vector(2, 32); case V4S32: @@ -98,6 +108,12 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg); case UniP5: return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg); + case UniPtr32: + return isAnyPtr(MRI.getType(Reg), 32) && MUI.isUniform(Reg); + case UniPtr64: + return isAnyPtr(MRI.getType(Reg), 64) && MUI.isUniform(Reg); + case UniPtr128: + return isAnyPtr(MRI.getType(Reg), 128) && MUI.isUniform(Reg); case UniV2S16: return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg); case UniB32: @@ -132,6 +148,12 @@ bool 
matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID, return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg); case DivP5: return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg); + case DivPtr32: + return isAnyPtr(MRI.getType(Reg), 32) && MUI.isDivergent(Reg); + case DivPtr64: + return isAnyPtr(MRI.getType(Reg), 64) && MUI.isDivergent(Reg); + case DivPtr128: + return isAnyPtr(MRI.getType(Reg), 128) && MUI.isDivergent(Reg); case DivV2S16: return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg); case DivB32: @@ -205,15 +227,14 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) { UniformityLLTOpPredicateID LLTToBId(LLT Ty) { if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) || - (Ty.isPointer() && Ty.getSizeInBits() == 32)) + isAnyPtr(Ty, 32)) return B32; if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) || - Ty == LLT::fixed_vector(4, 16) || - (Ty.isPointer() && Ty.getSizeInBits() == 64)) + Ty == LLT::fixed_vector(4, 16) || isAnyPtr(Ty, 64)) return B64; if (Ty == LLT::fixed_vector(3, 32)) return B96; - if (Ty == LLT::fixed_vector(4, 32)) + if (Ty == LLT::fixed_vector(4, 32) || isAnyPtr(Ty, 128)) return B128; return _; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h index 30b900d871f3c..7243d75aa830c 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h +++ b/llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h @@ -15,6 +15,7 @@ namespace llvm { +class LLT; class MachineRegisterInfo; class MachineInstr; class GCNSubtarget; @@ -26,6 +27,9 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>; namespace AMDGPU { +/// \returns true if \p Ty is a pointer type with size \p Width. +bool isAnyPtr(LLT Ty, unsigned Width); + // IDs used to build predicate for RegBankLegalizeRule. 
Predicate can have one // or more IDs and each represents a check for 'uniform or divergent' + LLT or // just LLT on register operand. @@ -59,18 +63,27 @@ enum UniformityLLTOpPredicateID { P3, P4, P5, + Ptr32, + Ptr64, + Ptr128, UniP0, UniP1, UniP3, UniP4, UniP5, + UniPtr32, + UniPtr64, + UniPtr128, DivP0, DivP1, DivP3, DivP4, DivP5, + DivPtr32, + DivPtr64, + DivPtr128, // vectors V2S16, @@ -125,6 +138,9 @@ enum RegBankLLTMappingApplyID { SgprP3, SgprP4, SgprP5, + SgprPtr32, + SgprPtr64, + SgprPtr128, SgprV2S16, SgprV4S32, SgprV2S32, @@ -145,6 +161,9 @@ enum RegBankLLTMappingApplyID { VgprP3, VgprP4, VgprP5, + VgprPtr32, + VgprPtr64, + VgprPtr128, VgprV2S16, VgprV2S32, VgprB32, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits