https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80466
>From f51bb7b15be55e682be76f2289b991ed42ab4d41 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 2 Feb 2024 11:37:03 -0600 Subject: [PATCH 1/3] [mlir][linalg] Implement canonicalizer for `linalg::BroadcastOp` on tensors --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 751edd0228830..11b6f50032c09 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -531,6 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ let hasCustomAssemblyFormat = 1; let hasVerifier = 1; + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index e86b9762d8581..cddb0671dd58f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1907,6 +1907,20 @@ void BroadcastOp::getEffects( getDpsInits()); } +LogicalResult BroadcastOp::canonicalize(BroadcastOp op, + PatternRewriter &rewriter) { + // For tensor semantics, if op's input and init are same shape, it is a no op. + // Otherwise, with buffer semantics, the op does a copy and we don't + // canonicalize.
+ if (op.hasPureTensorSemantics() && + (op.getInput().getType() == op.getInit().getType())) { + rewriter.replaceAllUsesWith(op.getResult(), op.getInput()); + rewriter.eraseOp(op); + return success(); + } + return failure(); +} + //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// >From c40476b2b7a186af7237a3fdc1599a129d65f749 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 2 Feb 2024 11:51:52 -0600 Subject: [PATCH 2/3] Add regression test --- mlir/test/Dialect/Linalg/canonicalize.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 052dc367ca677..a2777a035320f 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1017,3 +1017,15 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32> return %copy : tensor<?x?xf32> } + +// ----- + +// CHECK-LABEL: func @broadcast_same_shape( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>) +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[ARG0]] : tensor<2x3xf32> +func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> { + %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = [] + return %0 : tensor<2x3xf32> +} \ No newline at end of file >From 458e93a3a6cf1f4b28984ff9d1d7da2e9ce60a30 Mon Sep 17 00:00:00 2001 From: Sam <srcarroll...@gmail.com> Date: Fri, 2 Feb 2024 19:24:07 -0600 Subject: [PATCH 3/3] Refactor EraseIdentityGenericOp to be reused by any LinalgOp --- .../Dialect/Linalg/IR/LinalgStructuredOps.td | 2 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 
54 ++++++++----------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 11b6f50032c09..272bc3116c5fd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -531,7 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [ let hasCustomAssemblyFormat = 1; let hasVerifier = 1; - let hasCanonicalizeMethod = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cddb0671dd58f..a0f02f6a7f259 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1087,24 +1087,25 @@ LogicalResult GenericOp::verify() { return success(); } namespace { -/// Remove generic operations (on tensors) that are just copying +/// Remove any linalg operation (on tensors) that are just copying /// the values from inputs to the results. Requirements are /// 1) All iterator types are parallel /// 2) The body contains just a yield operation with the yielded values being /// the arguments corresponding to the operands. -struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { - using OpRewritePattern<GenericOp>::OpRewritePattern; +template <typename OpTy> +struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> { + using OpRewritePattern<OpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(GenericOp genericOp, + LogicalResult matchAndRewrite(OpTy linalgOp, PatternRewriter &rewriter) const override { // Check all indexing maps are identity. 
- if (llvm::any_of(genericOp.getIndexingMapsArray(), + if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) { return !map.isIdentity(); })) return failure(); // Check that the body of the linalg operation is just a linalg.yield // operation. - Block &body = genericOp.getRegion().front(); + Block &body = linalgOp->getRegion(0).front(); if (!llvm::hasSingleElement(body)) return failure(); auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator()); @@ -1112,18 +1113,18 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { return failure(); // In the buffer case, we need to check exact buffer equality. - if (genericOp.hasPureBufferSemantics()) { - if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 && - genericOp.getDpsInputOperand(0)->get() == - genericOp.getDpsInitOperand(0)->get()) { - rewriter.eraseOp(genericOp); + if (linalgOp.hasPureBufferSemantics()) { + if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 && + linalgOp.getDpsInputOperand(0)->get() == + linalgOp.getDpsInitOperand(0)->get()) { + rewriter.eraseOp(linalgOp); return success(); } return failure(); } // Mixed semantics is not supported yet. - if (!genericOp.hasPureTensorSemantics()) + if (!linalgOp.hasPureTensorSemantics()) return failure(); // Get the argument number of the returned values. That is the operand @@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { if (!yieldArg || yieldArg.getOwner() != &body) return failure(); unsigned argumentNumber = yieldArg.getArgNumber(); - Value returnedArg = genericOp->getOperand(argumentNumber); - Type resultType = genericOp->getResult(yieldVal.index()).getType(); + Value returnedArg = linalgOp->getOperand(argumentNumber); + Type resultType = linalgOp->getResult(yieldVal.index()).getType(); // The input can have a different type than the result, e.g. a dynamic // input dimension can be turned into a static output dimension. 
Type returnType = returnedArg.getType(); @@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { if (sparse_tensor::getSparseTensorEncoding(returnType) || sparse_tensor::getSparseTensorEncoding(resultType)) returnedArg = rewriter.create<sparse_tensor::ConvertOp>( - genericOp.getLoc(), resultType, returnedArg); + linalgOp.getLoc(), resultType, returnedArg); else { if (!tensor::CastOp::areCastCompatible(returnedArg.getType(), resultType)) return failure(); returnedArg = rewriter.create<tensor::CastOp>( - genericOp.getLoc(), resultType, returnedArg); + linalgOp.getLoc(), resultType, returnedArg); } } returnedArgs.push_back(returnedArg); } - if (returnedArgs.size() != genericOp->getNumResults()) + if (returnedArgs.size() != linalgOp->getNumResults()) return failure(); - rewriter.replaceOp(genericOp, returnedArgs); + rewriter.replaceOp(linalgOp, returnedArgs); return success(); } }; @@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> { void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<EraseIdentityGenericOp>(context); + results.add<EraseIdentityLinalgOp<GenericOp>>(context); } LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { @@ -1907,18 +1908,9 @@ void BroadcastOp::getEffects( getDpsInits()); } -LogicalResult BroadcastOp::canonicalize(BroadcastOp op, - PatternRewriter &rewriter) { - // For tensor semantics, if op's input and init are same shape, it is a no op. - // Otherwise, with buffer semantics, the op does a copy and we don't - // canonicalize. 
- if (op.hasPureTensorSemantics() && - (op.getInput().getType() == op.getInit().getType())) { - rewriter.replaceAllUsesWith(op.getResult(), op.getInput()); - rewriter.eraseOp(op); - return success(); - } - return failure(); +void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<EraseIdentityLinalgOp<BroadcastOp>>(context); } //===----------------------------------------------------------------------===// _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits