Author: Sanjay Patel Date: 2021-01-20T11:14:48-05:00 New Revision: 1c54112a5762ebab2c14a90c55f27d00bfced7f8
URL: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8 DIFF: https://github.com/llvm/llvm-project/commit/1c54112a5762ebab2c14a90c55f27d00bfced7f8.diff LOG: [SLP] refactor more reduction functions; NFC We were able to remove almost all of the state from OperationData, so these reduction helper functions no longer make sense as members of that class; instead, they are now static and take the RecurKind as a parameter. More streamlining is possible, but I'm trying to avoid logic/typo bugs while fixing this. Eventually, we should not need the `OperationData` class. Added: Modified: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp Removed: ################################################################################ diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 3d657b0b898c..3192d7959f70 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6427,76 +6427,6 @@ class HorizontalReduction { return IsLeafValue || Kind != RecurKind::None; } - /// Return true if this operation is a cmp+select idiom. - bool isCmpSel() const { - assert(Kind != RecurKind::None && "Expected reduction operation."); - return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind); - } - - /// Get the index of the first operand. - unsigned getFirstOperandIndex() const { - assert(!!*this && "The opcode is not set."); - // We allow calling this before 'Kind' is set, so handle that specially. - if (Kind == RecurKind::None) - return 0; - return isCmpSel() ? 1 : 0; - } - - /// Total number of operands in the reduction operation. - unsigned getNumberOfOperands() const { - assert(Kind != RecurKind::None && !!*this && - "Expected reduction operation."); - return isCmpSel() ? 3 : 2; - } - - /// Checks if the instruction is in basic block \p BB. - /// For a min/max reduction check that both compare and select are in \p BB. 
- bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const { - assert(Kind != RecurKind::None && !!*this && - "Expected reduction operation."); - if (IsRedOp && isCmpSel()) { - auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition()); - return I->getParent() == BB && Cmp && Cmp->getParent() == BB; - } - return I->getParent() == BB; - } - - /// Expected number of uses for reduction operations/reduced values. - bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const { - assert(Kind != RecurKind::None && !!*this && - "Expected reduction operation."); - // SelectInst must be used twice while the condition op must have single - // use only. - if (isCmpSel()) - return I->hasNUses(2) && - (!IsReductionOp || - cast<SelectInst>(I)->getCondition()->hasOneUse()); - - // Arithmetic reduction operation must be used once only. - return I->hasOneUse(); - } - - /// Initializes the list of reduction operations. - void initReductionOps(ReductionOpsListType &ReductionOps) { - assert(Kind != RecurKind::None && !!*this && - "Expected reduction operation."); - if (isCmpSel()) - ReductionOps.assign(2, ReductionOpsType()); - else - ReductionOps.assign(1, ReductionOpsType()); - } - - /// Add all reduction operations for the reduction instruction \p I. - void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) { - assert(Kind != RecurKind::None && "Expected reduction operation."); - if (isCmpSel()) { - ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition()); - ReductionOps[1].emplace_back(I); - } else { - ReductionOps[0].emplace_back(I); - } - } - /// Checks if instruction is associative and can be vectorized. bool isAssociative(Instruction *I) const { assert(Kind != RecurKind::None && "Expected reduction operation."); @@ -6529,16 +6459,6 @@ class HorizontalReduction { /// Get kind of reduction data. 
RecurKind getKind() const { return Kind; } - Value *getLHS(Instruction *I) const { - if (Kind == RecurKind::None) - return nullptr; - return I->getOperand(getFirstOperandIndex()); - } - Value *getRHS(Instruction *I) const { - if (Kind == RecurKind::None) - return nullptr; - return I->getOperand(getFirstOperandIndex() + 1); - } }; WeakTrackingVH ReductionRoot; @@ -6559,7 +6479,7 @@ class HorizontalReduction { // Do not perform analysis of remaining operands of ParentStackElem.first // instruction, this whole instruction is an extra argument. OperationData OpData = getOperationData(ParentStackElem.first); - ParentStackElem.second = OpData.getNumberOfOperands(); + ParentStackElem.second = getNumberOfOperands(OpData.getKind()); } else { // We ran into something like: // ParentStackElem.first += ... + ExtraArg + ... @@ -6730,6 +6650,81 @@ class HorizontalReduction { return OperationData(*I); } + /// Return true if this operation is a cmp+select idiom. + static bool isCmpSel(RecurKind Kind) { + return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind); + } + + /// Get the index of the first operand. + static unsigned getFirstOperandIndex(RecurKind Kind) { + // We allow calling this before 'Kind' is set, so handle that specially. + if (Kind == RecurKind::None) + return 0; + return isCmpSel(Kind) ? 1 : 0; + } + + /// Total number of operands in the reduction operation. + static unsigned getNumberOfOperands(RecurKind Kind) { + return isCmpSel(Kind) ? 3 : 2; + } + + /// Checks if the instruction is in basic block \p BB. + /// For a min/max reduction check that both compare and select are in \p BB. + static bool hasSameParent(RecurKind Kind, Instruction *I, BasicBlock *BB, + bool IsRedOp) { + if (IsRedOp && isCmpSel(Kind)) { + auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition()); + return I->getParent() == BB && Cmp && Cmp->getParent() == BB; + } + return I->getParent() == BB; + } + + /// Expected number of uses for reduction operations/reduced values. 
+ static bool hasRequiredNumberOfUses(RecurKind Kind, Instruction *I, + bool IsReductionOp) { + // SelectInst must be used twice while the condition op must have single + // use only. + if (isCmpSel(Kind)) + return I->hasNUses(2) && + (!IsReductionOp || + cast<SelectInst>(I)->getCondition()->hasOneUse()); + + // Arithmetic reduction operation must be used once only. + return I->hasOneUse(); + } + + /// Initializes the list of reduction operations. + static void initReductionOps(RecurKind Kind, + ReductionOpsListType &ReductionOps) { + if (isCmpSel(Kind)) + ReductionOps.assign(2, ReductionOpsType()); + else + ReductionOps.assign(1, ReductionOpsType()); + } + + /// Add all reduction operations for the reduction instruction \p I. + static void addReductionOps(RecurKind Kind, Instruction *I, + ReductionOpsListType &ReductionOps) { + assert(Kind != RecurKind::None && "Expected reduction operation."); + if (isCmpSel(Kind)) { + ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition()); + ReductionOps[1].emplace_back(I); + } else { + ReductionOps[0].emplace_back(I); + } + } + + static Value *getLHS(RecurKind Kind, Instruction *I) { + if (Kind == RecurKind::None) + return nullptr; + return I->getOperand(getFirstOperandIndex(Kind)); + } + static Value *getRHS(RecurKind Kind, Instruction *I) { + if (Kind == RecurKind::None) + return nullptr; + return I->getOperand(getFirstOperandIndex(Kind) + 1); + } + public: HorizontalReduction() = default; @@ -6744,13 +6739,13 @@ class HorizontalReduction { // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. 
if (Phi) { - if (RdxTreeInst.getLHS(B) == Phi) { + if (getLHS(RdxTreeInst.getKind(), B) == Phi) { Phi = nullptr; - B = dyn_cast<Instruction>(RdxTreeInst.getRHS(B)); + B = dyn_cast<Instruction>(getRHS(RdxTreeInst.getKind(), B)); RdxTreeInst = getOperationData(B); - } else if (RdxTreeInst.getRHS(B) == Phi) { + } else if (getRHS(RdxTreeInst.getKind(), B) == Phi) { Phi = nullptr; - B = dyn_cast<Instruction>(RdxTreeInst.getLHS(B)); + B = dyn_cast<Instruction>(getLHS(RdxTreeInst.getKind(), B)); RdxTreeInst = getOperationData(B); } } @@ -6775,8 +6770,9 @@ class HorizontalReduction { // Post order traverse the reduction tree starting at B. We only handle true // trees containing only binary operators. SmallVector<std::pair<Instruction *, unsigned>, 32> Stack; - Stack.push_back(std::make_pair(B, RdxTreeInst.getFirstOperandIndex())); - RdxTreeInst.initReductionOps(ReductionOps); + Stack.push_back( + std::make_pair(B, getFirstOperandIndex(RdxTreeInst.getKind()))); + initReductionOps(RdxTreeInst.getKind(), ReductionOps); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVisit = Stack.back().second++; @@ -6784,7 +6780,8 @@ class HorizontalReduction { bool IsReducedValue = OpData != RdxTreeInst; // Postorder visit. - if (IsReducedValue || EdgeToVisit == OpData.getNumberOfOperands()) { + if (IsReducedValue || + EdgeToVisit == getNumberOfOperands(OpData.getKind())) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -6802,7 +6799,7 @@ class HorizontalReduction { markExtraArg(Stack[Stack.size() - 2], TreeN); ExtraArgs.erase(TreeN); } else - RdxTreeInst.addReductionOps(TreeN, ReductionOps); + addReductionOps(RdxTreeInst.getKind(), TreeN, ReductionOps); } // Retract. Stack.pop_back(); @@ -6822,8 +6819,8 @@ class HorizontalReduction { // ultimate reduction. 
const bool IsRdxInst = EdgeOpData == RdxTreeInst; if (I && I != Phi && I != B && - RdxTreeInst.hasSameParent(I, B->getParent(), IsRdxInst) && - RdxTreeInst.hasRequiredNumberOfUses(I, IsRdxInst) && + hasSameParent(RdxTreeInst.getKind(), I, B->getParent(), IsRdxInst) && + hasRequiredNumberOfUses(RdxTreeInst.getKind(), I, IsRdxInst) && (!LeafOpcode || LeafOpcode == I->getOpcode() || IsRdxInst)) { if (IsRdxInst) { // We need to be able to reassociate the reduction operations. @@ -6835,7 +6832,8 @@ class HorizontalReduction { } else if (!LeafOpcode) { LeafOpcode = I->getOpcode(); } - Stack.push_back(std::make_pair(I, EdgeOpData.getFirstOperandIndex())); + Stack.push_back( + std::make_pair(I, getFirstOperandIndex(EdgeOpData.getKind()))); continue; } // NextV is an extra argument for TreeN (its parent operation). @@ -6976,7 +6974,7 @@ class HorizontalReduction { // Emit a reduction. If the root is a select (min/max idiom), the insert // point is the compare condition of that select. Instruction *RdxRootInst = cast<Instruction>(ReductionRoot); - if (RdxTreeInst.isCmpSel()) + if (isCmpSel(RdxTreeInst.getKind())) Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst)); else Builder.SetInsertPoint(RdxRootInst); @@ -7019,7 +7017,7 @@ class HorizontalReduction { // select, we also have to RAUW for the compare instruction feeding the // reduction root. That's because the original compare may have extra uses // besides the final select of the reduction. - if (RdxTreeInst.isCmpSel()) { + if (isCmpSel(RdxTreeInst.getKind())) { if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) { Instruction *ScalarCmp = getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot)); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits