================
@@ -8701,95 +8734,266 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) { // Look through Val for an IPM or SELECT_CCMASK and return its CC operand.
+  auto *N = Val.getNode(); // Result: {CC operand, CC-valid mask}, or {SDValue(), CCMASK_NONE} if none.
+  if (!N) // Empty SDValue: nothing to search.
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default: // Any opcode not listed below is reported as "not found".
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM: // IPM: operand 0 is returned as the CC operand.
+    if (N->getOperand(0).getOpcode() == SystemZISD::CLC ||
+        N->getOperand(0).getOpcode() == SystemZISD::STRCMP)
+      return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP); // CLC / STRCMP sources get the ICMP valid mask.
+    return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); // Every other source gets CCMASK_ANY.
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue Op4CCReg = N->getOperand(4); // Operand 4 is the select's CC input.
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); // Operand 2 (valid mask) must be a constant.
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0)); // Look through the compare/test into its first operand.
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid); // A CC operand found deeper takes priority.
+    }
+    return std::make_pair(Op4CCReg, CCValidVal); // Otherwise report the select's own CC input and valid mask.
+  }
+  case ISD::ADD: // Binary operators: search both operands.
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); // Operand 0 is tried first.
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1)); // Fall back to operand 1 (may also be "not found").
+  }
+}
+
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
+                          SelectionDAG &DAG); // Forward declaration: simplifyAssumingCCVal calls this before its definition.
+
+SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC, // Evaluate Val once for each CC value 0..3 of CC.
+                                                     SelectionDAG &DAG) { // Returns 4 values indexed by CC value, or {} if Val is unsupported.
+  auto *N = Val.getNode(), *CCNode = CC.getNode();
+  if (!N || !CCNode) // Either value empty: nothing to simplify.
+    return {};
+  SDLoc DL(N);
+  auto Opcode = N->getOpcode();
+  switch (Opcode) {
+  default: // Opcodes not listed below are not simplified.
+    return {};
+  case ISD::Constant:
+    return {Val, Val, Val, Val}; // A constant is the same for every CC value.
+  case SystemZISD::IPM: {
+    auto *IPMOp0Node = N->getOperand(0).getNode();
+    if (!IPMOp0Node || IPMOp0Node != CCNode) // Only an IPM reading exactly this CC node is handled.
+      return {};
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (auto CC : {0, 1, 2, 3}) // NOTE(review): this loop variable shadows the CC parameter.
+      ShiftedCCVals.emplace_back(
+          DAG.getConstant((CC << SystemZ::IPM_CC), DL, MVT::i32)); // CC value shifted left by IPM_CC, as an i32 constant.
+    return ShiftedCCVals;
+  }
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *TrueOp = TrueVal.getNode();
+    auto *FalseOp = FalseVal.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); // Operands 2 and 3 (valid mask, mask) must be constants.
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueOp || !FalseOp || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DAG); // Recursively evaluate both arms per CC value.
+    const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DAG);
+    if (TrueSDVals.empty() || FalseSDVals.empty())
+      return {};
+    SDValue Op4CCReg = N->getOperand(4); // Operand 4 is the select's CC input.
+    auto *Op4CCNode = Op4CCReg.getNode();
+    if (Op4CCNode && Op4CCNode != CCNode)
+      combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DAG); // May update the three locals in place; the bool result is ignored.
+    Op4CCNode = Op4CCReg.getNode(); // Re-read: combineCCMask may have replaced Op4CCReg.
+    if (!Op4CCNode || Op4CCNode != CCNode) // Give up unless the select is (now) keyed on CCNode.
+      return {};
+    SmallVector<SDValue, 4> MergedSDVals;
+    for (auto &CCVal : {0, 1, 2, 3})
+      MergedSDVals.emplace_back(((CCMaskVal & (1 << (3 - CCVal))) != 0) // Mask bit (3 - CCVal) set selects the true arm.
+                                    ? TrueSDVals[CCVal]
+                                    : FalseSDVals[CCVal]);
+    return MergedSDVals;
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:
+  case ISD::SRA:
+    if (!N->hasOneUse()) // Single-use required here but not for SHL/SRL; presumably to avoid CC spills, as the removed SRA code noted — confirm.
+      return {};
+    [[fallthrough]];
+  case ISD::SHL:
+  case ISD::SRL:
+    SDValue Op0 = N->getOperand(0), Op1 = N->getOperand(1);
+    const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, CC, DAG); // Both operands must be expressible per CC value.
+    const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, CC, DAG);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
+      return {};
+    SmallVector<SDValue, 4> BinaryOpSDVals;
+    for (auto CCVal : {0, 1, 2, 3})
+      BinaryOpSDVals.emplace_back(DAG.getNode( // Re-create the same operation on the per-CC operand values.
+          Opcode, DL, Val.getValueType(), Op0SDVals[CCVal], Op1SDVals[CCVal]));
+    return BinaryOpSDVals;
+  }
+}
+
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
+                          SelectionDAG &DAG) {
   // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
   // set by the CCReg instruction using the CCValid / CCMask masks,
-  // If the CCReg instruction is itself a ICMP testing the condition
+  // If the CCReg instruction is itself an ICMP / TM testing the condition
   // code set by some other instruction, see whether we can directly
   // use that condition code.
-
-  // Verify that we have an ICMP against some constant.
-  if (CCValid != SystemZ::CCMASK_ICMP)
-    return false;
-  auto *ICmp = CCReg.getNode();
-  if (ICmp->getOpcode() != SystemZISD::ICMP)
-    return false;
-  auto *CompareLHS = ICmp->getOperand(0).getNode();
-  auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
-  if (!CompareRHS)
+  auto *CCNode = CCReg.getNode();
+  if (!CCNode)
     return false;
+  const auto getAPIntSDVals = [](const SmallVector<SDValue, 4> &Vals) {
+    SmallVector<APInt, 4> APIntVals;
+    for (const auto &Val : Vals) {
+      auto *ConstValNode = dyn_cast<ConstantSDNode>(Val.getNode());
+      if (!ConstValNode)
+        return SmallVector<APInt, 4>();
+      APIntVals.emplace_back(ConstValNode->getAPIntValue());
+    }
+    return APIntVals;
+  };
 
-  // Optimize the case where CompareLHS is a SELECT_CCMASK.
-  if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
-    // Verify that we have an appropriate mask for a EQ or NE comparison.
-    bool Invert = false;
-    if (CCMask == SystemZ::CCMASK_CMP_NE)
-      Invert = !Invert;
-    else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+  if (CCNode->getOpcode() == SystemZISD::TM) {
+    if (CCValid != SystemZ::CCMASK_TM)
       return false;
-
-    // Verify that the ICMP compares against one of select values.
-    auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
-    if (!TrueVal)
+    auto emulateTMCCMask = [](const APInt &Op0Val, const APInt &Op1Val) {
+      auto Result = Op0Val & Op1Val;
+      bool AllOnes = Result == Op1Val;
+      bool AllZeros = Result == 0;
+      int MSBPos = Op1Val.countl_zero();
+      bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
+      return AllOnes ? 3 : AllZeros ? 0 : IsLeftMostBitSet ? 2 : 1;
+    };
+    SDValue Op0 = CCNode->getOperand(0);
+    SDValue Op1 = CCNode->getOperand(1);
+    auto [Op0CC, Op0CCValid] = findCCUse(Op0);
+    if (Op0CC == SDValue())
       return false;
-    auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
-    if (!FalseVal)
+    const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, Op0CC, DAG);
+    const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, Op0CC, DAG);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
       return false;
-    if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue())
-      Invert = !Invert;
-    else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue())
+    auto &&Op0APInts = getAPIntSDVals(Op0SDVals);
+    const auto &&Op1APInts = getAPIntSDVals(Op1SDVals);
+    if (Op0APInts.empty() || Op1APInts.empty())
       return false;
-
-    // Compute the effective CC mask for the new branch or select.
-    auto *NewCCValid = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(2));
-    auto *NewCCMask = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(3));
-    if (!NewCCValid || !NewCCMask)
+    SmallVector<int, 4> CCVals;
+    std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(),
+                   std::back_inserter(CCVals), emulateTMCCMask);
+    if (CCVals.empty())
       return false;
-    CCValid = NewCCValid->getZExtValue();
-    CCMask = NewCCMask->getZExtValue();
-    if (Invert)
-      CCMask ^= CCValid;
-
-    // Return the updated CCReg link.
-    CCReg = CompareLHS->getOperand(4);
+    int NewCCMask = 0;
+    for (auto CC : CCVals) {
+      NewCCMask <<= 1;
+      NewCCMask |= (CCMask & (1 << (3 - CC))) != 0;
+    }
+    CCReg = Op0CC;
+    CCMask = NewCCMask;
+    CCValid = Op0CCValid;
     return true;
   }
+  if (CCNode->getOpcode() != SystemZISD::ICMP ||
+      CCValid != SystemZ::CCMASK_ICMP)
+    return false;
 
-  // Optimize the case where CompareRHS is (SRA (SHL (IPM))).
-  if (CompareLHS->getOpcode() == ISD::SRA) {
-    auto *SRACount = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
-    if (!SRACount || SRACount->getZExtValue() != 30)
+  SDValue CmpOp0 = CCNode->getOperand(0);
+  SDValue CmpOp1 = CCNode->getOperand(1);
+  SDValue CmpOp2 = CCNode->getOperand(2);
+  auto [Op0CC, Op0CCValid] = findCCUse(CmpOp0);
+  if (Op0CC != SDValue()) {
+    const auto &&Op0SDVals = simplifyAssumingCCVal(CmpOp0, Op0CC, DAG);
+    const auto &&Op1SDVals = simplifyAssumingCCVal(CmpOp1, Op0CC, DAG);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
       return false;
-    auto *SHL = CompareLHS->getOperand(0).getNode();
-    if (SHL->getOpcode() != ISD::SHL)
-      return false;
-    auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
-    if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
-      return false;
-    auto *IPM = SHL->getOperand(0).getNode();
-    if (IPM->getOpcode() != SystemZISD::IPM)
+    auto &&Op0APInts = getAPIntSDVals(Op0SDVals);
+    const auto &&Op1APInts = getAPIntSDVals(Op1SDVals);
+    if (Op0APInts.empty() || Op1APInts.empty())
       return false;
 
-    // Avoid introducing CC spills (because SRA would clobber CC).
-    if (!CompareLHS->hasOneUse())
-      return false;
-    // Verify that the ICMP compares against zero.
-    if (CompareRHS->getZExtValue() != 0)
+    auto *CmpType = dyn_cast<ConstantSDNode>(CmpOp2);
+    auto CmpTypeVal = CmpType->getZExtValue();
+    const auto compareCCSigned = [&CmpTypeVal](const APInt &Op0Val,
+                                               const APInt &Op1Val) {
+      if (CmpTypeVal == SystemZICMP::SignedOnly)
+        return Op0Val == Op1Val ? 0 : Op0Val.slt(Op1Val) ? 1 : 2;
+      return Op0Val == Op1Val ? 0 : Op0Val.ult(Op1Val) ? 1 : 2;
+    };
+    SmallVector<int, 4> CCVals;
+    std::transform(Op0APInts.begin(), Op0APInts.end(), Op1APInts.begin(),
+                   std::back_inserter(CCVals), compareCCSigned);
+    if (CCVals.empty())
       return false;
----------------
uweigand wrote:

As above, I don't think this can ever be empty.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to