Author: QingShan Zhang Date: 2021-01-25T04:02:44Z New Revision: ffc3e800c65ee58166255ff897f8b7e6d850ddda
URL: https://github.com/llvm/llvm-project/commit/ffc3e800c65ee58166255ff897f8b7e6d850ddda DIFF: https://github.com/llvm/llvm-project/commit/ffc3e800c65ee58166255ff897f8b7e6d850ddda.diff LOG: [NFC] [DAGCombine] Correct the result for sqrt even the iteration is zero For now, we correct the result for sqrt only if iteration > 0. This doesn't make sense, as the two are not strictly related. Reviewed By: dmgreen, spatel, RKSimon Differential Revision: https://reviews.llvm.org/D94480 Added: Modified: llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.h llvm/lib/Target/PowerPC/PPCISelLowering.cpp Removed: ################################################################################ diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 5a237074a5a3..1bc5377e6863 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4287,9 +4287,7 @@ class TargetLowering : public TargetLoweringBase { /// comparison may check if the operand is NAN, INF, zero, normal, etc. The /// result should be used as the condition operand for a select or branch. virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, - const DenormalMode &Mode) const { - return SDValue(); - } + const DenormalMode &Mode) const; /// Return a target-dependent result if the input operand is not suitable for /// use with a square root estimate calculation. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 2ebf7c6ba0f3..cb273a6f299c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -22275,43 +22275,21 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags, Reciprocal)) { AddToWorklist(Est.getNode()); - if (Iterations) { + if (Iterations) Est = UseOneConstNR ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal) : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal); - - if (!Reciprocal) { - SDLoc DL(Op); - EVT CCVT = getSetCCResultType(VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - DenormalMode DenormMode = DAG.getDenormalMode(VT); - // Try the target specific test first. - SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode); - if (!Test) { - // If no test provided by target, testing it with denormal inputs to - // avoid wrong estimate. - if (DenormMode.Input == DenormalMode::IEEE) { - // This is specifically a check for the handling of denormal inputs, - // not the result. - - // Test = fabs(X) < SmallestNormal - const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); - APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); - SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); - SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); - Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); - } else - // Test = X == 0.0 - Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); - } - - // The estimate is now completely wrong if the input was exactly 0.0 or - // possibly a denormal. Force the answer to 0.0 or value provided by - // target for those cases. - Est = DAG.getNode( - Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, - Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); - } + if (!Reciprocal) { + SDLoc DL(Op); + // Try the target specific test first. 
+ SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT)); + + // The estimate is now completely wrong if the input was exactly 0.0 or + // possibly a denormal. Force the answer to 0.0 or value provided by + // target for those cases. + Est = DAG.getNode( + Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, + Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); } return Est; } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 80b745e0354a..7858bc6c43e4 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5841,6 +5841,28 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const { return false; } +SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + // Testing it with denormal inputs to avoid wrong estimate. + if (Mode.Input == DenormalMode::IEEE) { + // This is specifically a check for the handling of denormal inputs, + // not the result. 
+ + // Test = fabs(X) < SmallestNormal + const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); + APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); + SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); + SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); + return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); + } + // Test = X == 0.0 + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, bool OptForSize, NegatibleCost &Cost, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index f485490199fd..1be09186dc0a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7471,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode, return SDValue(); } +SDValue +AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + +SDValue +AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op, + SelectionDAG &DAG) const { + return Op; +} + SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, @@ -7494,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags); Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags); } - if (!Reciprocal) { - EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), - VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ); - + if (!Reciprocal) Estimate = 
DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags); - // Correct the result if the operand is 0.0. - Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, - VT, Eq, Operand, Estimate); - } ExtraSteps = 0; return Estimate; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 71e59f8f6ed7..9550197159e6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -961,6 +961,10 @@ class AArch64TargetLowering : public TargetLowering { bool Reciprocal) const override; SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps) const override; + SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, + const DenormalMode &Mode) const override; + SDValue getSqrtResultForDenormInput(SDValue Operand, + SelectionDAG &DAG) const override; unsigned combineRepeatedFPDivisors() const override; ConstraintType getConstraintType(StringRef Constraint) const override; diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp index a6382bd17a9f..9215c17cb94b 100644 --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -12133,7 +12133,7 @@ SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, if (!isTypeLegal(MVT::i1) || (VT != MVT::f64 && ((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX()))) - return SDValue(); + return TargetLowering::getSqrtInputTest(Op, DAG, Mode); SDLoc DL(Op); // The output register of FTSQRT is CR field. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits