Author: Sanjay Patel
Date: 2021-01-18T09:32:57-05:00
New Revision: d1c4e859ce42c35c61a0db2f1eb8a4209be4503d
URL: https://github.com/llvm/llvm-project/commit/d1c4e859ce42c35c61a0db2f1eb8a4209be4503d
DIFF: https://github.com/llvm/llvm-project/commit/d1c4e859ce42c35c61a0db2f1eb8a4209be4503d.diff

LOG: [SLP] reduce opcode API dependency in reduction cost calc; NFC

The icmp opcode is now hard-coded in the cost model call.
This will make it easier to eventually remove all opcode
queries for min/max patterns as we transition to intrinsics.

Added:

Modified:
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed:

################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8dd318a880fc..bf8ef208ccf9 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7058,12 +7058,10 @@ class HorizontalReduction {
   int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal,
                        unsigned ReduxWidth) {
     Type *ScalarTy = FirstReducedVal->getType();
-    auto *VecTy = FixedVectorType::get(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth);
 
     RecurKind Kind = RdxTreeInst.getKind();
-    unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
-    int SplittingRdxCost;
-    int ScalarReduxCost;
+    int VectorCost, ScalarCost;
     switch (Kind) {
     case RecurKind::Add:
     case RecurKind::Mul:
@@ -7071,22 +7069,24 @@ class HorizontalReduction {
     case RecurKind::And:
     case RecurKind::Xor:
     case RecurKind::FAdd:
-    case RecurKind::FMul:
-      SplittingRdxCost = TTI->getArithmeticReductionCost(
-          RdxOpcode, VecTy, /*IsPairwiseForm=*/false);
-      ScalarReduxCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
+    case RecurKind::FMul: {
+      unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind);
+      VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                   /*IsPairwiseForm=*/false);
+      ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
       break;
+    }
     case RecurKind::SMax:
     case RecurKind::SMin:
     case RecurKind::UMax:
     case RecurKind::UMin: {
-      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VecTy));
+      auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy));
       bool IsUnsigned = Kind == RecurKind::UMax || Kind == RecurKind::UMin;
-      SplittingRdxCost =
-          TTI->getMinMaxReductionCost(VecTy, VecCondTy,
+      VectorCost =
+          TTI->getMinMaxReductionCost(VectorTy, VecCondTy,
                                       /*IsPairwiseForm=*/false, IsUnsigned);
-      ScalarReduxCost =
-          TTI->getCmpSelInstrCost(RdxOpcode, ScalarTy) +
+      ScalarCost =
+          TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy) +
           TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy,
                                   CmpInst::makeCmpResultType(ScalarTy));
       break;
@@ -7095,12 +7095,12 @@ class HorizontalReduction {
       llvm_unreachable("Expected arithmetic or min/max reduction operation");
     }
 
-    ScalarReduxCost *= (ReduxWidth - 1);
-    LLVM_DEBUG(dbgs() << "SLP: Adding cost "
-                      << SplittingRdxCost - ScalarReduxCost
+    // Scalar cost is repeated for N-1 elements.
+    ScalarCost *= (ReduxWidth - 1);
+    LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VectorCost - ScalarCost
                       << " for reduction that starts with " << *FirstReducedVal
                       << " (It is a splitting reduction)\n");
-    return SplittingRdxCost - ScalarReduxCost;
+    return VectorCost - ScalarCost;
   }
 
   /// Emit a horizontal reduction of the vectorized value.

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits