https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83423
The dialect conversion uses a `SingleEraseRewriter` to ensure that an op/block is not erased twice. This can happen during the "commit" phase when an unresolved materialization is inserted into a block and the enclosing op is erased by the user. In that case, the unresolved materialization should not be erased a second time later in the "commit" phase. This problem cannot happen during "rollback", so ops/blocks can be erased directly without using the rewriter. With this change, the `SingleEraseRewriter` is used only during "commit"/"cleanup". At that point, the dialect conversion is guaranteed to succeed and no rollback can happen. Therefore, it is not necessary to store the number of erased IR objects (because we will never "reset" the rewriter to a previous state). >From 6b6c4b1a7dd4943bfe2d97245e8369b9ba63aa20 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Thu, 29 Feb 2024 12:48:28 +0000 Subject: [PATCH] [mlir][Transforms][NFC] Do not use SingleEraseRewriter during rollback --- .../Transforms/Utils/DialectConversion.cpp | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index cac990d498d7d3..9f6468402686bd 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -153,9 +153,9 @@ namespace { /// This is useful when saving and undoing a set of rewrites. struct RewriterState { RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numErased, unsigned numReplacedOps) + unsigned numReplacedOps) : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), - numErased(numErased), numReplacedOps(numReplacedOps) {} + numReplacedOps(numReplacedOps) {} /// The current number of rewrites performed. unsigned numRewrites; @@ -163,9 +163,6 @@ struct RewriterState { /// The current number of ignored operations.
unsigned numIgnoredOperations; - /// The current number of erased operations/blocks. - unsigned numErased; - /// The current number of replaced ops that are scheduled for erasure. unsigned numReplacedOps; }; @@ -273,8 +270,9 @@ class CreateBlockRewrite : public BlockRewrite { auto &blockOps = block->getOperations(); while (!blockOps.empty()) blockOps.remove(blockOps.begin()); + block->dropAllUses(); if (block->getParent()) - eraseBlock(block); + block->erase(); else delete block; } @@ -858,7 +856,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void notifyBlockErased(Block *block) override { erased.insert(block); } /// Pointers to all erased operations and blocks. - SetVector<void *> erased; + DenseSet<void *> erased; }; //===--------------------------------------------------------------------===// @@ -1044,7 +1042,7 @@ void CreateOperationRewrite::rollback() { region.getBlocks().remove(region.getBlocks().begin()); } op->dropAllUses(); - eraseOp(op); + op->erase(); } void UnresolvedMaterializationRewrite::rollback() { @@ -1052,7 +1050,7 @@ void UnresolvedMaterializationRewrite::rollback() { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); } - eraseOp(op); + op->erase(); } void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); } @@ -1069,8 +1067,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(rewrites.size(), ignoredOps.size(), - eraseRewriter.erased.size(), replacedOps.size()); + return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1081,9 +1078,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); - while (eraseRewriter.erased.size() != state.numErased) - eraseRewriter.erased.pop_back(); - 
while (replacedOps.size() != state.numReplacedOps) replacedOps.pop_back(); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits