https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81240
>From c60c43bcd2296715ceca83a3f9666433883ec303 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Mon, 12 Feb 2024 09:05:50 +0000 Subject: [PATCH 1/2] [mlir][Transforms][WIP] RewriteAction BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC --- .../Transforms/Utils/DialectConversion.cpp | 504 +++++++++++------- 1 file changed, 306 insertions(+), 198 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e41231d7cbd390..edca84e5a73f04 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,13 +154,12 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRewrites, unsigned numIgnoredOperations, unsigned numRootUpdates) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), - numArgReplacements(numArgReplacements), - numBlockActions(numBlockActions), + numArgReplacements(numArgReplacements), numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), numRootUpdates(numRootUpdates) {} @@ -176,8 +175,8 @@ struct RewriterState { /// The current number of argument replacements queued. unsigned numArgReplacements; - /// The current number of block actions performed. - unsigned numBlockActions; + /// The current number of rewrites performed. + unsigned numRewrites; /// The current number of ignored operations. unsigned numIgnoredOperations; @@ -235,86 +234,6 @@ struct OpReplacement { const TypeConverter *converter; }; -//===----------------------------------------------------------------------===// -// BlockAction - -/// The kind of the block action performed during the rewrite. 
Actions can be -/// undone if the conversion fails. -enum class BlockActionKind { - Create, - Erase, - Inline, - Move, - Split, - TypeConversion -}; - -/// Original position of the given block in its parent region. During undo -/// actions, the block needs to be placed before `insertBeforeBlock`. -struct BlockPosition { - Region *region; - Block *insertBeforeBlock; -}; - -/// Information needed to undo inlining actions. -/// - the source block -/// - the first inlined operation (could be null if the source block was empty) -/// - the last inlined operation (could be null if the source block was empty) -struct InlineInfo { - Block *sourceBlock; - Operation *firstInlinedInst; - Operation *lastInlinedInst; -}; - -/// The storage class for an undoable block action (one of BlockActionKind), -/// contains the information necessary to undo this action. -struct BlockAction { - static BlockAction getCreate(Block *block) { - return {BlockActionKind::Create, block, {}}; - } - static BlockAction getErase(Block *block, BlockPosition originalPosition) { - return {BlockActionKind::Erase, block, {originalPosition}}; - } - static BlockAction getInline(Block *block, Block *srcBlock, - Block::iterator before) { - BlockAction action{BlockActionKind::Inline, block, {}}; - action.inlineInfo = {srcBlock, - srcBlock->empty() ? nullptr : &srcBlock->front(), - srcBlock->empty() ? nullptr : &srcBlock->back()}; - return action; - } - static BlockAction getMove(Block *block, BlockPosition originalPosition) { - return {BlockActionKind::Move, block, {originalPosition}}; - } - static BlockAction getSplit(Block *block, Block *originalBlock) { - BlockAction action{BlockActionKind::Split, block, {}}; - action.originalBlock = originalBlock; - return action; - } - static BlockAction getTypeConversion(Block *block) { - return BlockAction{BlockActionKind::TypeConversion, block, {}}; - } - - // The action kind. - BlockActionKind kind; - - // A pointer to the block that was created by the action. 
- Block *block; - - union { - // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and - // contains a pointer to the region that originally contained the block as - // well as the position of the block in that region. - BlockPosition originalPosition; - // In use if kind == BlockActionKind::Split and contains a pointer to the - // block that was split into two parts. - Block *originalBlock; - // In use if kind == BlockActionKind::Inline, and contains the information - // needed to undo the inlining. - InlineInfo inlineInfo; - }; -}; - //===----------------------------------------------------------------------===// // UnresolvedMaterialization @@ -820,6 +739,251 @@ void ArgConverter::insertConversion(Block *newBlock, conversionInfo.insert({newBlock, std::move(info)}); } +//===----------------------------------------------------------------------===// +// IR rewrites +//===----------------------------------------------------------------------===// + +namespace { +/// An IR rewrite that can be committed (upon success) or rolled back (upon +/// failure). +/// +/// The dialect conversion keeps track of IR modifications (requested by the +/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites +/// are directly applied to the IR as the rewriter API is used, some are applied +/// partially, and some are delayed until the `IRRewrite` objects are committed. +class IRRewrite { +public: + /// The kind of the rewrite. Rewrites can be undone if the conversion fails. + enum class Kind { + CreateBlock, + EraseBlock, + InlineBlock, + MoveBlock, + SplitBlock, + BlockTypeConversion + }; + + virtual ~IRRewrite() = default; + + /// Roll back the rewrite. + virtual void rollback() = 0; + + /// Commit the rewrite. 
+ virtual void commit() {} + + Kind getKind() const { return kind; } + + static bool classof(const IRRewrite *rewrite) { return true; } + +protected: + IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) + : kind(kind), rewriterImpl(rewriterImpl) {} + + const Kind kind; + ConversionPatternRewriterImpl &rewriterImpl; +}; + +/// A block rewrite. +class BlockRewrite : public IRRewrite { +public: + /// Return the block that this rewrite operates on. + Block *getBlock() const { return block; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::CreateBlock && + rewrite->getKind() <= Kind::BlockTypeConversion; + } + +protected: + BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : IRRewrite(kind, rewriterImpl), block(block) {} + + // The block that this rewrite operates on. + Block *block; +}; + +/// Creation of a block. Block creations are immediately reflected in the IR. +/// There is no extra work to commit the rewrite. During rollback, the newly +/// created block is erased. +class CreateBlockRewrite : public BlockRewrite { +public: + CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) + : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateBlock; + } + + void rollback() override { + // Unlink all of the operations within this block, they will be deleted + // separately. + auto &blockOps = block->getOperations(); + while (!blockOps.empty()) + blockOps.remove(blockOps.begin()); + block->dropAllDefinedValueUses(); + block->erase(); + } +}; + +/// Erasure of a block. Block erasures are partially reflected in the IR. Erased +/// blocks are immediately unlinked, but only erased when the rewrite is +/// committed. This makes it easier to rollback a block erasure: the block is +/// simply inserted into its original location. 
+class EraseBlockRewrite : public BlockRewrite { +public: + EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Region *region, Block *insertBeforeBlock) + : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region), + insertBeforeBlock(insertBeforeBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::EraseBlock; + } + + ~EraseBlockRewrite() override { + assert(!block && "rewrite was neither rolled back nor committed"); + } + + void rollback() override { + // The block (owned by this rewrite) was not actually erased yet. It was + // just unlinked. Put it back into its original position. + assert(block && "expected block"); + auto &blockList = region->getBlocks(); + Region::iterator before = insertBeforeBlock + ? Region::iterator(insertBeforeBlock) + : blockList.end(); + blockList.insert(before, block); + block = nullptr; + } + + void commit() override { + // Erase the block. + assert(block && "expected block"); + delete block; + block = nullptr; + } + +private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was unlinked. "nullptr" if + // this block was the last block in the region. + Block *insertBeforeBlock; +}; + +/// Inlining of a block. This rewrite is immediately reflected in the IR. +/// Note: This rewrite represents only the inlining of the operations. The +/// erasure of the inlined block is a separate rewrite. +class InlineBlockRewrite : public BlockRewrite { +public: + InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *sourceBlock, Block::iterator before) + : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), + sourceBlock(sourceBlock), + firstInlinedInst(sourceBlock->empty() ? nullptr + : &sourceBlock->front()), + lastInlinedInst(sourceBlock->empty() ?
nullptr : &sourceBlock->back()) { + } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::InlineBlock; + } + + void rollback() override { + // Put the operations from the destination block (owned by the rewrite) + // back into the source block. + if (firstInlinedInst) { + assert(lastInlinedInst && "expected operation"); + sourceBlock->getOperations().splice(sourceBlock->begin(), + block->getOperations(), + Block::iterator(firstInlinedInst), + ++Block::iterator(lastInlinedInst)); + } + } + +private: + // The block that originally contained the operations. + Block *sourceBlock; + + // The first inlined operation. + Operation *firstInlinedInst; + + // The last inlined operation. + Operation *lastInlinedInst; +}; + +/// Moving of a block. This rewrite is immediately reflected in the IR. +class MoveBlockRewrite : public BlockRewrite { +public: + MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Region *region, Block *insertBeforeBlock) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), + insertBeforeBlock(insertBeforeBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveBlock; + } + + void rollback() override { + // Move the block back to its original position. + Region::iterator before = + insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); + region->getBlocks().splice(before, block->getParent()->getBlocks(), block); + } + +private: + // The region in which this block was previously contained. + Region *region; + + // The original successor of this block before it was moved. "nullptr" if + // this block was the last block in the region. + Block *insertBeforeBlock; +}; + +/// Splitting of a block. This rewrite is immediately reflected in the IR.
+class SplitBlockRewrite : public BlockRewrite { +public: + SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, + Block *originalBlock) + : BlockRewrite(Kind::SplitBlock, rewriterImpl, block), + originalBlock(originalBlock) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::SplitBlock; + } + + void rollback() override { + // Merge back the block that was split out. + originalBlock->getOperations().splice(originalBlock->end(), + block->getOperations()); + block->dropAllDefinedValueUses(); + block->erase(); + } + +private: + // The original block from which this block was split. + Block *originalBlock; +}; + +/// Block type conversion. This rewrite is partially reflected in the IR. +class BlockTypeConversionRewrite : public BlockRewrite { +public: + BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::BlockTypeConversion; + } + + // TODO: Block type conversions are currently committed in + // `ArgConverter::applyRewrites`. This should be done in the "commit" method. + void rollback() override; +}; +} // namespace + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -848,13 +1012,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Reset the state of the rewriter to a previously saved point. void resetState(RewriterState state); - /// Erase any blocks that were unlinked from their regions and stored in block - /// actions. - void eraseDanglingBlocks(); + /// Append a rewrite. Rewrites are committed upon success and rolled back upon + /// failure. + template <typename ActionTy, typename... 
Args> + void appendRewrite(Args &&...args) { + rewrites.push_back( + std::make_unique<ActionTy>(*this, std::forward<Args>(args)...)); + } - /// Undo the block actions (motions, splits) one by one in reverse order until - /// "numActionsToKeep" actions remains. - void undoBlockActions(unsigned numActionsToKeep = 0); + /// Undo the rewrites (motions, splits) one by one in reverse order until + /// "numRewritesToKeep" rewrites remains. + void undoRewrites(unsigned numRewritesToKeep = 0); /// Remap the given values to those with potentially different types. Returns /// success if the values could be remapped, failure otherwise. `valueDiagTag` @@ -954,7 +1122,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { SmallVector<BlockArgument, 4> argReplacements; /// Ordered list of block operations (creations, splits, motions). - SmallVector<BlockAction, 4> blockActions; + SmallVector<std::unique_ptr<IRRewrite>> rewrites; /// A set of operations that should no longer be considered for legalization, /// but were not directly replace/erased/etc. by a pattern. These are @@ -995,6 +1163,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } // namespace detail } // namespace mlir +void BlockTypeConversionRewrite::rollback() { + // Undo the type conversion. + rewriterImpl.argConverter.discardRewrites(block); +} + /// Detach any operations nested in the given operation from their parent /// blocks, and erase the given operation. This can be used when the nested /// operations are scheduled for erasure themselves, so deleting the regions of @@ -1020,7 +1193,7 @@ void ConversionPatternRewriterImpl::discardRewrites() { for (auto &state : rootUpdates) state.resetOperation(); - undoBlockActions(); + undoRewrites(); // Remove any newly created ops. 
for (UnresolvedMaterialization &materialization : unresolvedMaterializations) @@ -1083,8 +1256,9 @@ void ConversionPatternRewriterImpl::applyRewrites() { argConverter.applyRewrites(mapping); - // Now that the ops have been erased, also erase dangling blocks. - eraseDanglingBlocks(); + // Commit all rewrites. + for (auto &rewrite : rewrites) + rewrite->commit(); } //===----------------------------------------------------------------------===// @@ -1093,8 +1267,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - blockActions.size(), ignoredOps.size(), - rootUpdates.size()); + rewrites.size(), ignoredOps.size(), rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1109,8 +1282,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { mapping.erase(replacedArg); argReplacements.resize(state.numArgReplacements); - // Undo any block actions. - undoBlockActions(state.numBlockActions); + // Undo any rewrites. + undoRewrites(state.numRewrites); // Reset any replaced operations and undo any saved mappings. for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) @@ -1149,76 +1322,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { operationsWithChangedResults.pop_back(); } -void ConversionPatternRewriterImpl::eraseDanglingBlocks() { - for (auto &action : blockActions) - if (action.kind == BlockActionKind::Erase) - delete action.block; -} - -void ConversionPatternRewriterImpl::undoBlockActions( - unsigned numActionsToKeep) { - for (auto &action : - llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { - switch (action.kind) { - // Delete the created block. 
- case BlockActionKind::Create: { - // Unlink all of the operations within this block, they will be deleted - // separately. - auto &blockOps = action.block->getOperations(); - while (!blockOps.empty()) - blockOps.remove(blockOps.begin()); - action.block->dropAllDefinedValueUses(); - action.block->erase(); - break; - } - // Put the block (owned by action) back into its original position. - case BlockActionKind::Erase: { - auto &blockList = action.originalPosition.region->getBlocks(); - Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock; - blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock) - : blockList.end()), - action.block); - break; - } - // Put the instructions from the destination block (owned by the action) - // back into the source block. - case BlockActionKind::Inline: { - Block *sourceBlock = action.inlineInfo.sourceBlock; - if (action.inlineInfo.firstInlinedInst) { - assert(action.inlineInfo.lastInlinedInst && "expected operation"); - sourceBlock->getOperations().splice( - sourceBlock->begin(), action.block->getOperations(), - Block::iterator(action.inlineInfo.firstInlinedInst), - ++Block::iterator(action.inlineInfo.lastInlinedInst)); - } - break; - } - // Move the block back to its original position. - case BlockActionKind::Move: { - Region *originalRegion = action.originalPosition.region; - Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock; - originalRegion->getBlocks().splice( - (insertBeforeBlock ? Region::iterator(insertBeforeBlock) - : originalRegion->end()), - action.block->getParent()->getBlocks(), action.block); - break; - } - // Merge back the block that was split out. - case BlockActionKind::Split: { - action.originalBlock->getOperations().splice( - action.originalBlock->end(), action.block->getOperations()); - action.block->dropAllDefinedValueUses(); - action.block->erase(); - break; - } - // Undo the type conversion. 
- case BlockActionKind::TypeConversion: { - argConverter.discardRewrites(action.block); - break; - } - } - } - blockActions.resize(numActionsToKeep); +void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { + for (auto &rewrite : + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) + rewrite->rollback(); + rewrites.resize(numRewritesToKeep); } LogicalResult ConversionPatternRewriterImpl::remapValues( @@ -1309,7 +1417,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( return failure(); if (Block *newBlock = *result) { if (newBlock != block) - blockActions.push_back(BlockAction::getTypeConversion(newBlock)); + appendRewrite<BlockTypeConversionRewrite>(newBlock); } return result; } @@ -1410,28 +1518,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { Region *region = block->getParent(); Block *origNextBlock = block->getNextNode(); - blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock})); + appendRewrite<EraseBlockRewrite>(block, region, origNextBlock); } void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { if (!previous) { // This is a newly created block. - blockActions.push_back(BlockAction::getCreate(block)); + appendRewrite<CreateBlockRewrite>(block); return; } Block *prevBlock = previousIt == previous->end() ? 
nullptr : &*previousIt; - blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock})); + appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); } void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { - blockActions.push_back(BlockAction::getSplit(continuation, block)); + appendRewrite<SplitBlockRewrite>(continuation, block); } void ConversionPatternRewriterImpl::notifyBlockBeingInlined( Block *block, Block *srcBlock, Block::iterator before) { - blockActions.push_back(BlockAction::getInline(block, srcBlock, before)); + appendRewrite<InlineBlockRewrite>(block, srcBlock, before); } void ConversionPatternRewriterImpl::notifyMatchFailure( @@ -1501,8 +1609,8 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { for (Operation &op : *block) eraseOp(&op); - // Unlink the block from its parent region. The block is kept in the block - // action and will be actually destroyed when rewrites are applied. This + // Unlink the block from its parent region. The block is kept in the rewrite + // object and will be actually destroyed when rewrites are applied. This // allows us to keep the operations in the block live and undo the removal by // re-inserting the block. block->getParent()->getBlocks().remove(block); @@ -1700,11 +1808,11 @@ class OperationLegalizer { RewriterState &curState); /// Legalizes the actions registered during the execution of a pattern. 
- LogicalResult legalizePatternBlockActions(Operation *op, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - RewriterState &state, - RewriterState &newState); + LogicalResult + legalizePatternBlockRewrites(Operation *op, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState); LogicalResult legalizePatternCreatedOperations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState); @@ -1986,8 +2094,8 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, // Legalize each of the actions registered during application. RewriterState newState = impl.getCurrentState(); - if (failed(legalizePatternBlockActions(op, rewriter, impl, curState, - newState)) || + if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, + newState)) || failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || failed(legalizePatternCreatedOperations(rewriter, impl, curState, newState))) { @@ -1998,7 +2106,7 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, return success(); } -LogicalResult OperationLegalizer::legalizePatternBlockActions( +LogicalResult OperationLegalizer::legalizePatternBlockRewrites( Operation *op, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { @@ -2006,22 +2114,22 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. 
- for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; - ++i) { - auto &action = impl.blockActions[i]; - if (action.kind == BlockActionKind::TypeConversion || - action.kind == BlockActionKind::Erase) + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get()); + if (!rewrite) + continue; + Block *block = rewrite->getBlock(); + if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite)) continue; // Only check blocks outside of the current operation. - Operation *parentOp = action.block->getParentOp(); - if (!parentOp || parentOp == op || action.block->getNumArguments() == 0) + Operation *parentOp = block->getParentOp(); + if (!parentOp || parentOp == op || block->getNumArguments() == 0) continue; // If the region of the block has a type converter, try to convert the block // directly. - if (auto *converter = - impl.argConverter.getConverter(action.block->getParent())) { - if (failed(impl.convertBlockSignature(action.block, converter))) { + if (auto *converter = impl.argConverter.getConverter(block->getParent())) { + if (failed(impl.convertBlockSignature(block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); @@ -2042,9 +2150,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions( // If this operation should be considered for re-legalization, try it. 
if (operationsToIgnore.insert(parentOp).second && failed(legalize(parentOp, rewriter))) { - LLVM_DEBUG(logFailure( - impl.logger, "operation '{0}'({1}) became illegal after block action", - parentOp->getName(), parentOp)); + LLVM_DEBUG(logFailure(impl.logger, + "operation '{0}'({1}) became illegal after rewrite", + parentOp->getName(), parentOp)); return failure(); } } >From ebfaca6b688394233b0d6a22f77b8b7cccaf67a8 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Mon, 12 Feb 2024 09:08:21 +0000 Subject: [PATCH 2/2] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`. `RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.) BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC --- mlir/include/mlir/IR/PatternMatch.h | 6 +- .../mlir/Transforms/DialectConversion.h | 9 +-- .../Transforms/Utils/DialectConversion.cpp | 74 +++++++++++++++---- mlir/test/Transforms/test-legalizer.mlir | 14 ++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 20 ++++- 5 files changed, 95 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 78dcfe7f6fc3d2..b8aeea0d23475b 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right before /// `iterator` in the specified block.
- virtual void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator); + void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); /// Unlink this operation from its current block and insert it right after /// `existingOp` which may be in the same or another block in the same @@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right after /// `iterator` in the specified block. - virtual void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator); + void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); /// Unlink this block and insert it right before `existingBlock`. void moveBlockBefore(Block *block, Block *anotherBlock); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f061d761ecefbb..b028d2b71b3762 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -721,8 +721,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// PatternRewriter hook for updating the given operation in-place. /// Note: These methods only track updates to the given operation itself, - /// and not nested regions. Updates to regions will still require notification - /// through other more specific hooks above. + /// and not nested regions. Updates to regions will still require + /// notification through other more specific hooks above. void startOpModification(Operation *op) override; /// PatternRewriter hook for updating the given operation in-place. @@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter { // Hide unsupported pattern rewriter API. 
using OpBuilder::setListener; - void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) override; - void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) override; - std::unique_ptr<detail::ConversionPatternRewriterImpl> impl; }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index edca84e5a73f04..85b67bb834de7c 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -760,7 +760,8 @@ class IRRewrite { InlineBlock, MoveBlock, SplitBlock, - BlockTypeConversion + BlockTypeConversion, + MoveOperation }; virtual ~IRRewrite() = default; @@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite { // `ArgConverter::applyRewrites`. This should be done in the "commit" method. void rollback() override; }; + +/// An operation rewrite. +class OperationRewrite : public IRRewrite { +public: + /// Return the operation that this rewrite operates on. + Operation *getOperation() const { return op; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() >= Kind::MoveOperation && + rewrite->getKind() <= Kind::MoveOperation; + } + +protected: + OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : IRRewrite(kind, rewriterImpl), op(op) {} + + // The operation that this rewrite operates on. + Operation *op; +}; + +/// Moving of an operation. This rewrite is immediately reflected in the IR. 
+class MoveOperationRewrite : public OperationRewrite { +public: + MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op, Block *block, Operation *insertBeforeOp) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), + insertBeforeOp(insertBeforeOp) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::MoveOperation; + } + + void rollback() override { + // Move the operation back to its original position. + Block::iterator before = + insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); + block->getOperations().splice(before, op->getBlock()->getOperations(), op); + } + +private: + // The block in which this operation was previously contained. + Block *block; + + // The original successor of this operation before it was moved. "nullptr" if + // this operation was the last operation in the block. + Operation *insertBeforeOp; +}; } // namespace //===----------------------------------------------------------------------===// @@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { - assert(!previous.isSet() && "expected newly created op"); LLVM_DEBUG({ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); - createdOps.push_back(op); + if (!previous.isSet()) { + // This is a newly created op. + createdOps.push_back(op); + return; + } + Operation *prevOp = previous.getPoint() == previous.getBlock()->end() + ?
nullptr + : &*previous.getPoint(); + appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); } void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, @@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { rootUpdates.erase(rootUpdates.begin() + updateIdx); } -void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - -void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index d8cf6e4719cede..84fcc18ab7d370 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -320,3 +320,17 @@ module { return } } + +// ----- + +// CHECK-LABEL: func @test_move_op_before_rollback() +func.func @test_move_op_before_rollback() { + // CHECK: "test.one_region_op"() + // CHECK: "test.hoist_me"() + "test.one_region_op"() ({ + // expected-remark @below{{'test.hoist_me' is not legalizable}} + %0 = "test.hoist_me"() : () -> (i32) + "test.valid"(%0) : (i32) -> () + }) : () -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d7e5d6db50c1fb..1c02232b8adbb1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern { } }; +/// This pattern hoists a "test.hoist_me" op out of its parent op and then +/// fails conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern { + TestUndoMoveOpBefore(MLIRContext *ctx) + : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.moveOpBefore(op, op->getParentOp()); + // Replace with an illegal op to ensure the conversion fails. + rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); + return success(); + } +}; + /// A rewrite pattern that tests the undo mechanism when erasing a block. struct TestUndoBlockErase : public ConversionPattern { TestUndoBlockErase(MLIRContext *ctx) @@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, - TestCreateUnregisteredOp>(&getContext()); + TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext()); patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp<ModuleOp>(); target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, - TerminatorOp>(); + TerminatorOp, OneRegionOp>(); target .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits