================
@@ -8701,95 +8734,266 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) { // Searches Val's operand tree for a CC source; returns {value, valid-mask}, or {SDValue(), CCMASK_NONE} when none is found.
+  auto *N = Val.getNode();
+  if (!N)
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default:
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM: // IPM: report its operand 0 as the CC source.
+    if (N->getOperand(0).getOpcode() == SystemZISD::CLC ||
+        N->getOperand(0).getOpcode() == SystemZISD::STRCMP)
+      return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP); // CLC / STRCMP operands are reported with the ICMP mask.
+    return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); // Any other operand is reported with CCMASK_ANY.
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); // Operand 2 must be a constant; it supplies the returned mask.
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0)); // For ICMP / TM, first look through the compared value (operand 0).
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid);
+    }
+    return std::make_pair(Op4CCReg, CCValidVal); // Otherwise report operand 4 itself with the select's own mask.
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); // Binary ops: a hit in operand 0 takes precedence over operand 1.
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1));
+  }
+}
+
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
+                          SelectionDAG &DAG); // Forward declaration: simplifyAssumingCCVal below calls combineCCMask before its definition.
+
+SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC, // Evaluates Val once per CC value 0..3, treating CC as the node that produces the condition code.
+                                                     SelectionDAG &DAG) { // Returns four values indexed by CC value, or an empty vector if Val cannot be evaluated this way.
+  auto *N = Val.getNode(), *CCNode = CC.getNode();
+  if (!N || !CCNode)
+    return {};
+  SDLoc DL(N);
+  auto Opcode = N->getOpcode();
+  switch (Opcode) {
+  default:
+    return {};
+  case ISD::Constant:
+    return {Val, Val, Val, Val}; // A constant has the same value whatever CC is.
+  case SystemZISD::IPM: {
+    auto *IPMOp0Node = N->getOperand(0).getNode();
+    if (!IPMOp0Node || IPMOp0Node != CCNode) // Only an IPM reading exactly this CC node can be evaluated.
+      return {};
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (auto CC : {0, 1, 2, 3}) // NOTE: this loop variable shadows the SDValue parameter CC.
+      ShiftedCCVals.emplace_back(
+          DAG.getConstant((CC << SystemZ::IPM_CC), DL, MVT::i32)); // The CC value shifted left by SystemZ::IPM_CC, as an i32 constant.
+    return ShiftedCCVals;
+  }
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *TrueOp = TrueVal.getNode();
+    auto *FalseOp = FalseVal.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueOp || !FalseOp || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DAG);
+    const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DAG);
+    if (TrueSDVals.empty() || FalseSDVals.empty()) // Both arms must be evaluable for every CC value.
+      return {};
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    if (Op4CCNode && Op4CCNode != CCNode)
+      combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DAG); // combineCCMask takes these three by reference and may rewrite them; its bool result is ignored here.
+    Op4CCNode = Op4CCReg.getNode();
+    if (!Op4CCNode || Op4CCNode != CCNode) // The select must (now) be controlled by the CC node being assumed.
+      return {};
+    SmallVector<SDValue, 4> MergedSDVals;
+    for (auto &CCVal : {0, 1, 2, 3})
+      MergedSDVals.emplace_back(((CCMaskVal & (1 << (3 - CCVal))) != 0) // CC value N is tested via mask bit (1 << (3 - N)); a set bit picks the true arm.
+                                    ? TrueSDVals[CCVal]
+                                    : FalseSDVals[CCVal]);
+    return MergedSDVals;
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:
+  case ISD::SRA:
+    if (!N->hasOneUse()) // These opcodes are only handled when the node has a single use.
+      return {};
+    [[fallthrough]];
+  case ISD::SHL: // SHL and SRL enter here directly and so skip the single-use check.
+  case ISD::SRL:
+    SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1);
+    const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, CC, DAG);
+    const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, CC, DAG);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
+      return {};
+    SmallVector<SDValue, 4> BinaryOpSDVals;
+    for (auto CCVal : {0, 1, 2, 3})
+      BinaryOpSDVals.emplace_back(DAG.getNode( // Build the same opcode on each pair of per-CC operand values.
+          Opcode, DL, Val.getValueType(), Op0SDVals[CCVal], Op1SDVals[CCVal]));
+    return BinaryOpSDVals;
+  }
+}
+
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
+                          SelectionDAG &DAG) {
   // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
   // set by the CCReg instruction using the CCValid / CCMask masks,
-  // If the CCReg instruction is itself a ICMP testing the condition
+  // If the CCReg instruction is itself a ICMP / TM  testing the condition
   // code set by some other instruction, see whether we can directly
   // use that condition code.
-
-  // Verify that we have an ICMP against some constant.
-  if (CCValid != SystemZ::CCMASK_ICMP)
-    return false;
-  auto *ICmp = CCReg.getNode();
-  if (ICmp->getOpcode() != SystemZISD::ICMP)
-    return false;
-  auto *CompareLHS = ICmp->getOperand(0).getNode();
-  auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
-  if (!CompareRHS)
+  auto *CCNode = CCReg.getNode();
+  if (!CCNode)
     return false;
+  const auto getAPIntSDVals = [](const SmallVector<SDValue, 4> &Vals) {
+    SmallVector<APInt, 4> APIntVals;
+    for (const auto &Val : Vals) {
+      auto *ConstValNode = dyn_cast<ConstantSDNode>(Val.getNode());
+      if (!ConstValNode)
+        return SmallVector<APInt, 4>();
+      APIntVals.emplace_back(ConstValNode->getAPIntValue());
+    }
+    return APIntVals;
+  };
 
-  // Optimize the case where CompareLHS is a SELECT_CCMASK.
-  if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
-    // Verify that we have an appropriate mask for a EQ or NE comparison.
-    bool Invert = false;
-    if (CCMask == SystemZ::CCMASK_CMP_NE)
-      Invert = !Invert;
-    else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+  if (CCNode->getOpcode() == SystemZISD::TM) {
+    if (CCValid != SystemZ::CCMASK_TM)
       return false;
-
-    // Verify that the ICMP compares against one of select values.
-    auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
-    if (!TrueVal)
+    auto emulateTMCCMask = [](const APInt &Op0Val, const APInt &Op1Val) {
+      auto Result = Op0Val & Op1Val;
+      bool AllOnes = Result == Op1Val;
+      bool AllZeros = Result == 0;
+      int MSBPos = Op1Val.countl_zero();
+      bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
+      return AllOnes ? 3 : AllZeros ? 0 : IsLeftMostBitSet ? 2 : 1;
----------------
uweigand wrote:

This gets the case wrong where the mask (Op1) is zero.  In this case the 
instruction will set CC0, but your logic would return CC3.  To fix this, you 
can simply test for `AllZeros` *before* the test for `AllOnes`.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to