================ @@ -8701,95 +8734,341 @@ SDValue SystemZTargetLowering::combineSETCC( return SDValue(); } -static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) { +static std::pair<SDValue, int> findCCUse(const SDValue &Val) { + auto *N = Val.getNode(); + if (!N) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + switch (N->getOpcode()) { + default: + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + case SystemZISD::IPM: + if (N->getOperand(0).getOpcode() == SystemZISD::CLC || + N->getOperand(0).getOpcode() == SystemZ::CLST || + N->getOperand(0).getOpcode() == SystemZISD::STRCMP) + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP); + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); + case ISD::SHL: + case ISD::SRA: + case ISD::SRL: + return findCCUse(N->getOperand(0)); + case SystemZISD::SELECT_CCMASK: { + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + if (!CCValid || !Op4CCNode) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + int CCValidVal = CCValid->getZExtValue(); + if (Op4CCNode->getOpcode() == SystemZISD::ICMP || + Op4CCNode->getOpcode() == SystemZISD::TM) { + auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0)); + if (OpCC != SDValue()) + return std::make_pair(OpCC, OpCCValid); + } + auto [OpCC, OpCCValid] = findCCUse(Op4CCReg); + return OpCC != SDValue() ? std::make_pair(OpCC, OpCCValid) + : std::make_pair(Op4CCReg, CCValidVal); + } + case ISD::ADD: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); + if (Op0CC != SDValue()) + return std::make_pair(Op0CC, Op0CCValid); + return findCCUse(N->getOperand(1)); + } +} + +SmallVector<SDValue, 4> +SystemZTargetLowering::simplifyAssumingCCVal(SDValue &Val, SDValue &CC, + DAGCombinerInfo &DCI) const { + const auto isValidBinaryOperation = [](const SDValue &Op, SDValue &Op0, + SDValue &Op1, unsigned &Opcode) { + auto *N = Op.getNode(); + if (!N) + return false; + Opcode = N->getOpcode(); + if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR && + Opcode != ISD::XOR) + return false; + Op0 = N->getOperand(0); + Op1 = N->getOperand(1); + return true; + }; + if (isa<ConstantSDNode>(Val)) { + return {Val, Val, Val, Val}; + } + auto *N = Val.getNode(), *CCNode = CC.getNode(); + if (!N || !CCNode) + return {}; + SelectionDAG &DAG = DCI.DAG; + SDLoc DL(N); + if (N->getOpcode() == SystemZISD::IPM) { + SmallVector<SDValue, 4> ShiftedCCVals; + for (auto CC : {0, 1, 2, 3}) { + SDValue CCVal = DAG.getConstant(CC, DL, MVT::i32); + ShiftedCCVals.emplace_back( + DAG.getNode(ISD::SHL, DL, MVT::i32, CCVal, + DAG.getConstant(SystemZ::IPM_CC, DL, MVT::i32))); + } + return ShiftedCCVals; + } + if (N->getOpcode() == ISD::SRL) { + SDValue Op0 = N->getOperand(0); + auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!SRLCount) + return {}; + auto SRLCountVal = SRLCount->getZExtValue(); + const auto &&SDVals = simplifyAssumingCCVal(Op0, CC, DCI); + if (SDVals.empty()) + return SDVals; + SmallVector<SDValue, 4> ShiftedVals; + for (const auto &SDVal : SDVals) + ShiftedVals.emplace_back( + DAG.getNode(ISD::SRL, DL, MVT::i32, SDVal, + DAG.getConstant(SRLCountVal, DL, MVT::i32))); + return ShiftedVals; + } + if (N->getOpcode() == ISD::SRA) { + // Keep SRA and SHL opcode together and check for shift amount the same as + // in original code. + auto *SRACount = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!SRACount || SRACount->getZExtValue() != 30) + return {}; + auto *SHL = N->getOperand(0).getNode(); + if (SHL->getOpcode() != ISD::SHL) + return {}; + auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1)); + if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC) + return {}; + // Avoid introducing CC spills (because SRA would clobber CC). + if (!N->hasOneUse()) + return {}; + SDValue IPM = SHL->getOperand(0); + const auto &&SDVals = simplifyAssumingCCVal(IPM, CC, DCI); + if (SDVals.empty()) + return SDVals; + auto SRAShift = SRACount->getZExtValue(); + auto SHLShift = SHLCount->getZExtValue(); + SmallVector<SDValue, 4> ShiftedVals; + for (const auto &SDVal : SDVals) { + SDValue SRAVal = DAG.getNode(ISD::SHL, DL, MVT::i32, SDVal, + DAG.getConstant(SHLShift, DL, MVT::i32)); + ShiftedVals.emplace_back( + DAG.getNode(ISD::SRA, DL, MVT::i32, SRAVal, + DAG.getConstant(SRAShift, DL, MVT::i32))); + } + return ShiftedVals; + } + if (N->getOpcode() == SystemZISD::SELECT_CCMASK) { + SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1); + auto *TrueOp = TrueVal.getNode(); + auto *FalseOp = FalseVal.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3)); + if (!TrueOp || !FalseOp || !CCValid || !CCMask) + return {}; + + int CCValidVal = CCValid->getZExtValue(); + int CCMaskVal = CCMask->getZExtValue(); + const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DCI); + const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DCI); + if (TrueSDVals.empty() || FalseSDVals.empty()) + return {}; + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + if (Op4CCNode && Op4CCNode != CCNode) + combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DCI); + Op4CCNode = Op4CCReg.getNode(); + if (!Op4CCNode || Op4CCNode != CCNode) + return {}; + SmallVector<SDValue, 4> MergedSDVals; + for (auto &CCVal : {0, 1, 2, 3}) + MergedSDVals.emplace_back((((CCMaskVal & (1 << (3 - CCVal))) != 0) && + ((CCValidVal & (1 << (3 - CCVal))) != 0)) ---------------- uweigand wrote:
CCMask is guaranteed to be a subset of CCValid, so the second check is unnecessary. https://github.com/llvm/llvm-project/pull/125970 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits