Author: Sanjay Patel Date: 2021-01-20T11:14:48-05:00 New Revision: 8590d245434dd4205c89f0a05b4c22feccb7421c
URL: https://github.com/llvm/llvm-project/commit/8590d245434dd4205c89f0a05b4c22feccb7421c DIFF: https://github.com/llvm/llvm-project/commit/8590d245434dd4205c89f0a05b4c22feccb7421c.diff LOG: [SLP] move reduction createOp functions; NFC We were able to remove almost all of the state from OperationData, so these createOp functions no longer make sense as members of that class; instead, just pass the RecurKind in as a parameter. Added: Modified: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp Removed: ################################################################################ diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 24885e4d8257..3d657b0b898c 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6397,7 +6397,7 @@ namespace { class HorizontalReduction { using ReductionOpsType = SmallVector<Value *, 16>; using ReductionOpsListType = SmallVector<ReductionOpsType, 2>; - ReductionOpsListType ReductionOps; + ReductionOpsListType ReductionOps; SmallVector<Value *, 32> ReducedVals; // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; @@ -6412,47 +6412,6 @@ class HorizontalReduction { /// Checks if the reduction operation can be vectorized. bool isVectorizable() const { return Kind != RecurKind::None; } - /// Creates reduction operation with the current opcode. 
- Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS, - const Twine &Name) const { - assert(isVectorizable() && "Unhandled reduction operation."); - unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); - switch (Kind) { - case RecurKind::Add: - case RecurKind::Mul: - case RecurKind::Or: - case RecurKind::And: - case RecurKind::Xor: - case RecurKind::FAdd: - case RecurKind::FMul: - return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, - Name); - case RecurKind::FMax: - return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS); - case RecurKind::FMin: - return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS); - - case RecurKind::SMax: { - Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name); - return Builder.CreateSelect(Cmp, LHS, RHS, Name); - } - case RecurKind::SMin: { - Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name); - return Builder.CreateSelect(Cmp, LHS, RHS, Name); - } - case RecurKind::UMax: { - Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name); - return Builder.CreateSelect(Cmp, LHS, RHS, Name); - } - case RecurKind::UMin: { - Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name); - return Builder.CreateSelect(Cmp, LHS, RHS, Name); - } - default: - llvm_unreachable("Unknown reduction operation."); - } - } - public: explicit OperationData() = default; @@ -6580,40 +6539,6 @@ class HorizontalReduction { return nullptr; return I->getOperand(getFirstOperandIndex() + 1); } - - /// Creates reduction operation with the current opcode with the IR flags - /// from \p ReductionOps. 
- Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS, - const Twine &Name, - const ReductionOpsListType &ReductionOps) const { - assert(isVectorizable() && - "Expected add|fadd or min/max reduction operation."); - Value *Op = createOp(Builder, LHS, RHS, Name); - if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) { - if (auto *Sel = dyn_cast<SelectInst>(Op)) - propagateIRFlags(Sel->getCondition(), ReductionOps[0]); - propagateIRFlags(Op, ReductionOps[1]); - return Op; - } - propagateIRFlags(Op, ReductionOps[0]); - return Op; - } - /// Creates reduction operation with the current opcode with the IR flags - /// from \p I. - Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS, - const Twine &Name, Instruction *I) const { - assert(isVectorizable() && - "Expected add|fadd or min/max reduction operation."); - Value *Op = createOp(Builder, LHS, RHS, Name); - if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind)) { - if (auto *Sel = dyn_cast<SelectInst>(Op)) { - propagateIRFlags(Sel->getCondition(), - cast<SelectInst>(I)->getCondition()); - } - } - propagateIRFlags(Op, I); - return Op; - } }; WeakTrackingVH ReductionRoot; @@ -6642,6 +6567,76 @@ class HorizontalReduction { } } + /// Creates reduction operation with the current opcode. 
+ static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS, + Value *RHS, const Twine &Name) { + unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); + switch (Kind) { + case RecurKind::Add: + case RecurKind::Mul: + case RecurKind::Or: + case RecurKind::And: + case RecurKind::Xor: + case RecurKind::FAdd: + case RecurKind::FMul: + return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, + Name); + case RecurKind::FMax: + return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS); + case RecurKind::FMin: + return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS); + + case RecurKind::SMax: { + Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name); + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + case RecurKind::SMin: { + Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name); + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + case RecurKind::UMax: { + Value *Cmp = Builder.CreateICmpUGT(LHS, RHS, Name); + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + case RecurKind::UMin: { + Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name); + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + default: + llvm_unreachable("Unknown reduction operation."); + } + } + + /// Creates reduction operation with the current opcode with the IR flags + /// from \p ReductionOps. + static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, + Value *RHS, const Twine &Name, + const ReductionOpsListType &ReductionOps) { + Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name); + if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { + if (auto *Sel = dyn_cast<SelectInst>(Op)) + propagateIRFlags(Sel->getCondition(), ReductionOps[0]); + propagateIRFlags(Op, ReductionOps[1]); + return Op; + } + propagateIRFlags(Op, ReductionOps[0]); + return Op; + } + /// Creates reduction operation with the current opcode with the IR flags + /// from \p I. 
+ static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, + Value *RHS, const Twine &Name, Instruction *I) { + Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name); + if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { + if (auto *Sel = dyn_cast<SelectInst>(Op)) { + propagateIRFlags(Sel->getCondition(), + cast<SelectInst>(I)->getCondition()); + } + } + propagateIRFlags(Op, I); + return Op; + } + static OperationData getOperationData(Instruction *I) { if (!I) return OperationData(); @@ -6995,8 +6990,9 @@ class HorizontalReduction { } else { // Update the final value in the reduction. Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = RdxTreeInst.createOp( - Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps); + VectorizedTree = + createOp(Builder, RdxTreeInst.getKind(), VectorizedTree, + ReducedSubTree, "op.rdx", ReductionOps); } i += ReduxWidth; ReduxWidth = PowerOf2Floor(NumReducedVals - i); @@ -7007,15 +7003,15 @@ class HorizontalReduction { for (; i < NumReducedVals; ++i) { auto *I = cast<Instruction>(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = RdxTreeInst.createOp(Builder, VectorizedTree, I, "", - ReductionOps); + VectorizedTree = createOp(Builder, RdxTreeInst.getKind(), + VectorizedTree, I, "", ReductionOps); } for (auto &Pair : ExternallyUsedValues) { // Add each externally used value to the final reduction. for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = RdxTreeInst.createOp(Builder, VectorizedTree, - Pair.first, "op.extra", I); + VectorizedTree = createOp(Builder, RdxTreeInst.getKind(), + VectorizedTree, Pair.first, "op.extra", I); } } @@ -7039,9 +7035,7 @@ class HorizontalReduction { return VectorizedTree != nullptr; } - unsigned numReductionValues() const { - return ReducedVals.size(); - } + unsigned numReductionValues() const { return ReducedVals.size(); } private: /// Calculate the cost of a reduction. 
@@ -7062,7 +7056,7 @@ class HorizontalReduction { case RecurKind::FMul: { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, - /*IsPairwiseForm=*/false); + /*IsPairwiseForm=*/false); ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy); break; } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits