================ @@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold( }); return converged; } + +//===----------------------------------------------------------------------===// +// One-Shot Dialect Conversion Infrastructure +//===----------------------------------------------------------------------===// + +namespace { +/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter +/// immediately materializes all IR changes. It derives from +/// `ConversionPatternRewriter` so that the existing conversion patterns can +/// be used with the One-Shot Dialect Conversion. +class OneShotConversionPatternRewriter : public ConversionPatternRewriter { +public: + OneShotConversionPatternRewriter(MLIRContext *ctx) + : ConversionPatternRewriter(ctx) {} + + bool canRecoverFromRewriteFailure() const override { return false; } + + void replaceOp(Operation *op, ValueRange newValues) override; + + void replaceOp(Operation *op, Operation *newOp) override { + replaceOp(op, newOp->getResults()); + } + + void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); } + + void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); } + + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override { + PatternRewriter::inlineBlockBefore(source, dest, before, argValues); + } + using PatternRewriter::inlineBlockBefore; + + void startOpModification(Operation *op) override { + PatternRewriter::startOpModification(op); + } + + void finalizeOpModification(Operation *op) override { + PatternRewriter::finalizeOpModification(op); + } + + void cancelOpModification(Operation *op) override { + PatternRewriter::cancelOpModification(op); + } + + void setCurrentTypeConverter(const TypeConverter *converter) override { + typeConverter = converter; + } + + const TypeConverter *getCurrentTypeConverter() const override { + return typeConverter; + } + + LogicalResult getAdapterOperands(StringRef valueDiagTag, + std::optional<Location> inputLoc, + ValueRange values, + SmallVector<Value> &remapped) override; + +private: + /// Build an unrealized_conversion_cast op or look it up in the cache. + Value buildUnrealizedConversionCast(Location loc, Type type, Value value); + + /// The current type converter. + const TypeConverter *typeConverter; + + /// A cache for unrealized_conversion_casts. To ensure that identical casts + /// are not built multiple times. + DenseMap<std::pair<Value, Type>, Value> castCache; +}; + +void OneShotConversionPatternRewriter::replaceOp(Operation *op, + ValueRange newValues) { + assert(op->getNumResults() == newValues.size()); + for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) { + if (orig.getType() != repl.getType()) { + // Type mismatch: insert unrealized_conversion cast. + replaceAllUsesWith(orig, buildUnrealizedConversionCast( + op->getLoc(), orig.getType(), repl)); + } else { + // Same type: use replacement value directly. + replaceAllUsesWith(orig, repl); + } + } + eraseOp(op); +} + +Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast( + Location loc, Type type, Value value) { + auto it = castCache.find(std::make_pair(value, type)); + if (it != castCache.end()) + return it->second; + + // Insert cast at the beginning of the block (for block arguments) or right + // after the defining op. + OpBuilder::InsertionGuard g(*this); + Block *insertBlock = value.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast<OpResult>(value)) + insertPt = ++inputRes.getOwner()->getIterator(); + setInsertionPoint(insertBlock, insertPt); + auto castOp = create<UnrealizedConversionCastOp>(loc, type, value); + castCache[std::make_pair(value, type)] = castOp.getOutputs()[0]; + return castOp.getOutputs()[0]; +} + +class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver { ---------------- joker-eph wrote:
Relying on the greedy driver makes me concerned about the "one shot" aspect. I would really like to see a **simple** implementation, with the minimum amount of bookkeeping (ideally not even a worklist!), and relying on the GreedyPatternRewriteDriver does not really go in this direction right now. https://github.com/llvm/llvm-project/pull/93412 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits