llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> This commit updates the internal `ConversionValueMapping` data structure in the dialect conversion driver to support 1:N replacements. This is the last major commit for adding 1:N support to the dialect conversion driver. Since #116470, the infrastructure has already supported 1:N replacements, but the `ConversionValueMapping` still stored only 1:1 value mappings. To work around this, the driver inserted temporary argument materializations (converting N SSA values into 1 value). This is no longer the case: argument materializations are now entirely gone from the conversion driver. (They will be deleted from the type converter later, once the old 1:N dialect conversion driver is deleted.) Note for LLVM integration: Replace all occurrences of `addArgumentMaterialization` (except in 1:N dialect conversion passes) with `addSourceMaterialization`. Depends on #117513. 
--- Patch is 46.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116524.diff 14 Files Affected: - (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (-1) - (modified) mlir/docs/DialectConversion.md (+5-30) - (modified) mlir/include/mlir/Transforms/DialectConversion.h (+7-11) - (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+6-8) - (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (-1) - (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (-1) - (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (-1) - (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (-3) - (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (-1) - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+206-226) - (modified) mlir/test/Transforms/test-legalizer.mlir (+2-5) - (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+1-1) - (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (-1) - (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (-1) ``````````diff diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp index 1bb91d252529f0..104ae7408b80c1 100644 --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter { addConversion([&](TypeDescType ty) { return TypeDescType::get(convertType(ty.getOfTy())); }); - addArgumentMaterialization(materializeProcedure); addSourceMaterialization(materializeProcedure); addTargetMaterialization(materializeProcedure); } diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 3168f5e13c7515..abacd5a82c61eb 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -242,19 +242,6 @@ cannot. 
These materializations are used by the conversion framework to ensure type safety during the conversion process. There are several types of materializations depending on the situation. -* Argument Materialization - - - An argument materialization is used when converting the type of a block - argument during a [signature conversion](#region-signature-conversion). - The new block argument types are specified in a `SignatureConversion` - object. An original block argument can be converted into multiple - block arguments, which is not supported everywhere in the dialect - conversion. (E.g., adaptors support only a single replacement value for - each original value.) Therefore, an argument materialization is used to - convert potentially multiple new block arguments back into a single SSA - value. An argument materialization is also used when replacing an op - result with multiple values. - * Source Materialization - A source materialization is used when a value was replaced with a value @@ -343,17 +330,6 @@ class TypeConverter { /// Materialization functions must be provided when a type conversion may /// persist after the conversion has finished. - /// This method registers a materialization that will be called when - /// converting (potentially multiple) block arguments that were the result of - /// a signature conversion of a single block argument, to a single SSA value - /// with the old argument type. - template <typename FnT, - typename T = typename llvm::function_traits<FnT>::template arg_t<1>> - void addArgumentMaterialization(FnT &&callback) { - argumentMaterializations.emplace_back( - wrapMaterialization<T>(std::forward<FnT>(callback))); - } - /// This method registers a materialization that will be called when /// converting a replacement value back to its original source type. /// This is used when some uses of the original value persist beyond the main @@ -406,12 +382,11 @@ done explicitly via a conversion pattern. 
To convert the types of block arguments within a Region, a custom hook on the `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook uses a provided type converter to apply type conversions to all blocks of a -given region. As noted above, the conversions performed by this method use the -argument materialization hook on the `TypeConverter`. This hook also takes an -optional `TypeConverter::SignatureConversion` parameter that applies a custom -conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to the operation, e.g., -`func::FuncOp`, `AffineForOp`, etc. +given region. This hook also takes an optional +`TypeConverter::SignatureConversion` parameter that applies a custom conversion +to the entry block of the region. The types of the entry block arguments are +often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`, +etc. To convert the signature of just one given block, the `applySignatureConversion` hook can be used. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 28150e886913e3..9a6975dcf8dfae 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -181,6 +181,10 @@ class TypeConverter { /// converting (potentially multiple) block arguments that were the result of /// a signature conversion of a single block argument, to a single SSA value /// with the old block argument type. + /// + /// Note: Argument materializations are used only with the 1:N dialect + /// conversion driver. The 1:N dialect conversion driver will be removed soon + /// and so will be argument materializations. 
template <typename FnT, typename T = typename llvm::function_traits< std::decay_t<FnT>>::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { @@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter { void replaceOp(Operation *op, Operation *newOp) override; /// Replace the given operation with the new value ranges. The number of op - /// results and value ranges must match. If an original SSA value is replaced - /// by multiple SSA values (i.e., a value range has more than 1 element), the - /// conversion driver will insert an argument materialization to convert the - /// N SSA values back into 1 SSA value of the original type. The given - /// operation is erased. - /// - /// Note: The argument materialization is a workaround until we have full 1:N - /// support in the dialect conversion. (It is going to disappear from both - /// `replaceOpWithMultiple` and `applySignatureConversion`.) + /// results and value ranges must match. The given operation is erased. void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues); /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -1285,8 +1281,8 @@ struct ConversionConfig { // represented at the moment. RewriterBase::Listener *listener = nullptr; - /// If set to "true", the dialect conversion attempts to build source/target/ - /// argument materializations through the type converter API in lieu of + /// If set to "true", the dialect conversion attempts to build source/target + /// materializations through the type converter API in lieu of /// "builtin.unrealized_conversion_cast ops". The conversion process fails if /// at least one materialization could not be built. 
/// diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index e2ab0ed6f66cc5..d27b557736c924 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -189,9 +189,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, auto unrakedMemRefMaterialization = [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) { - // An argument materialization must return a value of type - // `resultType`, so insert a cast from the memref descriptor type - // (!llvm.struct) to the original memref type. + // A source materialization must return a value of type `resultType`, so + // insert a cast from the memref descriptor type (!llvm.struct) to the + // original memref type. Value packed = packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this); if (!packed) @@ -223,7 +223,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, auto rankedMemRefMaterialization = [&](OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc) { - // An argument materialization must return a value of type `resultType`, + // A source materialization must return a value of type `resultType`, // so insert a cast from the memref descriptor type (!llvm.struct) to the // original memref type. Value packed = @@ -234,11 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, .getResult(0); }; - // Argument materializations convert from the new block argument types - // (multiple SSA values that make up a memref descriptor) back to the + // Source materializations convert from the new block argument types + // (e.g., multiple SSA values that make up a memref descriptor) back to the // original block argument type. 
- addArgumentMaterialization(unrakedMemRefMaterialization); - addArgumentMaterialization(rankedMemRefMaterialization); addSourceMaterialization(unrakedMemRefMaterialization); addSourceMaterialization(rankedMemRefMaterialization); diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index 0b3a494794f3f5..72c8fd0f324850 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) { converter.addSourceMaterialization(materializeAsUnrealizedCast); converter.addTargetMaterialization(materializeAsUnrealizedCast); - converter.addArgumentMaterialization(materializeAsUnrealizedCast); } /// Get an unsigned integer or size data type corresponding to \p ty. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index af38485291182f..61bc5022893741 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter { }); addSourceMaterialization(sourceMaterializationCallback); - addArgumentMaterialization(sourceMaterializationCallback); } }; diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 61912722662830..71b88d1be1b05b 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter { addConversion(convertQuantizedType); addConversion(convertTensorType); - addArgumentMaterialization(materializeConversion); addSourceMaterialization(materializeConversion); addTargetMaterialization(materializeConversion); } diff --git 
a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index 834e3634cc130d..8bbb2cac5efdf3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { // Required by scf.for 1:N type conversion. addSourceMaterialization(materializeTuple); - - // Required as a workaround until we have full 1:N support. - addArgumentMaterialization(materializeTuple); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 757631944f224f..68535ae5a7a5c6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return builder.create<vector::ShapeCastOp>(loc, type, inputs.front()); }; - typeConverter.addArgumentMaterialization(materializeCast); typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 51686646a0a2fc..ea169a1df42b6a 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args 
&&...args) { }); } +/// Given two insertion points in the same block, choose the later one. +static OpBuilder::InsertPoint +chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a, + OpBuilder::InsertPoint b) { + assert(a.getBlock() == b.getBlock() && "expected same block"); + Block *block = a.getBlock(); + if (a.getPoint() == block->begin()) + return b; + if (b.getPoint() == block->begin()) + return a; + if (a.getPoint()->isBeforeInBlock(&*b.getPoint())) + return b; + return a; +} + +/// Helper function that chooses the insertion point among the two given ones +/// that is later. +// TODO: Extend DominanceInfo API to work with block iterators. +static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a, + OpBuilder::InsertPoint b) { + // Case 1: Same block. + if (a.getBlock() == b.getBlock()) + return chooseLaterInsertPointInBlock(a, b); + + // Case 2: Different block, but same region. + if (a.getBlock()->getParent() == b.getBlock()->getParent()) { + DominanceInfo domInfo; + if (domInfo.properlyDominates(a.getBlock(), b.getBlock())) + return b; + if (domInfo.properlyDominates(b.getBlock(), a.getBlock())) + return a; + // Neither of the two blocks dominante each other. + llvm_unreachable("unable to find valid insertion point"); + } + + // Case 3: b's region contains a: choose a. + if (Operation *aParent = b.getBlock()->getParent()->findAncestorOpInRegion( + *a.getPoint()->getParentOp())) + return a; + + // Case 4: a's region contains b: choose b. + if (Operation *bParent = a.getBlock()->getParent()->findAncestorOpInRegion( + *b.getPoint()->getParentOp())) + return b; + + // Neither of the two operations contain each other. + llvm_unreachable("unable to find valid insertion point"); +} + /// Helper function that computes an insertion point where the given value is /// defined and can be used without a dominance violation. 
static OpBuilder::InsertPoint computeInsertPoint(Value value) { @@ -63,11 +113,36 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given values are +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { + assert(!vals.empty() && "expected at least one value"); + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) + pt = chooseLaterInsertPoint(pt, computeInsertPoint(v)); + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// +/// A vector of SSA values, optimized for the most common case of a single +/// value. +using ValueVector = SmallVector<Value, 1>; + namespace { + +/// Helper class to make it possible to use `ValueVector` as a key in DenseMap. +struct ValueVectorMapInfo { + static ValueVector getEmptyKey() { return ValueVector{}; } + static ValueVector getTombstoneKey() { return ValueVector{}; } + static ::llvm::hash_code getHashValue(ValueVector val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { @@ -75,68 +150,103 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped value with the desired type in the + /// Lookup the most recently mapped values with the desired types in the /// mapping. 
/// /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given - /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; - - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// values. + ValueVector lookupOrDefault(ValueVector from, + TypeRange desiredTypes = {}) const; + + /// Lookup the given values within the map, or return an empty vector if the + /// values are not mapped. If they are mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(const ValueVector &from, + TypeRange desiredTypes = {}) const; /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { + void map(const ValueVector &oldVal, const ValueVector &newVal) { LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); + ValueVector next = newVal; + while (true) { + assert(next != oldVal && "inserting cyclic mapping"); + auto it = mapping.find(next); + if (it == mapping.end()) + break; + next = it->second; + } }); - mapping.map(oldVal, newVal); - mappedTo.insert(newVal); + mapping[oldVal] = newVal; + for (Value v : newVal) + mappedTo.insert(v); } - /// Drop the last mapping for the given value. 
- void erase(Value value) { mapping.erase(value); } + /// Drop the last mapping for the given values. + void erase(ValueVector value) { mapping.erase(value); } private: /// Current value mappings. - IRMapping mapping; + DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping; /// All SSA values that are mapped to. May contain false positives. DenseSet<Value> mappedTo; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; +ValueVector +ConversionValueMapping::lookupOrDefault(ValueVector from, + TypeRange desiredTypes) const { + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; do { - if (!desiredType || from.getType() == desiredType) + // Store the current value if the types match. + if (desiredTypes.empty() || TypeRange(from) == desiredTypes) desiredValue = from; - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) + // If possible, Replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : from) { + auto it = mapping.find({v}); + if (it != mapping.end()) { + llvm::append_range(next, it->second); + } else { + next.push_back(v); + } + } + if (next != from) { + // If at least one value was replaced, continue the lookup from there. + from = next; + ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/116524 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits