https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/113032
>From 7ac8aa2ee4ee5634d8dd6f2d64d0e10b800e2d70 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Sat, 19 Oct 2024 12:05:13 +0200 Subject: [PATCH] [mlir][Transforms] Merge 1:1 and 1:N type converters --- .../Dialect/SparseTensor/Transforms/Passes.h | 2 +- .../mlir/Transforms/DialectConversion.h | 56 ++++++++++++++----- .../mlir/Transforms/OneToNTypeConversion.h | 45 +-------------- .../ArmSME/Transforms/VectorLegalization.cpp | 2 +- .../Transforms/Utils/DialectConversion.cpp | 24 ++++++-- .../Transforms/Utils/OneToNTypeConversion.cpp | 44 +++++---------- .../TestOneToNTypeConversionPass.cpp | 18 ++++-- 7 files changed, 93 insertions(+), 98 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 6ccbc40bdd6034..2e9c297f20182a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass(); //===----------------------------------------------------------------------===// /// Type converter for iter_space and iterator. -struct SparseIterationTypeConverter : public OneToNTypeConverter { +struct SparseIterationTypeConverter : public TypeConverter { SparseIterationTypeConverter(); }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ff36160dd6162..37da03bbe386e9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -173,7 +173,9 @@ class TypeConverter { /// conversion has finished. /// /// Note: Target materializations may optionally accept an additional Type - /// parameter, which is the original type of the SSA value. + /// parameter, which is the original type of the SSA value. 
Furthermore, `T` + /// can be a TypeRange; in that case, the function must return a + /// SmallVector<Value>. /// This method registers a materialization that will be called when /// converting (potentially multiple) block arguments that were the result of @@ -210,6 +212,9 @@ class TypeConverter { /// will be invoked with: outputType = "t3", inputs = "v2", // originalType = "t1". Note that the original type "t1" cannot be recovered /// from just "t3" and "v2"; that's why the originalType parameter exists. + /// + /// Note: During a 1:N conversion, the result types can be a TypeRange. In + /// that case the materialization produces a SmallVector<Value>. template <typename FnT, typename T = typename llvm::function_traits< std::decay_t<FnT>>::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { @@ -316,6 +321,11 @@ class TypeConverter { Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType = {}) const; + SmallVector<Value> materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultType, + ValueRange inputs, + Type originalType = {}) const; /// Convert an attribute present `attr` from within the type `type` using /// the registered conversion functions. If no applicable conversion has been @@ -340,9 +350,9 @@ class TypeConverter { /// The signature of the callback used to materialize a target conversion. /// - /// Arguments: builder, result type, inputs, location, original type - using TargetMaterializationCallbackFn = - std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>; + /// Arguments: builder, result types, inputs, location, original type + using TargetMaterializationCallbackFn = std::function<SmallVector<Value>( + OpBuilder &, TypeRange, ValueRange, Location, Type)>; /// The signature of the callback used to convert a type attribute. using TypeAttributeConversionCallbackFn = @@ -409,22 +419,40 @@ class TypeConverter { /// callback. 
/// /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location, Type)` + /// - Value(OpBuilder &, T, ValueRange, Location, Type) + /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type) template <typename T, typename FnT> std::enable_if_t< std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>, TargetMaterializationCallbackFn> wrapTargetMaterialization(FnT &&callback) const { return [callback = std::forward<FnT>(callback)]( - OpBuilder &builder, Type resultType, ValueRange inputs, - Location loc, Type originalType) -> Value { - if (T derivedType = dyn_cast<T>(resultType)) - return callback(builder, derivedType, inputs, loc, originalType); - return Value(); + OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector<Value> { + SmallVector<Value> result; + if constexpr (std::is_same<T, TypeRange>::value) { + // This is a 1:N target materialization. Return the produced values + // directly. + result = callback(builder, resultTypes, inputs, loc, originalType); + } else { + // This is a 1:1 target materialization. Invoke it only if the result + // type class of the callback matches the requested result type. + if (T derivedType = dyn_cast<T>(resultTypes.front())) { + // 1:1 materializations produce single values, but we store 1:N + // target materialization functions in the type converter. Wrap the + // result value in a SmallVector<Value>. 
+ std::optional<Value> val = + callback(builder, derivedType, inputs, loc, originalType); + if (val.has_value() && *val) + result.push_back(*val); + } + } + return result; }; } /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location)` + /// - Value(OpBuilder &, T, ValueRange, Location) + /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location) template <typename T, typename FnT> std::enable_if_t< std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>, @@ -432,9 +460,9 @@ class TypeConverter { wrapTargetMaterialization(FnT &&callback) const { return wrapTargetMaterialization<T>( [callback = std::forward<FnT>(callback)]( - OpBuilder &builder, T resultType, ValueRange inputs, Location loc, - Type originalType) -> Value { - return callback(builder, resultType, inputs, loc); + OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc, + Type originalType) { + return callback(builder, resultTypes, inputs, loc); }); } diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h index c59a3a52f028f3..7b4dd65cbff7b2 100644 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -33,49 +33,6 @@ namespace mlir { -/// Extends `TypeConverter` with 1:N target materializations. Such -/// materializations have to provide the "reverse" of 1:N type conversions, -/// i.e., they need to materialize N values with target types into one value -/// with a source type (which isn't possible in the base class currently). -class OneToNTypeConverter : public TypeConverter { -public: - /// Callback that expresses user-provided materialization logic from the given - /// value to N values of the given types. This is useful for expressing target - /// materializations for 1:N type conversions, which materialize one value in - /// a source type as N values in target types. 
- using OneToNMaterializationCallbackFn = - std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange, - Value, Location)>; - - /// Creates the mapping of the given range of original types to target types - /// of the conversion and stores that mapping in the given (signature) - /// conversion. This function simply calls - /// `TypeConverter::convertSignatureArgs` and exists here with a different - /// name to reflect the broader semantic. - LogicalResult computeTypeMapping(TypeRange types, - SignatureConversion &result) const { - return convertSignatureArgs(types, result); - } - - /// Applies one of the user-provided 1:N target materializations. If several - /// exists, they are tried out in the reverse order in which they have been - /// added until the first one succeeds. If none succeeds, the functions - /// returns `std::nullopt`. - std::optional<SmallVector<Value>> - materializeTargetConversion(OpBuilder &builder, Location loc, - TypeRange resultTypes, Value input) const; - - /// Adds a 1:N target materialization to the converter. Such materializations - /// build IR that converts N values with target types into 1 value of the - /// source type. - void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { - oneToNTargetMaterializations.emplace_back(std::move(callback)); - } - -private: - SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations; -}; - /// Stores a 1:N mapping of types and provides several useful accessors. This /// class extends `SignatureConversion`, which already supports 1:N type /// mappings but lacks some accessors into the mapping as well as access to the @@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern { /// not fail if some ops or types remain unconverted (i.e., the conversion is /// only "partial"). 
LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns); /// Add a pattern to the given pattern list to convert the signature of a diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 4968c4fc463d04..e908a536e6fb27 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -921,7 +921,7 @@ struct VectorLegalizationPass : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { void runOnOperation() override { auto *context = &getContext(); - OneToNTypeConverter converter; + TypeConverter converter; RewritePatternSet patterns(context); converter.addConversion([](Type type) { return type; }); converter.addConversion( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3cfcaa965f3546..bf969e74e8bfe0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType) const { + SmallVector<Value> result = materializeTargetConversion( + builder, loc, TypeRange(resultType), inputs, originalType); + if (result.empty()) + return nullptr; + assert(result.size() == 1 && "requested 1:1 materialization, but callback " + "produced 1:N materialization"); + return result.front(); +} + +SmallVector<Value> TypeConverter::materializeTargetConversion( + OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs, + Type originalType) const { for (const TargetMaterializationCallbackFn &fn : - llvm::reverse(targetMaterializations)) - if (Value result = fn(builder, resultType, 
inputs, loc, originalType)) - return result; - return nullptr; + llvm::reverse(targetMaterializations)) { + SmallVector<Value> result = + fn(builder, resultTypes, inputs, loc, originalType); + if (result.empty()) + continue; + return result; + } + return {}; } std::optional<TypeConverter::SignatureConversion> diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index 19e29d48623e04..c208716891ef1f 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -17,20 +17,6 @@ using namespace llvm; using namespace mlir; -std::optional<SmallVector<Value>> -OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, - Location loc, - TypeRange resultTypes, - Value input) const { - for (const OneToNMaterializationCallbackFn &fn : - llvm::reverse(oneToNTargetMaterializations)) { - if (std::optional<SmallVector<Value>> result = - fn(builder, resultTypes, input, loc)) - return *result; - } - return std::nullopt; -} - TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { TypeRange convertedTypes = getConvertedTypes(); if (auto mapping = getInputMapping(originalTypeNo)) @@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion( LogicalResult OneToNConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); + auto *typeConverter = getTypeConverter(); // Construct conversion mapping for results. Operation::result_type_range originalResultTypes = op->getResultTypes(); OneToNTypeMapping resultMapping(originalResultTypes); - if (failed(typeConverter->computeTypeMapping(originalResultTypes, - resultMapping))) + if (failed(typeConverter->convertSignatureArgs(originalResultTypes, + resultMapping))) return failure(); // Construct conversion mapping for operands. 
Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); OneToNTypeMapping operandMapping(originalOperandTypes); - if (failed(typeConverter->computeTypeMapping(originalOperandTypes, - operandMapping))) + if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, + operandMapping))) return failure(); // Cast operands to target types. @@ -318,7 +304,7 @@ namespace mlir { // inserted by this pass are annotated with a string attribute that also // documents which kind of the cast (source, argument, or target). LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns) { #ifndef NDEBUG // Remember existing unrealized casts. This data structure is only used in @@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, // Target materialization. assert(!areOperandTypesLegal && areResultsTypesLegal && operands.size() == 1 && "found unexpected target cast"); - std::optional<SmallVector<Value>> maybeResults = - typeConverter.materializeTargetConversion( - rewriter, castOp->getLoc(), resultTypes, operands.front()); - if (!maybeResults) { + materializedResults = typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (materializedResults.empty()) { emitError(castOp->getLoc()) << "failed to create target materialization"; return failure(); } - materializedResults = maybeResults.value(); } else { // Source and argument materializations. 
assert(areOperandTypesLegal && !areResultsTypesLegal && @@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern { const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const override { auto funcOp = cast<FunctionOpInterface>(op); - auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); + auto *typeConverter = getTypeConverter(); // Construct mapping for function arguments. OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(), - argumentMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(), + argumentMapping))) return failure(); // Construct mapping for function results. OneToNTypeMapping funcResultMapping(funcOp.getResultTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(), - funcResultMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(), + funcResultMapping))) return failure(); // Nothing to do if the op doesn't have any non-identity conversions for its diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp index 5c03ac12d1e58c..b18dfd8bb22cb1 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter, /// /// This function has been copied (with small adaptions) from /// TestDecomposeCallGraphTypes.cpp. 
-static std::optional<SmallVector<Value>> -buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) { +static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder, + TypeRange resultTypes, + ValueRange inputs, + Location loc) { + if (inputs.size() != 1) + return {}; + Value input = inputs.front(); + TupleType inputType = dyn_cast<TupleType>(input.getType()); if (!inputType) return {}; @@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() { auto *context = &getContext(); // Assemble type converter. - OneToNTypeConverter typeConverter; + TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( @@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() { typeConverter.addArgumentMaterialization(buildMakeTupleOp); typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildGetTupleElementOps); + // Test the other target materialization variant that takes the original type + // as additional argument. This materialization function always fails. + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector<Value> { return {}; }); // Assemble patterns. RewritePatternSet patterns(context); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits