Author: Diego Caballero Date: 2021-01-21T00:37:23+02:00 New Revision: 735a07f0478566f6f7c60a8a98eb8884db574113
URL: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113 DIFF: https://github.com/llvm/llvm-project/commit/735a07f0478566f6f7c60a8a98eb8884db574113.diff LOG: Revert "[mlir][Affine] Add support for multi-store producer fusion" This reverts commit 7dd198852b4db52ae22242dfeda4eccda83aa8b2. ASAN issue. Added: Modified: mlir/include/mlir/Analysis/AffineStructures.h mlir/include/mlir/Analysis/Utils.h mlir/include/mlir/Transforms/LoopFusionUtils.h mlir/include/mlir/Transforms/Passes.td mlir/lib/Analysis/AffineStructures.cpp mlir/lib/Analysis/Utils.cpp mlir/lib/Transforms/LoopFusion.cpp mlir/lib/Transforms/Utils/LoopFusionUtils.cpp mlir/test/Transforms/loop-fusion.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 893d4ea4ff46..fa80db7d4b63 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -234,21 +234,6 @@ class FlatAffineConstraints { // TODO: add support for non-unit strides. LogicalResult addAffineForOpDomain(AffineForOp forOp); - /// Adds constraints (lower and upper bounds) for each loop in the loop nest - /// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice. - /// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in - /// the nest, sorted outer-to-inner. 'operands' contains the bound operands - /// for a single bound map. All the bound maps will use the same bound - /// operands. Note that some loops described by a computation slice might not - /// exist yet in the IR so the Value attached to those dimension identifiers - /// might be empty. For that reason, this method doesn't perform Value - /// look-ups to retrieve the dimension identifier positions. 
Instead, it - /// assumes the position of the dim identifiers in the constraint system is - /// the same as the position of the loop in the loop nest. - LogicalResult addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, - ArrayRef<AffineMap> ubMaps, - ArrayRef<Value> operands); - /// Adds constraints imposed by the `affine.if` operation. These constraints /// are collected from the IntegerSet attached to the given `affine.if` /// instance argument (`ifOp`). It is asserted that: diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index ee6f8095f25e..30b6272181f5 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -83,25 +83,10 @@ struct ComputationSliceState { // Clears all bounds and operands in slice state. void clearBounds(); - /// Returns true if the computation slice is empty. + /// Return true if the computation slice is empty. bool isEmpty() const { return ivs.empty(); } - /// Returns true if the computation slice encloses all the iterations of the - /// sliced loop nest. Returns false if it does not. Returns llvm::None if it - /// cannot determine if the slice is maximal or not. - // TODO: Cache 'isMaximal' so that we don't recompute it when the slice - // information hasn't changed. - Optional<bool> isMaximal() const; - void dump() const; - -private: - /// Fast check to determine if the computation slice is maximal. Returns true - /// if each slice dimension maps to an existing dst dimension and both the src - /// and the dst loops for those dimensions have the same bounds. Returns false - /// if both the src and the dst loops don't have the same bounds. Returns - /// llvm::None if none of the above can be proven. 
- Optional<bool> isSliceMaximalFastCheck() const; }; /// Computes the computation slice loop bounds for one loop nest as affine maps diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h index 10d6b83d022f..eade565e0325 100644 --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -50,8 +50,7 @@ struct FusionResult { // TODO: Generalize utilities so that producer-consumer and sibling fusion // strategies can be used without the assumptions made in the AffineLoopFusion // pass. -class FusionStrategy { -public: +struct FusionStrategy { enum StrategyEnum { // Generic loop fusion: Arbitrary loops are considered for fusion. No // assumptions about a specific fusion strategy from AffineLoopFusion pass @@ -70,34 +69,13 @@ class FusionStrategy { // implementation in AffineLoopFusion pass are made. See pass for specific // details. Sibling - }; + } strategy; - /// Construct a generic or producer-consumer fusion strategy. - FusionStrategy(StrategyEnum strategy) : strategy(strategy) { - assert(strategy != Sibling && - "Sibling fusion strategy requires a specific memref"); - } - - /// Construct a sibling fusion strategy targeting 'memref'. This construct - /// should only be used for sibling fusion. - FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {} - - /// Returns the fusion strategy. - StrategyEnum getStrategy() const { return strategy; }; - - /// Returns the memref attached to this sibling fusion strategy. - Value getSiblingFusionMemRef() const { - assert(strategy == Sibling && "Memref is only valid for sibling fusion"); - return memref; - } - -private: - /// Fusion strategy. - StrategyEnum strategy; - - /// Target memref for this fusion transformation. Only used for sibling - /// fusion. + // Target memref for this fusion transformation. 
Value memref; + + FusionStrategy(StrategyEnum strategy, Value memref) + : strategy(strategy), memref(memref) {} }; /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the @@ -108,10 +86,11 @@ class FusionStrategy { /// NOTE: This function is not feature complete and should only be used in /// testing. /// TODO: Update comments when this function is fully implemented. -FusionResult -canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, - ComputationSliceState *srcSlice, - FusionStrategy fusionStrategy = FusionStrategy::Generic); +FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, + unsigned dstLoopDepth, + ComputationSliceState *srcSlice, + FusionStrategy fusionStrategy = { + FusionStrategy::Generic, Value()}); /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point /// and source slice loop bounds specified in 'srcSlice'. @@ -155,12 +134,6 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, const ComputationSliceState &slice, int64_t *computeCost); -/// Returns in 'producerConsumerMemrefs' the memrefs involved in a -/// producer-consumer dependence between write ops in 'srcOps' and read ops in -/// 'dstOps'. -void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps, - ArrayRef<Operation *> dstOps, - DenseSet<Value> &producerConsumerMemrefs); } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a03b439af339..438a468673b5 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -17,111 +17,6 @@ include "mlir/Pass/PassBase.td" def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> { let summary = "Fuse affine loop nests"; - let description = [{ - This pass performs fusion of loop nests using a slicing-based approach. 
It - combines two fusion strategies: producer-consumer fusion and sibling fusion. - Producer-consumer fusion is aimed at fusing pairs of loops where the first - one writes to a memref that the second reads. Sibling fusion targets pairs - of loops that share no dependences between them but that load from the same - memref. The fused loop nests, when possible, are rewritten to access - significantly smaller local buffers instead of the original memref's, and - the latter are often either completely optimized away or contracted. This - transformation leads to enhanced locality and lower memory footprint through - the elimination or contraction of temporaries/intermediate memref's. These - benefits are sometimes achieved at the expense of redundant computation - through a cost model that evaluates available choices such as the depth at - which a source slice should be materialized in the designation slice. - - Example 1: Producer-consumer fusion. - Input: - ```mlir - func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) { - %0 = alloc() : memref<10xf32> - %1 = alloc() : memref<10xf32> - %cst = constant 0.000000e+00 : f32 - affine.for %arg2 = 0 to 10 { - affine.store %cst, %0[%arg2] : memref<10xf32> - affine.store %cst, %1[%arg2] : memref<10xf32> - } - affine.for %arg2 = 0 to 10 { - %2 = affine.load %0[%arg2] : memref<10xf32> - %3 = addf %2, %2 : f32 - affine.store %3, %arg0[%arg2] : memref<10xf32> - } - affine.for %arg2 = 0 to 10 { - %2 = affine.load %1[%arg2] : memref<10xf32> - %3 = mulf %2, %2 : f32 - affine.store %3, %arg1[%arg2] : memref<10xf32> - } - return - } - ``` - Output: - ```mlir - func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) { - %0 = alloc() : memref<1xf32> - %1 = alloc() : memref<1xf32> - %cst = constant 0.000000e+00 : f32 - affine.for %arg2 = 0 to 10 { - affine.store %cst, %0[0] : memref<1xf32> - affine.store %cst, %1[0] : memref<1xf32> - %2 = affine.load %1[0] : memref<1xf32> - %3 = mulf %2, %2 : 
f32 - affine.store %3, %arg1[%arg2] : memref<10xf32> - %4 = affine.load %0[0] : memref<1xf32> - %5 = addf %4, %4 : f32 - affine.store %5, %arg0[%arg2] : memref<10xf32> - } - return - } - ``` - - Example 2: Sibling fusion. - Input: - ```mlir - func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>, - %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>, - %arg4: memref<10x10xf32>) { - affine.for %arg5 = 0 to 3 { - affine.for %arg6 = 0 to 3 { - %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> - %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 - affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> - } - } - affine.for %arg5 = 0 to 3 { - affine.for %arg6 = 0 to 3 { - %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> - %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %2 = addf %0, %1 : f32 - affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32> - } - } - return - } - ``` - Output: - ```mlir - func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>, - %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>, - %arg4: memref<10x10xf32>) { - affine.for %arg5 = 0 to 3 { - affine.for %arg6 = 0 to 3 { - %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> - %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> - %2 = mulf %0, %1 : f32 - affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> - %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> - %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> - %5 = addf %3, %4 : f32 - affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32> - } - } - return - } - ``` - }]; let constructor = "mlir::createLoopFusionPass()"; let options = [ Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double", diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 81dc7855184e..12c90fbcfc54 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp 
@@ -708,70 +708,6 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { /*eq=*/false, /*lower=*/false); } -/// Adds constraints (lower and upper bounds) for each loop in the loop nest -/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice. -/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in -/// the nest, sorted outer-to-inner. 'operands' contains the bound operands -/// for a single bound map. All the bound maps will use the same bound -/// operands. Note that some loops described by a computation slice might not -/// exist yet in the IR so the Value attached to those dimension identifiers -/// might be empty. For that reason, this method doesn't perform Value -/// look-ups to retrieve the dimension identifier positions. Instead, it -/// assumes the position of the dim identifiers in the constraint system is -/// the same as the position of the loop in the loop nest. -LogicalResult -FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps, - ArrayRef<AffineMap> ubMaps, - ArrayRef<Value> operands) { - assert(lbMaps.size() == ubMaps.size()); - assert(lbMaps.size() <= getNumDimIds()); - - for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { - AffineMap lbMap = lbMaps[i]; - AffineMap ubMap = ubMaps[i]; - assert(!lbMap || lbMap.getNumInputs() == operands.size()); - assert(!ubMap || ubMap.getNumInputs() == operands.size()); - - // Check if this slice is just an equality along this dimension. If so, - // retrieve the existing loop it equates to and add it to the system. - if (lbMap && ubMap && lbMap.getNumResults() == 1 && - ubMap.getNumResults() == 1 && - lbMap.getResult(0) + 1 == ubMap.getResult(0) && - // The condition above will be true for maps describing a single - // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). - // Make sure we skip those cases by checking that the lb result is not - // just a constant. 
- !lbMap.getResult(0).isa<AffineConstantExpr>()) { - // Limited support: we expect the lb result to be just a loop dimension. - // Not supported otherwise for now. - AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>(); - if (!result) - return failure(); - - AffineForOp loop = - getForInductionVarOwner(operands[result.getPosition()]); - if (!loop) - return failure(); - - if (failed(addAffineForOpDomain(loop))) - return failure(); - continue; - } - - // This slice refers to a loop that doesn't exist in the IR yet. Add its - // bounds to the system assuming its dimension identifier position is the - // same as the position of the loop in the loop nest. - if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false, - /*lower=*/true))) - return failure(); - - if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false, - /*lower=*/false))) - return failure(); - } - return success(); -} - void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { // Create the base constraints from the integer set attached to ifOp. FlatAffineConstraints cst(ifOp.getIntegerSet()); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 383a6587bbef..a1e7d1ffe844 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Utils.h" + #include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/PresburgerSet.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -127,128 +127,6 @@ void ComputationSliceState::dump() const { } } -/// Fast check to determine if the computation slice is maximal. Returns true if -/// each slice dimension maps to an existing dst dimension and both the src -/// and the dst loops for those dimensions have the same bounds. 
Returns false -/// if both the src and the dst loops don't have the same bounds. Returns -/// llvm::None if none of the above can be proven. -Optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const { - assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() && - "Unexpected number of lbs, ubs and ivs in slice"); - - for (unsigned i = 0, end = lbs.size(); i < end; ++i) { - AffineMap lbMap = lbs[i]; - AffineMap ubMap = ubs[i]; - - // Check if this slice is just an equality along this dimension. - if (!lbMap || !ubMap || lbMap.getNumResults() != 1 || - ubMap.getNumResults() != 1 || - lbMap.getResult(0) + 1 != ubMap.getResult(0) || - // The condition above will be true for maps describing a single - // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). - // Make sure we skip those cases by checking that the lb result is not - // just a constant. - lbMap.getResult(0).isa<AffineConstantExpr>()) - return llvm::None; - - // Limited support: we expect the lb result to be just a loop dimension for - // now. - AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>(); - if (!result) - return llvm::None; - - // Retrieve dst loop bounds. - AffineForOp dstLoop = - getForInductionVarOwner(lbOperands[i][result.getPosition()]); - if (!dstLoop) - return llvm::None; - AffineMap dstLbMap = dstLoop.getLowerBoundMap(); - AffineMap dstUbMap = dstLoop.getUpperBoundMap(); - - // Retrieve src loop bounds. - AffineForOp srcLoop = getForInductionVarOwner(ivs[i]); - assert(srcLoop && "Expected affine for"); - AffineMap srcLbMap = srcLoop.getLowerBoundMap(); - AffineMap srcUbMap = srcLoop.getUpperBoundMap(); - - // Limited support: we expect simple src and dst loops with a single - // constant component per bound for now. 
- if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 || - dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1) - return llvm::None; - - AffineExpr srcLbResult = srcLbMap.getResult(0); - AffineExpr dstLbResult = dstLbMap.getResult(0); - AffineExpr srcUbResult = srcUbMap.getResult(0); - AffineExpr dstUbResult = dstUbMap.getResult(0); - if (!srcLbResult.isa<AffineConstantExpr>() || - !srcUbResult.isa<AffineConstantExpr>() || - !dstLbResult.isa<AffineConstantExpr>() || - !dstUbResult.isa<AffineConstantExpr>()) - return llvm::None; - - // Check if src and dst loop bounds are the same. If not, we can guarantee - // that the slice is not maximal. - if (srcLbResult != dstLbResult || srcUbResult != dstUbResult) - return false; - } - - return true; -} - -/// Returns true if the computation slice encloses all the iterations of the -/// sliced loop nest. Returns false if it does not. Returns llvm::None if it -/// cannot determine if the slice is maximal or not. -Optional<bool> ComputationSliceState::isMaximal() const { - // Fast check to determine if the computation slice is maximal. If the result - // is inconclusive, we proceed with a more expensive analysis. - Optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck(); - if (isMaximalFastCheck.hasValue()) - return isMaximalFastCheck; - - // Create constraints for the src loop nest being sliced. - FlatAffineConstraints srcConstraints; - srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, - /*numLocals=*/0, ivs); - for (Value iv : ivs) { - AffineForOp loop = getForInductionVarOwner(iv); - assert(loop && "Expected affine for"); - if (failed(srcConstraints.addAffineForOpDomain(loop))) - return llvm::None; - } - - // Create constraints for the slice using the dst loop nest information. We - // retrieve existing dst loops from the lbOperands. 
- SmallVector<Value, 8> consumerIVs; - for (Value lbOp : lbOperands[0]) - if (getForInductionVarOwner(lbOp)) - consumerIVs.push_back(lbOp); - - // Add empty IV Values for those new loops that are not equalities and, - // therefore, are not yet materialized in the IR. - for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i) - consumerIVs.push_back(Value()); - - FlatAffineConstraints sliceConstraints; - sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0, - /*numLocals=*/0, consumerIVs); - - if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0]))) - return llvm::None; - - if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds()) - // Constraint dims are different. The integer set difference can't be - // computed so we don't know if the slice is maximal. - return llvm::None; - - // Compute the difference between the src loop nest and the slice integer - // sets. - PresburgerSet srcSet(srcConstraints); - PresburgerSet sliceSet(sliceConstraints); - PresburgerSet diffSet = srcSet.subtract(sliceSet); - return diffSet.isIntegerEmpty(); -} - unsigned MemRefRegion::getRank() const { return memref.getType().cast<MemRefType>().getRank(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6c56368ca6e1..6fe112b89baf 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -30,7 +30,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <iomanip> -#include <set> #include <sstream> #define DEBUG_TYPE "affine-loop-fusion" @@ -271,6 +270,64 @@ struct MemRefDependenceGraph { return false; } + // Returns the unique AffineWriteOpInterface in `node` that meets all the + // following: + // *) store is the only one that writes to a function-local memref live out + // of `node`, + // *) store is not the source of a self-dependence on `node`. + // Otherwise, returns a null AffineWriteOpInterface.
+ AffineWriteOpInterface getUniqueOutgoingStore(Node *node) { + AffineWriteOpInterface uniqueStore; + + // Return null if `node` doesn't have any outgoing edges. + auto outEdgeIt = outEdges.find(node->id); + if (outEdgeIt == outEdges.end()) + return nullptr; + + const auto &nodeOutEdges = outEdgeIt->second; + for (auto *op : node->stores) { + auto storeOp = cast<AffineWriteOpInterface>(op); + auto memref = storeOp.getMemRef(); + // Skip this store if there are no dependences on its memref. This means + // that store either: + // *) writes to a memref that is only read within the same loop nest + // (self-dependence edges are not represented in graph at the moment), + // *) writes to a function live out memref (function parameter), or + // *) is dead. + if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { + return (edge.value != memref); + })) + continue; + + if (uniqueStore) + // Found multiple stores to function-local live-out memrefs. + return nullptr; + // Found first store to function-local live-out memref. + uniqueStore = storeOp; + } + + return uniqueStore; + } + + // Returns true if node 'id' can be removed from the graph. Returns false + // otherwise. A node can be removed from the graph iff the following + // conditions are met: + // *) The node does not write to any memref which escapes (or is a + // function/block argument). + // *) The node has no successors in the dependence graph. + bool canRemoveNode(unsigned id) { + if (writesToLiveInOrEscapingMemrefs(id)) + return false; + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { + // Return false if there exist out edges from 'id' on 'memref'. + auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef(); + if (getOutEdgeCount(id, storeMemref) > 0) + return false; + } + return true; + } + // Returns true iff there is an edge from node 'srcId' to node 'dstId' which // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. 
@@ -438,49 +495,42 @@ struct MemRefDependenceGraph { return dstNodeInst; } - // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, - // taking into account that: - // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion, - // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a - // private memref. - void updateEdges(unsigned srcId, unsigned dstId, - const DenseSet<Value> &privateMemRefs, bool removeSrcId) { + // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' + // has been replaced in node at 'dstId' by a private memref depending + // on the value of 'createPrivateMemRef'. + void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef, + bool createPrivateMemRef) { // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector<Edge, 2> oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { - // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref. - if (privateMemRefs.count(inEdge.value) == 0) + // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. + if (inEdge.value != oldMemRef) addEdge(inEdge.id, dstId, inEdge.value); } } // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. - // If 'srcId' is going to be removed, remap all the out edges to 'dstId'. if (outEdges.count(srcId) > 0) { SmallVector<Edge, 2> oldOutEdges = outEdges[srcId]; for (auto &outEdge : oldOutEdges) { // Remove any out edges from 'srcId' to 'dstId' across memrefs. if (outEdge.id == dstId) removeEdge(srcId, outEdge.id, outEdge.value); - else if (removeSrcId) { - addEdge(dstId, outEdge.id, outEdge.value); - removeEdge(srcId, outEdge.id, outEdge.value); - } } } // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being // replaced by a private memref). These edges could come from nodes // other than 'srcId' which were removed in the previous step. 
- if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) { + if (inEdges.count(dstId) > 0 && createPrivateMemRef) { SmallVector<Edge, 2> oldInEdges = inEdges[dstId]; for (auto &inEdge : oldInEdges) - if (privateMemRefs.count(inEdge.value) > 0) + if (inEdge.value == oldMemRef) removeEdge(inEdge.id, dstId, inEdge.value); } } // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion - // of sibling node 'sibId' into node 'dstId'. + // of sibling node 'sidId' into node 'dstId'. void updateEdges(unsigned sibId, unsigned dstId) { // For each edge in 'inEdges[sibId]': // *) Add new edge from source node 'inEdge.id' to 'dstNode'. @@ -574,132 +624,6 @@ struct MemRefDependenceGraph { void dump() const { print(llvm::errs()); } }; -/// Returns true if node 'srcId' can be removed after fusing it with node -/// 'dstId'. The node can be removed if any of the following conditions are met: -/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs. -/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs -/// and the fusion slice is maximal. -/// 3. 'srcId' has output dependences after fusion, the fusion slice is -/// maximal and the fusion insertion point dominates all the dependences. -static bool canRemoveSrcNodeAfterFusion( - unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, - Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs, - MemRefDependenceGraph *mdg) { - - Operation *dstNodeOp = mdg->getNode(dstId)->op; - bool hasOutDepsAfterFusion = false; - - for (auto &outEdge : mdg->outEdges[srcId]) { - Operation *depNodeOp = mdg->getNode(outEdge.id)->op; - // Skip dependence with dstOp since it will be removed after fusion. - if (depNodeOp == dstNodeOp) - continue; - - // Only fusion within the same block is supported. Use domination analysis - // when needed. 
- if (depNodeOp->getBlock() != dstNodeOp->getBlock()) - return false; - - // Check if the insertion point of the fused loop dominates the dependence. - // Otherwise, the src loop can't be removed. - if (fusedLoopInsPoint != depNodeOp && - !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " - "dominate dependence\n"); - return false; - } - - hasOutDepsAfterFusion = true; - } - - // If src loop has dependences after fusion or it writes to an live-out or - // escaping memref, we can only remove it if the fusion slice is maximal so - // that all the dependences are preserved. - if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { - Optional<bool> isMaximal = fusionSlice.isMaximal(); - if (!isMaximal.hasValue()) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " - "if fusion is maximal\n"); - return false; - } - - if (!isMaximal.getValue()) { - LLVM_DEBUG(llvm::dbgs() - << "Src loop can't be removed: fusion is not maximal\n"); - return false; - } - } - - return true; -} - -/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer -/// 'dstId'. -// TODO: Move this to a loop fusion utility once 'mdg' is also moved. -static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, - DenseSet<unsigned> &srcIdCandidates) { - // Skip if no input edges along which to fuse. - if (mdg->inEdges.count(dstId) == 0) - return; - - // Gather memrefs from loads in 'dstId'. - auto *dstNode = mdg->getNode(dstId); - DenseSet<Value> consumedMemrefs; - for (Operation *load : dstNode->loads) - consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef()); - - // Traverse 'dstId' incoming edges and gather the nodes that contain a store - // to one of the consumed memrefs. - for (auto &srcEdge : mdg->inEdges[dstId]) { - auto *srcNode = mdg->getNode(srcEdge.id); - // Skip if 'srcNode' is not a loop nest. 
- if (!isa<AffineForOp>(srcNode->op)) - continue; - - if (any_of(srcNode->stores, [&](Operation *op) { - auto storeOp = cast<AffineWriteOpInterface>(op); - return consumedMemrefs.count(storeOp.getMemRef()) > 0; - })) - srcIdCandidates.insert(srcNode->id); - } -} - -/// Returns in 'producerConsumerMemrefs' the memrefs involved in a -/// producer-consumer dependence between 'srcId' and 'dstId'. -static void -gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, - MemRefDependenceGraph *mdg, - DenseSet<Value> &producerConsumerMemrefs) { - auto *dstNode = mdg->getNode(dstId); - auto *srcNode = mdg->getNode(srcId); - gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, - producerConsumerMemrefs); -} - -/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' -/// that escape the function. A memref escapes the function if either: -/// 1. It's a function argument, or -/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.) -void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, - DenseSet<Value> &escapingMemRefs) { - auto *node = mdg->getNode(id); - for (auto *storeOpInst : node->stores) { - auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef(); - if (escapingMemRefs.count(memref)) - continue; - // Check if 'memref' escapes because it's a block argument. - if (memref.isa<BlockArgument>()) { - escapingMemRefs.insert(memref); - continue; - } - // Check if 'memref' escapes through a non-affine op (e.g., std load/store, - // call op, etc.). - for (Operation *user : memref.getUsers()) - if (!isMemRefDereferencingOp(*user)) - escapingMemRefs.insert(memref); - } -} - } // end anonymous namespace // Initializes the data dependence graph by walking operations in 'f'. @@ -707,7 +631,6 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, // TODO: Add support for taking a Block arg to construct the // dependence graph at a different depth.
bool MemRefDependenceGraph::init(FuncOp f) { - LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); DenseMap<Value, SetVector<unsigned>> memrefAccesses; // TODO: support multi-block functions. @@ -763,12 +686,6 @@ bool MemRefDependenceGraph::init(FuncOp f) { } } -#ifndef NDEBUG - for (auto &idAndNode : nodes) - LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n" - << *(idAndNode.second.op) << "\n"); -#endif - // Add dependence edges between nodes which produce SSA values and their // users. for (auto &idAndNode : nodes) { @@ -808,6 +725,22 @@ bool MemRefDependenceGraph::init(FuncOp f) { return true; } +// Removes load operations from 'srcLoads' which operate on 'memref', and +// adds them to 'dstLoads'. +static void moveLoadsAccessingMemrefTo(Value memref, + SmallVectorImpl<Operation *> *srcLoads, + SmallVectorImpl<Operation *> *dstLoads) { + dstLoads->clear(); + SmallVector<Operation *, 4> srcLoadsToKeep; + for (auto *load : *srcLoads) { + if (cast<AffineReadOpInterface>(load).getMemRef() == memref) + dstLoads->push_back(load); + else + srcLoadsToKeep.push_back(load); + } + srcLoads->swap(srcLoadsToKeep); +} + // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -999,6 +932,75 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, return false; } +// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' +// may write to multiple memrefs but it is required that only one of them, +// 'srcLiveOutStoreOp', has output edges. +// Returns true if 'dstNode's read/write region to 'memref' is a super set of +// 'srcNode's write region to 'memref' and 'srcId' has only one output edge. +// TODO: Generalize this to handle more live in/out cases. 
+static bool +canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, + AffineWriteOpInterface srcLiveOutStoreOp, + MemRefDependenceGraph *mdg) { + assert(srcLiveOutStoreOp && "Expected a valid store op"); + auto *dstNode = mdg->getNode(dstId); + Value memref = srcLiveOutStoreOp.getMemRef(); + // Return false if 'srcNode' has more than one output edge on 'memref'. + if (mdg->getOutEdgeCount(srcId, memref) > 1) + return false; + + // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. + MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); + if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for source operation\n."); + return false; + } + SmallVector<int64_t, 4> srcShape; + // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. + // by 'srcStoreOp' at depth 'dstLoopDepth'. + Optional<int64_t> srcNumElements = + srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); + if (!srcNumElements.hasValue()) + return false; + + // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'. + // TODO: Compute 'unionboundingbox' of all write regions (one for + // each store op in 'dstStoreOps'). + SmallVector<Operation *, 2> dstStoreOps; + dstNode->getStoreOpsForMemref(memref, &dstStoreOps); + SmallVector<Operation *, 2> dstLoadOps; + dstNode->getLoadOpsForMemref(memref, &dstLoadOps); + + auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0]; + MemRefRegion dstRegion(dstOpInst->getLoc()); + if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute MemRefRegion for dest operation\n."); + return false; + } + SmallVector<int64_t, 4> dstShape; + // Query 'dstRegion' for 'dstShape' and 'dstNumElements'. + // by 'dstOpInst' at depth 'dstLoopDepth'. 
+ Optional<int64_t> dstNumElements = + dstRegion.getConstantBoundingSizeAndShape(&dstShape); + if (!dstNumElements.hasValue()) + return false; + + // Return false if write region is not a superset of 'srcNodes' write + // region to 'memref'. + // TODO: Check the shape and lower bounds here too. + if (srcNumElements != dstNumElements) + return false; + + // Return false if 'memref' is used by a non-affine operation that is + // between node 'srcId' and node 'dstId'. + if (hasNonAffineUsersOnThePath(srcId, dstId, mdg)) + return false; + + return true; +} + // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // The argument 'srcStoreOpInst' is used to calculate the storage reduction on @@ -1027,6 +1029,9 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, // the largest computation slice at the maximal dst loop depth (closest to // the load) to minimize reuse distance and potentially enable subsequent // load/store forwarding. +// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for +// the same memref as is written by 'srcOpInst', then the union of slice +// loop bounds is used to compute the slice and associated slice cost. // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop // nest, at which the src computation slice is inserted/fused. // NOTE: We attempt to maximize the dst loop depth, but there are cases @@ -1036,18 +1041,18 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. -// TODO: Extend profitability analysis to support scenarios with multiple -// stores. 
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, - AffineForOp dstForOp, + ArrayRef<Operation *> dstLoadOpInsts, ArrayRef<ComputationSliceState> depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; - llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; - llvm::dbgs() << dstForOp << "\n"; + llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n"; + for (auto dstOpInst : dstLoadOpInsts) { + llvm::dbgs() << " " << *dstOpInst << "\n"; + }; }); if (maxLegalFusionDepth == 0) { @@ -1065,8 +1070,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, return false; // Compute cost of dst loop nest. + SmallVector<AffineForOp, 4> dstLoopIVs; + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); + LoopNestStats dstLoopNestStats; - if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) + if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) return false; // Search for min cost value for 'dstLoopDepth'. At each value of @@ -1100,19 +1108,18 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); + uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats); // Evaluate all depth choices for materializing the slice in the destination // loop nest. for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { - const ComputationSliceState &slice = depthSliceUnions[i - 1]; // Skip slice union if it wasn't computed for this depth. 
- if (slice.isEmpty()) + if (depthSliceUnions[i - 1].isEmpty()) continue; int64_t fusedLoopNestComputeCost; - if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, - dstLoopNestStats, slice, + if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], + dstLoopNestStats, depthSliceUnions[i - 1], &fusedLoopNestComputeCost)) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; @@ -1124,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 1; // Determine what the slice write MemRefRegion would be, if the src loop - // nest slice 'slice' were to be inserted into the dst loop nest at loop - // depth 'i'. + // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst + // loop nest at loop depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &slice))) { + &depthSliceUnions[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); @@ -1211,7 +1218,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, << "\n fused loop nest compute cost: " << minFusedLoopNestComputeCost << "\n"); - auto dstMemSize = getMemoryFootprintBytes(dstForOp); + auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); Optional<double> storageReduction = None; @@ -1315,6 +1322,8 @@ struct GreedyFusion { MemRefDependenceGraph *mdg; // Worklist of graph nodes visited during the fusion pass. SmallVector<unsigned, 8> worklist; + // Set of graph nodes which are present on the worklist. + llvm::SmallDenseSet<unsigned, 16> worklistSet; // Parameter for local buffer size threshold. unsigned localBufSizeThreshold; // Parameter for fast memory space. 
@@ -1335,14 +1344,16 @@ struct GreedyFusion { fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion), computeToleranceThreshold(computeToleranceThreshold) {} - /// Initializes 'worklist' with nodes from 'mdg'. + // Initializes 'worklist' with nodes from 'mdg' void init() { // TODO: Add a priority queue for prioritizing nodes by different // metrics (e.g. arithmetic intensity/flops-to-bytes ratio). worklist.clear(); + worklistSet.clear(); for (auto &idAndNode : mdg->nodes) { const Node &node = idAndNode.second; worklist.push_back(node.id); + worklistSet.insert(node.id); } } @@ -1361,11 +1372,11 @@ struct GreedyFusion { } void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); + worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) @@ -1375,97 +1386,114 @@ struct GreedyFusion { // Skip if 'dstNode' is not a loop nest. if (!isa<AffineForOp>(dstNode->op)) continue; - - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); - // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. This can increase the maximum loop // depth at which we can fuse a slice of a producer loop nest into a // consumer loop nest. sinkSequentialLoops(dstNode); - auto dstAffineForOp = cast<AffineForOp>(dstNode->op); - - // Try to fuse 'dstNode' with candidate producer loops until a fixed point - // is reached. Fusing two loops may expose new fusion opportunities. - bool dstNodeChanged; - do { - // Gather src loop candidates for 'dstNode' and visit them in "quasi" - // reverse program order to minimize the number of iterations needed to - // reach the fixed point.
Note that this is a best effort approach since - // 'getProducerCandidates' does not always guarantee that program order - // in 'srcIdCandidates'. - dstNodeChanged = false; - DenseSet<unsigned> srcIdCandidates; - getProducerCandidates(dstId, mdg, srcIdCandidates); - - /// Visit candidates in reverse node id order. This order corresponds to - /// the reverse program order when the 'mdg' is created. However, - /// reverse program order is not guaranteed and must not be required. - /// Reverse program order won't be held if the 'mdg' is reused from a - /// previous fusion step or if the node creation order changes in the - /// future to support more advance cases. - SmallVector<unsigned, 16> sortedSrcIdCandidates; - sortedSrcIdCandidates.reserve(srcIdCandidates.size()); - sortedSrcIdCandidates.append(srcIdCandidates.begin(), - srcIdCandidates.end()); - llvm::sort(sortedSrcIdCandidates, std::greater<unsigned>()); - - for (unsigned srcId : sortedSrcIdCandidates) { + + SmallVector<Operation *, 4> loads = dstNode->loads; + SmallVector<Operation *, 4> dstLoadOpInsts; + DenseSet<Value> visitedMemrefs; + while (!loads.empty()) { + // Get memref of load on top of the stack. + auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef(); + if (visitedMemrefs.count(memref) > 0) + continue; + visitedMemrefs.insert(memref); + // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. + moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); + // Skip if no input edges along which to fuse. + if (mdg->inEdges.count(dstId) == 0) + continue; + // Iterate through in-edges for 'dstId' and src node id for any + // edges on 'memref'. + SmallVector<unsigned, 2> srcNodeIds; + for (auto &srcEdge : mdg->inEdges[dstId]) { + // Skip 'srcEdge' if not for 'memref'. + if (srcEdge.value != memref) + continue; + srcNodeIds.push_back(srcEdge.id); + } + for (unsigned srcId : srcNodeIds) { + // Skip if this node was removed (fused into another node). 
+ if (mdg->nodes.count(srcId) == 0) + continue; // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); - auto srcAffineForOp = cast<AffineForOp>(srcNode->op); - LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId - << " for dst loop " << dstId << "\n"); - - DenseSet<Value> producerConsumerMemrefs; - gatherProducerConsumerMemrefs(srcId, dstId, mdg, - producerConsumerMemrefs); - - // Skip if 'srcNode' out edge count on any memref is greater than - // 'maxSrcUserCount'. - if (any_of(producerConsumerMemrefs, [&](Value memref) { - return mdg->getOutEdgeCount(srcNode->id, memref) > - maxSrcUserCount; - })) + // Skip if 'srcNode' is not a loop nest. + if (!isa<AffineForOp>(srcNode->op)) continue; + // Skip if 'srcNode' has more than one live-out store to a + // function-local memref. + // TODO: Support more generic multi-output src loop nests + // fusion. + auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); + if (!srcStoreOp) { + // Get the src store op at the deepest loop depth. + // We will use 'LoopFusionUtils::canFuseLoops' to check fusion + // feasibility for loops with multiple stores. + unsigned maxLoopDepth = 0; + for (auto *op : srcNode->stores) { + auto storeOp = cast<AffineWriteOpInterface>(op); + if (storeOp.getMemRef() != memref) { + srcStoreOp = nullptr; + break; + } + unsigned loopDepth = getNestingDepth(storeOp); + if (loopDepth > maxLoopDepth) { + maxLoopDepth = loopDepth; + srcStoreOp = storeOp; + } + } + if (!srcStoreOp) + continue; + } - // Gather memrefs in 'srcNode' that are written and escape to the - // function (e.g., memref function arguments, returned memrefs, - // memrefs passed to function calls, etc.). - DenseSet<Value> srcEscapingMemRefs; - gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); - - // Skip if there are non-affine operations in between the 'srcNode' - // and 'dstNode' using their memrefs. If so, we wouldn't be able to - // compute a legal insertion point for now. 
'srcNode' and 'dstNode' - // memrefs with non-affine operation users would be considered - // escaping memrefs so we can limit this check to only scenarios with - // escaping memrefs. - if (!srcEscapingMemRefs.empty() && - hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't fuse: non-affine users in between the loops\n."); + // Unique outgoing store found must write to 'memref' since 'memref' + // is the one that established the producer-consumer relationship + // between 'srcNode' and 'dstNode'. + assert(srcStoreOp.getMemRef() == memref && + "Found store to unexpected memref"); + + // Skip if 'srcNode' writes to any live in or escaping memrefs, + // and cannot be fused. + bool writesToLiveInOrOut = + mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); + if (writesToLiveInOrOut && + !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) continue; + + // Don't create a private memref if 'writesToLiveInOrOut'. + bool createPrivateMemref = !writesToLiveInOrOut; + // Don't create a private memref if 'srcNode' has in edges on + // 'memref', or if 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || + mdg->getOutEdgeCount(dstNode->id, memref) > 0) { + createPrivateMemref = false; } + // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. + if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) + continue; + // Compute an operation list insertion point for the fused loop // nest which preserves dependences. - Operation *fusedLoopInsPoint = + Operation *insertPointInst = mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); - if (fusedLoopInsPoint == nullptr) + if (insertPointInst == nullptr) continue; - // Compute the innermost common loop depth for dstNode - // producer-consumer loads/stores. 
+ auto srcAffineForOp = cast<AffineForOp>(srcNode->op); + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + + // Compute the innermost common loop depth for dstNode loads/stores. SmallVector<Operation *, 2> dstMemrefOps; for (Operation *op : dstNode->loads) - if (producerConsumerMemrefs.count( - cast<AffineReadOpInterface>(op).getMemRef()) > 0) + if (cast<AffineReadOpInterface>(op).getMemRef() == memref) dstMemrefOps.push_back(op); for (Operation *op : dstNode->stores) - if (producerConsumerMemrefs.count( - cast<AffineWriteOpInterface>(op).getMemRef())) + if (cast<AffineWriteOpInterface>(op).getMemRef() == memref) dstMemrefOps.push_back(op); unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); @@ -1474,7 +1502,7 @@ struct GreedyFusion { unsigned maxLegalFusionDepth = 0; SmallVector<ComputationSliceState, 8> depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); - FusionStrategy strategy(FusionStrategy::ProducerConsumer); + FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( srcAffineForOp, dstAffineForOp, @@ -1484,82 +1512,27 @@ struct GreedyFusion { maxLegalFusionDepth = i; } - if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: fusion is not legal at any depth\n"); + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; - } // Check if fusion would be profitable. We skip profitability analysis // for maximal fusion since we already know the maximal legal depth to // fuse. unsigned bestDstLoopDepth = maxLegalFusionDepth; - if (!maximalFusion) { - // Retrieve producer stores from the src loop. 
- SmallVector<Operation *, 2> producerStores; - for (Operation *op : srcNode->stores) - if (producerConsumerMemrefs.count( - cast<AffineWriteOpInterface>(op).getMemRef())) - producerStores.push_back(op); - - // TODO: Suppport multiple producer stores in profitability - // analysis. We limit profitability analysis to only scenarios with - // a single producer store for now. Note that some multi-store - // producer scenarios will still go through profitability analysis - // if only one of the stores is involved the producer-consumer - // relationship of the candidate loops. - assert(producerStores.size() > 0 && "Expected producer store"); - if (producerStores.size() > 1) - LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " - "supported for this case\n"); - else if (!isFusionProfitable(producerStores[0], producerStores[0], - dstAffineForOp, depthSliceUnions, - maxLegalFusionDepth, &bestDstLoopDepth, - computeToleranceThreshold)) - continue; - } + if (!maximalFusion && + !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) + continue; assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); - ComputationSliceState &bestSlice = - depthSliceUnions[bestDstLoopDepth - 1]; - assert(!bestSlice.isEmpty() && "Missing slice union for depth"); - - // Determine if 'srcId' can be removed after fusion, taking into - // account remaining dependences, escaping memrefs and the fusion - // insertion point. - bool removeSrcNode = canRemoveSrcNodeAfterFusion( - srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, - mdg); - - DenseSet<Value> privateMemrefs; - for (Value memref : producerConsumerMemrefs) { - // Don't create a private memref if 'srcNode' writes to escaping - // memrefs. - if (srcEscapingMemRefs.count(memref) > 0) - continue; - - // Don't create a private memref if 'srcNode' has in edges on - // 'memref' or 'dstNode' has out edges on 'memref'. 
- if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 || - mdg->getOutEdgeCount(dstId, memref) > 0) - continue; - - // If 'srcNode' will be removed but it has out edges on 'memref' to - // nodes other than 'dstNode', we have to preserve dependences and - // cannot create a private memref. - if (removeSrcNode && - any_of(mdg->outEdges[srcId], [&](const auto &edge) { - return edge.value == memref && edge.id != dstId; - })) - continue; - - // Create a private version of this memref. - privateMemrefs.insert(memref); - } + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Missing slice union for depth"); // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); - dstNodeChanged = true; + fuseLoops(srcAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); LLVM_DEBUG(llvm::dbgs() << "Fused src loop " << srcId << " into dst loop " << dstId @@ -1567,20 +1540,18 @@ struct GreedyFusion { << dstAffineForOp << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (fusedLoopInsPoint != dstAffineForOp.getOperation()) - dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint); + if (insertPointInst != dstAffineForOp.getOperation()) + dstAffineForOp->moveBefore(insertPointInst); // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, - removeSrcNode); + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); // Collect slice loop stats. LoopNestStateCollector dstForCollector; dstForCollector.collect(dstAffineForOp); - for (Value memref : privateMemrefs) { + if (createPrivateMemref) { // Create private memref for 'memref' in 'dstAffineForOp'. - // TODO: remove storesForMemref and move the code below to the - // loop-if. 
SmallVector<Operation *, 4> storesForMemref; for (auto *storeOpInst : dstForCollector.storeOpInsts) { if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == @@ -1592,6 +1563,7 @@ struct GreedyFusion { auto newMemRef = createPrivateMemRef( dstAffineForOp, storesForMemref[0], bestDstLoopDepth, fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); // Create new node in dependence graph for 'newMemRef' alloc op. unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); // Add edge from 'newMemRef' node to dstNode. @@ -1602,21 +1574,58 @@ struct GreedyFusion { LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstAffineForOp.getOperation()); + // Add new load ops to current Node load op list 'loads' to continue + // fusing based on new operands. + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + // NOTE: Change 'loads' to a hash set in case efficiency is an + // issue. We still use a vector since it's expected to be small. + if (!llvm::is_contained(loads, loadOpInst)) + loads.push_back(loadOpInst); + } + // Clear visited memrefs after fusion so that previously visited src + // nodes are considered for fusion again in the context of the new + // fused node. + // TODO: This shouldn't be necessary if we visited candidates in the + // dependence graph in post-order or once we fully support multi-store + // producers. Currently, in a multi-store producer scenario such as + // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing + // edges. However, after fusing B+C, A has a single outgoing edge and + // can be fused if we revisit it in the context of the new fused B+C + // node. + visitedMemrefs.clear(); + // Clear and add back loads and stores. 
mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); - - if (removeSrcNode) { - LLVM_DEBUG(llvm::dbgs() - << "Removing src loop " << srcId << " after fusion\n"); - // srcNode is no longer valid after it is removed from mdg. - srcAffineForOp.erase(); - mdg->removeNode(srcId); - srcNode = nullptr; + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and if it does not write to a memref which escapes the + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been + // fused into 'dstNode' and write region of 'dstNode' covers the write + // region of 'srcNode', and 'srcNode' has no other users so it is safe + // to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { + mdg->removeNode(srcNode->id); + srcNode->op->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created new + // fusion opportunities. + if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); + } + } + } } } - } while (dstNodeChanged); + } } } @@ -1627,6 +1636,7 @@ struct GreedyFusion { while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); + worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). 
if (mdg->nodes.count(dstId) == 0) @@ -1688,7 +1698,7 @@ struct GreedyFusion { SmallVector<ComputationSliceState, 8> depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); unsigned maxLegalFusionDepth = 0; - FusionStrategy strategy(memref); + FusionStrategy strategy(FusionStrategy::Sibling, memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( sibAffineForOp, dstAffineForOp, @@ -1702,10 +1712,10 @@ struct GreedyFusion { if (maxLegalFusionDepth == 0) continue; - unsigned bestDstLoopDepth = maxLegalFusionDepth; + unsigned bestDstLoopDepth = dstLoopDepthTest; if (!maximalFusion) { // Check if fusion would be profitable. - if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp, + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, depthSliceUnions, maxLegalFusionDepth, &bestDstLoopDepth, computeToleranceThreshold)) continue; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 9749a8de2351..9759300f2e42 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -191,8 +191,11 @@ gatherLoadsAndStores(AffineForOp forOp, /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. // TODO: Generalize this check for sibling and more generic fusion scenarios. // TODO: Support forward slice fusion. -static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps, - ArrayRef<Operation *> dstOps) { +static unsigned getMaxLoopDepth(ArrayRef<Operation *> dstOps, + FusionStrategy fusionStrategy) { + assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer && + "Fusion strategy not supported"); + if (dstOps.empty()) // Expected at least one memory operation. // TODO: Revisit this case with a specific example. 
@@ -200,14 +203,15 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps, // Filter out ops in 'dstOps' that do not use the producer-consumer memref so // that they are not considered for analysis. - DenseSet<Value> producerConsumerMemrefs; - gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs); + // TODO: Currently, we pass the producer-consumer memref through + // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we + // generalize the algorithm. SmallVector<Operation *, 4> targetDstOps; for (Operation *dstOp : dstOps) { auto loadOp = dyn_cast<AffineReadOpInterface>(dstOp); Value memref = loadOp ? loadOp.getMemRef() : cast<AffineWriteOpInterface>(dstOp).getMemRef(); - if (producerConsumerMemrefs.count(memref) > 0) + if (memref == fusionStrategy.memref) targetDstOps.push_back(dstOp); } @@ -304,10 +308,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // loop dependences. // TODO: Enable this check for sibling and more generic loop fusion // strategies. - if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) { + if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) { // TODO: 'getMaxLoopDepth' does not support forward slice fusion. assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); - if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { + if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); return FusionResult::FailFusionDependence; } @@ -320,7 +324,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // Filter out ops in 'opsA' to compute the slice union based on the // assumptions made by the fusion strategy. SmallVector<Operation *, 4> strategyOpsA; - switch (fusionStrategy.getStrategy()) { + switch (fusionStrategy.strategy) { case FusionStrategy::Generic: // Generic fusion. 
Take into account all the memory operations to compute // the slice union. @@ -328,9 +332,10 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, break; case FusionStrategy::ProducerConsumer: // Producer-consumer fusion (AffineLoopFusion pass) only takes into - // account stores in 'srcForOp' to compute the slice union. + // account stores to 'memref' in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { - if (isa<AffineWriteOpInterface>(op)) + auto store = dyn_cast<AffineWriteOpInterface>(op); + if (store && store.getMemRef() == fusionStrategy.memref) strategyOpsA.push_back(op); } break; @@ -339,7 +344,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // to 'memref' in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { auto load = dyn_cast<AffineReadOpInterface>(op); - if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef()) + if (load && load.getMemRef() == fusionStrategy.memref) strategyOpsA.push_back(op); } break; @@ -623,23 +628,3 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); return true; } - -/// Returns in 'producerConsumerMemrefs' the memrefs involved in a -/// producer-consumer dependence between write ops in 'srcOps' and read ops in -/// 'dstOps'. -void mlir::gatherProducerConsumerMemrefs( - ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps, - DenseSet<Value> &producerConsumerMemrefs) { - // Gather memrefs from stores in 'srcOps'. - DenseSet<Value> srcStoreMemRefs; - for (Operation *op : srcOps) - if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) - srcStoreMemRefs.insert(storeOp.getMemRef()); - - // Compute the intersection between memrefs from stores in 'srcOps' and - // memrefs from loads in 'dstOps'. 
- for (Operation *op : dstOps) - if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) - if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) - producerConsumerMemrefs.insert(loadOp.getMemRef()); -} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index c1bccea4c9f5..a23f0e2ee430 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -364,8 +364,8 @@ func @should_fuse_and_move_to_preserve_war_dep() { // ----- -// CHECK-LABEL: func @should_fuse_if_top_level_access() { -func @should_fuse_if_top_level_access() { +// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() { +func @should_fuse_with_private_memref_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -378,45 +378,14 @@ func @should_fuse_if_top_level_access() { %c0 = constant 4 : index %v1 = affine.load %m[%c0] : memref<10xf32> - // Top-level load to '%m' should prevent creating a private memref but - // loop nests should be fused and '%i0' should be removed. - // CHECK: %[[m:.*]] = alloc() : memref<10xf32> - // CHECK-NOT: alloc - - // CHECK: affine.for %[[i1:.*]] = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %[[m]][%[[i1]]] : memref<10xf32> - // CHECK-NEXT: affine.load %[[m]][%[[i1]]] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: affine.load %[[m]][%{{.*}}] : memref<10xf32> - return -} - -// ----- - -// CHECK-LABEL: func @should_fuse_but_not_remove_src() { -func @should_fuse_but_not_remove_src() { - %m = alloc() : memref<100xf32> - %cf7 = constant 7.0 : f32 - - affine.for %i0 = 0 to 100 { - affine.store %cf7, %m[%i0] : memref<100xf32> - } - affine.for %i1 = 0 to 17 { - %v0 = affine.load %m[%i1] : memref<100xf32> - } - %v1 = affine.load %m[99] : memref<100xf32> - - // Loop '%i0' and '%i1' should be fused but '%i0' shouldn't be removed to - // preserve the dependence with the top-level access. 
- // CHECK: affine.for %{{.*}} = 0 to 100 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32> + // Top-level load to '%{{.*}}' should prevent fusion. + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 17 { + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.load %{{.*}}[99] : memref<100xf32> - // CHECK-NEXT: return return } @@ -1141,8 +1110,8 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() { // ----- -// CHECK-LABEL: func @should_fuse_live_out_arg_but_preserve_src_loop(%{{.*}}: memref<10xf32>) { -func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) { +// CHECK-LABEL: func @should_not_fuse_live_out_arg(%{{.*}}: memref<10xf32>) { +func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { @@ -1160,7 +1129,6 @@ func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1192,8 +1160,8 @@ func @should_fuse_live_out_arg(%arg0: memref<10xf32>) { // ----- -// CHECK-LABEL: func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> -func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> { +// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32> +func @should_not_fuse_escaping_memref() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { @@ -1202,21 +1170,19 @@
func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> { affine.for %i1 = 0 to 9 { %v0 = affine.load %m[%i1] : memref<10xf32> } - // This tests that the loop nest '%i0' should not be removed after fusion - // because it writes to memref '%m', which is returned by the function, and - // the '%i1' memory region does not cover '%i0' memory region. - + // This tests that the loop nest '%{{.*}}' should not be removed after fusion + // because it writes to memref '%{{.*}}' which is returned by the function. // CHECK-DAG: alloc() : memref<10xf32> // CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %{{.*}} : memref<10xf32> return %m : memref<10xf32> } + // ----- // This should fuse with the %in becoming a 1x1x1. 
@@ -1264,7 +1230,7 @@ func @R3_to_R2_reshape() { // ----- -func @should_fuse_multi_output_producer() { +func @should_not_fuse_multi_output_producer() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> @@ -1280,10 +1246,12 @@ func @should_fuse_multi_output_producer() { } // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1536,8 +1504,8 @@ func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32> // ----- -// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() { -func @should_fuse_only_two_loops_and_remove_producer() { +// CHECK-LABEL: func @should_fuse_after_private_memref_creation() { +func @should_fuse_after_private_memref_creation() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> @@ -1557,21 +1525,18 @@ func @should_fuse_only_two_loops_and_remove_producer() { // On the first visit to '%i2', the fusion algorithm can not fuse loop nest // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on - // '%i1'. Then, '%i0' is fused into '%i1' and no private memref is created for - // memref '%a' to be able to remove '%i0' and still preserve the depencence on - // '%a' with '%i2'. 
- // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private memref, - // the dependence between '%i0' and '%i1' on memref '%a' would no longer exist, - // and '%i0' could be fused into '%i2' as well. Note that this approach would - // duplicate the computation in loop nest '%i0' to loop nests '%i1' and '%i2', - // which would limit its profitability. + // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a + // private memref, the dependence between '%i0' and '%i1' on memref '%a' no + // longer exists, so '%i0' can now be fused into '%i2'. + // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -2255,7 +2220,7 @@ func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<10 } } - // CHECK: affine.for %{{.*}} = 0 to 1024 { + // CHECK: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32> @@ -2346,8 +2311,8 @@ func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref } // CHECK: affine.for %[[i0:.*]] = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] 
: memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%[[i0]]] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -2408,11 +2373,12 @@ func @mul_add_0(%arg0: memref<3x4xf32>, %arg1: memref<4x3xf32>, %arg2: memref<3x // ----- -// Verify that 'fuseProducerConsumerNodes' fuse a producer loop with a store -// that has multiple outgoing edges. +// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with +// a store that has multiple outgoing edges. Sibling loop fusion should not fuse +// any of these loops due to dependencies on external memref '%a'. -// CHECK-LABEL: func @should_fuse_multi_outgoing_edge_store_producer -func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) { +// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1 +func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) { %cst = constant 0.000000e+00 : f32 affine.for %arg0 = 0 to 1 { affine.store %cst, %a[%arg0] : memref<1xf32> @@ -2425,12 +2391,9 @@ func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) { affine.for %arg0 = 0 to 1 { %0 = affine.load %a[%arg0] : memref<1xf32> } - // CHECK: affine.for %{{.*}} = 0 to 1 { - // CHECK-NEXT: affine.store - // CHECK-NEXT: affine.load - // CHECK-NEXT: affine.load - // CHECK-NEXT: } - + // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK: affine.for %{{.*}} = 0 to 1 return } @@ -2700,109 +2663,3 @@ func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) { // MAXIMAL: affine.for // MAXIMAL-NEXT: affine.for // MAXIMAL-NOT: affine.for -// MAXIMAL: return - -// ----- - -// CHECK-LABEL: func @should_fuse_multi_store_producer_and_privatize_memfefs -func @should_fuse_multi_store_producer_and_privatize_memfefs() { - %a = alloc() : memref<10xf32> - %b = alloc() : memref<10xf32> - %c = alloc() : memref<10xf32> - %cst = constant 
0.000000e+00 : f32 - affine.for %arg0 = 0 to 10 { - affine.store %cst, %a[%arg0] : memref<10xf32> - affine.store %cst, %b[%arg0] : memref<10xf32> - affine.store %cst, %c[%arg0] : memref<10xf32> - %0 = affine.load %c[%arg0] : memref<10xf32> - } - - affine.for %arg0 = 0 to 10 { - %0 = affine.load %a[%arg0] : memref<10xf32> - } - - affine.for %arg0 = 0 to 10 { - %0 = affine.load %b[%arg0] : memref<10xf32> - } - - // All the memrefs should be privatized except '%c', which is not involved in - // the producer-consumer fusion. - // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: } - - return -} - -// ----- - -func @should_fuse_multi_store_producer_with_scaping_memrefs_and_remove_src( - %a : memref<10xf32>, %b : memref<10xf32>) { - %cst = constant 0.000000e+00 : f32 - affine.for %i0 = 0 to 10 { - affine.store %cst, %a[%i0] : memref<10xf32> - affine.store %cst, %b[%i0] : memref<10xf32> - } - - affine.for %i1 = 0 to 10 { - %0 = affine.load %a[%i1] : memref<10xf32> - } - - affine.for %i2 = 0 to 10 { - %0 = affine.load %b[%i2] : memref<10xf32> - } - - // Producer loop '%i0' should be removed after fusion since fusion is maximal. - // No memref should be privatized since they escape the function. 
- // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NOT: affine.for - - return -} - -// ----- - -func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src( - %a : memref<10xf32>, %b : memref<10xf32>) { - %cst = constant 0.000000e+00 : f32 - affine.for %i0 = 0 to 10 { - affine.store %cst, %a[%i0] : memref<10xf32> - affine.store %cst, %b[%i0] : memref<10xf32> - } - - affine.for %i1 = 0 to 5 { - %0 = affine.load %a[%i1] : memref<10xf32> - } - - affine.for %i2 = 0 to 10 { - %0 = affine.load %b[%i2] : memref<10xf32> - } - - // Loops '%i0' and '%i2' should be fused first and '%i0' should be removed - // since fusion is maximal. Then the fused loop and '%i1' should be fused - // and the fused loop shouldn't be removed since fusion is not maximal. - // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK: affine.for %{{.*}} = 0 to 5 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NOT: affine.for - - return -} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits