Author: Johannes Reifferscheid
Date: 2024-04-15T11:48:12+02:00
New Revision: de88bd7e8925f5df51547e20f6fbd1ef006386ad
URL: https://github.com/llvm/llvm-project/commit/de88bd7e8925f5df51547e20f6fbd1ef006386ad
DIFF: https://github.com/llvm/llvm-project/commit/de88bd7e8925f5df51547e20f6fbd1ef006386ad.diff

LOG: Revert "Fix rsqrt inaccuracies. (#88691)"

This reverts commit 8ddaf750746d7f9b5f7e878870b086edc0f55326.

Added:

Modified:
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed:

################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 3ebee9baff31bd..49eb575212ffc1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,11 +27,9 @@ using namespace mlir;
 
 namespace {
 
-enum class AbsFn { abs, sqrt, rsqrt };
-
-// Returns the absolute value, its square root or its reciprocal square root.
+// Returns the absolute value or its square root.
 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
-                 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
+                 ImplicitLocOpBuilder &b, bool returnSqrt = false) {
   Value one = b.create<arith::ConstantOp>(real.getType(),
                                           b.getFloatAttr(real.getType(), 1.0));
 
@@ -45,13 +43,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
   Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
   Value result;
 
-  if (fn == AbsFn::rsqrt) {
-    ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
-    min = b.create<math::RsqrtOp>(min, fmf);
-    max = b.create<math::RsqrtOp>(max, fmf);
-  }
-
-  if (fn == AbsFn::sqrt) {
+  if (returnSqrt) {
     Value quarter = b.create<arith::ConstantOp>(
         real.getType(), b.getFloatAttr(real.getType(), 0.25));
     // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -871,7 +863,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
 
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-    Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
+    Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
     Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
     Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
     Value cos = b.create<math::CosOp>(sqrtArg, fmf);
@@ -1155,74 +1147,18 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
   LogicalResult
   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
 
-    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
-
-    auto cst = [&](APFloat v) {
-      return b.create<arith::ConstantOp>(elementType,
-                                         b.getFloatAttr(elementType, v));
-    };
-    const auto &floatSemantics = elementType.getFloatSemantics();
-    Value zero = cst(APFloat::getZero(floatSemantics));
-    Value inf = cst(APFloat::getInf(floatSemantics));
-    Value negHalf = b.create<arith::ConstantOp>(
-        elementType, b.getFloatAttr(elementType, -0.5));
-    Value nan = cst(APFloat::getNaN(floatSemantics));
-
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-    Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
-    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
-    Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
-    Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
-    Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
-
-    Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
-    Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
-
-    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
-                                        arith::FastMathFlags::ninf)) {
-      Value negOne = b.create<arith::ConstantOp>(
-          elementType, b.getFloatAttr(elementType, -1));
-
-      Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
-      Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
-      Value negImagSignedZero =
-          b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
+    Value c = builder.create<arith::ConstantOp>(
+        elementType, builder.getFloatAttr(elementType, -0.5));
+    Value d = builder.create<arith::ConstantOp>(
+        elementType, builder.getFloatAttr(elementType, 0));
 
-      Value absReal = b.create<math::AbsFOp>(real, fmf);
-      Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
-      Value absImagIsInf =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
-      Value realIsNan =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
-      Value realIsInf =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
-      Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
-
-      Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
-
-      resultReal =
-          b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
-      resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
-                                             resultImag);
-    }
-
-    Value isRealZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
-    Value isImagZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
-    Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
-
-    resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
-    resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
-
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
-                                                   resultImag);
+    rewriter.replaceOp(op,
+                       {powOpConversionImpl(builder, type, adaptor.getComplex(),
+                                            c, d, op.getFastmath())});
     return success();
   }
 };

diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8b4ea9777f7976..e0e7cdadd317d2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -837,21 +837,6 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
   return %rsqrt : complex<f32>
 }
 
-// CHECK-COUNT-5: arith.select
-// CHECK-NOT: arith.select
-
-// -----
-
-// CHECK-LABEL: func @complex_rsqrt_nnan_ninf
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>
-func.func @complex_rsqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
-  %sqrt = complex.rsqrt %arg fastmath<nnan,ninf> : complex<f32>
-  return %sqrt : complex<f32>
-}
-
-// CHECK-COUNT-3: arith.select
-// CHECK-NOT: arith.select
-
 // -----
 
 // CHECK-LABEL: func.func @complex_angle
@@ -2118,4 +2103,4 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits