https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/168703
>From 9186d8ed4c61f5277c08b57b34cbb120dfcacf28 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Thu, 13 Nov 2025 13:37:15 +0800 Subject: [PATCH 1/4] [FIR][Lowering] Add FIRToMLIR pass. --- .../include/flang/Optimizer/Support/InitFIR.h | 8 ++++- .../flang/Optimizer/Transforms/Passes.td | 11 +++++++ flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../Optimizer/Transforms/ConvertFIRToMLIR.cpp | 30 +++++++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h index 67e9287ddad4f..b90badf8ede0f 100644 --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -32,6 +32,8 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -54,7 +56,8 @@ namespace fir::support { mlir::NVVM::NVVMDialect, mlir::gpu::GPUDialect, \ mlir::index::IndexDialect, mif::MIFDialect -#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect +#define FLANG_CODEGEN_DIALECT_LIST \ + FIRCodeGenDialect, mlir::memref::MemRefDialect, mlir::LLVM::LLVMDialect // The definitive list of dialects used by flang. 
#define FLANG_DIALECT_LIST \ @@ -129,6 +132,9 @@ inline void registerMLIRPassesForFortranTools() { mlir::affine::registerAffineLoopTilingPass(); mlir::affine::registerAffineDataCopyGenerationPass(); + mlir::registerMem2RegPass(); + mlir::memref::registerMemRefPasses(); + mlir::registerLowerAffinePass(); } diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index bb2509b1747d5..0bf1537b2215c 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -87,6 +87,18 @@ def FIRToSCFPass : Pass<"fir-to-scf"> { ]; } +def ConvertFIRToMLIRPass : Pass<"fir-to-mlir", "mlir::ModuleOp"> { + let summary = "Convert the FIR dialect module to MLIR standard dialects."; + let description = [{ + Convert the FIR dialect module to MLIR standard dialects. + }]; + let dependentDialects = [ + "fir::FIROpsDialect", "fir::FIRCodeGenDialect", "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", "mlir::affine::AffineDialect", + "mlir::arith::ArithDialect" + ]; +} + def AnnotateConstantOperands : Pass<"annotate-constant"> { let summary = "Annotate constant operands to all FIR operations"; let description = [{ diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 0388439f89a54..a6423b3dea5a9 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -36,6 +36,7 @@ add_flang_library(FIRTransforms SimplifyFIROperations.cpp OptimizeArrayRepacking.cpp ConvertComplexPow.cpp + ConvertFIRToMLIR.cpp MIFOpConversion.cpp DEPENDS diff --git a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp new file mode 100644 index 0000000000000..a24d011da50c9 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp @@ -0,0 +1,30 @@ +//===-- ConvertFIRToMLIR.cpp ----------------------------------------------===// +// +// Part of the LLVM 
Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +namespace fir { +#define GEN_PASS_DEF_CONVERTFIRTOMLIRPASS +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +namespace { +class ConvertFIRToMLIRPass + : public fir::impl::ConvertFIRToMLIRPassBase<ConvertFIRToMLIRPass> { +public: + void runOnOperation() override; +}; +} // namespace + +void ConvertFIRToMLIRPass::runOnOperation() { + // TODO: +} >From 1de4ae5f9d688710da4c6bb71e5c8271186de1a8 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 16:30:04 +0800 Subject: [PATCH 2/4] [FIR][Lowering] Add a flag to select lowering through MLIR. 
--- clang/include/clang/Options/Options.td | 5 +++++ clang/lib/Driver/ToolChains/Flang.cpp | 1 + flang/include/flang/Lower/LoweringOptions.def | 3 +++ .../flang/Optimizer/Passes/Pipelines.h | 2 ++ flang/include/flang/Tools/CrossToolHelpers.h | 1 + flang/lib/Frontend/CompilerInvocation.cpp | 4 ++++ flang/lib/Frontend/FrontendActions.cpp | 1 + flang/lib/Optimizer/Passes/Pipelines.cpp | 20 ++++++++++++++----- flang/test/Driver/frontend-forwarding.f90 | 2 ++ 9 files changed, 34 insertions(+), 5 deletions(-) diff --git a/clang/include/clang/Options/Options.td b/clang/include/clang/Options/Options.td index cda11fdc94230..cd8409de8c5a9 100644 --- a/clang/include/clang/Options/Options.td +++ b/clang/include/clang/Options/Options.td @@ -7217,6 +7217,11 @@ def flang_deprecated_no_hlfir : Flag<["-"], "flang-deprecated-no-hlfir">, Flags<[HelpHidden]>, Visibility<[FlangOption, FC1Option]>, HelpText<"Do not use HLFIR lowering (deprecated)">; +def flang_experimental_lower_through_mlir + : Flag<["-"], "flang-experimental-lower-through-mlir">, + Flags<[HelpHidden]>, Visibility<[FlangOption, FC1Option]>, + HelpText<"Lower from FIR through MLIR to LLVM (experimental)">; + //===----------------------------------------------------------------------===// // FLangOption + CoreOption + NoXarchOption //===----------------------------------------------------------------------===// diff --git a/clang/lib/Driver/ToolChains/Flang.cpp b/clang/lib/Driver/ToolChains/Flang.cpp index 270904de544d6..e294ac59af73d 100644 --- a/clang/lib/Driver/ToolChains/Flang.cpp +++ b/clang/lib/Driver/ToolChains/Flang.cpp @@ -222,6 +222,7 @@ void Flang::addCodegenOptions(const ArgList &Args, {options::OPT_fdo_concurrent_to_openmp_EQ, options::OPT_flang_experimental_hlfir, options::OPT_flang_deprecated_no_hlfir, + options::OPT_flang_experimental_lower_through_mlir, options::OPT_fno_ppc_native_vec_elem_order, options::OPT_fppc_native_vec_elem_order, options::OPT_finit_global_zero, options::OPT_fno_init_global_zero, 
options::OPT_frepack_arrays, diff --git a/flang/include/flang/Lower/LoweringOptions.def b/flang/include/flang/Lower/LoweringOptions.def index 39f197d8d35c8..01fc96b78df50 100644 --- a/flang/include/flang/Lower/LoweringOptions.def +++ b/flang/include/flang/Lower/LoweringOptions.def @@ -38,6 +38,9 @@ ENUM_LOWERINGOPT(Underscoring, unsigned, 1, 1) /// (i.e. wraps around as two's complement). Off by default. ENUM_LOWERINGOPT(IntegerWrapAround, unsigned, 1, 0) +/// If true, lower from FIR through MLIR to LLVM. +ENUM_LOWERINGOPT(LowerThroughMLIR, unsigned, 1, 0) + /// If true (default), follow Fortran 2003 rules for (re)allocating /// the allocatable on the left side of the intrinsic assignment, /// if LHS and RHS have mismatching shapes/types. diff --git a/flang/include/flang/Optimizer/Passes/Pipelines.h b/flang/include/flang/Optimizer/Passes/Pipelines.h index 70b9341347244..fc1ebaf7d24f7 100644 --- a/flang/include/flang/Optimizer/Passes/Pipelines.h +++ b/flang/include/flang/Optimizer/Passes/Pipelines.h @@ -18,10 +18,12 @@ #include "flang/Optimizer/Passes/CommandLineOpts.h" #include "flang/Optimizer/Transforms/Passes.h" #include "flang/Tools/CrossToolHelpers.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/OpenMP/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h index e964882ef6dac..0dcb99e1eb5b1 100644 --- a/flang/include/flang/Tools/CrossToolHelpers.h +++ b/flang/include/flang/Tools/CrossToolHelpers.h @@ -137,6 +137,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks { bool 
EnableOpenMP = false; ///< Enable OpenMP lowering. bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode. bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion. + bool LowerThroughMLIR = false; ///< Lower from FIR through MLIR to LLVM. std::string InstrumentFunctionEntry = ""; ///< Name of the instrument-function that is called on each ///< function-entry diff --git a/flang/lib/Frontend/CompilerInvocation.cpp b/flang/lib/Frontend/CompilerInvocation.cpp index 893121fe01f27..8c3fde0a27153 100644 --- a/flang/lib/Frontend/CompilerInvocation.cpp +++ b/flang/lib/Frontend/CompilerInvocation.cpp @@ -1580,6 +1580,10 @@ bool CompilerInvocation::createFromArgs( invoc.loweringOpts.setLowerToHighLevelFIR(false); } + // -flang-experimental-lower-through-mlir + invoc.loweringOpts.setLowerThroughMLIR( + args.hasArg(clang::options::OPT_flang_experimental_lower_through_mlir)); + // -fno-ppc-native-vector-element-order if (args.hasArg(clang::options::OPT_fno_ppc_native_vec_elem_order)) { invoc.loweringOpts.setNoPPCNativeVecElemOrder(true); diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 159d08a2797b3..0cb241f209522 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -769,6 +769,7 @@ void CodeGenAction::generateLLVMIR() { config.NSWOnLoopVarInc = false; config.ComplexRange = opts.getComplexRange(); + config.LowerThroughMLIR = invoc.getLoweringOpts().getLowerThroughMLIR(); // Create the pass pipeline fir::createMLIRToLLVMPassPipeline(pm, config, getCurrentFile()); diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 103e736accca0..6aa81a1b44c6b 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -109,6 +109,18 @@ void addDebugInfoPass(mlir::PassManager &pm, void addFIRToLLVMPass(mlir::PassManager &pm, const MLIRToLLVMPassPipelineConfig &config) { + if 
(disableFirToLlvmIr) + return; + + if (config.LowerThroughMLIR) { + pm.addPass(createConvertFIRToMLIRPass()); + pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass()); + pm.addPass(mlir::createMem2Reg()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + } + fir::FIRToLLVMPassOptions options; options.ignoreMissingTypeDescriptors = ignoreMissingTypeDescriptors; options.skipExternalRttiDefinition = skipExternalRttiDefinition; @@ -117,13 +129,11 @@ void addFIRToLLVMPass(mlir::PassManager &pm, options.typeDescriptorsRenamedForAssembly = !disableCompilerGeneratedNamesConversion; options.ComplexRange = config.ComplexRange; - addPassConditionally(pm, disableFirToLlvmIr, - [&]() { return fir::createFIRToLLVMPass(options); }); + pm.addPass(fir::createFIRToLLVMPass(options)); + // The dialect conversion framework may leave dead unrealized_conversion_cast // ops behind, so run reconcile-unrealized-casts to clean them up. - addPassConditionally(pm, disableFirToLlvmIr, [&]() { - return mlir::createReconcileUnrealizedCastsPass(); - }); + pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } void addLLVMDialectToLLVMPass(mlir::PassManager &pm, diff --git a/flang/test/Driver/frontend-forwarding.f90 b/flang/test/Driver/frontend-forwarding.f90 index 952937168c95d..ab9e5e8b4d088 100644 --- a/flang/test/Driver/frontend-forwarding.f90 +++ b/flang/test/Driver/frontend-forwarding.f90 @@ -20,6 +20,7 @@ ! RUN: -fversion-loops-for-stride \ ! RUN: -flang-experimental-hlfir \ ! RUN: -flang-deprecated-no-hlfir \ +! RUN: -flang-experimental-lower-through-mlir \ ! RUN: -fno-ppc-native-vector-element-order \ ! RUN: -fppc-native-vector-element-order \ ! RUN: -mllvm -print-before-all \ @@ -51,6 +52,7 @@ ! CHECK: "-fversion-loops-for-stride" ! CHECK: "-flang-experimental-hlfir" ! CHECK: "-flang-deprecated-no-hlfir" +! CHECK: "-flang-experimental-lower-through-mlir" ! 
CHECK: "-fno-ppc-native-vector-element-order" ! CHECK: "-fppc-native-vector-element-order" ! CHECK: "-Rpass" >From c86bcbffee9f812ce2d27c18900ca2ac708e3c69 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 18:27:39 +0800 Subject: [PATCH 3/4] [FIR][Lowering] Add lowering for `fir.convert` between `fir.ref` and `memref` type. --- flang/lib/Optimizer/CodeGen/CodeGen.cpp | 25 ++++++++++++++- flang/test/Fir/convert-to-llvm.fir | 41 +++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index f96d45d3f6b66..7959c846a2418 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1008,6 +1008,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> { rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0); return mlir::success(); } + // Pointer to MemRef conversion. Only rank-0 memrefs are handled: a raw + // pointer carries no sizes or strides to fill a ranked descriptor with. + auto toMemRefTy = mlir::dyn_cast<mlir::MemRefType>(toFirTy); + if (toMemRefTy && toMemRefTy.getRank() == 0) { + auto dstMemRef = mlir::MemRefDescriptor::poison(rewriter, loc, toTy); + dstMemRef.setAlignedPtr(rewriter, loc, op0); + dstMemRef.setOffset( + rewriter, loc, + createIndexAttrConstant(rewriter, loc, getIndexType(), 0)); + rewriter.replaceOp(convert, {dstMemRef}); + return mlir::success(); + } + } else if (mlir::isa<mlir::MemRefType>(fromFirTy) && + mlir::isa<mlir::LLVM::LLVMPointerType>(toTy)) { + // MemRef to pointer conversion. 
+ auto srcMemRef = mlir::MemRefDescriptor(op0); + mlir::Type elementType = typeConverter->convertType( + mlir::cast<mlir::MemRefType>(fromFirTy).getElementType()); + mlir::Value srcBasePtr = srcMemRef.alignedPtr(rewriter, loc); + mlir::Value srcOffset = srcMemRef.offset(rewriter, loc); + mlir::Value srcPtr = + mlir::LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(), + elementType, srcBasePtr, srcOffset); + rewriter.replaceOp(convert, srcPtr); + return mlir::success(); } return emitError(loc) << "cannot convert " << fromTy << " to " << toTy; } @@ -4326,7 +4349,7 @@ class FIRToLLVMLowering target.addLegalDialect<mlir::gpu::GPUDialect>(); // required NOPs for applying a full conversion - target.addLegalOp<mlir::ModuleOp>(); + target.addLegalOp<mlir::ModuleOp, mlir::UnrealizedConversionCastOp>(); // If we're on Windows, we might need to rename some libm calls. bool isMSVC = fir::getTargetTriple(mod).isOSMSVCRT(); diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 864368740be02..41c7ea992c29c 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -882,6 +882,47 @@ func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.stru // ----- +// Test `fir.convert` operation conversion between `memref` and `fir.ref`. 
+ +func.func @convert_to_memref(%arg0 : !fir.ref<i32>) -> memref<i32, strided<[], offset: ?>> { + %0 = fir.convert %arg0 : (!fir.ref<i32>) -> memref<i32, strided<[], offset: ?>> + return %0 : memref<i32, strided<[], offset: ?>> +} + +// CHECK-LABEL: llvm.func @convert_to_memref( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)> { +// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_0]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: llvm.return %[[INSERTVALUE_1]] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: } + +// ----- + +// Test `fir.convert` operation conversion from `memref` to `fir.ref`. + +func.func @convert_from_memref(%arg0 : memref<i32, strided<[], offset: ?>>) -> !fir.ref<i32> { + %0 = fir.convert %arg0 : (memref<i32, strided<[], offset: ?>>) -> !fir.ref<i32> + return %0 : !fir.ref<i32> +} + +// CHECK-LABEL: llvm.func @convert_from_memref( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr, +// CHECK-SAME: %[[ARG2:.*]]: i64) -> !llvm.ptr { +// CHECK: %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[ARG0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[ARG1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[ARG2]], %[[INSERTVALUE_1]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[INSERTVALUE_2]][2] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr 
%[[EXTRACTVALUE_0]]{{\[}}%[[EXTRACTVALUE_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK: llvm.return %[[GETELEMENTPTR_0]] : !llvm.ptr +// CHECK: } + +// ----- + // Test `fir.store` --> `llvm.store` conversion func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) { >From 59b2dad1b14cafa5adc2fe70e64ea28e8c1d7433 Mon Sep 17 00:00:00 2001 From: yanming <[email protected]> Date: Wed, 19 Nov 2025 18:56:14 +0800 Subject: [PATCH 4/4] [FIR][Lowering] Add fir to mlir core dialect patterns. --- .../Optimizer/Transforms/ConvertFIRToMLIR.cpp | 198 +++++++++++++++++- flang/test/Fir/convert-to-mlir.fir | 135 ++++++++++++ 2 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 flang/test/Fir/convert-to-mlir.fir diff --git a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp index a24d011da50c9..34535fec003ce 100644 --- a/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertFIRToMLIR.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// +#include "flang/Optimizer/Dialect/FIRCG/CGOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" namespace fir { #define GEN_PASS_DEF_CONVERTFIRTOMLIRPASS @@ -23,8 +25,202 @@ class ConvertFIRToMLIRPass public: void runOnOperation() override; }; + +class FIRLoadOpLowering : public mlir::OpConversionPattern<fir::LoadOp> { +public: + using mlir::OpConversionPattern<fir::LoadOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(fir::LoadOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + if (!mlir::MemRefType::isValidElementType(op.getType())) + return mlir::failure(); + + 
+    // memref.load needs a memref operand. The reference stays unconverted
+    // when it is not a fir.ref (e.g. fir.ptr/fir.heap) or when its element
+    // type has no memref counterpart, so bail out instead of asserting.
+    if (!mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType()))
+      return mlir::failure();
+
+    rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getMemref(),
+                                                      mlir::ValueRange{});
+    return mlir::success();
+  }
+};
+
+/// Lowers `fir.store` of a memref-compatible scalar into `memref.store` on
+/// the rank-0 memref that the type converter produces for the reference.
+class FIRStoreOpLowering : public mlir::OpConversionPattern<fir::StoreOp> {
+public:
+  using mlir::OpConversionPattern<fir::StoreOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::StoreOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    if (!mlir::MemRefType::isValidElementType(op.getValue().getType()))
+      return mlir::failure();
+    // Only store through a reference that was actually converted to a memref.
+    if (!mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType()))
+      return mlir::failure();
+
+    rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(
+        op, adaptor.getValue(), adaptor.getMemref(), mlir::ValueRange{});
+    return mlir::success();
+  }
+};
+
+/// Lowers `fir.convert` between scalar integer/index/float types to the
+/// matching `arith` cast. It also folds away conversions that become
+/// identities after type conversion (e.g. between two references mapped to
+/// the same memref type). Anything else is left for the FIR-to-LLVM codegen.
+class FIRConvertOpLowering : public mlir::OpConversionPattern<fir::ConvertOp> {
+public:
+  using mlir::OpConversionPattern<fir::ConvertOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::ConvertOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Value srcVal = adaptor.getValue();
+    mlir::Type srcType = srcVal.getType();
+    mlir::Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return mlir::failure();
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, mlir::ValueRange{srcVal});
+      return mlir::success();
+    }
+
+    if (srcType.isIntOrIndex() && dstType.isIntOrIndex()) {
+      // arith casts only accept signless integers.
+      if ((srcType.isInteger() && !srcType.isSignlessInteger()) ||
+          (dstType.isInteger() && !dstType.isSignlessInteger()))
+        return mlir::failure();
+      // i1 is zero-extended so that true widens to 1, not -1 (intended to
+      // mirror the FIR-to-LLVM codegen); other integers are sign-extended.
+      const bool zeroExtend = srcType.isInteger(1);
+      if (srcType.isIndex() || dstType.isIndex()) {
+        if (zeroExtend)
+          rewriter.replaceOpWithNewOp<mlir::arith::IndexCastUIOp>(op, dstType,
+                                                                  srcVal);
+        else
+          rewriter.replaceOpWithNewOp<mlir::arith::IndexCastOp>(op, dstType,
+                                                                srcVal);
+      } else if (srcType.getIntOrFloatBitWidth() <
+                 dstType.getIntOrFloatBitWidth()) {
+        if (zeroExtend)
+          rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, dstType,
+                                                            srcVal);
+        else
+          rewriter.replaceOpWithNewOp<mlir::arith::ExtSIOp>(op, dstType,
+                                                            srcVal);
+      } else {
+        // Distinct signless integers always differ in width, so this is a
+        // genuine truncation.
+        rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, dstType, srcVal);
+      }
+      return mlir::success();
+    }
+
+    if (srcType.isFloat() && dstType.isFloat()) {
+      const unsigned srcWidth = srcType.getIntOrFloatBitWidth();
+      const unsigned dstWidth = dstType.getIntOrFloatBitWidth();
+      // Distinct float types of equal width (e.g. f16 <-> bf16) are neither
+      // an extension nor a truncation; arith.extf/truncf would not verify.
+      if (srcWidth == dstWidth)
+        return mlir::failure();
+      if (srcWidth < dstWidth)
+        rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, dstType, srcVal);
+      else
+        rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, dstType, srcVal);
+      return mlir::success();
+    }
+
+    return mlir::failure();
+  }
+};
+
+/// Lowers a scalar `fir.alloca` of a memref-compatible type to a rank-0
+/// `memref.alloca`, then casts it to the dynamic-offset memref type produced
+/// by the type converter.
+class FIRAllocOpLowering : public mlir::OpConversionPattern<fir::AllocaOp> {
+public:
+  using mlir::OpConversionPattern<fir::AllocaOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::AllocaOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    // Only fixed-size scalars: no length-parameter and no shape operands.
+    if (!mlir::MemRefType::isValidElementType(op.getAllocatedType()) ||
+        op.hasLenParams() || op->getNumOperands() != 0)
+      return mlir::failure();
+
+    auto dstType = mlir::dyn_cast_or_null<mlir::MemRefType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!dstType)
+      return mlir::failure();
+
+    auto allocaOp = mlir::memref::AllocaOp::create(
+        rewriter, op.getLoc(),
+        mlir::MemRefType::get({}, op.getAllocatedType()));
+    // Forwards fir.alloca's attributes (e.g. `in_type`) onto memref.alloca.
+    allocaOp->setAttrs(op->getAttrs());
+    rewriter.replaceOpWithNewOp<mlir::memref::CastOp>(op, dstType, allocaOp);
+    return mlir::success();
+  }
+};
+
+/// Lowers `fircg.ext_array_coor` on a plain `fir.ref` array. The base buffer
+/// is reinterpreted as a row-major memref (dimensions reversed relative to
+/// the column-major Fortran array), and a rank-0 `memref.subview` then
+/// selects the requested element.
+class FIRXArrayCoorOpLowering
+    : public mlir::OpConversionPattern<fir::cg::XArrayCoorOp> {
+public:
+  using mlir::OpConversionPattern<fir::cg::XArrayCoorOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::cg::XArrayCoorOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    if (!mlir::isa<fir::ReferenceType>(op.getMemref().getType()))
+      return mlir::failure();
+    // Only unit lower bounds, no slicing and no component path are modeled
+    // below; anything else would silently compute the wrong address.
+    if (!op.getShift().empty() || !op.getSlice().empty() ||
+        !op.getSubcomponent().empty())
+      return mlir::failure();
+
+    unsigned rank = op.getRank();
+    mlir::ValueRange shape = adaptor.getShape();
+    mlir::ValueRange indices = adaptor.getIndices();
+    if (rank == 0 || shape.size() != rank || indices.size() != rank)
+      return mlir::failure();
+
+    // Both the base and the result must map to memrefs. Check this before
+    // creating any IR so that bailing out leaves nothing behind.
+    auto resultType = mlir::dyn_cast_or_null<mlir::MemRefType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!resultType ||
+        !mlir::isa<mlir::MemRefType>(adaptor.getMemref().getType()))
+      return mlir::failure();
+
+    mlir::Location loc = op.getLoc();
+    // Shape operands and indices may be any integer type, but the memref ops
+    // and the stride arithmetic below operate on `index`.
+    auto toIndex = [&](mlir::Value value) -> mlir::Value {
+      if (value.getType().isIndex())
+        return value;
+      return mlir::arith::IndexCastOp::create(rewriter, loc,
+                                              rewriter.getIndexType(), value);
+    };
+
+    auto metadata = mlir::memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, adaptor.getMemref());
+    mlir::Value base = metadata.getBaseBuffer();
+    mlir::Value offset = metadata.getOffset();
+
+    auto extents = llvm::map_to_vector(shape, toIndex);
+    auto sizes = llvm::to_vector_of<mlir::OpFoldResult>(llvm::reverse(extents));
+    mlir::SmallVector<mlir::OpFoldResult> strides(rank);
+
+    // Column-major: the stride of Fortran dimension i is the product of the
+    // extents of dimensions 0..i-1; memref dimensions are in reverse order.
+    strides[rank - 1] = rewriter.getIndexAttr(1);
+    mlir::Value stride = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
+    for (unsigned i = 1; i < rank; ++i) {
+      stride =
+          mlir::arith::MulIOp::create(rewriter, loc, stride, extents[i - 1]);
+      strides[rank - 1 - i] = stride;
+    }
+
+    mlir::Value memref = mlir::memref::ReinterpretCastOp::create(
+        rewriter, loc, base, offset, sizes, strides);
+
+    // Without shift operands the Fortran indices are 1-based; memref
+    // offsets are 0-based.
+    mlir::SmallVector<mlir::OpFoldResult> oneAttrs(rank,
+                                                   rewriter.getIndexAttr(1));
+    auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
+    auto offsets = llvm::map_to_vector(
+        llvm::reverse(indices), [&](mlir::Value idx) -> mlir::OpFoldResult {
+          mlir::Value zeroBased =
+              mlir::arith::SubIOp::create(rewriter, loc, toIndex(idx), one);
+          return zeroBased;
+        });
+
+    auto subview = mlir::memref::SubViewOp::create(
+        rewriter, loc, resultType, memref, offsets, oneAttrs, oneAttrs);
+
+    rewriter.replaceOp(op, mlir::ValueRange{subview});
+    return mlir::success();
+  }
+};
+
 } // namespace
 
+/// Builds the type converter used by this pass. `fir.ref<T>` and
+/// `fir.ref<!fir.array<...xT>>` become a rank-0 `memref<T>` with a dynamic
+/// offset when T is a valid memref element type; every other type is kept.
+static mlir::TypeConverter prepareTypeConverter() {
+  mlir::TypeConverter converter;
+  converter.addConversion([](mlir::Type ty) { return ty; });
+  converter.addConversion(
+      [](fir::ReferenceType ty) -> std::optional<mlir::Type> {
+        mlir::Type eleTy = ty.getElementType();
+        if (auto sequenceTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
+          eleTy = sequenceTy.getElementType();
+        // MemRefType::get rejects element types that are not valid memref
+        // elements (e.g. FIR derived types). Returning std::nullopt falls
+        // back to the identity conversion registered above.
+        if (!mlir::MemRefType::isValidElementType(eleTy))
+          return std::nullopt;
+
+        auto layout = mlir::StridedLayoutAttr::get(
+            ty.getContext(), mlir::ShapedType::kDynamic, {});
+        return mlir::MemRefType::get({}, eleTy, layout);
+      });
+
+  // Use fir.convert as the bridge so that we don't need to pull in patterns for
+  // other dialects.
+  auto materializeProcedure = [](mlir::OpBuilder &builder, mlir::Type type,
+                                 mlir::ValueRange inputs,
+                                 mlir::Location loc) -> mlir::Value {
+    // fir.convert takes exactly one operand.
+    if (inputs.size() != 1)
+      return mlir::Value();
+    auto convertOp = fir::ConvertOp::create(builder, loc, type, inputs);
+    return convertOp;
+  };
+
+  converter.addSourceMaterialization(materializeProcedure);
+  converter.addTargetMaterialization(materializeProcedure);
+  return converter;
+}
+
 void ConvertFIRToMLIRPass::runOnOperation() {
-  // TODO:
+  mlir::MLIRContext *ctx = &getContext();
+  mlir::ModuleOp theModule = getOperation();
+  mlir::TypeConverter converter = prepareTypeConverter();
+  mlir::RewritePatternSet patterns(ctx);
+
+  patterns.add<FIRAllocOpLowering, FIRLoadOpLowering, FIRStoreOpLowering,
+               FIRConvertOpLowering, FIRXArrayCoorOpLowering>(converter, ctx);
+
+  mlir::ConversionTarget target(getContext());
+
+  target.addLegalDialect<mlir::arith::ArithDialect, mlir::affine::AffineDialect,
+                         mlir::memref::MemRefDialect, mlir::scf::SCFDialect>();
+
+  if (mlir::failed(mlir::applyPartialConversion(theModule, target,
+                                                std::move(patterns)))) {
+    signalPassFailure();
+  }
 }
diff --git a/flang/test/Fir/convert-to-mlir.fir b/flang/test/Fir/convert-to-mlir.fir
new file mode 100644
index 0000000000000..3265349969e83
--- /dev/null
+++ b/flang/test/Fir/convert-to-mlir.fir
@@ -0,0 +1,165 @@
+// RUN: fir-opt --split-input-file --fir-to-mlir %s | FileCheck %s
+
+//===================================================
+// SUMMARY: Tests for FIR --> MLIR core dialects conversion
+//===================================================
+
+// Test `fir.load` --> `memref.load` conversion
+
+func.func @test_load_f32(%addr : !fir.ref<f32>) -> f32 {
+  %0 = fir.load %addr : !fir.ref<f32>
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @test_load_f32(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<f32>) -> f32 {
+// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG0]] : (!fir.ref<f32>) -> memref<f32, strided<[], offset: ?>>
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[CONVERT_0]][] : memref<f32, strided<[], offset: ?>>
+// CHECK: return %[[LOAD_0]] : f32
+// CHECK: }
+
+// -----
+
+// Test `fir.store` --> `memref.store` conversion
+
+func.func @test_store_f32(%val : f32, %addr : !fir.ref<f32>) {
+  fir.store %val to %addr : !fir.ref<f32>
+  return
+}
+
+// CHECK-LABEL: func.func @test_store_f32(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<f32>) {
+// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG1]] : (!fir.ref<f32>) -> memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[ARG0]], %[[CONVERT_0]][] : memref<f32, strided<[], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+// Test `fir.load` from a non-`fir.ref` reference is left unconverted.
+
+func.func @test_load_ptr_f32(%addr : !fir.ptr<f32>) -> f32 {
+  %0 = fir.load %addr : !fir.ptr<f32>
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @test_load_ptr_f32(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ptr<f32>) -> f32 {
+// CHECK: %[[LOAD_0:.*]] = fir.load %[[ARG0]] : !fir.ptr<f32>
+// CHECK: return %[[LOAD_0]] : f32
+// CHECK: }
+
+// -----
+
+// Test `fir.convert` operation conversion between integer and index types.
+
+func.func @convert_between_int_and_index(%arg0 : i32) -> i64 {
+  %0 = fir.convert %arg0 : (i32) -> index
+  %1 = fir.convert %0 : (index) -> i64
+  return %1 : i64
+}
+
+// CHECK-LABEL: func.func @convert_between_int_and_index(
+// CHECK-SAME: %[[ARG0:.*]]: i32) -> i64 {
+// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG0]] : i32 to index
+// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[INDEX_CAST_0]] : index to i64
+// CHECK: return %[[INDEX_CAST_1]] : i64
+// CHECK: }
+
+// -----
+
+// Test `fir.convert` operation conversion between integer types.
+
+func.func @convert_between_int(%arg0 : i32) -> i16 {
+  %0 = fir.convert %arg0 : (i32) -> i64
+  %1 = fir.convert %0 : (i64) -> i16
+  return %1 : i16
+}
+
+// CHECK-LABEL: func.func @convert_between_int(
+// CHECK-SAME: %[[ARG0:.*]]: i32) -> i16 {
+// CHECK: %[[EXTSI_0:.*]] = arith.extsi %[[ARG0]] : i32 to i64
+// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[EXTSI_0]] : i64 to i16
+// CHECK: return %[[TRUNCI_0]] : i16
+// CHECK: }
+
+// -----
+
+// Test `fir.convert` from i1 zero-extends, so that true widens to 1.
+
+func.func @convert_from_i1(%arg0 : i1) -> i32 {
+  %0 = fir.convert %arg0 : (i1) -> i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: func.func @convert_from_i1(
+// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
+// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[ARG0]] : i1 to i32
+// CHECK: return %[[EXTUI_0]] : i32
+// CHECK: }
+
+// -----
+
+// Test `fir.convert` operation conversion between float types.
+ +func.func @convert_between_fp(%arg0 : f32) -> f16 { + %0 = fir.convert %arg0 : (f32) -> f64 + %1 = fir.convert %0 : (f64) -> f16 + return %1 : f16 +} + +// CHECK-LABEL: func.func @convert_between_fp( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> f16 { +// CHECK: %[[EXTF_0:.*]] = arith.extf %[[ARG0]] : f32 to f64 +// CHECK: %[[TRUNCF_0:.*]] = arith.truncf %[[EXTF_0]] : f64 to f16 +// CHECK: return %[[TRUNCF_0]] : f16 +// CHECK: } + +// ----- + +// Test `fir.alloca` --> `memref.alloca` conversion + +func.func @test_alloca_f32() -> !fir.ref<f32> { + %1 = fir.alloca f32 + return %1 : !fir.ref<f32> +} + +// CHECK-LABEL: func.func @test_alloca_f32() -> !fir.ref<f32> { +// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() {in_type = f32} : memref<f32> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<f32> to memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[CAST_0]] : (memref<f32, strided<[], offset: ?>>) -> !fir.ref<f32> +// CHECK: return %[[CONVERT_0]] : !fir.ref<f32> +// CHECK: } + +// ----- + +// Test `fircg.ext_array_coor` conversion. 
+ +func.func @test_ext_array_coor(%arg0: !fir.ref<!fir.array<100x200xf32>>, %i : i64, %j : i64) -> !fir.ref<f32> { + %c200 = arith.constant 200 : index + %c100 = arith.constant 100 : index + %0 = fircg.ext_array_coor %arg0(%c100, %c200)<%i, %j> : (!fir.ref<!fir.array<100x200xf32>>, index, index, i64, i64) -> !fir.ref<f32> + return %0 : !fir.ref<f32> +} + +// CHECK-LABEL: func.func @test_ext_array_coor( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100x200xf32>>, +// CHECK-SAME: %[[ARG1:.*]]: i64, +// CHECK-SAME: %[[ARG2:.*]]: i64) -> !fir.ref<f32> { +// CHECK: %[[CONVERT_0:.*]] = fir.convert %[[ARG0]] : (!fir.ref<!fir.array<100x200xf32>>) -> memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 200 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_0:.*]], %[[EXTRACT_STRIDED_METADATA_0:.*]] = memref.extract_strided_metadata %[[CONVERT_0]] : memref<f32, strided<[], offset: ?>> -> memref<f32>, index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index +// CHECK: %[[MULI_0:.*]] = arith.muli %[[CONSTANT_2]], %[[CONSTANT_1]] : index +// CHECK: %[[REINTERPRET_CAST_0:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[EXTRACT_STRIDED_METADATA_0]]], sizes: {{\[}}%[[CONSTANT_0]], %[[CONSTANT_1]]], strides: {{\[}}%[[MULI_0]], 1] : memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>> +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 1 : index +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[ARG2]] : i64 to index +// CHECK: %[[SUBI_0:.*]] = arith.subi %[[INDEX_CAST_0]], %[[CONSTANT_3]] : index +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ARG1]] : i64 to index +// CHECK: %[[SUBI_1:.*]] = arith.subi %[[INDEX_CAST_1]], %[[CONSTANT_3]] : index +// CHECK: %[[SUBVIEW_0:.*]] = memref.subview %[[REINTERPRET_CAST_0]]{{\[}}%[[SUBI_0]], %[[SUBI_1]]] [1, 1] [1, 1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<f32, strided<[], offset: ?>> +// CHECK: %[[CONVERT_1:.*]] = fir.convert 
%[[SUBVIEW_0]] : (memref<f32, strided<[], offset: ?>>) -> !fir.ref<f32> +// CHECK: return %[[CONVERT_1]] : !fir.ref<f32> +// CHECK: } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
