================
@@ -115,126 +117,233 @@ class AMDGPURegBankLegalizeCombiner {
         VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
         VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {};
 
-  bool isLaneMask(Register Reg) {
-    const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
-    if (RB && RB->getID() == AMDGPU::VCCRegBankID)
-      return true;
+  bool isLaneMask(Register Reg);
+  std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode);
+  std::pair<GUnmerge *, int> tryMatchRALFromUnmerge(Register Src);
+  Register getReadAnyLaneSrc(Register Src);
+  void replaceRegWithOrBuildCopy(Register Dst, Register Src);
 
-    const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
-    return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
-  }
+  bool tryEliminateReadAnyLane(MachineInstr &Copy);
+  void tryCombineCopy(MachineInstr &MI);
+  void tryCombineS1AnyExt(MachineInstr &MI);
+};
 
-  void cleanUpAfterCombine(MachineInstr &MI, MachineInstr *Optional0) {
-    MI.eraseFromParent();
-    if (Optional0 && isTriviallyDead(*Optional0, MRI))
-      Optional0->eraseFromParent();
-  }
+bool AMDGPURegBankLegalizeCombiner::isLaneMask(Register Reg) {
+  const RegisterBank *RB = MRI.getRegBankOrNull(Reg);
+  if (RB && RB->getID() == AMDGPU::VCCRegBankID)
+    return true;
 
-  std::pair<MachineInstr *, Register> tryMatch(Register Src, unsigned Opcode) {
-    MachineInstr *MatchMI = MRI.getVRegDef(Src);
-    if (MatchMI->getOpcode() != Opcode)
-      return {nullptr, Register()};
-    return {MatchMI, MatchMI->getOperand(1).getReg()};
-  }
+  const TargetRegisterClass *RC = MRI.getRegClassOrNull(Reg);
+  return RC && TRI.isSGPRClass(RC) && MRI.getType(Reg) == LLT::scalar(1);
+}
 
-  void tryCombineCopy(MachineInstr &MI) {
-    Register Dst = MI.getOperand(0).getReg();
-    Register Src = MI.getOperand(1).getReg();
-    // Skip copies of physical registers.
-    if (!Dst.isVirtual() || !Src.isVirtual())
-      return;
-
-    // This is a cross bank copy, sgpr S1 to lane mask.
-    //
-    // %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
-    // %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
-    // ->
-    // %Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %TruncS32Src:sgpr(s32)
-    if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
-      auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
-      assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
-             "sgpr S1 must be result of G_TRUNC of sgpr S32");
-
-      B.setInstr(MI);
-      // Ensure that truncated bits in BoolSrc are 0.
-      auto One = B.buildConstant({SgprRB, S32}, 1);
-      auto BoolSrc = B.buildAnd({SgprRB, S32}, TruncS32Src, One);
-      B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {Dst}, {BoolSrc});
-      cleanUpAfterCombine(MI, Trunc);
-      return;
-    }
+std::pair<MachineInstr *, Register>
+AMDGPURegBankLegalizeCombiner::tryMatch(Register Src, unsigned Opcode) {
+  MachineInstr *MatchMI = MRI.getVRegDef(Src);
+  if (MatchMI->getOpcode() != Opcode)
+    return {nullptr, Register()};
+  return {MatchMI, MatchMI->getOperand(1).getReg()};
+}
+
+std::pair<GUnmerge *, int>
+AMDGPURegBankLegalizeCombiner::tryMatchRALFromUnmerge(Register Src) {
+  MachineInstr *ReadAnyLane = MRI.getVRegDef(Src);
+  if (ReadAnyLane->getOpcode() != AMDGPU::G_AMDGPU_READANYLANE)
+    return {nullptr, -1};
+
+  Register RALSrc = ReadAnyLane->getOperand(1).getReg();
+  if (auto *UnMerge = getOpcodeDef<GUnmerge>(RALSrc, MRI))
+    return {UnMerge, UnMerge->findRegisterDefOperandIdx(RALSrc, nullptr)};
 
-    // Src = G_AMDGPU_READANYLANE RALSrc
-    // Dst = COPY Src
-    // ->
-    // Dst = RALSrc
-    if (MRI.getRegBankOrNull(Dst) == VgprRB &&
-        MRI.getRegBankOrNull(Src) == SgprRB) {
-      auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE);
-      if (!RAL)
-        return;
-
-      assert(MRI.getRegBank(RALSrc) == VgprRB);
-      MRI.replaceRegWith(Dst, RALSrc);
-      cleanUpAfterCombine(MI, RAL);
-      return;
+  return {nullptr, -1};
+}
+
+Register AMDGPURegBankLegalizeCombiner::getReadAnyLaneSrc(Register Src) {
+  // Src = G_AMDGPU_READANYLANE RALSrc
+  auto [RAL, RALSrc] = tryMatch(Src, AMDGPU::G_AMDGPU_READANYLANE);
+  if (RAL)
+    return RALSrc;
+
+  // LoVgpr, HiVgpr = G_UNMERGE_VALUES UnmergeSrc
+  // LoSgpr = G_AMDGPU_READANYLANE LoVgpr
+  // HiSgpr = G_AMDGPU_READANYLANE HiVgpr
+  // Src = G_MERGE_VALUES LoSgpr, HiSgpr
+  auto *Merge = getOpcodeDef<GMergeLikeInstr>(Src, MRI);
+  if (Merge) {
+    unsigned NumElts = Merge->getNumSources();
+    auto [Unmerge, Idx] = tryMatchRALFromUnmerge(Merge->getSourceReg(0));
+    if (!Unmerge || Unmerge->getNumDefs() != NumElts || Idx != 0)
+      return {};
+
+    // Check if all elements are from same unmerge and there is no shuffling.
+    for (unsigned i = 1; i < NumElts; ++i) {
+      auto [UnmergeI, IdxI] = tryMatchRALFromUnmerge(Merge->getSourceReg(i));
+      if (UnmergeI != Unmerge || (unsigned)IdxI != i)
+        return {};
     }
+    return Unmerge->getSourceReg();
   }
 
-  void tryCombineS1AnyExt(MachineInstr &MI) {
-    // %Src:sgpr(S1) = G_TRUNC %TruncSrc
-    // %Dst = G_ANYEXT %Src:sgpr(S1)
-    // ->
-    // %Dst = G_... %TruncSrc
-    Register Dst = MI.getOperand(0).getReg();
-    Register Src = MI.getOperand(1).getReg();
-    if (MRI.getType(Src) != S1)
-      return;
-
-    auto [Trunc, TruncSrc] = tryMatch(Src, AMDGPU::G_TRUNC);
-    if (!Trunc)
-      return;
-
-    LLT DstTy = MRI.getType(Dst);
-    LLT TruncSrcTy = MRI.getType(TruncSrc);
-
-    if (DstTy == TruncSrcTy) {
-      MRI.replaceRegWith(Dst, TruncSrc);
-      cleanUpAfterCombine(MI, Trunc);
-      return;
-    }
+  // ..., VgprI, ... = G_UNMERGE_VALUES VgprLarge
+  // SgprI = G_AMDGPU_READANYLANE VgprI
+  // SgprLarge = G_MERGE_VALUES ..., SgprI, ...
+  // ..., Src, ... = G_UNMERGE_VALUES SgprLarge
+  auto *UnMerge = getOpcodeDef<GUnmerge>(Src, MRI);
+  if (!UnMerge)
+    return {};
+
+  int Idx = UnMerge->findRegisterDefOperandIdx(Src, nullptr);
+  Merge = getOpcodeDef<GMergeLikeInstr>(UnMerge->getSourceReg(), MRI);
+  if (!Merge || UnMerge->getNumDefs() != Merge->getNumSources())
+    return {};
+
+  Register SrcRegIdx = Merge->getSourceReg(Idx);
+  if (MRI.getType(Src) != MRI.getType(SrcRegIdx))
+    return {};
+
+  auto [RALEl, RALElSrc] = tryMatch(SrcRegIdx, AMDGPU::G_AMDGPU_READANYLANE);
+  if (RALEl)
+    return RALElSrc;
+
+  return {};
+}
+
+void AMDGPURegBankLegalizeCombiner::replaceRegWithOrBuildCopy(Register Dst,
+                                                              Register Src) {
+  if (Dst.isVirtual())
+    MRI.replaceRegWith(Dst, Src);
+  else
+    B.buildCopy(Dst, Src);
+}
+
+bool AMDGPURegBankLegalizeCombiner::tryEliminateReadAnyLane(
+    MachineInstr &Copy) {
+  Register Dst = Copy.getOperand(0).getReg();
+  Register Src = Copy.getOperand(1).getReg();
+
+  // Skip non-vgpr Dst
+  if (Dst.isVirtual() ? (MRI.getRegBankOrNull(Dst) != VgprRB)
+                      : !TRI.isVGPR(MRI, Dst))
+    return false;
+
+  // Skip physical source registers and source registers with register class
+  if (!Src.isVirtual() || MRI.getRegClassOrNull(Src))
+    return false;
+
+  Register RALDst = Src;
+  MachineInstr &SrcMI = *MRI.getVRegDef(Src);
+  if (SrcMI.getOpcode() == AMDGPU::G_BITCAST)
+    RALDst = SrcMI.getOperand(1).getReg();
+
+  Register RALSrc = getReadAnyLaneSrc(RALDst);
+  if (!RALSrc)
+    return false;
+
+  B.setInstr(Copy);
+  if (SrcMI.getOpcode() != AMDGPU::G_BITCAST) {
+    // Src = READANYLANE RALSrc     Src = READANYLANE RALSrc
+    // Dst = Copy Src               $Dst = Copy Src
+    // ->                           ->
+    // Dst = RALSrc                 $Dst = Copy RALSrc
+    replaceRegWithOrBuildCopy(Dst, RALSrc);
+  } else {
+    // RALDst = READANYLANE RALSrc  RALDst = READANYLANE RALSrc
+    // Src = G_BITCAST RALDst       Src = G_BITCAST RALDst
+    // Dst = Copy Src               Dst = Copy Src
+    // ->                          ->
+    // NewVgpr = G_BITCAST RALDst   NewVgpr = G_BITCAST RALDst
+    // Dst = NewVgpr                $Dst = Copy NewVgpr
+    auto Bitcast = B.buildBitcast({VgprRB, MRI.getType(Src)}, RALSrc);
+    replaceRegWithOrBuildCopy(Dst, Bitcast.getReg(0));
+  }
+
+  eraseInstr(Copy, MRI);
+  return true;
+}
+
+void AMDGPURegBankLegalizeCombiner::tryCombineCopy(MachineInstr &MI) {
+  if (tryEliminateReadAnyLane(MI))
+    return;
+
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+  // Skip copies of physical registers.
+  if (!Dst.isVirtual() || !Src.isVirtual())
+    return;
+
+  // This is a cross bank copy, sgpr S1 to lane mask.
+  //
+  // %Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
+  // %Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
+  // ->
+  // %Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %TruncS32Src:sgpr(s32)
+  if (isLaneMask(Dst) && MRI.getRegBankOrNull(Src) == SgprRB) {
+    auto [Trunc, TruncS32Src] = tryMatch(Src, AMDGPU::G_TRUNC);
+    assert(Trunc && MRI.getType(TruncS32Src) == S32 &&
+           "sgpr S1 must be result of G_TRUNC of sgpr S32");
----------------
petar-avramovic wrote:

You are correct, it is not guaranteed. I was trying to make one with a trunc 
from i64 to i1 but did not find one yet. Here is one test that actually hits 
that assert:
```
define amdgpu_cs void @loop_with_1break(ptr addrspace(1) %x, ptr addrspace(1) %a) {
entry:
  br label %A

A:
  %counter = phi i32 [ %counter.plus.1, %loop.body ], [ 0, %entry ]
  %a.plus.counter = getelementptr inbounds i32, ptr addrspace(1) %a, i32 %counter
  %a.val = load i32, ptr addrspace(1) %a.plus.counter
  %a.cond = icmp eq i32 %a.val, 0
  br i1 %a.cond, label %exit, label %loop.body

loop.body:
  %x.plus.counter = getelementptr inbounds i32, ptr addrspace(1) %x, i32 %counter
  %x.val = load i32, ptr addrspace(1) %x.plus.counter
  %x.val.plus.1 = add i32 %x.val, 1
  store i32 %x.val.plus.1, ptr addrspace(1) %x.plus.counter
  %counter.plus.1 = add i32 %counter, 1
  %x.cond = trunc i32 %counter to i1
  br i1 %x.cond, label %exit, label %A

exit:
  ret void
}
```
The original idea behind that assert was that the trunc is created by 
regbanklegalize itself: a legal i1 is used by something that gets lowered by 
divergence lowering, and a uniform i1 is lowered as an sgpr S32 that is then 
truncated to S1.
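For reference, a minimal sketch of the pattern the assert expects, in the same 
shape as the comment in tryCombineCopy above (register names are illustrative):
```
%TruncS32Src:sgpr(s32) = ...
%Src:sgpr(s1) = G_TRUNC %TruncS32Src:sgpr(s32)
%Dst:lane-mask(s1) = COPY %Src:sgpr(s1)
->
%One:sgpr(s32) = G_CONSTANT i32 1
%BoolSrc:sgpr(s32) = G_AND %TruncS32Src:sgpr(s32), %One:sgpr(s32)
%Dst:lane-mask(s1) = G_AMDGPU_COPY_VCC_SCC %BoolSrc:sgpr(s32)
```
The test above instead gets its i1 from `trunc i32 %counter to i1` in the 
source IR rather than from a trunc introduced by regbanklegalize, so the 
copy's sgpr(s1) source is not guaranteed to match this pattern.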
As this pass is still only partially implemented, can we deal with this in a 
later patch?

https://github.com/llvm/llvm-project/pull/145911