================ @@ -8701,95 +8734,341 @@ SDValue SystemZTargetLowering::combineSETCC( return SDValue(); } -static bool combineCCMask(SDValue &CCReg, int &CCValid, int &CCMask) { +static std::pair<SDValue, int> findCCUse(const SDValue &Val) { + auto *N = Val.getNode(); + if (!N) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + switch (N->getOpcode()) { + default: + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + case SystemZISD::IPM: + if (N->getOperand(0).getOpcode() == SystemZISD::CLC || + N->getOperand(0).getOpcode() == SystemZ::CLST || + N->getOperand(0).getOpcode() == SystemZISD::STRCMP) + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ICMP); + return std::make_pair(N->getOperand(0), SystemZ::CCMASK_ANY); + case ISD::SHL: + case ISD::SRA: + case ISD::SRL: + return findCCUse(N->getOperand(0)); + case SystemZISD::SELECT_CCMASK: { + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + if (!CCValid || !Op4CCNode) + return std::make_pair(SDValue(), SystemZ::CCMASK_NONE); + int CCValidVal = CCValid->getZExtValue(); + if (Op4CCNode->getOpcode() == SystemZISD::ICMP || + Op4CCNode->getOpcode() == SystemZISD::TM) { + auto [OpCC, OpCCValid] = findCCUse(Op4CCNode->getOperand(0)); + if (OpCC != SDValue()) + return std::make_pair(OpCC, OpCCValid); + } + auto [OpCC, OpCCValid] = findCCUse(Op4CCReg); + return OpCC != SDValue() ? 
std::make_pair(OpCC, OpCCValid) + : std::make_pair(Op4CCReg, CCValidVal); + } + case ISD::ADD: + case ISD::AND: + case ISD::OR: + case ISD::XOR: + auto [Op0CC, Op0CCValid] = findCCUse(N->getOperand(0)); + if (Op0CC != SDValue()) + return std::make_pair(Op0CC, Op0CCValid); + return findCCUse(N->getOperand(1)); + } +} + +SmallVector<SDValue, 4> +SystemZTargetLowering::simplifyAssumingCCVal(SDValue &Val, SDValue &CC, + DAGCombinerInfo &DCI) const { + const auto isValidBinaryOperation = [](const SDValue &Op, SDValue &Op0, + SDValue &Op1, unsigned &Opcode) { + auto *N = Op.getNode(); + if (!N) + return false; + Opcode = N->getOpcode(); + if (Opcode != ISD::ADD && Opcode != ISD::AND && Opcode != ISD::OR && + Opcode != ISD::XOR) + return false; + Op0 = N->getOperand(0); + Op1 = N->getOperand(1); + return true; + }; + if (isa<ConstantSDNode>(Val)) { + return {Val, Val, Val, Val}; + } + auto *N = Val.getNode(), *CCNode = CC.getNode(); + if (!N || !CCNode) + return {}; + SelectionDAG &DAG = DCI.DAG; + SDLoc DL(N); + if (N->getOpcode() == SystemZISD::IPM) { + SmallVector<SDValue, 4> ShiftedCCVals; + for (auto CC : {0, 1, 2, 3}) { + SDValue CCVal = DAG.getConstant(CC, DL, MVT::i32); + ShiftedCCVals.emplace_back( + DAG.getNode(ISD::SHL, DL, MVT::i32, CCVal, + DAG.getConstant(SystemZ::IPM_CC, DL, MVT::i32))); + } + return ShiftedCCVals; + } + if (N->getOpcode() == ISD::SRL) { + SDValue Op0 = N->getOperand(0); + auto *SRLCount = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!SRLCount) + return {}; + auto SRLCountVal = SRLCount->getZExtValue(); + const auto &&SDVals = simplifyAssumingCCVal(Op0, CC, DCI); + if (SDVals.empty()) + return SDVals; + SmallVector<SDValue, 4> ShiftedVals; + for (const auto &SDVal : SDVals) + ShiftedVals.emplace_back( + DAG.getNode(ISD::SRL, DL, MVT::i32, SDVal, + DAG.getConstant(SRLCountVal, DL, MVT::i32))); + return ShiftedVals; + } + if (N->getOpcode() == ISD::SRA) { + // Keep SRA and SHL opcode together and check for shift amount the same as + // 
in original code. + auto *SRACount = dyn_cast<ConstantSDNode>(N->getOperand(1)); + if (!SRACount || SRACount->getZExtValue() != 30) + return {}; + auto *SHL = N->getOperand(0).getNode(); + if (SHL->getOpcode() != ISD::SHL) + return {}; + auto *SHLCount = dyn_cast<ConstantSDNode>(SHL->getOperand(1)); + if (!SHLCount || SHLCount->getZExtValue() != 30 - SystemZ::IPM_CC) + return {}; + // Avoid introducing CC spills (because SRA would clobber CC). + if (!N->hasOneUse()) + return {}; + SDValue IPM = SHL->getOperand(0); + const auto &&SDVals = simplifyAssumingCCVal(IPM, CC, DCI); + if (SDVals.empty()) + return SDVals; + auto SRAShift = SRACount->getZExtValue(); + auto SHLShift = SHLCount->getZExtValue(); + SmallVector<SDValue, 4> ShiftedVals; + for (const auto &SDVal : SDVals) { + SDValue SRAVal = DAG.getNode(ISD::SHL, DL, MVT::i32, SDVal, + DAG.getConstant(SHLShift, DL, MVT::i32)); + ShiftedVals.emplace_back( + DAG.getNode(ISD::SRA, DL, MVT::i32, SRAVal, + DAG.getConstant(SRAShift, DL, MVT::i32))); + } + return ShiftedVals; + } + if (N->getOpcode() == SystemZISD::SELECT_CCMASK) { + SDValue TrueVal = N->getOperand(0), FalseVal = N->getOperand(1); + auto *TrueOp = TrueVal.getNode(); + auto *FalseOp = FalseVal.getNode(); + auto *CCValid = dyn_cast<ConstantSDNode>(N->getOperand(2)); + auto *CCMask = dyn_cast<ConstantSDNode>(N->getOperand(3)); + if (!TrueOp || !FalseOp || !CCValid || !CCMask) + return {}; + + int CCValidVal = CCValid->getZExtValue(); + int CCMaskVal = CCMask->getZExtValue(); + const auto &&TrueSDVals = simplifyAssumingCCVal(TrueVal, CC, DCI); + const auto &&FalseSDVals = simplifyAssumingCCVal(FalseVal, CC, DCI); + if (TrueSDVals.empty() || FalseSDVals.empty()) + return {}; + SDValue Op4CCReg = N->getOperand(4); + auto *Op4CCNode = Op4CCReg.getNode(); + if (Op4CCNode && Op4CCNode != CCNode) + combineCCMask(Op4CCReg, CCValidVal, CCMaskVal, DCI); + Op4CCNode = Op4CCReg.getNode(); + if (!Op4CCNode || Op4CCNode != CCNode) + return {}; + SmallVector<SDValue, 4> 
MergedSDVals; + for (auto &CCVal : {0, 1, 2, 3}) + MergedSDVals.emplace_back((((CCMaskVal & (1 << (3 - CCVal))) != 0) && + ((CCValidVal & (1 << (3 - CCVal))) != 0)) + ? TrueSDVals[CCVal] + : FalseSDVals[CCVal]); + return MergedSDVals; + } + SDValue Op0, Op1; + unsigned Opcode; + if (isValidBinaryOperation(Val, Op0, Op1, Opcode)) { + const auto &&Op0SDVals = simplifyAssumingCCVal(Op0, CC, DCI); + const auto &&Op1SDVals = simplifyAssumingCCVal(Op1, CC, DCI); + if (Op0SDVals.empty() || Op1SDVals.empty()) + return {}; + SmallVector<SDValue, 4> BinaryOpSDVals; + for (auto CCVal : {0, 1, 2, 3}) + BinaryOpSDVals.emplace_back(DAG.getNode( + Opcode, DL, Val.getValueType(), Op0SDVals[CCVal], Op1SDVals[CCVal])); + return BinaryOpSDVals; + } + return {}; +} + +bool SystemZTargetLowering::combineCCMask(SDValue &CCReg, int &CCValid, + int &CCMask, + DAGCombinerInfo &DCI) const { // We have a SELECT_CCMASK or BR_CCMASK comparing the condition code // set by the CCReg instruction using the CCValid / CCMask masks, - // If the CCReg instruction is itself a ICMP testing the condition + // If the CCReg instruction is itself a ICMP / TM testing the condition // code set by some other instruction, see whether we can directly // use that condition code. - - // Verify that we have an ICMP against some constant. - if (CCValid != SystemZ::CCMASK_ICMP) - return false; - auto *ICmp = CCReg.getNode(); - if (ICmp->getOpcode() != SystemZISD::ICMP) - return false; - auto *CompareLHS = ICmp->getOperand(0).getNode(); - auto *CompareRHS = dyn_cast<ConstantSDNode>(ICmp->getOperand(1)); - if (!CompareRHS) + auto *CCNode = CCReg.getNode(); + if (!CCNode) return false; + const auto getConstFromConstSDVals = [](const SmallVector<SDValue, 4> &Vals) { + SmallVector<int, 4> CCVals; + for (const auto &Val : Vals) + if (auto *ConstNode = dyn_cast<ConstantSDNode>(Val.getNode())) + CCVals.emplace_back(ConstNode->getZExtValue()); ---------------- uweigand wrote:
This is wrong: storing the result of `getZExtValue()` into a `SmallVector<int, 4>` will truncate the constant to `int` again. The whole point of using `SDValue` is to avoid that truncation - this was all for nothing if you truncate here again. You should leave everything as (constant) `SDValue` and then just operate on those (possibly via the embedded `APInt`). https://github.com/llvm/llvm-project/pull/125970 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits