================
@@ -8701,95 +8734,341 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+/// Search backwards from \p Val for the node whose condition code \p Val is
+/// derived from, looking through IPM, shifts, SELECT_CCMASK and
+/// ADD/AND/OR/XOR nodes.
+///
+/// Returns the CC-producing SDValue paired with the CC-valid mask to assume
+/// for it, or {SDValue(), CCMASK_NONE} when no producer is found. For binary
+/// operations operand 0 is searched first; operand 1 is only searched when
+/// operand 0 yields nothing.
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
+  auto *N = Val.getNode();
+  if (!N)
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default:
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM: {
+    // The IPM's operand is the node that produced the CC being inserted.
+    SDValue CCProducer = N->getOperand(0);
+    // Only SelectionDAG (ISD / SystemZISD) opcodes may be compared here. A
+    // machine opcode such as SystemZ::CLST lives in a different numbering
+    // space and could spuriously match an unrelated DAG node.
+    // NOTE(review): the CLST-based strcmp() sequence appears to be modelled
+    // by SystemZISD::STRCMP; confirm no separate DAG node needs listing.
+    if (CCProducer.getOpcode() == SystemZISD::CLC ||
+        CCProducer.getOpcode() == SystemZISD::STRCMP)
+      return std::make_pair(CCProducer, SystemZ::CCMASK_ICMP);
+    return std::make_pair(CCProducer, SystemZ::CCMASK_ANY);
+  }
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Only the shifted value is searched, not the shift amount.
+    return findCCUse(N->getOperand(0));
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    // If the select tests an ICMP/TM of a CC-derived value, prefer the CC
+    // that value is derived from.
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0));
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid);
+    }
+    // Otherwise search through the select's CC operand itself, falling back
+    // to that operand together with the select's own CC-valid mask.
+    auto [OpCC, OpCCValid] = findCCUse(Op4CCReg);
+    return OpCC != SDValue() ? std::make_pair(OpCC, OpCCValid)
+                             : std::make_pair(Op4CCReg, CCValidVal);
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1));
+  }
+  }
+}
+
+/// Evaluate \p Val once for each possible value (0..3) of the condition code
+/// produced by \p CC.
+///
+/// Returns a 4-element vector whose element I is the DAG value \p Val takes
+/// when CC == I. Returns an empty vector if \p Val is not built solely from:
+/// - constants,
+/// - IPMs of \p CC,
+/// - SRL by a constant, or the (SRA (SHL x, 30 - IPM_CC), 30) idiom,
+/// - SELECT_CCMASKs that test \p CC (possibly after combineCCMask),
+/// - ADD/AND/OR/XOR of such values.
+/// New nodes are created in DCI.DAG. Callers such as combineCCMask expect
+/// the results to be constants.
+SmallVector<SDValue, 4>
+SystemZTargetLowering::simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
+                                             DAGCombinerInfo &DCI) const {
+  // Recognizes the binary operations through which per-CC values are
+  // propagated, returning the opcode and both operands.
+  const auto isValidBinaryOperation = [](const SDValue &Op, SDValue &Op0,
+                                         SDValue &Op1, unsigned &Opcode) {
+    auto *N = Op.getNode();
+    if (!N)
+      return false;
+    Opcode = N->getOpcode();
+    if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR &&
+        Opcode != ISD::XOR)
+      return false;
+    Op0 = N->getOperand(0);
+    Op1 = N->getOperand(1);
+    return true;
+  };
+  // A constant has the same value whatever CC is.
+  if (isa<ConstantSDNode>(Val)) {
+    return {Val, Val, Val, Val};
+  }
+  auto *N = Val.getNode(), *CCNode = CC.getNode();
+  if (!N || !CCNode)
+    return {};
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  if (N->getOpcode() == SystemZISD::IPM) {
+    // The per-CC values are only meaningful if this IPM reads CC itself. An
+    // IPM of some other CC producer is unrelated to CC's value, and folding
+    // it here would be unsound. The SELECT_CCMASK case below performs the
+    // same check on its CC operand.
+    if (N->getOperand(0).getNode() != CCNode)
+      return {};
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (auto CCNum : {0, 1, 2, 3}) {
+      // Model the IPM result as CC << IPM_CC.
+      SDValue CCConst = DAG.getConstant(CCNum, DL, MVT::i32);
+      ShiftedCCVals.emplace_back(
+          DAG.getNode(ISD::SHL, DL, MVT::i32, CCConst,
+                      DAG.getConstant(SystemZ::IPM_CC, DL, MVT::i32)));
+    }
+    return ShiftedCCVals;
+  }
+  if (N->getOpcode() == ISD::SRL) {
+    SDValue Op0 = N->getOperand(0);
+    auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRLCount)
+      return {};
+    auto SRLCountVal = SRLCount->getZExtValue();
+    auto SDVals = simplifyAssumingCCVal(Op0, CC, DCI);
+    if (SDVals.empty())
+      return {};
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const auto &SDVal : SDVals)
+      ShiftedVals.emplace_back(
+          DAG.getNode(ISD::SRL, DL, MVT::i32, SDVal,
+                      DAG.getConstant(SRLCountVal, DL, MVT::i32)));
+    return ShiftedVals;
+  }
+  if (N->getOpcode() == ISD::SRA) {
+    // Only the (SRA (SHL x, 30 - IPM_CC), 30) idiom is handled. It moves the
+    // two bits at IPM_CC to the top of the word and shifts them back down
+    // arithmetically.
+    auto *SRACount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRACount || SRACount->getZExtValue() != 30)
+      return {};
+    auto *SHL = N->getOperand(0).getNode();
+    if (SHL->getOpcode() != ISD::SHL)
+      return {};
+    auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
+    if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
+      return {};
+    // Avoid introducing CC spills (because SRA would clobber CC).
+    if (!N->hasOneUse())
+      return {};
+    SDValue IPM = SHL->getOperand(0);
+    auto SDVals = simplifyAssumingCCVal(IPM, CC, DCI);
+    if (SDVals.empty())
+      return {};
+    auto SRAShift = SRACount->getZExtValue();
+    auto SHLShift = SHLCount->getZExtValue();
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const auto &SDVal : SDVals) {
+      SDValue SHLVal = DAG.getNode(ISD::SHL, DL, MVT::i32, SDVal,
+                                   DAG.getConstant(SHLShift, DL, MVT::i32));
+      ShiftedVals.emplace_back(
+          DAG.getNode(ISD::SRA, DL, MVT::i32, SHLVal,
+                      DAG.getConstant(SRAShift, DL, MVT::i32)));
+    }
+    return ShiftedVals;
+  }
+  if (N->getOpcode() == SystemZISD::SELECT_CCMASK) {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *TrueOp = TrueVal.getNode();
+    auto *FalseOp = FalseVal.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueOp || !FalseOp || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    auto TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DCI);
+    auto FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DCI);
+    if (TrueSDVals.empty() || FalseSDVals.empty())
+      return {};
+    // The select must test CC itself, possibly after folding an intervening
+    // ICMP/TM via combineCCMask. Otherwise its choice is unrelated to CC.
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    if (Op4CCNode && Op4CCNode != CCNode)
+      combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DCI);
+    Op4CCNode = Op4CCReg.getNode();
+    if (!Op4CCNode || Op4CCNode != CCNode)
+      return {};
+    // Mask bit (1 << (3 - I)) corresponds to CC == I. The true value is
+    // taken when that bit is set in both CCMask and CCValid.
+    SmallVector<SDValue, 4> MergedSDVals;
+    for (auto CCNum : {0, 1, 2, 3}) {
+      int Bit = 1 << (3 - CCNum);
+      bool PicksTrue = (CCMaskVal & Bit) != 0 && (CCValidVal & Bit) != 0;
+      MergedSDVals.emplace_back(PicksTrue ? TrueSDVals[CCNum]
+                                          : FalseSDVals[CCNum]);
+    }
+    return MergedSDVals;
+  }
+  SDValue Op0, Op1;
+  unsigned Opcode;
+  if (isValidBinaryOperation(Val, Op0, Op1, Opcode)) {
+    // Stop as soon as either operand cannot be evaluated per CC value.
+    auto Op0SDVals = simplifyAssumingCCVal(Op0, CC, DCI);
+    if (Op0SDVals.empty())
+      return {};
+    auto Op1SDVals = simplifyAssumingCCVal(Op1, CC, DCI);
+    if (Op1SDVals.empty())
+      return {};
+    SmallVector<SDValue, 4> BinaryOpSDVals;
+    for (auto CCNum : {0, 1, 2, 3})
+      BinaryOpSDVals.emplace_back(DAG.getNode(Opcode, DL, Val.getValueType(),
+                                              Op0SDVals[CCNum],
+                                              Op1SDVals[CCNum]));
+    return BinaryOpSDVals;
+  }
+  return {};
+}
+
+bool SystemZTargetLowering::combineCCMask(SDValue &CCReg, int &CCValid,
+                                          int &CCMask,
+                                          DAGCombinerInfo &DCI) const {
   // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
   // set by the CCReg instruction using the CCValid / CCMask masks,
-  // If the CCReg instruction is itself a ICMP testing the condition
+  // If the CCReg instruction is itself a ICMP / TM  testing the condition
   // code set by some other instruction, see whether we can directly
   // use that condition code.
-
-  // Verify that we have an ICMP against some constant.
-  if (CCValid != SystemZ::CCMASK_ICMP)
-    return false;
-  auto *ICmp = CCReg.getNode();
-  if (ICmp->getOpcode() != SystemZISD::ICMP)
-    return false;
-  auto *CompareLHS = ICmp->getOperand(0).getNode();
-  auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
-  if (!CompareRHS)
+  auto *CCNode = CCReg.getNode();
+  if (!CCNode)
     return false;
+  const auto getConstFromConstSDVals = [](const SmallVector<SDValue, 4> &Vals) 
{
+    SmallVector<int, 4> CCVals;
+    for (const auto &Val : Vals)
+      if (auto *ConstNode = dyn_cast<ConstantSDNode>(Val.getNode()))
+        CCVals.emplace_back(ConstNode->getZExtValue());
+      else
+        return SmallVector<int, 4>();
+    return CCVals;
+  };
+  const auto getMSBPosSet = [](unsigned int Mask) {
+    int NumBits = std::numeric_limits<unsigned int>::digits;
+    int count = 0;
+    // Keep target search space to the left.
+    while (NumBits > 0) {
+      NumBits /= 2;
+      // Upper half zeros.
+      if (!(Mask >> NumBits)) {
+        count += NumBits;
+        // Search lower half.
+        Mask <<= NumBits;
+      }
+    }
+    return count;
+  };
 
-  // Optimize the case where CompareLHS is a SELECT_CCMASK.
-  if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
-    // Verify that we have an appropriate mask for a EQ or NE comparison.
-    bool Invert = false;
-    if (CCMask == SystemZ::CCMASK_CMP_NE)
-      Invert = !Invert;
-    else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+  if (CCNode->getOpcode() == SystemZISD::TM) {
+    if (CCValid != SystemZ::CCMASK_TM)
       return false;
-
-    // Verify that the ICMP compares against one of select values.
-    auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
-    if (!TrueVal)
+    const auto emulateTMCCMask = [&](int CCVal, int Mask) {
+      if (!Mask)
+        return std::numeric_limits<unsigned int>::digits;
+      int Result = CCVal & Mask;
+      bool AllOnes = Result == Mask;
+      bool AllZeros = Result == 0;
+      bool MixedZerosOnes = (!AllOnes && !AllZeros);
+      int MSBPos = getMSBPosSet(static_cast<unsigned int>(Mask));
+      bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
+      return AllOnes                                ? 3
+             : AllZeros                             ? 0
+             : (MixedZerosOnes && IsLeftMostBitSet) ? 2
+                                                    : 1;
+    };
+    SDValue Op0 = CCNode->getOperand(0);
+    SDValue Op1 = CCNode->getOperand(1);
+    auto [Op0CC, Op0CCValid] = findCCUse(Op0);
+    if (Op0CC == SDValue())
       return false;
-    auto *FalseVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
-    if (!FalseVal)
+    const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, Op0CC, DCI);
+    const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, Op0CC, DCI);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
       return false;
-    if (CompareRHS->getAPIntValue() == FalseVal->getAPIntValue())
-      Invert = !Invert;
-    else if (CompareRHS->getAPIntValue() != TrueVal->getAPIntValue())
+    auto &&Op0CCVals = getConstFromConstSDVals(Op0SDVals);
+    const auto &&Op1CCVals = getConstFromConstSDVals(Op1SDVals);
+    if (Op0CCVals.empty() || Op1CCVals.empty())
       return false;
-
-    // Compute the effective CC mask for the new branch or select.
-    auto *NewCCValid = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(2));
-    auto *NewCCMask = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(3));
-    if (!NewCCValid || !NewCCMask)
+    std::transform(Op0CCVals.begin(), Op0CCVals.end(), Op1CCVals.begin(),
+                   Op0CCVals.begin(), emulateTMCCMask);
+    if (std::any_of(Op0CCVals.begin(), Op0CCVals.end(),
+                    [](const auto &CC) { return CC < 0 || CC > 3; }))
       return false;
-    CCValid = NewCCValid->getZExtValue();
-    CCMask = NewCCMask->getZExtValue();
-    if (Invert)
-      CCMask ^= CCValid;
-
-    // Return the updated CCReg link.
-    CCReg = CompareLHS->getOperand(4);
+    int NewCCMask = 0;
+    for (auto CC : Op0CCVals) {
+      NewCCMask <<= 1;
+      NewCCMask |= (CCMask & (1 << (3 - CC))) != 0;
+    }
+    CCReg = Op0CC;
+    CCMask = NewCCMask;
     return true;
   }
+  if (CCNode->getOpcode() != SystemZISD::ICMP ||
+      CCValid != SystemZ::CCMASK_ICMP)
+    return false;
 
-  // Optimize the case where CompareRHS is (SRA (SHL (IPM))).
-  if (CompareLHS->getOpcode() == ISD::SRA) {
-    auto *SRACount = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(1));
-    if (!SRACount || SRACount->getZExtValue() != 30)
-      return false;
-    auto *SHL = CompareLHS->getOperand(0).getNode();
-    if (SHL->getOpcode() != ISD::SHL)
-      return false;
-    auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
-    if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
-      return false;
-    auto *IPM = SHL->getOperand(0).getNode();
-    if (IPM->getOpcode() != SystemZISD::IPM)
-      return false;
-
-    // Avoid introducing CC spills (because SRA would clobber CC).
-    if (!CompareLHS->hasOneUse())
+  SDValue CmpOp0 = CCNode->getOperand(0);
+  SDValue CmpOp1 = CCNode->getOperand(1);
+  SDValue CmpOp2 = CCNode->getOperand(2);
+  auto [Op0CC, Op0CCValid] = findCCUse(CmpOp0);
+  if (Op0CC != SDValue()) {
+    const auto &&Op0SDVals = simplifyAssumingCCVal(CmpOp0, Op0CC, DCI);
+    const auto &&Op1SDVals = simplifyAssumingCCVal(CmpOp1, Op0CC, DCI);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
       return false;
-    // Verify that the ICMP compares against zero.
-    if (CompareRHS->getZExtValue() != 0)
+    auto &&Op0CCVals = getConstFromConstSDVals(Op0SDVals);
+    const auto &&Op1CCVals = getConstFromConstSDVals(Op1SDVals);
----------------
uweigand wrote:

Again with the `int` truncation (`getConstFromConstSDVals` narrows each
`getZExtValue()` result to `int`) ... this should also operate on the full
`APInt`s.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to