https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/144898
>From 392514e4d56491575ec47a1eb5607fd52f5b1ff9 Mon Sep 17 00:00:00 2001 From: Tom Eccles <tom.ecc...@arm.com> Date: Wed, 18 Jun 2025 21:01:13 +0000 Subject: [PATCH 1/2] [flang][OpenMP][NFC] remove globals with mlir::StateStack Idea suggested by @skatrak --- flang/include/flang/Lower/AbstractConverter.h | 3 + flang/lib/Lower/Bridge.cpp | 6 ++ flang/lib/Lower/OpenMP/OpenMP.cpp | 102 ++++++++++++------ mlir/include/mlir/Support/StateStack.h | 11 ++ 4 files changed, 91 insertions(+), 31 deletions(-) diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h index 8ae68e143cd2f..de3e833f60699 100644 --- a/flang/include/flang/Lower/AbstractConverter.h +++ b/flang/include/flang/Lower/AbstractConverter.h @@ -26,6 +26,7 @@ namespace mlir { class SymbolTable; +class StateStack; } namespace fir { @@ -361,6 +362,8 @@ class AbstractConverter { /// functions in order to be in sync). virtual mlir::SymbolTable *getMLIRSymbolTable() = 0; + virtual mlir::StateStack &getStateStack() = 0; + private: /// Options controlling lowering behavior. const Fortran::lower::LoweringOptions &loweringOptions; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 64b16b3abe991..8506b9a984e58 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -69,6 +69,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" +#include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" @@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; } + mlir::StateStack &getStateStack() override { return stateStack; } + /// Add the symbol to the local map and return `true`. If the symbol is /// already in the map and \p forced is `false`, the map is not updated. /// Instead the value `false` is returned. 
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// attribute since mlirSymbolTable must pro-actively be maintained when /// new Symbol operations are created. mlir::SymbolTable mlirSymbolTable; + + /// Used to store context while recursing into regions during lowering. + mlir::StateStack stateStack; }; } // namespace diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7ad8869597274..bff3321af2814 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -38,6 +38,7 @@ #include "flang/Support/OpenMP-utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Support/StateStack.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -200,9 +201,41 @@ class HostEvalInfo { /// the handling of the outer region by keeping a stack of information /// structures, but it will probably still require some further work to support /// reverse offloading. -static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo; -static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0> - sectionsStack; +class HostEvalInfoStackFrame + : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame) + + HostEvalInfo info; +}; + +static HostEvalInfo * +getHostEvalInfoStackTop(lower::AbstractConverter &converter) { + HostEvalInfoStackFrame *frame = + converter.getStateStack().getStackTop<HostEvalInfoStackFrame>(); + return frame ? &frame->info : nullptr; +} + +/// Stack frame for storing the OpenMPSectionsConstruct currently being +/// processed so that it can be refered to when lowering the construct. 
+class SectionsConstructStackFrame + : public mlir::StateStackFrameBase<SectionsConstructStackFrame> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame) + + explicit SectionsConstructStackFrame( + const parser::OpenMPSectionsConstruct &sectionsConstruct) + : sectionsConstruct{sectionsConstruct} {} + + const parser::OpenMPSectionsConstruct &sectionsConstruct; +}; + +static const parser::OpenMPSectionsConstruct * +getSectionsConstructStackTop(lower::AbstractConverter &converter) { + SectionsConstructStackFrame *frame = + converter.getStateStack().getStackTop<SectionsConstructStackFrame>(); + return frame ? &frame->sectionsConstruct : nullptr; +} /// Bind symbols to their corresponding entry block arguments. /// @@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, if (!ompEval) return; - HostEvalInfo &hostInfo = hostEvalInfo.back(); + HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter); + assert(hostInfo && "expected HOST_EVAL info structure"); switch (extractOmpDirective(*ompEval)) { case OMPD_teams_distribute_parallel_do: case OMPD_teams_distribute_parallel_do_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute_parallel_do: case OMPD_target_teams_distribute_parallel_do_simd: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute_parallel_do: case OMPD_distribute_parallel_do_simd: - cp.processNumThreads(stmtCtx, hostInfo.ops); + cp.processNumThreads(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_distribute: case OMPD_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; case OMPD_teams: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams: - 
cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); processSingleNestedIf([](Directive nestedDir) { return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir); }); @@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, case OMPD_teams_distribute: case OMPD_teams_distribute_simd: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_distribute: case OMPD_target_teams_distribute_simd: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); + cp.processNumTeams(stmtCtx, hostInfo->ops); break; case OMPD_teams_loop: - cp.processThreadLimit(stmtCtx, hostInfo.ops); + cp.processThreadLimit(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_target_teams_loop: - cp.processNumTeams(stmtCtx, hostInfo.ops); + cp.processNumTeams(stmtCtx, hostInfo->ops); [[fallthrough]]; case OMPD_loop: - cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; // Standalone 'target' case. 
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, } }; - assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); - const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); assert(ompEval && llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && @@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp( mlir::Region &region = targetOp.getRegion(); mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region); bindEntryBlockArgs(converter, targetOp, args); - if (!hostEvalInfo.empty()) - hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); + if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) + hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. // If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv)) cp.processCollapse(loc, eval, clauseOps, iv); clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); @@ -1753,7 +1786,8 @@ static void genParallelClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) cp.processNumThreads(stmtCtx, clauseOps); cp.processProcBind(clauseOps); @@ -1818,16 +1852,17 @@ static void genTargetClauses( llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> 
&isDevicePtrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); ClauseProcessor cp(converter, semaCtx, clauses); cp.processBare(clauseOps); cp.processDefaultMap(stmtCtx, defaultMaps); cp.processDepend(symTable, stmtCtx, clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms); - if (!hostEvalInfo.empty()) { + if (hostEvalInfo) { // Only process host_eval if compiling for the host device. processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); - hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); @@ -1963,7 +1998,8 @@ static void genTeamsClauses( cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { + HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); + if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) { cp.processNumTeams(stmtCtx, clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); } @@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, lower::pft::Evaluation &eval, mlir::Location loc, const ConstructQueue &queue, ConstructQueue::const_iterator item) { - assert(!sectionsStack.empty()); + const parser::OpenMPSectionsConstruct *sectionsConstruct = + getSectionsConstructStackTop(converter); + assert(sectionsConstruct); + const auto &sectionBlocks = - std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t); - sectionsStack.pop_back(); + std::get<parser::OmpSectionBlocks>(sectionsConstruct->t); + converter.getStateStack().stackPop(); mlir::omp::SectionsOperands clauseOps; llvm::SmallVector<const semantics::Symbol *> reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, 
loc, clauseOps, @@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Introduce a new host_eval information structure for this target region. if (!isTargetDevice) - hostEvalInfo.emplace_back(); + converter.getStateStack().stackPush<HostEvalInfoStackFrame>(); mlir::omp::TargetOperands clauseOps; DefaultMapsTy defaultMaps; @@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, // Remove the host_eval information structure created for this target region. if (!isTargetDevice) - hostEvalInfo.pop_back(); + converter.getStateStack().stackPop(); return targetOp; } @@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx, eval, source, directive, clauses)}; - sectionsStack.push_back(&sectionsConstruct); + converter.getStateStack().stackPush<SectionsConstructStackFrame>( + sectionsConstruct); genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, queue.begin()); } diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h index ac70d05a3020a..44972fafe7fed 100644 --- a/mlir/include/mlir/Support/StateStack.h +++ b/mlir/include/mlir/Support/StateStack.h @@ -84,6 +84,17 @@ class StateStack { return WalkResult::advance(); } + /// Get the top instance of frame type `T` or nullptr if none are found + template <typename T> + T *getStackTop() { + T *top = nullptr; + stackWalk<T>([&](T &frame) -> mlir::WalkResult { + top = &frame; + return mlir::WalkResult::interrupt(); + }); + return top; + } + private: SmallVector<std::unique_ptr<StateStackFrame>> stack; }; >From ec40d1aba6ab9af2830881b48e33ebfb1badb1a9 Mon Sep 17 00:00:00 2001 From: Tom Eccles <tom.ecc...@arm.com> Date: Fri, 20 Jun 2025 11:25:58 +0000 Subject: [PATCH 2/2] Review comments --- flang/lib/Lower/OpenMP/OpenMP.cpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 
deletions(-) diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index bff3321af2814..14e279cb2e759 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -217,7 +217,7 @@ getHostEvalInfoStackTop(lower::AbstractConverter &converter) { } /// Stack frame for storing the OpenMPSectionsConstruct currently being -/// processed so that it can be refered to when lowering the construct. +/// processed so that it can be referred to when lowering the construct. class SectionsConstructStackFrame : public mlir::StateStackFrameBase<SectionsConstructStackFrame> { public: @@ -1852,14 +1852,13 @@ static void genTargetClauses( llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) { - HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter); ClauseProcessor cp(converter, semaCtx, clauses); cp.processBare(clauseOps); cp.processDefaultMap(stmtCtx, defaultMaps); cp.processDepend(symTable, stmtCtx, clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms); - if (hostEvalInfo) { + if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) { // Only process host_eval if compiling for the host device. processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); hostEvalInfo->collectValues(clauseOps.hostEvalVars); @@ -2251,9 +2250,6 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable, converter.getCurrentLocation(), clauseOps); } -/// This breaks the normal prototype of the gen*Op functions: adding the -/// sectionBlocks argument so that the enclosed section constructs can be -/// lowered here with correct reduction symbol remapping. 
static mlir::omp::SectionsOp genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, @@ -2262,11 +2258,10 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { const parser::OpenMPSectionsConstruct *sectionsConstruct = getSectionsConstructStackTop(converter); - assert(sectionsConstruct); + assert(sectionsConstruct && "Missing additional parsing information"); const auto &sectionBlocks = std::get<parser::OmpSectionBlocks>(sectionsConstruct->t); - converter.getStateStack().stackPop(); mlir::omp::SectionsOperands clauseOps; llvm::SmallVector<const semantics::Symbol *> reductionSyms; genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps, @@ -4274,8 +4269,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx, eval, source, directive, clauses)}; - converter.getStateStack().stackPush<SectionsConstructStackFrame>( - sectionsConstruct); + mlir::SaveStateStack<SectionsConstructStackFrame> saveStateStack{ + converter.getStateStack(), sectionsConstruct}; genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue, queue.begin()); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits