Author: Sanjay Patel
Date: 2021-01-23T11:17:20-05:00
New Revision: a6f02212764a76935ec5fb704fe86a1a76f65745

URL: https://github.com/llvm/llvm-project/commit/a6f02212764a76935ec5fb704fe86a1a76f65745
DIFF: https://github.com/llvm/llvm-project/commit/a6f02212764a76935ec5fb704fe86a1a76f65745.diff

LOG: [SLP] fix fast-math-flag propagation on FP reductions

As shown in the test diffs, we could miscompile by propagating flags
that did not exist in the original code.

The flags required for fmin/fmax reductions will be fixed in a
follow-up patch.

Added:


Modified:
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll

Removed:


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 78ce4870588c..6c2b10e5b9fa 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6820,12 +6820,18 @@ class HorizontalReduction {
     if (NumReducedVals < 4)
       return false;
 
-    // FIXME: Fast-math-flags should be set based on the instructions in the
-    // reduction (not all of 'fast' are required).
+    // Intersect the fast-math-flags from all reduction operations.
+    FastMathFlags RdxFMF;
+    RdxFMF.set();
+    for (ReductionOpsType &RdxOp : ReductionOps) {
+      for (Value *RdxVal : RdxOp) {
+        if (auto *FPMO = dyn_cast<FPMathOperator>(RdxVal))
+          RdxFMF &= FPMO->getFastMathFlags();
+      }
+    }
+
     IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
-    FastMathFlags Unsafe;
-    Unsafe.setFast();
-    Builder.setFastMathFlags(Unsafe);
+    Builder.setFastMathFlags(RdxFMF);
 
     BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
     // The same extra argument may be used several times, so log each attempt
@@ -7071,9 +7077,6 @@ class HorizontalReduction {
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
 
-    // FIXME: The builder should use an FMF guard. It should not be hard-coded
-    //        to 'fast'.
-    assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
     return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
                                        ReductionOps.back());
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
index 38d36c676fa7..03ec04cb8cbe 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
@@ -1766,7 +1766,6 @@ bb.1:
   ret void
 }
 
-; FIXME: This is a miscompile.
 ; The FMF on the reduction should match the incoming insts.
 
 define float @fadd_v4f32_fmf(float* %p) {
@@ -1776,7 +1775,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf(
@@ -1785,7 +1784,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
   %p1 = getelementptr inbounds float, float* %p, i64 1
@@ -1801,6 +1800,10 @@ define float @fadd_v4f32_fmf(float* %p) {
   ret float %add3
 }
 
+; The minimal FMF for fadd reduction are "reassoc nsz".
+; Only the common FMF of all operations in the reduction propagate to the result.
+; In this example, "contract nnan arcp" are dropped, but "ninf" transfers with the required flags.
+
 define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-LABEL: @fadd_v4f32_fmf_intersect(
 ; CHECK-NEXT:    [[P1:%.*]] = getelementptr inbounds float, float* [[P:%.*]], i64 1
@@ -1808,7 +1811,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf_intersect(
@@ -1817,7 +1820,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
   %p1 = getelementptr inbounds float, float* %p, i64 1
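
To make the intersection rule concrete: the patch seeds RdxFMF with every
flag set and then bitwise-ANDs in the flags of each scalar op in the
reduction, so only flags common to all ops reach the final reduction call.
Below is a minimal standalone C++ sketch of that rule. It uses a
hypothetical 6-flag bitmask for illustration, not LLVM's real FastMathFlags
class, and the per-op flags are loosely modeled on the
@fadd_v4f32_fmf_intersect test above.

// Hypothetical bitmask model of fast-math flags (illustrative only).
#include <cstdint>
#include <cstdio>

enum Flag : uint8_t {
  Reassoc  = 1 << 0, // allow reassociation
  NNaN     = 1 << 1, // no NaNs
  NInf     = 1 << 2, // no infinities
  NSZ      = 1 << 3, // no signed zeros
  ARcp     = 1 << 4, // allow reciprocal
  Contract = 1 << 5, // allow FP contraction
  AllFast  = 0x3F    // all of the above, i.e. 'fast'
};

int main() {
  // Assumed flags on the three scalar fadds of a 4-element reduction.
  const uint8_t Ops[] = {
      Reassoc | NSZ | NInf | NNaN,     // reassoc nsz ninf nnan
      Reassoc | NSZ | NInf | Contract, // reassoc nsz ninf contract
      Reassoc | NSZ | NInf | ARcp      // reassoc nsz ninf arcp
  };

  // Start from "all flags set" (RdxFMF.set() in the patch) and
  // intersect (RdxFMF &= ...): only the common subset survives.
  uint8_t Rdx = AllFast;
  for (uint8_t F : Ops)
    Rdx &= F;

  std::printf("reassoc=%d ninf=%d nsz=%d nnan=%d arcp=%d contract=%d\n",
              !!(Rdx & Reassoc), !!(Rdx & NInf), !!(Rdx & NSZ),
              !!(Rdx & NNaN), !!(Rdx & ARcp), !!(Rdx & Contract));
  return 0;
}

Compiled with any C++11 compiler, this prints reassoc=1 ninf=1 nsz=1
nnan=0 arcp=0 contract=0, matching the "reassoc ninf nsz" on the
vector.reduce.fadd calls in the updated CHECK lines.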