================
@@ -8701,95 +8734,266 @@ SDValue SystemZTargetLowering::combineSETCC(
   return SDValue();
 }
 
-static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) {
+// Walk the expression rooted at Val looking for a node that consumes a
+// condition-code (CC) value, and return that CC value together with an
+// integer mask describing it.
+//
+// Returns {SDValue(), SystemZ::CCMASK_NONE} when Val has no node, when the
+// opcode is not one of the handled cases, or when a SELECT_CCMASK is not in
+// the expected shape. Otherwise returns the CC-producing operand and a mask
+// (presumably the set of CC values that operand can produce -- the local
+// names call it "CCValid"; TODO confirm against callers).
+static std::pair<SDValue, int> findCCUse(const SDValue &Val) {
+  auto *N = Val.getNode();
+  if (!N)
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  switch (N->getOpcode()) {
+  default:
+    // Any opcode not listed below is treated as "no CC use found".
+    return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+  case SystemZISD::IPM:
+    // IPM's operand 0 is the CC value being read. CLC and STRCMP producers
+    // are reported with CCMASK_ICMP; every other producer with CCMASK_ANY.
+    if (N->getOperand(0).getOpcode() == SystemZISD::CLC ||
+        N->getOperand(0).getOpcode() == SystemZISD::STRCMP)
+      return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP);
+    return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY);
+  case SystemZISD::SELECT_CCMASK: {
+    // Operand 2 is read as the constant CCValid mask and operand 4 as the CC
+    // value the select tests.
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    if (!CCValid || !Op4CCNode)
+      return std::make_pair(SDValue(), SystemZ::CCMASK_NONE);
+    int CCValidVal = CCValid->getZExtValue();
+    // If the tested CC comes from an ICMP or TM, look through it: search the
+    // first operand of that node for a CC use and prefer that result when
+    // one is found.
+    if (Op4CCNode->getOpcode() == SystemZISD::ICMP ||
+        Op4CCNode->getOpcode() == SystemZISD::TM) {
+      auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0));
+      if (OpCC != SDValue())
+        return std::make_pair(OpCC, OpCCValid);
+    }
+    // Otherwise report the select's own CC operand with its CCValid mask.
+    return std::make_pair(Op4CCReg, CCValidVal);
+  }
+  case ISD::ADD:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR:
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Binary arithmetic/logic/shift nodes: search operand 0 first and fall
+    // back to operand 1 only if operand 0 yields nothing.
+    auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0));
+    if (Op0CC != SDValue())
+      return std::make_pair(Op0CC, Op0CCValid);
+    return findCCUse(N->getOperand(1));
+  }
+}
+
+// Forward declaration of combineCCMask so that it can be referenced before
+// its definition.
+static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask,
+                          SelectionDAG &DAG);
+
+SmallVector<SDValue, 4> static simplifyAssumingCCVal(SDValue &Val, SDValue &CC,
+                                                     SelectionDAG &DAG) {
+  auto *N = Val.getNode(), *CCNode = CC.getNode();
+  if (!N || !CCNode)
+    return {};
+  SDLoc DL(N);
+  auto Opcode = N->getOpcode();
+  switch (Opcode) {
+  default:
+    return {};
+  case ISD::Constant:
+    return {Val, Val, Val, Val};
+  case SystemZISD::IPM: {
+    auto *IPMOp0Node = N->getOperand(0).getNode();
+    if (!IPMOp0Node || IPMOp0Node != CCNode)
+      return {};
+    SmallVector<SDValue, 4> ShiftedCCVals;
+    for (auto CC : {0, 1, 2, 3})
+      ShiftedCCVals.emplace_back(
+          DAG.getConstant((CC << SystemZ::IPM_CC), DL, MVT::i32));
+    return ShiftedCCVals;
+  }
+  case SystemZISD::SELECT_CCMASK: {
+    SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1);
+    auto *TrueOp = TrueVal.getNode();
+    auto *FalseOp = FalseVal.getNode();
+    auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2));
+    auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3));
+    if (!TrueOp || !FalseOp || !CCValid || !CCMask)
+      return {};
+
+    int CCValidVal = CCValid->getZExtValue();
+    int CCMaskVal = CCMask->getZExtValue();
+    const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DAG);
+    const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DAG);
+    if (TrueSDVals.empty() || FalseSDVals.empty())
+      return {};
+    SDValue Op4CCReg = N->getOperand(4);
+    auto *Op4CCNode = Op4CCReg.getNode();
+    if (Op4CCNode && Op4CCNode != CCNode)
----------------
uweigand wrote:

As above, compare the SDValues.

https://github.com/llvm/llvm-project/pull/125970
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to