Author: Sanjay Patel Date: 2021-01-04T17:05:28-05:00 New Revision: 36263a7cccc0d98afc36dea55e7a004d08455811
URL: https://github.com/llvm/llvm-project/commit/36263a7cccc0d98afc36dea55e7a004d08455811 DIFF: https://github.com/llvm/llvm-project/commit/36263a7cccc0d98afc36dea55e7a004d08455811.diff LOG: [LoopUtils] remove redundant opcode parameter; NFC While here, rename the inaccurate getRecurrenceBinOp() because that was also used to get CmpInst opcodes. The recurrence/reduction kind should always refer to the expected opcode for a reduction. SLP appears to be the only direct caller of createSimpleTargetReduction(), and that calling code ideally should not be carrying around both an opcode and a reduction kind. This should allow us to generalize reduction matching to use intrinsics instead of only binops. Added: Modified: llvm/include/llvm/Analysis/IVDescriptors.h llvm/include/llvm/Transforms/Utils/LoopUtils.h llvm/lib/Analysis/IVDescriptors.cpp llvm/lib/Transforms/Utils/LoopUtils.cpp llvm/lib/Transforms/Vectorize/LoopVectorize.cpp llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp llvm/lib/Transforms/Vectorize/VPlan.cpp Removed: ################################################################################ diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h index 798eb430df08f..6bb6c4cae0a2c 100644 --- a/llvm/include/llvm/Analysis/IVDescriptors.h +++ b/llvm/include/llvm/Analysis/IVDescriptors.h @@ -139,9 +139,8 @@ class RecurrenceDescriptor { /// Returns identity corresponding to the RecurrenceKind. static Constant *getRecurrenceIdentity(RecurKind K, Type *Tp); - /// Returns the opcode of binary operation corresponding to the - /// RecurrenceKind. - static unsigned getRecurrenceBinOp(RecurKind Kind); + /// Returns the opcode corresponding to the RecurrenceKind. + static unsigned getOpcode(RecurKind Kind); /// Returns true if Phi is a reduction of type Kind and adds it to the /// RecurrenceDescriptor. If either \p DB is non-null or \p AC and \p DT are @@ -178,9 +177,7 @@ class RecurrenceDescriptor { RecurKind getRecurrenceKind() const { return Kind; } - unsigned getRecurrenceBinOp() const { - return getRecurrenceBinOp(getRecurrenceKind()); - } + unsigned getOpcode() const { return getOpcode(getRecurrenceKind()); } FastMathFlags getFastMathFlags() const { return FMF; } diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index b29add4cba0e5..d606fa954f952 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -366,8 +366,7 @@ Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op, /// required to implement the reduction. /// Fast-math-flags are propagated using the IRBuilder's setting. Value *createSimpleTargetReduction(IRBuilderBase &B, - const TargetTransformInfo *TTI, - unsigned Opcode, Value *Src, + const TargetTransformInfo *TTI, Value *Src, RecurKind RdxKind, ArrayRef<Value *> RedOps = None); diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp index 0bd4f98541587..a11faac093db0 100644 --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -800,8 +800,7 @@ Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp) { } } -/// This function translates the recurrence kind to an LLVM binary operator. -unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurKind Kind) { +unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) { switch (Kind) { case RecurKind::Add: return Instruction::Add; @@ -833,7 +832,7 @@ unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurKind Kind) { SmallVector<Instruction *, 4> RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const { SmallVector<Instruction *, 4> ReductionOperations; - unsigned RedOp = getRecurrenceBinOp(Kind); + unsigned RedOp = getOpcode(Kind); // Search down from the Phi to the LoopExitInstr, looking for instructions // with a single user of the correct type for the reduction. diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 3245f5f21017f..f2b94d9e78adc 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -979,9 +979,9 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, const TargetTransformInfo *TTI, - unsigned Opcode, Value *Src, - RecurKind RdxKind, + Value *Src, RecurKind RdxKind, ArrayRef<Value *> RedOps) { + unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind); TargetTransformInfo::ReductionFlags RdxFlags; RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax || RdxKind == RecurKind::FMax; @@ -991,42 +991,34 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, return getShuffleReduction(Builder, Src, Opcode, RdxKind, RedOps); auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); - switch (Opcode) { - case Instruction::Add: + switch (RdxKind) { + case RecurKind::Add: return Builder.CreateAddReduce(Src); - case Instruction::Mul: + case RecurKind::Mul: return Builder.CreateMulReduce(Src); - case Instruction::And: + case RecurKind::And: return Builder.CreateAndReduce(Src); - case Instruction::Or: + case RecurKind::Or: return Builder.CreateOrReduce(Src); - case Instruction::Xor: + case RecurKind::Xor: return Builder.CreateXorReduce(Src); - case Instruction::FAdd: + case RecurKind::FAdd: return Builder.CreateFAddReduce(ConstantFP::getNegativeZero(SrcVecEltTy), Src); - case Instruction::FMul: + case RecurKind::FMul: return Builder.CreateFMulReduce(ConstantFP::get(SrcVecEltTy, 1.0), Src); - case Instruction::ICmp: - switch (RdxKind) { - case RecurKind::SMax: - return Builder.CreateIntMaxReduce(Src, true); - case RecurKind::SMin: - return Builder.CreateIntMinReduce(Src, true); - case RecurKind::UMax: - return Builder.CreateIntMaxReduce(Src, false); - case RecurKind::UMin: - return Builder.CreateIntMinReduce(Src, false); - default: - llvm_unreachable("Unexpected min/max reduction type"); - } - case Instruction::FCmp: - assert((RdxKind == RecurKind::FMax || RdxKind == RecurKind::FMin) && - "Unexpected min/max reduction type"); - if (RdxKind == RecurKind::FMax) - return Builder.CreateFPMaxReduce(Src); - else - return Builder.CreateFPMinReduce(Src); + case RecurKind::SMax: + return Builder.CreateIntMaxReduce(Src, true); + case RecurKind::SMin: + return Builder.CreateIntMinReduce(Src, true); + case RecurKind::UMax: + return Builder.CreateIntMaxReduce(Src, false); + case RecurKind::UMin: + return Builder.CreateIntMinReduce(Src, false); + case RecurKind::FMax: + return Builder.CreateFPMaxReduce(Src); + case RecurKind::FMin: + return Builder.CreateFPMinReduce(Src); default: llvm_unreachable("Unhandled opcode"); } @@ -1040,8 +1032,7 @@ Value *llvm::createTargetReduction(IRBuilderBase &B, // descriptor. IRBuilderBase::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(Desc.getFastMathFlags()); - return createSimpleTargetReduction(B, TTI, Desc.getRecurrenceBinOp(), Src, - Desc.getRecurrenceKind()); + return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind()); } void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index a55efe1c323ac..7f89fd9a13490 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -4254,7 +4254,7 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { RecurrenceDescriptor RdxDesc = Legal->getReductionVars()[Phi]; if (PreferPredicatedReductionSelect || TTI->preferPredicatedReductionSelect( - RdxDesc.getRecurrenceBinOp(), Phi->getType(), + RdxDesc.getOpcode(), Phi->getType(), TargetTransformInfo::ReductionFlags())) { auto *VecRdxPhi = cast<PHINode>(getOrCreateVectorValue(Phi, Part)); VecRdxPhi->setIncomingValueForBlock( @@ -4296,7 +4296,7 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { // Reduce all of the unrolled parts into a single vector. Value *ReducedPartRdx = VectorLoopValueMap.getVectorValue(LoopExitInst, 0); - unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); + unsigned Op = RecurrenceDescriptor::getOpcode(RK); // The middle block terminator has already been assigned a DebugLoc here (the // OrigLoop's single latch terminator). We want the whole middle block to @@ -7325,7 +7325,7 @@ void LoopVectorizationCostModel::collectInLoopReductions() { // If the target would prefer this reduction to happen "in-loop", then we // want to record it as such. - unsigned Opcode = RdxDesc.getRecurrenceBinOp(); + unsigned Opcode = RdxDesc.getOpcode(); if (!PreferInLoopReductions && !TTI.preferInLoopReduction(Opcode, Phi->getType(), TargetTransformInfo::ReductionFlags())) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 2de09089be0c8..a655d3dd91bd1 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -7255,9 +7255,9 @@ class HorizontalReduction { // FIXME: The builder should use an FMF guard. It should not be hard-coded // to 'fast'. assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF"); - return createSimpleTargetReduction( - Builder, TTI, RdxTreeInst.getOpcode(), VectorizedValue, - RdxTreeInst.getKind(), ReductionOps.back()); + return createSimpleTargetReduction(Builder, TTI, VectorizedValue, + RdxTreeInst.getKind(), + ReductionOps.back()); } Value *TmpVec = VectorizedValue; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index f5ce1a3ccafb9..c6e44d11e7b38 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -917,7 +917,7 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, printAsOperand(O, SlotTracker); O << " = "; getChainOp()->printAsOperand(O, SlotTracker); - O << " + reduce." << Instruction::getOpcodeName(RdxDesc->getRecurrenceBinOp()) + O << " + reduce." << Instruction::getOpcodeName(RdxDesc->getOpcode()) << " ("; getVecOp()->printAsOperand(O, SlotTracker); if (getCondOp()) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits