================
@@ -98,21 +110,98 @@ void applySPIRVDistance(MachineInstr &MI, 
MachineRegisterInfo &MRI,
 
   SPIRVGlobalRegistry *GR =
       MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
-  auto RemoveAllUses = [&](Register Reg) {
-    SmallVector<MachineInstr *, 4> UsesToErase(
-        llvm::make_pointer_range(MRI.use_instructions(Reg)));
-
-    // calling eraseFromParent to early invalidates the iterator.
-    for (auto *MIToErase : UsesToErase) {
-      GR->invalidateMachineInstr(MIToErase);
-      MIToErase->eraseFromParent();
-    }
-  };
-  RemoveAllUses(SubDestReg);   // remove all uses of FSUB Result
+  removeAllUses(SubDestReg, MRI, GR); // remove all uses of FSUB Result
   GR->invalidateMachineInstr(SubInstr);
   SubInstr->eraseFromParent(); // remove FSUB instruction
 }
 
+/// This match is part of a combine that
+/// rewrites select(fcmp(dot(I, Ng), 0), N, 0 - N) to faceforward(N, I, Ng)
+///   (vXf32 (g_select
+///             (g_fcmp
+///                (g_intrinsic dot(vXf32 I) (vXf32 Ng)
+///                 0)
+///             (vXf32 N)
+///             (vXf32 g_fsub (0) (vXf32 N))))
+/// ->
+///   (vXf32 (g_intrinsic faceforward
+///             (vXf32 N) (vXf32 I) (vXf32 Ng)))
+///
+bool matchSelectToFaceForward(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  if (MI.getOpcode() != TargetOpcode::G_SELECT)
+    return false;
+
+  // Check if select's condition is a comparison between a dot product and 0.
+  Register CondReg = MI.getOperand(1).getReg();
+  MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
+  if (!CondInstr || CondInstr->getOpcode() != TargetOpcode::G_FCMP)
+    return false;
+
+  Register DotReg = CondInstr->getOperand(2).getReg();
+  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
+  if (DotInstr->getOpcode() != TargetOpcode::G_FMUL &&
+      (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
+       cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot))
+    return false;
+
+  Register CondZeroReg = CondInstr->getOperand(3).getReg();
+  MachineInstr *CondZeroInstr = MRI.getVRegDef(CondZeroReg);
+  if (CondZeroInstr->getOpcode() != TargetOpcode::G_FCONSTANT ||
+      !CondZeroInstr->getOperand(1).getFPImm()->isZero())
+    return false;
+
+  // Check if select's false operand is the negation of the true operand.
+  Register TrueReg = MI.getOperand(2).getReg();
+  Register FalseReg = MI.getOperand(3).getReg();
+  MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
+  if (FalseInstr->getOpcode() != TargetOpcode::G_FNEG)
+    return false;
+  if (TrueReg != FalseInstr->getOperand(1).getReg())
+    return false;
----------------
s-perron wrote:

Is it possible that the operands could be constants like `1` and `-1`? In that case, 
this would fail because the `False` operand is not a negation. There are many other cases 
where the optimizer could simplify `N` and `0-N`, breaking this pattern matching.

I don't know LLVM well enough. Is there a generic way to check if one value is 
the negation of the other? In other compilers I have worked on, you could create a 
temporary expression (a + b) and check whether the simplifier folds it to 0.

https://github.com/llvm/llvm-project/pull/139959
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to