https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/117666
This commit fixes the implementation of `ConversionPatternRewriter::replaceUsesOfBlockArgument`. The old implementation did not match its documentation:

```
/// Replace all the uses of the block argument `from` with value `to`.
void ConversionPatternRewriter::replaceUsesOfBlockArgument(
    BlockArgument from, Value to) {
  // ...
  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
```

The extra `mapping.lookupOrDefault` call was incorrect: if `from` was already mapped, it replaced the value that `from` is mapped to instead of `from` itself.

This function is typically used after a block signature conversion to "fix up" some block arguments. During a 1:N conversion, an argument materialization is inserted. The old implementation could be used to replace that argument materialization by passing the old block argument as the `from` parameter. This was unintuitive, because it is not the block argument that is being replaced. Furthermore, replacing a block argument of an erased block (to be precise, a block scheduled for erasure) is incorrect from an API perspective, because a block argument of an erased block should no longer have any uses.

The new implementation of `replaceUsesOfBlockArgument` does what the documentation says: it replaces the `from` argument, without any extra lookup. To replace an argument materialization, users can call `replaceOp` on the argument materialization.
>From 18302a346179fda0b04416883a380decae3e4bfd Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Tue, 26 Nov 2024 05:38:50 +0100
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Fix `replaceUsesOfBlockArgument`

---
 .../Transforms/Utils/DialectConversion.cpp  |  2 +-
 mlir/test/Transforms/test-legalizer.mlir    | 20 ++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp | 48 ++++++++++++++++++-
 3 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 60b3656d98a38e..8b7ffd791d2591 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1641,7 +1641,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  impl->mapping.map(from, to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 624add08846a28..dfa619796700eb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,3 +472,23 @@ func.func @circular_mapping() {
   %0 = "test.erase_op"() : () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_replace_uses_of_block_arg() {
+// CHECK: "test.convert_block_and_replace_arg"() ({
+// CHECK: ^bb0(%[[arg0:.*]]: f64, %[[arg1:.*]]: f64):
+// CHECK: %[[producer:.*]] = "test.type_producer"() : () -> f64
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]], %[[arg1]]) : (f64, f64) -> f32
+// CHECK: "test.some_user"(%[[cast]]) : (f32) -> ()
+// CHECK: }) {legal} : () -> ()
+// CHECK: "test.return"() : () -> ()
+// CHECK: }
+func.func @test_replace_uses_of_block_arg() {
+  "test.convert_block_and_replace_arg"() ({
+  ^bb0(%arg0: f32):
+    // expected-remark @below{{'test.some_user' is not legalizable}}
+    "test.some_user"(%arg0) : (f32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e931b394c86210..54699f402e2f1e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -902,6 +902,47 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// This pattern converts the single block argument of a
+/// "test.convert_block_and_replace_arg" op into two f64 block arguments and
+/// then replaces the first new argument via `replaceUsesOfBlockArgument`.
+struct TestConvertBlockAndReplaceArg : public ConversionPattern {
+  TestConvertBlockAndReplaceArg(MLIRContext *ctx,
+                                const TypeConverter &converter)
+      : ConversionPattern(converter, "test.convert_block_and_replace_arg",
+                          /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Expect a single region with a single block with a single argument.
+    if (op->getNumRegions() != 1)
+      return failure();
+    if (!op->getRegion(0).hasOneBlock())
+      return failure();
+    Block *block = &op->getRegion(0).front();
+    if (block->getNumArguments() != 1)
+      return failure();
+
+    // Convert the block argument into two f64 block arguments.
+    TypeConverter::SignatureConversion result(1);
+    result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()});
+    Block *newBlock =
+        rewriter.applySignatureConversion(block, result, getTypeConverter());
+
+    // Replace the first new block argument with the result of a new op.
+    BlockArgument arg = newBlock->getArgument(0);
+    rewriter.setInsertionPointToStart(newBlock);
+    Value producer = rewriter.create<TestTypeProducerOp>(op->getLoc(),
+                                                         rewriter.getF64Type());
+    rewriter.replaceUsesOfBlockArgument(arg, producer);
+
+    // Mark the op as legal so that the conversion target accepts it.
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
 /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
 /// This is to test the rollback logic.
 struct TestUndoMoveOpBefore : public ConversionPattern {
@@ -1265,7 +1306,8 @@ struct TestLegalizePatternDriver
                  TestCreateUnregisteredOp, TestUndoMoveOpBefore,
                  TestUndoPropertiesModification, TestEraseOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp>(&getContext(), converter);
+                 TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>(
+        &getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1359,10 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
 
+    target.addDynamicallyLegalOp(
+        OperationName("test.convert_block_and_replace_arg", &getContext()),
+        [](Operation *op) { return op->hasAttr("legal"); });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits