llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> This commit generalizes `replaceUsesOfBlockArgument` to `replaceAllUsesWith`. In rollback mode, the same restrictions keep applying: a value cannot be replaced multiple times and a call to `replaceAllUsesWith` will replace all current and future uses of the `from` value. `replaceAllUsesWith` is now fully supported and its behavior is consistent with the remaining dialect conversion API. Before this commit, `replaceAllUsesWith` was immediately reflected in the IR when running in rollback mode. After this commit, `replaceAllUsesWith` changes are materialized in a delayed fashion, at the end of the dialect conversion. This is consistent with the `replaceUsesOfBlockArgument` and `replaceOp` APIs. `replaceAllUsesExcept` etc. are still not supported and will be deactivated on the `ConversionPatternRewriter` (when running in rollback mode) in a follow-up commit. Note for LLVM integration: Replace `replaceUsesOfBlockArgument` with `replaceAllUsesWith`. If you are seeing failures, you may have patterns that use `replaceAllUsesWith` incorrectly (e.g., being called multiple times on the same value) or bypass the rewriter API entirely. E.g., such failures were mitigated in Flang by switching to the walk-patterns driver (#156171). You can temporarily reactivate the old behavior by calling `RewriterBase::replaceAllUsesWith`. However, note that this behavior is faulty in a dialect conversion. E.g., the base `RewriterBase::replaceAllUsesWith` implementation does not see uses of the `from` value that have not materialized yet and will, therefore, not replace them. 
--- Patch is 22.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155244.diff 7 Files Affected: - (modified) mlir/include/mlir/IR/PatternMatch.h (+1-1) - (modified) mlir/include/mlir/Transforms/DialectConversion.h (+19-8) - (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1) - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+98-60) - (modified) mlir/test/Transforms/test-legalizer-rollback.mlir (+5-3) - (modified) mlir/test/Transforms/test-legalizer.mlir (+22-4) - (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+12-12) ``````````diff diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 57e73c1d8c7c1..7b0b9cef9c5bd 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder { /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). - void replaceAllUsesWith(Value from, Value to) { + virtual void replaceAllUsesWith(Value from, Value to) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 14dfbf18836c6..fe48e45a9b98c 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter { Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); - /// Replace all the uses of the block argument `from` with `to`. This - /// function supports both 1:1 and 1:N replacements. + /// Replace all the uses of `from` with `to`. 
The type of `from` and `to` is + /// allowed to differ. The conversion driver will try to reconcile all type + /// mismatches that still exist at the end of the conversion with + /// materializations. This function supports both 1:1 and 1:N replacements. /// - /// Note: If `allowPatternRollback` is set to "true", this function replaces - /// all current and future uses of the block argument. This same block - /// block argument must not be replaced multiple times. Uses are not replaced - /// immediately but in a delayed fashion. Patterns may still see the original - /// uses when inspecting IR. - void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to); + /// Note: If `allowPatternRollback` is set to "true", this function behaves + /// slightly different: + /// + /// 1. All current and future uses of `from` are replaced. The same value must + /// not be replaced multiple times. That's an API violation. + /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns + /// may still see the original uses when inspecting IR. + /// 3. Uses within the same block that appear before the defining operation + /// of the replacement value are not replaced. This allows users to + /// perform certain replaceAllUsesExcept-style replacements, even though + /// such API is not directly supported. + void replaceAllUsesWith(Value from, ValueRange to); + void replaceAllUsesWith(Value from, Value to) override { + replaceAllUsesWith(from, ValueRange{to}); + } /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. 
Return nullptr in the case diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 42c76ed475b4c..93fe2edad5274 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -284,7 +284,7 @@ static void restoreByValRefArgumentType( cast<TypeAttr>(byValRefAttr->getValue()).getValue()); Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); - rewriter.replaceUsesOfBlockArgument(arg, valueArg); + rewriter.replaceAllUsesWith(arg, valueArg); } } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 5ba109d96cf13..d72429298754f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -277,13 +277,14 @@ class IRRewrite { InlineBlock, MoveBlock, BlockTypeConversion, - ReplaceBlockArg, // Operation rewrites MoveOperation, ModifyOperation, ReplaceOperation, CreateOperation, - UnresolvedMaterialization + UnresolvedMaterialization, + // Value rewrites + ReplaceValue }; virtual ~IRRewrite() = default; @@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::CreateBlock && - rewrite->getKind() <= Kind::ReplaceBlockArg; + rewrite->getKind() <= Kind::BlockTypeConversion; } protected: @@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite { Block *block; }; +/// A value rewrite. +class ValueRewrite : public IRRewrite { +public: + /// Return the value that this rewrite operates on. + Value getValue() const { return value; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceValue; + } + +protected: + ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Value value) + : IRRewrite(kind, rewriterImpl), value(value) {} + + // The value that this rewrite operates on. 
+ Value value; +}; + /// Creation of a block. Block creations are immediately reflected in the IR. /// There is no extra work to commit the rewrite. During rollback, the newly /// created block is erased. @@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite { Block *newBlock; }; -/// Replacing a block argument. This rewrite is not immediately reflected in the +/// Replacing a value. This rewrite is not immediately reflected in the /// IR. An internal IR mapping is updated, but the actual replacement is delayed /// until the rewrite is committed. -class ReplaceBlockArgRewrite : public BlockRewrite { +class ReplaceValueRewrite : public ValueRewrite { public: - ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg, - const TypeConverter *converter) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), + ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value, + const TypeConverter *converter) + : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value), converter(converter) {} static bool classof(const IRRewrite *rewrite) { - return rewrite->getKind() == Kind::ReplaceBlockArg; + return rewrite->getKind() == Kind::ReplaceValue; } void commit(RewriterBase &rewriter) override; @@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite { void rollback() override; private: - BlockArgument arg; - - /// The current type converter when the block argument was replaced. + /// The current type converter when the value was replaced. const TypeConverter *converter; }; @@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// uses. void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); - /// Replace the given block argument with the given values. The specified + /// Replace the uses of the given value with the given values. The specified /// converter is used to build materializations (if necessary). 
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, - const TypeConverter *converter); + void replaceAllUsesWith(Value from, ValueRange to, + const TypeConverter *converter); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { IRRewriter notifyingRewriter; #ifndef NDEBUG - /// A set of replaced block arguments. This set is for debugging purposes - /// only and it is maintained only if `allowPatternRollback` is set to - /// "true". - DenseSet<BlockArgument> replacedArgs; + /// A set of replaced values. This set is for debugging purposes only and it + /// is maintained only if `allowPatternRollback` is set to "true". + DenseSet<Value> replacedValues; /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, - Value repl) { +/// Replace all uses of `from` with `repl`. +static void performReplaceValue(RewriterBase &rewriter, Value from, + Value repl) { if (isa<BlockArgument>(repl)) { - rewriter.replaceAllUsesWith(arg, repl); + // `repl` is a block argument. Directly replace all uses. + rewriter.replaceAllUsesWith(from, repl); return; } - // If the replacement value is an operation, we check to make sure that we - // don't replace uses that are within the parent operation of the - // replacement value. - Operation *replOp = cast<OpResult>(repl).getOwner(); + // If the replacement value is an operation, only replace those uses that: + // - are in a different block than the replacement operation, or + // - are in the same block but after the replacement operation. 
+ // + // Example: + // ^bb0(%arg0: i32): + // %0 = "consumer"(%arg0) : (i32) -> (i32) + // "another_consumer"(%arg0) : (i32) -> () + // + // In the above example, replaceAllUsesWith(%arg0, %0) will replace the + // use in "another_consumer" but not the use in "consumer". When using the + // normal RewriterBase API, this would typically be done with + // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not + // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism + // it cannot be supported efficiently with `allowPatternRollback` set to + // "true". Therefore, the conversion driver is trying to be smart and replaces + // only those uses that do not lead to a dominance violation. E.g., the + // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this + // behavior. + // + // TODO: As we move more and more towards `allowPatternRollback` set to + // "false", we should remove this special handling, in order to align the + // `ConversionPatternRewriter` API with the normal `RewriterBase` API. 
+ Operation *replOp = repl.getDefiningOp(); Block *replBlock = replOp->getBlock(); - rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); }); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); +void ReplaceValueRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); if (!repl) return; - performReplaceBlockArg(rewriter, arg, repl); + performReplaceValue(rewriter, value, repl); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } +void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = @@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceUsesOfBlockArgument(origArg, mat, converter); + replaceAllUsesWith(origArg, mat, converter); continue; } @@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceUsesOfBlockArgument(origArg, inputMap->replacementValues, - converter); + replaceAllUsesWith(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. 
auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceUsesOfBlockArgument(origArg, replArgs, converter); + replaceAllUsesWith(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( - BlockArgument from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceAllUsesWith( + Value from, ValueRange to, const TypeConverter *converter) { if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( if (!repl) return; - performReplaceBlockArg(r, from, repl); + performReplaceValue(r, from, repl); return; } #ifndef NDEBUG - // Make sure that a block argument is not replaced multiple times. In - // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current - // uses of the given block argument, but also all future uses that may be - // introduced by future pattern applications. Therefore, it does not make - // sense to call `replaceUsesOfBlockArgument` multiple times with the same - // block argument. Doing so would overwrite the mapping and mess with the - // internal state of the dialect conversion driver. - assert(!replacedArgs.contains(from) && - "attempting to replace a block argument that was already replaced"); - replacedArgs.insert(from); + // Make sure that a value is not replaced multiple times. In rollback mode, + // `replaceAllUsesWith` replaces not only all current uses of the given value, + // but also all future uses that may be introduced by future pattern + // applications. Therefore, it does not make sense to call + // `replaceAllUsesWith` multiple times with the same value. 
Doing so would + // overwrite the mapping and mess with the internal state of the dialect + // conversion driver. + assert(!replacedValues.contains(from) && + "attempting to replace a value that was already replaced"); + replacedValues.insert(from); #endif // NDEBUG - appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); + appendRewrite<ReplaceValueRewrite>(from, converter); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { @@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( return impl->convertRegionTypes(region, converter, entryConversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - ValueRange to) { +void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Argument : '" << from << "'"; - if (Operation *parentOp = from.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; + impl->logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + impl->logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")\n"; + } else { + impl->logger.getOStream() << " (unlinked block)\n"; + } } }); - impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter); + impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); } Value ConversionPatternRewriter::getRemappedValue(Value key) { @@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // Replace all uses of block arguments. 
for (auto it : llvm::zip(source->getArguments(), argValues)) - replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); if (fastPath) { // Move all ops at once. diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir index 460911fd88ad1..71e11782e14b0 100644 --- a/mlir/test/Transforms/test-legalizer-rollback.mlir +++ b/mlir/test/Transforms/test-legalizer-rollback.mlir @@ -49,14 +49,16 @@ func.func @create_illegal_block() { // expected-remark@+1{{applyPartialConversion failed}} module { func.func @undo_block_arg_replace() { - // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}} - "test.block_arg_replace"() ({ + "test.legal_op"() ({ ^bb0(%arg0: i32, %arg1: i16): // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16): + // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback} // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32) + // expected-error@+1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}} + "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> () "test.return"(%arg0) : (i32) -> () - }) {trigger_rollback} : () -> () + }) : () -> () return } } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 3fa42ff6b2757..94c5bb4e93b06 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -269,12 +269,14 @@ builtin.module { // CHECK-LABEL: @replace_block_arg_1_to_n func.func @replace_block_arg_1_to_n() { - // CHECK: "test.block_arg_replace" - "test.block_arg_replace"() ({ + // CHECK: "test.legal_op" + "test.legal_op"() ({ ^bb0(%arg0: i32, %arg1: i16): - // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16): - // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32 + // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, 
%[[ARG1:.*]]: i16): + // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32 + // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> () // CHECK-NEXT: "test.return"(%[[cast]]) : (i32) + "test.value_replace"(%arg0, %arg1) : (i32, i16) -> () "test.return"(%arg0) : (i32) -> () }) : () -> () "test.return"() : () -> () @@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() { // ----- +// CHECK-LABEL: @replace_op_result_1_to_n +func.func @replace_op_result_1_to_n() { + // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32 + // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16 + %0 = "test.legal_op"() : () -> i32 + %1 = "test.legal_op"() : () -> i16 + + // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32 + // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> () + // CHECK-NEXT: "test.return"(%[[cast]]) : (i32) + "test.value_replace"(%0, %1) : (i32, i16) -> () + "test.return"(%0) : (i32) -> () +} + +// ----- + // Check that a conversion pattern on `test.blackhole` can mark the producer // for deletion. // CHECK-LABEL: @blackhole diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 95f381ec471d6..93b007c792ad9 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern { } }; -/// A simple pattern that tests the "replaceUsesOfBlockArgument" API. -struct TestBlockArgReplace : public ConversionPattern { - TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter) - : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1, - ctx) {} +/// A simple pattern that tests the "replaceAllUsesWith"... 
[truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/155244 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
