================
@@ -8701,95 +8734,341 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+/// Search \p Val for the node that produces the condition code it is built
+/// from. The search looks through IPM, the shifts SHL / SRA / SRL,
+/// SELECT_CCMASK, and the binary operations ADD / AND / OR / XOR.
+///
+/// \returns the CC-producing value together with a CCValid mask for it, or
+/// {SDValue(), CCMASK_NONE} if no CC producer is found.
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
+  auto *N = Val.getNode();
+  if (!N)
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default:
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM: {
+    // Only SelectionDAG (ISD / SystemZISD) opcodes may be compared with
+    // getOpcode() here. Machine opcodes such as SystemZ::CLST belong to a
+    // different numbering space and could match an unrelated node.
+    unsigned CCOpc = N->getOperand(0).getOpcode();
+    if (CCOpc == SystemZISD::CLC || CCOpc == SystemZISD::STRCMP)
+      return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP);
+    return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY);
+  }
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Look through the shift to the shifted value.
+    return findCCUse(N->getOperand(0));
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    // If the select tests an ICMP / TM, prefer a CC producer found beneath
+    // that node's first operand.
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0));
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid);
+    }
+    // Otherwise search the CC operand itself. If nothing is found there,
+    // fall back to the select's own CC operand and CCValid mask.
+    auto [OpCC, OpCCValid] = findCCUse(Op4CCReg);
+    return OpCC != SDValue() ? std::make_pair(OpCC, OpCCValid)
+                             : std::make_pair(Op4CCReg, CCValidVal);
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    // Use the first CC producer found, searching operand 0 before operand 1.
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1));
+  }
+  }
+}
+
+/// Evaluate \p Val once for each possible condition-code value (0..3) of the
+/// CC operand \p CC.
+///
+/// \returns four values, where element I is \p Val simplified on the
+/// assumption that the condition code equals I. Returns an empty vector if
+/// \p Val is not built, recursively, only from the following:
+///   - constants;
+///   - IPM;
+///   - SRL by a constant;
+///   - the SHL / SRA-by-30 pair;
+///   - SELECT_CCMASK testing \p CC;
+///   - ADD / AND / OR / XOR.
+SmallVector<SDValue, 4>
+SystemZTargetLowering::simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
+                                             DAGCombinerInfo &DCI) const {
+  // Matches the binary operations that are simplified element-wise below.
+  const auto isValidBinaryOperation = [](const SDValue &Op, SDValue &Op0,
+                                         SDValue &Op1, unsigned &Opcode) {
+    auto *N = Op.getNode();
+    if (!N)
+      return false;
+    Opcode = N->getOpcode();
+    if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR &&
+        Opcode != ISD::XOR)
+      return false;
+    Op0 = N->getOperand(0);
+    Op1 = N->getOperand(1);
+    return true;
+  };
+  // A constant has the same value whatever the condition code is.
+  if (isa<ConstantSDNode>(Val))
+    return {Val, Val, Val, Val};
+  auto *N = Val.getNode(), *CCNode = CC.getNode();
+  if (!N || !CCNode)
+    return {};
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  if (N->getOpcode() == SystemZISD::IPM) {
+    // If the condition code is CCVal, the IPM result is CCVal << IPM_CC.
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (int CCVal : {0, 1, 2, 3}) {
+      SDValue CCConst = DAG.getConstant(CCVal, DL, MVT::i32);
+      ShiftedCCVals.emplace_back(
+          DAG.getNode(ISD::SHL, DL, MVT::i32, CCConst,
+                      DAG.getConstant(SystemZ::IPM_CC, DL, MVT::i32)));
+    }
+    return ShiftedCCVals;
+  }
+  if (N->getOpcode() == ISD::SRL) {
+    SDValue Op0 = N->getOperand(0);
+    auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRLCount)
+      return {};
+    auto SRLCountVal = SRLCount->getZExtValue();
+    SmallVector<SDValue, 4> SDVals = simplifyAssumingCCVal(Op0, CC, DCI);
+    if (SDVals.empty())
+      return {};
+    // Shift each per-CC value right by the same constant amount.
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const SDValue &SDVal : SDVals)
+      ShiftedVals.emplace_back(
+          DAG.getNode(ISD::SRL, DL, MVT::i32, SDVal,
+                      DAG.getConstant(SRLCountVal, DL, MVT::i32)));
+    return ShiftedVals;
+  }
+  if (N->getOpcode() == ISD::SRA) {
+    // Keep SRA and SHL opcode together and check for shift amount the same as
+    // in original code.
+    auto *SRACount = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!SRACount || SRACount->getZExtValue() != 30)
+      return {};
+    auto *SHL = N->getOperand(0).getNode();
+    if (SHL->getOpcode() != ISD::SHL)
+      return {};
+    auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1));
+    if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC)
+      return {};
+    // Avoid introducing CC spills (because SRA would clobber CC).
+    if (!N->hasOneUse())
+      return {};
+    SDValue IPM = SHL->getOperand(0);
+    SmallVector<SDValue, 4> SDVals = simplifyAssumingCCVal(IPM, CC, DCI);
+    if (SDVals.empty())
+      return {};
+    auto SRAShift = SRACount->getZExtValue();
+    auto SHLShift = SHLCount->getZExtValue();
+    // Re-apply the SHL followed by the SRA to each per-CC value.
+    SmallVector<SDValue, 4> ShiftedVals;
+    for (const SDValue &SDVal : SDVals) {
+      SDValue SHLVal = DAG.getNode(ISD::SHL, DL, MVT::i32, SDVal,
+                                   DAG.getConstant(SHLShift, DL, MVT::i32));
+      ShiftedVals.emplace_back(
+          DAG.getNode(ISD::SRA, DL, MVT::i32, SHLVal,
+                      DAG.getConstant(SRAShift, DL, MVT::i32)));
+    }
+    return ShiftedVals;
+  }
+  if (N->getOpcode() == SystemZISD::SELECT_CCMASK) {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *TrueOp = TrueVal.getNode();
+    auto *FalseOp = FalseVal.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueOp || !FalseOp || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    SmallVector<SDValue, 4> TrueSDVals =
+        simplifyAssumingCCVal(TrueVal, CC, DCI);
+    SmallVector<SDValue, 4> FalseSDVals =
+        simplifyAssumingCCVal(FalseVal, CC, DCI);
+    if (TrueSDVals.empty() || FalseSDVals.empty())
+      return {};
+    // If the select's CC operand is not CC, let combineCCMask try to rewrite
+    // that operand and the masks. The select can only be resolved if it then
+    // tests CC itself.
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    if (Op4CCNode && Op4CCNode != CCNode)
+      combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DCI);
+    Op4CCNode = Op4CCReg.getNode();
+    if (!Op4CCNode || Op4CCNode != CCNode)
+      return {};
+    // Bit (3 - I) of the 4-bit CCValid / CCMask values corresponds to
+    // condition-code value I. Pick the true value only when that CC value is
+    // both valid and selected by the mask.
+    SmallVector<SDValue, 4> MergedSDVals;
+    for (int CCVal : {0, 1, 2, 3}) {
+      const int CCBit = 1 << (3 - CCVal);
+      const bool Selected =
+          (CCMaskVal & CCBit) != 0 && (CCValidVal & CCBit) != 0;
+      MergedSDVals.emplace_back(Selected ? TrueSDVals[CCVal]
+                                         : FalseSDVals[CCVal]);
+    }
+    return MergedSDVals;
+  }
+  SDValue Op0, Op1;
+  unsigned Opcode = 0;
+  if (isValidBinaryOperation(Val, Op0, Op1, Opcode)) {
+    SmallVector<SDValue, 4> Op0SDVals = simplifyAssumingCCVal(Op0, CC, DCI);
+    SmallVector<SDValue, 4> Op1SDVals = simplifyAssumingCCVal(Op1, CC, DCI);
+    if (Op0SDVals.empty() || Op1SDVals.empty())
+      return {};
+    // Apply the operation element-wise, giving one result per CC value.
+    SmallVector<SDValue, 4> BinaryOpSDVals;
+    for (int CCVal : {0, 1, 2, 3})
+      BinaryOpSDVals.emplace_back(DAG.getNode(
+          Opcode, DL, Val.getValueType(), Op0SDVals[CCVal], Op1SDVals[CCVal]));
+    return BinaryOpSDVals;
+  }
+  return {};
+}
+
+bool SystemZTargetLowering::combineCCMask(SDValue &CCReg, int &CCValid,
+                                          int &CCMask,
+                                          DAGCombinerInfo &DCI) const {
   // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code
   // set by the CCReg instruction using the CCValid / CCMask masks,
-  // If the CCReg instruction is itself a ICMP testing the condition
+  // If the CCReg instruction is itself a ICMP / TM  testing the condition
   // code set by some other instruction, see whether we can directly
   // use that condition code.
-
-  // Verify that we have an ICMP against some constant.
-  if (CCValid != SystemZ::CCMASK_ICMP)
-    return false;
-  auto *ICmp = CCReg.getNode();
-  if (ICmp->getOpcode() != SystemZISD::ICMP)
-    return false;
-  auto *CompareLHS = ICmp->getOperand(0).getNode();
-  auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1));
-  if (!CompareRHS)
+  auto *CCNode = CCReg.getNode();
+  if (!CCNode)
     return false;
+  const auto getConstFromConstSDVals = [](const SmallVector<SDValue, 4> &Vals) 
{
+    SmallVector<int, 4> CCVals;
+    for (const auto &Val : Vals)
+      if (auto *ConstNode = dyn_cast<ConstantSDNode>(Val.getNode()))
+        CCVals.emplace_back(ConstNode->getZExtValue());
+      else
+        return SmallVector<int, 4>();
+    return CCVals;
+  };
+  const auto getMSBPosSet = [](unsigned int Mask) {
+    int NumBits = std::numeric_limits<unsigned int>::digits;
+    int count = 0;
+    // Keep target search space to the left.
+    while (NumBits > 0) {
+      NumBits /= 2;
+      // Upper half zeros.
+      if (!(Mask >> NumBits)) {
+        count += NumBits;
+        // Search lower half.
+        Mask <<= NumBits;
+      }
+    }
+    return count;
+  };
 
-  // Optimize the case where CompareLHS is a SELECT_CCMASK.
-  if (CompareLHS->getOpcode() == SystemZISD::SELECT_CCMASK) {
-    // Verify that we have an appropriate mask for a EQ or NE comparison.
-    bool Invert = false;
-    if (CCMask == SystemZ::CCMASK_CMP_NE)
-      Invert = !Invert;
-    else if (CCMask != SystemZ::CCMASK_CMP_EQ)
+  if (CCNode->getOpcode() == SystemZISD::TM) {
+    if (CCValid != SystemZ::CCMASK_TM)
       return false;
-
-    // Verify that the ICMP compares against one of select values.
-    auto *TrueVal = dyn_cast<ConstantSDNode>(CompareLHS->getOperand(0));
-    if (!TrueVal)
+    const auto emulateTMCCMask = [&](int CCVal, int Mask) {
+      if (!Mask)
+        return std::numeric_limits<unsigned int>::digits;
+      int Result = CCVal & Mask;
+      bool AllOnes = Result == Mask;
+      bool AllZeros = Result == 0;
+      bool MixedZerosOnes = (!AllOnes && !AllZeros);
+      int MSBPos = getMSBPosSet(static_cast<unsigned int>(Mask));
+      bool IsLeftMostBitSet = (Result & (1 << MSBPos)) != 0;
+      return AllOnes                                ? 3
+             : AllZeros                             ? 0
+             : (MixedZerosOnes && IsLeftMostBitSet) ? 2
----------------
uweigand wrote:

According to the sequence of comparisons, `MixedZerosOnes` must be true if we 
get here, so it seems redundant to test it.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to