llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> `ConversionPatternRewriter` objects should not be constructed outside of dialect conversions. Some IR modifications performed through a `ConversionPatternRewriter` are reflected in the IR in a delayed fashion (e.g., only when the dialect conversion is guaranteed to succeed). Using a `ConversionPatternRewriter` outside of the dialect conversion is incorrect API usage and can bring the IR into an inconsistent state. Migration guide: Use `IRRewriter` instead of `ConversionPatternRewriter`. --- Full diff: https://github.com/llvm/llvm-project/pull/82244.diff 2 Files Affected: - (modified) mlir/include/mlir/Transforms/DialectConversion.h (+9-1) - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+11-7) ``````````diff diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 2575be4cdea1ac..5c91a9498b35d4 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -27,6 +27,7 @@ class Block; class ConversionPatternRewriter; class MLIRContext; class Operation; +struct OperationConverter; class Type; class Value; @@ -657,7 +658,6 @@ struct ConversionPatternRewriterImpl; /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This @@ -764,6 +764,14 @@ class ConversionPatternRewriter final : public PatternRewriter { detail::ConversionPatternRewriterImpl &getImpl(); private: + // Allow OperationConverter to construct new rewriters. + friend struct OperationConverter; + + /// Conversion pattern rewriters must not be used outside of dialect + /// conversions. 
They apply some IR rewrites in a delayed fashion and could + /// bring the IR into an inconsistent state when used standalone. + explicit ConversionPatternRewriter(MLIRContext *ctx); + // Hide unsupported pattern rewriter API. using OpBuilder::setListener; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4ef26a739e4ea1..6cf178e149be7f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -594,9 +594,11 @@ class ReplaceOperationRewrite : public OperationRewrite { void cleanup() override; -private: - friend struct OperationConverter; + const TypeConverter *getConverter() const { return converter; } + + bool hasChangedResults() const { return changedResults; } +private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. const TypeConverter *converter; @@ -2354,7 +2356,9 @@ enum OpConversionMode { /// applied to the operations on success. Analysis, }; +} // namespace +namespace mlir { // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. @@ -2414,7 +2418,7 @@ struct OperationConverter { /// *not* to be legalizable to the target. 
DenseSet<Operation *> *trackedOps; }; -} // namespace +} // namespace mlir LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { @@ -2506,7 +2510,7 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get()); - if (!opReplacement || !opReplacement->changedResults) + if (!opReplacement || !opReplacement->hasChangedResults()) continue; Operation *op = opReplacement->getOperation(); for (OpResult result : op->getResults()) { @@ -2530,9 +2534,9 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) { // Legalize this result. rewriter.setInsertionPoint(op); - if (failed(legalizeChangedResultType(op, result, newValue, - opReplacement->converter, rewriter, - rewriterImpl, *inverseMapping))) + if (failed(legalizeChangedResultType( + op, result, newValue, opReplacement->getConverter(), rewriter, + rewriterImpl, *inverseMapping))) return failure(); } } `````````` </details> https://github.com/llvm/llvm-project/pull/82244 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits