Author: Thomas Raoux Date: 2020-12-23T11:25:01-08:00 New Revision: 74186880ba99b37c0375e9d87df818beee8b4ff2
URL: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2
DIFF: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2.diff

LOG: [mlir][vector] Add more vector Ops canonicalization

Add canonicalization for BroadcastOp, ExtractStridedSliceOp and ShapeCastOp

Differential Revision: https://reviews.llvm.org/D93120

Added:

Modified:
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed:

################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 13aba2076ee9..e031f87cfb8e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -271,6 +271,7 @@ def Vector_BroadcastOp :
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a3ad355d30b2..539e00d58dbf 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1110,6 +1110,36 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+namespace {
+
+// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
+// the degenerate case where the broadcast only adds dimensions of size 1 it
+// can be replaced by a ShapeCastOp. This canonicalization checks if the total
+// number of elements is the same before and after the broadcast to detect if
+// the only changes to the vector type are new dimensions of size 1.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<BroadcastToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
@@ -1768,7 +1798,8 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
-// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
+// ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
@@ -1847,14 +1878,84 @@ class StridedSliceConstantFolder final
   }
 };
 
+// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 0,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  auto range = arrayAttr.getAsRange<IntegerAttr>();
+  SmallVector<int64_t, 4> res;
+  res.reserve(arrayAttr.size() - dropFront - dropBack);
+  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+       it != eit; ++it)
+    res.push_back((*it).getValue().getSExtValue());
+  return res;
+}
+
+// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
+// BroadcastOp(ExtractStridedSliceOp), or just BroadcastOp when possible.
+class StridedSliceBroadcast final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
+    if (!broadcast)
+      return failure();
+    auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
+    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
+    auto dstVecType = op.getType().cast<VectorType>();
+    unsigned dstRank = dstVecType.getRank();
+    unsigned rankDiff = dstRank - srcRank;
+    // Each inner dimension of the broadcast source is either stretched from 1
+    // to N or left untouched by the broadcast. A stretched dimension never
+    // needs slicing, and an untouched one only needs slicing if the extract
+    // changes its size. When no dimension needs slicing we can broadcast the
+    // original source directly to the destination type.
+    bool needsSlice = false;
+    for (unsigned i = 0; i < srcRank; i++) {
+      int64_t srcDimSize = srcVecType.getDimSize(i);
+      if (srcDimSize != 1 &&
+          srcDimSize != dstVecType.getDimSize(i + rankDiff)) {
+        needsSlice = true;
+        break;
+      }
+    }
+    Value source = broadcast.source();
+    if (needsSlice) {
+      // Extract from the source of the original broadcast and then broadcast
+      // the extracted value. A dimension stretched by the broadcast has a
+      // single element in the source, so reusing the extract's offset/size
+      // for it could go out of bounds: take offset 0 and size 1 instead.
+      SmallVector<int64_t, 4> offsets =
+          getI64SubArray(op.offsets(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t, 4> sizes =
+          getI64SubArray(op.sizes(), /*dropFront=*/rankDiff);
+      for (unsigned i = 0, e = offsets.size(); i < e; i++) {
+        if (srcVecType.getDimSize(i) == 1) {
+          offsets[i] = 0;
+          sizes[i] = 1;
+        }
+      }
+      source = rewriter.create<ExtractStridedSliceOp>(
+          op->getLoc(), source, offsets, sizes,
+          getI64SubArray(op.strides(), /*dropFront=*/rankDiff));
+    }
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
-      context);
+  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
+                 StridedSliceBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2652,10 +2753,13 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
     return source();
 
   // Canceling shape casts.
-  if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
     if (result().getType() == otherOp.source().getType())
       return otherOp.source();
-
+    // Otherwise cast directly from the source of the first shape_cast.
+    setOperand(otherOp.source());
+    return getResult();
+  }
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f07285d7d98c..f94c3bcce5be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,4 +613,78 @@ func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
   return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
 }
 
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+      {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
+      vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast2
+// CHECK: %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
+// CHECK-NEXT: %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x2xf16>
+func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+      {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+      vector<16x4xf16> to vector<2x2xf16>
+  return %1 : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast3
+// CHECK: %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 1], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf16> to vector<1x2xf16>
+// CHECK-NEXT: %[[B:.*]] = vector.broadcast %[[E]] : vector<1x2xf16> to vector<2x2xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x2xf16>
+func @extract_strided_broadcast3(%arg0: vector<1x4xf16>) -> vector<2x2xf16> {
+  %0 = vector.broadcast %arg0 : vector<1x4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+      {offsets = [4, 1], sizes = [2, 2], strides = [1, 1]} :
+      vector<16x4xf16> to vector<2x2xf16>
+  return %1 : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast4
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x4xf16> to vector<2x4xf16>
+// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast4(%arg0: vector<1x4xf16>) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<1x4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+      {offsets = [4, 0], sizes = [2, 4], strides = [1, 1]} :
+      vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: consecutive_shape_cast
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
+func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+  %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_to_shapecast
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+  return %0 : vector<1x4x4xf16>
+}
 

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits