https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116532
>From 4e4a5c81a1c45c8d4fbadacd67fa5439231e912e Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Sat, 23 Nov 2024 08:22:13 +0100
Subject: [PATCH 1/2] [mlir][Func] Delete `DecomposeCallGraphTypes.cpp`

---
 .../Func/Transforms/DecomposeCallGraphTypes.h |  34 -----
 .../Dialect/Func/Transforms/CMakeLists.txt    |   1 -
 .../Transforms/DecomposeCallGraphTypes.cpp    | 136 ------------------
 .../Func/Transforms/FuncConversions.cpp       |   8 +-
 .../Func/TestDecomposeCallGraphTypes.cpp      |   7 +-
 5 files changed, 8 insertions(+), 178 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
 delete mode 100644 mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
deleted file mode 100644
index 1be406bf3adf92..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ /dev/null
@@ -1,34 +0,0 @@
-//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Conversion patterns for decomposing types along call graph edges. That is,
-// decomposing types for calls, returns, and function args.
-//
-// TODO: Make this handle dialect-defined functions, calls, and returns.
-// Currently, the generic interfaces aren't sophisticated enough for the
-// types of mutations that we are doing here.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
-
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
-
-namespace mlir {
-
-/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `TypeConverter`.
-void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
-                                             const TypeConverter &typeConverter,
-                                             RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index f8fb1f436a95b1..6384d25ee70273 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRFuncTransforms
-  DecomposeCallGraphTypes.cpp
   DuplicateFunctionElimination.cpp
   FuncConversions.cpp
   OneToNFuncConversions.cpp
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
deleted file mode 100644
index 03be00328bda33..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForFuncArgs
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand function arguments according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForFuncArgs
-    : public OpConversionPattern<func::FuncOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto functionType = op.getFunctionType();
-
-    // Convert function arguments using the provided TypeConverter.
-    TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
-    for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
-      SmallVector<Type, 2> decomposedTypes;
-      if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
-        return failure();
-      if (!decomposedTypes.empty())
-        conversion.addInputs(argType.index(), decomposedTypes);
-    }
-
-    // If the SignatureConversion doesn't apply, bail out.
-    if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
-                                           &conversion)))
-      return failure();
-
-    // Update the signature of the function.
-    SmallVector<Type, 2> newResultTypes;
-    if (failed(typeConverter->convertTypes(functionType.getResults(),
-                                           newResultTypes)))
-      return failure();
-    rewriter.modifyOpInPlace(op, [&] {
-      op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
-                                          newResultTypes));
-    });
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForReturnOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand return operands according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForReturnOp
-    : public OpConversionPattern<ReturnOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
-    return success();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesForCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Expand call op operands and results according to the provided TypeConverter.
-struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-
-    // Create the operands list of the new `CallOp`.
-    SmallVector<Value, 2> newOperands;
-    for (ValueRange operand : adaptor.getOperands())
-      llvm::append_range(newOperands, operand);
-
-    // Create the new result types for the new `CallOp` and track the number of
-    // replacement types for each original op result.
-    SmallVector<Type, 2> newResultTypes;
-    SmallVector<unsigned> expandedResultSizes;
-    for (Type resultType : op.getResultTypes()) {
-      unsigned oldSize = newResultTypes.size();
-      if (failed(typeConverter->convertType(resultType, newResultTypes)))
-        return failure();
-      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
-    }
-
-    CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
-                                               newResultTypes, newOperands);
-
-    // Build a replacement value for each result to replace its uses.
-    SmallVector<ValueRange> replacedValues;
-    replacedValues.reserve(op.getNumResults());
-    unsigned startIdx = 0;
-    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      ValueRange repl =
-          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
-      replacedValues.push_back(repl);
-      startIdx += expandedResultSizes[i];
-    }
-    rewriter.replaceOpWithMultiple(op, replacedValues);
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateDecomposeCallGraphTypesPatterns(
-    MLIRContext *context, const TypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns
-      .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
-           DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
-}
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index 9e7759bef6d8fd..a3638c8766a5c6 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -124,12 +124,10 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
   using OpConversionPattern<ReturnOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    // For a return, all operands go to the results of the parent, so
-    // rewrite them all.
-    rewriter.modifyOpInPlace(op,
-                             [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.replaceOpWithNewOp<ReturnOp>(op,
+                                          flattenValues(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c58ae6ee0..09c5b4b2a0ad50 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -9,7 +9,7 @@
 #include "TestDialect.h"
 #include "TestOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -142,7 +142,10 @@ struct TestDecomposeCallGraphTypes
     typeConverter.addArgumentMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
-    populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    populateReturnOpTypeConversionPattern(patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
 
     if (failed(applyPartialConversion(module, target, std::move(patterns))))
       return signalPassFailure();

>From fe68c3c6702ae0de6549a6db3b014e6fb4dc898a Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Sun, 17 Nov 2024 09:00:45 +0100
Subject: [PATCH 2/2] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks

---
 .../Conversion/LLVMCommon/TypeConverter.cpp | 32 ++++----
 .../MemRefToLLVM/type-conversion.mlir       | 57 ++++++++++++++
 mlir/test/lib/Dialect/LLVM/CMakeLists.txt   |  1 +
 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp | 77 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp            |  2 +
 5 files changed, 154 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
 create mode 100644 mlir/test/lib/Dialect/LLVM/TestPatterns.cpp

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Helper function that checks if the given value range is a bare pointer.
+  auto isBarePointer = [](ValueRange values) {
+    return values.size() == 1 &&
+           isa<LLVM::LLVMPointerType>(values.front().getType());
+  };
+
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
                                  ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
+    // Note: Bare pointers are not supported for unranked memrefs because a
+    // memref descriptor cannot be built just from a bare pointer.
+    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
       return Value();
-    }
     Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
                                                 resultType, inputs);
     // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc) {
     Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return Value();
+    if (isBarePointer(inputs)) {
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
-    } else {
+    } else if (TypeRange(inputs) ==
+               getMemRefDescriptorFields(resultType,
+                                         /*unpackAggregates=*/true)) {
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    } else {
+      // The inputs are neither a bare pointer nor an unpacked memref
+      // descriptor. This materialization function cannot be used.
+      return Value();
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
new file mode 100644
index 00000000000000..0288aa11313c72
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
+
+// Test the argument materializer for ranked MemRef types.
+
+// CHECK-LABEL: func @construct_ranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-COUNT-7: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<5x4xf32>
+func.func @construct_ranked_memref_descriptor(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
+  %0 = "test.direct_replacement"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for ranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_ranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<5x4xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_ranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<5x4xf32>)
+  "test.legal_op"(%0) : (memref<5x4xf32>) -> ()
+  return
+}
+
+// -----
+
+// Test the argument materializer for unranked MemRef types.
+
+// CHECK-LABEL: func @construct_unranked_memref_descriptor(
+// CHECK: llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK-COUNT-2: llvm.insertvalue
+// CHECK: builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(i64, ptr)> to memref<*xf32>
+func.func @construct_unranked_memref_descriptor(%arg0: i64, %arg1: !llvm.ptr) {
+  %0 = "test.direct_replacement"(%arg0, %arg1) : (i64, !llvm.ptr) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
+
+// -----
+
+// The argument materializer for unranked MemRef types is called with incorrect
+// input types. Make sure that the materializer is skipped and we do not
+// generate invalid IR.
+
+// CHECK-LABEL: func @invalid_unranked_memref_descriptor(
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i1 to memref<*xf32>
+// CHECK: "test.legal_op"(%[[cast]])
+func.func @invalid_unranked_memref_descriptor(%arg0: i1) {
+  %0 = "test.direct_replacement"(%arg0) : (i1) -> (memref<*xf32>)
+  "test.legal_op"(%0) : (memref<*xf32>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
index 734757ce79da37..6a2f0ba2756d43 100644
--- a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLLVMTestPasses
   TestLowerToLLVM.cpp
+  TestPatterns.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
new file mode 100644
index 00000000000000..ab02866970b1d5
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -0,0 +1,77 @@
+//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Replace this op (which is expected to have 1 result) with the operands.
+struct TestDirectReplacementOp : public ConversionPattern {
+  TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (op->getNumResults() != 1)
+      return failure();
+    rewriter.replaceOpWithMultiple(op, {operands});
+    return success();
+  }
+};
+
+struct TestLLVMLegalizePatternsPass
+    : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
+
+  StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
+  StringRef getDescription() const final {
+    return "Run LLVM dialect legalization patterns";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    LLVMTypeConverter converter(ctx);
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add<TestDirectReplacementOp>(ctx, converter);
+
+    // Define the conversion target used for the test.
+    ConversionTarget target(*ctx);
+    target.addLegalOp(OperationName("test.legal_op", ctx));
+
+    // Handle a partial conversion.
+    DenseSet<Operation *> unlegalizedOps;
+    ConversionConfig config;
+    config.unlegalizedOps = &unlegalizedOps;
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns), config)))
+      getOperation()->emitError() << "applyPartialConversion failed";
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace test {
+void registerTestLLVMLegalizePatternsPass() {
+  PassRegistration<TestLLVMLegalizePatternsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 002c3900056dee..94bc67a1e96093 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestLinalgRankReduceContractionOps();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
+void registerTestLLVMLegalizePatternsPass();
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
@@ -250,6 +251,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
+  mlir::test::registerTestLLVMLegalizePatternsPass();
   mlir::test::registerTestLoopFusion();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits