https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/145171
This commit adds 1:N support to `ConversionPatternRewriter::replaceUsesOfBlockArgument`. This was one of the few remaining dialect conversion APIs that do not support 1:N conversions yet. This commit also reuses `replaceUsesOfBlockArgument` in the implementation of `applySignatureConversion`. This is in preparation for the One-Shot Dialect Conversion refactoring. The goal is to bring the `applySignatureConversion` implementation into a state where it works both with and without rollbacks. To that end, `applySignatureConversion` should not directly access the `mapping`. Depends on #145155. >From b3760c623e4fe8161232533a0b1e65d7bf883d2d Mon Sep 17 00:00:00 2001 From: Matthias Springer <m...@m-sp.org> Date: Sat, 21 Jun 2025 14:29:50 +0000 Subject: [PATCH] [mlir][Transforms] Add 1:N support to `replaceUsesOfBlockArgument` --- .../mlir/Transforms/DialectConversion.h | 5 +- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 2 +- .../Transforms/Utils/DialectConversion.cpp | 40 +++++++++------ mlir/test/Transforms/test-legalizer.mlir | 31 ++++++++--- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 51 +++++++++++-------- 5 files changed, 82 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5a5f116073a9a..81858812d2623 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter { Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); - /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument from, Value to); + /// Replace all the uses of the block argument `from` with `to`. This + /// function supports both 1:1 and 1:N replacements.
+ void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to); /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 538016927256b..9e8e746507557 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -294,7 +294,7 @@ static void restoreByValRefArgumentType( Type resTy = typeConverter.convertType( cast<TypeAttr>(byValRefAttr->getValue()).getValue()); - auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg); + Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg); rewriter.replaceUsesOfBlockArgument(arg, valueArg); } } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 774d58973eb91..9cb6f2ba1eaae 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// uses. void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); + /// Replace the given block argument with the given values. The specified + /// converter is used to build materializations (if necessary). + void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, + const TypeConverter *converter); + /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( if (!inputMap) { // This block argument was dropped and no replacement value was provided. // Materialize a replacement value "out of thin air". 
- buildUnresolvedMaterialization( - MaterializationKind::Source, - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), - /*outputTypes=*/origArgType, /*originalType=*/Type(), converter); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + Value mat = + buildUnresolvedMaterialization( + MaterializationKind::Source, + OpBuilder::InsertPoint(newBlock, newBlock->begin()), + origArg.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter) + .front(); + replaceUsesOfBlockArgument(origArg, mat, converter); continue; } @@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map(origArg, inputMap->replacementValues); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + replaceUsesOfBlockArgument(origArg, inputMap->replacementValues, + converter); continue; } // This is a 1->1+ mapping. 
auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs); - mapping.map(origArg, std::move(replArgVals)); - appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); + replaceUsesOfBlockArgument(origArg, replArgs, converter); } appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); @@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } +void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( + BlockArgument from, ValueRange to, const TypeConverter *converter) { + appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); + mapping.map(from, to); +} + void ConversionPatternRewriterImpl::eraseBlock(Block *block) { assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); @@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { + ValueRange to) { LLVM_DEBUG({ impl->logger.startLine() << "** Replace Argument : '" << from << "'"; if (Operation *parentOp = from.getOwner()->getParentOp()) { @@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, impl->logger.getOStream() << " (unlinked block)\n"; } }); - impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, - impl->currentTypeConverter); - impl->mapping.map(from, to); + impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter); } Value ConversionPatternRewriter::getRemappedValue(Value key) { diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 204c8c1456826..79518b04e7158 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -300,18 +300,35 @@ func.func 
@create_illegal_block() { // ----- // CHECK-LABEL: @undo_block_arg_replace +// expected-remark@+1{{applyPartialConversion failed}} +module { func.func @undo_block_arg_replace() { - // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}} - "test.undo_block_arg_replace"() ({ - ^bb0(%arg0: i32): - // CHECK: ^bb0(%[[ARG:.*]]: i32): - // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32) + // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}} + "test.block_arg_replace"() ({ + ^bb0(%arg0: i32, %arg1: i16): + // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16): + // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32) "test.return"(%arg0) : (i32) -> () - }) : () -> () - // expected-remark@+1 {{op 'func.return' is not legalizable}} + }) {trigger_rollback} : () -> () return } +} + +// ----- + +// CHECK-LABEL: @replace_block_arg_1_to_n +func.func @replace_block_arg_1_to_n() { + // CHECK: "test.block_arg_replace" + "test.block_arg_replace"() ({ + ^bb0(%arg0: i32, %arg1: i16): + // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16): + // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32 + // CHECK-NEXT: "test.return"(%[[cast]]) : (i32) + "test.return"(%arg0) : (i32) -> () + }) : () -> () + "test.return"() : () -> () +} // ----- diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d073843484d81..588e529665dd1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern { } }; -/// A simple pattern that tests the undo mechanism when replacing the uses of a -/// block argument. 
-struct TestUndoBlockArgReplace : public ConversionPattern { - TestUndoBlockArgReplace(MLIRContext *ctx) - : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} +/// A simple pattern that tests the "replaceUsesOfBlockArgument" API. +struct TestBlockArgReplace : public ConversionPattern { + TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1, + ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - auto illegalOp = - rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); + // Replace the first block argument with 2x the second block argument. + Value repl = op->getRegion(0).getArgument(1); rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), - illegalOp->getResult(0)); - rewriter.modifyOpInPlace(op, [] {}); + {repl, repl}); + rewriter.modifyOpInPlace(op, [&] { + // If the "trigger_rollback" attribute is set, keep the op illegal, so + // that a rollback is triggered. 
+ if (!op->hasAttr("trigger_rollback")) + op->setAttr("is_legal", rewriter.getUnitAttr()); + }); return success(); } }; @@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver TestTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); - patterns - .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, - TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, - TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, - TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, - TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, - TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, - TestCreateUnregisteredOp, TestUndoMoveOpBefore, - TestUndoPropertiesModification, TestEraseOp, - TestRepetitive1ToNConsumer>(&getContext()); + patterns.add< + TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, + TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, + TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32, + TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, + TestUpdateConsumerType, TestNonRootReplacement, + TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, + TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore, + TestUndoPropertiesModification, TestEraseOp, + TestRepetitive1ToNConsumer>(&getContext()); patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, - TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( - &getContext(), converter); + TestPassthroughInvalidOp, TestMultiple1ToNReplacement, + TestBlockArgReplace>(&getContext(), converter); patterns.add<TestConvertBlockArgs>(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver }); target.addDynamicallyLegalOp<func::CallOp>( [&](func::CallOp op) { return 
converter.isLegal(op); }); + target.addDynamicallyLegalOp( + OperationName("test.block_arg_replace", &getContext()), + [](Operation *op) { return op->hasAttr("is_legal"); }); // TestCreateUnregisteredOp creates `arith.constant` operation, // which was not added to target intentionally to test _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits