https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83286
When a block signature is converted during dialect conversion, a `BlockTypeConversionRewrite` object is pushed onto the stack of rewrites. Such an object represents multiple steps: - Splitting the old block, i.e., creating a new block and moving all operations over. - Rewriting block arguments. - Erasing the old block. We have dedicated `IRRewrite` objects that represent "creating a block", "moving an op", and "erasing a block". This commit reuses those rewrite objects, so that there is less work to do in `BlockTypeConversionRewrite::rollback` and `BlockTypeConversionRewrite::commit`. >From ab78e8c90ca3ecf60e1192c198d9f5025563dec2 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Wed, 28 Feb 2024 16:33:00 +0000 Subject: [PATCH] [mlir][Transforms][NFC] Simplify `BlockTypeConversionRewrite` --- .../Transforms/Utils/DialectConversion.cpp | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index b81495a95c80ed..cac990d498d7d3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -746,24 +746,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// block is returned containing the new arguments. Returns `block` if it did /// not require conversion. FailureOr<Block *> convertBlockSignature( - Block *block, const TypeConverter *converter, + ConversionPatternRewriter &rewriter, Block *block, + const TypeConverter *converter, TypeConverter::SignatureConversion *conversion = nullptr); /// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, + ConversionPatternRewriter &rewriter, Region *region, + const TypeConverter &converter, ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); /// Apply a signature conversion on the given region, using `converter` for /// materializations if not null. Block * - applySignatureConversion(Region *region, + applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter); /// Convert the types of block arguments within the given region. FailureOr<Block *> - convertRegionTypes(Region *region, const TypeConverter &converter, + convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, + const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); /// Apply the given signature conversion on the given block. The new block @@ -773,7 +776,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// translate between the origin argument types and those specified in the /// signature conversion. Block *applySignatureConversion( - Block *block, const TypeConverter *converter, + ConversionPatternRewriter &rewriter, Block *block, + const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion); //===--------------------------------------------------------------------===// @@ -940,24 +944,10 @@ void BlockTypeConversionRewrite::commit() { rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType())); } } - - assert(origBlock->empty() && "expected empty block"); - origBlock->dropAllDefinedValueUses(); - delete origBlock; - origBlock = nullptr; } void BlockTypeConversionRewrite::rollback() { - // Drop all uses of the new block arguments and replace uses of the new block. 
- for (int i = block->getNumArguments() - 1; i >= 0; --i) - block->getArgument(i).dropAllUses(); block->replaceAllUsesWith(origBlock); - - // Move the operations back the original block, move the original block back - // into its original location and the delete the new block. - origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); - eraseBlock(block); } LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( @@ -1173,10 +1163,11 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Type Conversion FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( - Block *block, const TypeConverter *converter, + ConversionPatternRewriter &rewriter, Block *block, + const TypeConverter *converter, TypeConverter::SignatureConversion *conversion) { if (conversion) - return applySignatureConversion(block, converter, *conversion); + return applySignatureConversion(rewriter, block, converter, *conversion); // If a converter wasn't provided, and the block wasn't already converted, // there is nothing we can do. @@ -1185,35 +1176,39 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( // Try to convert the signature for the block with the provided converter. 
if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(block, converter, *conversion); + return applySignatureConversion(rewriter, block, converter, *conversion); return failure(); } Block *ConversionPatternRewriterImpl::applySignatureConversion( - Region *region, TypeConverter::SignatureConversion &conversion, + ConversionPatternRewriter &rewriter, Region *region, + TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { if (!region->empty()) - return *convertBlockSignature(&region->front(), converter, &conversion); + return *convertBlockSignature(rewriter, &region->front(), converter, + &conversion); return nullptr; } FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( - Region *region, const TypeConverter &converter, + ConversionPatternRewriter &rewriter, Region *region, + const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { regionToConverter[region] = &converter; if (region->empty()) return nullptr; - if (failed(convertNonEntryRegionTypes(region, converter))) + if (failed(convertNonEntryRegionTypes(rewriter, region, converter))) return failure(); - FailureOr<Block *> newEntry = - convertBlockSignature(&region->front(), &converter, entryConversion); + FailureOr<Block *> newEntry = convertBlockSignature( + rewriter, &region->front(), &converter, entryConversion); return newEntry; } LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, + ConversionPatternRewriter &rewriter, Region *region, + const TypeConverter &converter, ArrayRef<TypeConverter::SignatureConversion> blockConversions) { regionToConverter[region] = &converter; if (region->empty()) @@ -1234,16 +1229,18 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( : const_cast<TypeConverter::SignatureConversion *>( &blockConversions[blockIdx++]); - if (failed(convertBlockSignature(&block, &converter,
blockConversion))) + if (failed(convertBlockSignature(rewriter, &block, &converter, + blockConversion))) return failure(); } return success(); } Block *ConversionPatternRewriterImpl::applySignatureConversion( - Block *block, const TypeConverter *converter, + ConversionPatternRewriter &rewriter, Block *block, + const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { - MLIRContext *ctx = eraseRewriter.getContext(); + MLIRContext *ctx = rewriter.getContext(); // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); @@ -1253,11 +1250,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // Split the block at the beginning to get a new block to use for the updated // signature. - Block *newBlock = block->splitBlock(block->begin()); + Block *newBlock = rewriter.splitBlock(block, block->begin()); block->replaceAllUsesWith(newBlock); - // Unlink the block, but do not erase it yet, so that the change can be rolled - // back. - block->getParent()->getBlocks().remove(block); // Map all new arguments to the location of the argument they originate from. SmallVector<Location> newLocs(convertedTypes.size(), @@ -1333,6 +1327,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo, converter); + + // Erase the old block. (It is just unlinked for now and will be erased during + // cleanup.) 
+ rewriter.eraseBlock(block); + return newBlock; } @@ -1531,7 +1530,7 @@ Block *ConversionPatternRewriter::applySignatureConversion( assert(!impl->wasOpReplaced(region->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(region, conversion, converter); + return impl->applySignatureConversion(*this, region, conversion, converter); } FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( @@ -1540,7 +1539,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( assert(!impl->wasOpReplaced(region->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->convertRegionTypes(region, converter, entryConversion); + return impl->convertRegionTypes(*this, region, converter, entryConversion); } LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( @@ -1549,7 +1548,8 @@ LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( assert(!impl->wasOpReplaced(region->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->convertNonEntryRegionTypes(region, converter, blockConversions); + return impl->convertNonEntryRegionTypes(*this, region, converter, + blockConversions); } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, @@ -2051,7 +2051,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. 
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - if (failed(impl.convertBlockSignature(block, converter))) { + if (failed(impl.convertBlockSignature(rewriter, block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits