Author: Murali Vijayaraghavan Date: 2022-11-17T22:26:02Z New Revision: dddf6ab27212d9813a360eb95440c61e81a308be
URL: https://github.com/llvm/llvm-project/commit/dddf6ab27212d9813a360eb95440c61e81a308be DIFF: https://github.com/llvm/llvm-project/commit/dddf6ab27212d9813a360eb95440c61e81a308be.diff LOG: Simplifying the SplitReduction logic that uses the control to get the dimension where the extra parallel dimension is inserted Currently, the innerParallel and non-innerParallel strategies use two different ways to determine where the extra loop is inserted and where the extra dimension for the intermediate result is inserted - innerParallel adds the extra (parallel) loop right after the pre-existing reduction loop, whereas non-innerParallel adds the reduction loop at the index immediately after the index supplied by the control, and the parallel loop at the index supplied by the control. The semantics of the index supplied by the control are supposed to only control where the extra tensor dimension is inserted in the intermediate tensor. Conflating this index with where the reduction (and parallel) loops are inserted leads to more complex (and confusing) logic overall. This differential stops conflating the two uses of the index: it keeps the reduction and parallel loops in the same vicinity and uses the supplied index only to determine the position of the extra tensor dimension. It also simplifies the code by merging the two strategies in a lot more places. 
Differential Revision: https://reviews.llvm.org/D137478 Added: Modified: mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 2fb550b27ec0c..26a49b91db1ed 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -34,7 +34,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( SplitReductionOptions control = controlSplitReductionFn(op); int64_t ratio = control.ratio; - unsigned insertSplitDimension = control.index; + unsigned insertSplitIndex = control.index; if (ratio <= 1) return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); @@ -45,10 +45,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges(); int64_t reductionDimSize = loopRanges[reductionDim]; if (reductionDimSize == ShapedType::kDynamicSize || - reductionDimSize % ratio != 0 || - insertSplitDimension >= loopRanges.size()) + reductionDimSize % ratio != 0) return b.notifyMatchFailure( op, "Reduction dimension not divisible by split ratio"); + if (op.getNumDpsInits() != 1) + return b.notifyMatchFailure(op, "More than one output in split reduction"); + if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size()) + return b.notifyMatchFailure(op, "Insert dimension position too large " + "compared to intermediate tensor size"); SmallVector<Operation *, 4> combinerOps; if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || @@ -80,25 +84,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( newShape.push_back(ratio); newShape.push_back(op.getShape(operand)[idx] / ratio); } + exprs.push_back(b.getAffineDimExpr(reductionDim)); + 
exprs.push_back(b.getAffineDimExpr(reductionDim + 1)); reassociation.push_back({index++, index++}); - if (control.innerParallel) { - exprs.push_back(b.getAffineDimExpr(reductionDim)); - exprs.push_back(b.getAffineDimExpr(reductionDim + 1)); - } else { - exprs.push_back(b.getAffineDimExpr(insertSplitDimension)); - exprs.push_back( - b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); - } continue; } newShape.push_back(op.getShape(operand)[idx]); - if (control.innerParallel) { - exprs.push_back( - b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1)); - } else { - exprs.push_back( - b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); - } + exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1)); reassociation.push_back({index++}); } newMaps.push_back( @@ -122,26 +114,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0)); ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0)); SmallVector<AffineExpr> outputExpr; - for (unsigned idx : - llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) { - if (idx == insertSplitDimension) { + for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) { + if (insertSplitIndex == idx) { newOutputShape.push_back(ratio); if (control.innerParallel) { outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1)); } else { - outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension)); + outputExpr.push_back(b.getAffineDimExpr(reductionDim)); } - continue; } - unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1; - newOutputShape.push_back(oldShape[oldIdx]); - unsigned dim = oldOutputMap.getDimPosition(oldIdx); - if (control.innerParallel) { - outputExpr.push_back( - b.getAffineDimExpr(dim <= reductionDim ? 
dim : dim + 1)); - } else { + if (idx < oldShape.size()) { + newOutputShape.push_back(oldShape[idx]); + unsigned dim = oldOutputMap.getDimPosition(idx); outputExpr.push_back( - b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); + b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1)); } } Value emptyOrAllocTensor; @@ -164,10 +150,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( op.getContext())); SmallVector<utils::IteratorType> newIteratorTypes; for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) { - if (insertSplitDimension == it.index() && !control.innerParallel) + if (reductionDim == it.index() && !control.innerParallel) newIteratorTypes.push_back(utils::IteratorType::parallel); newIteratorTypes.push_back(it.value()); - if (insertSplitDimension == it.index() && control.innerParallel) + if (reductionDim == it.index() && control.innerParallel) newIteratorTypes.push_back(utils::IteratorType::parallel); } // Create the new op matching the original op with an extra parallel @@ -185,7 +171,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( SmallVector<utils::IteratorType> reductionIteratorTypes; SmallVector<AffineExpr> exprs; for (unsigned i : llvm::seq<unsigned>(0, intermRank)) { - if (insertSplitDimension == i) { + if (insertSplitIndex == i) { reductionIteratorTypes.push_back(utils::IteratorType::reduction); } else { exprs.push_back(b.getAffineDimExpr(i)); diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir index cb7a92198c04d..eb035d9ffe092 100644 --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -106,9 +106,9 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> return %0 : tensor<5x2xf32> } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, 
d1, d2, d3) -> (d3, d2, d1)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @generic_split_3d @@ -117,7 +117,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32> // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> -// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} +// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]} // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { // CHECK: arith.addf // CHECK: arith.maxf _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits