llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside a target region that need to be evaluated in the host device. This is done through the `OpenMP_HostEvalClause` set of `omp.target` operands and entry block arguments.

When lowering clauses for a target construct, a more involved `processHostEvalClauses()` function is called. It looks at the current construct and, potentially, at other constructs nested inside it, in order to find and lower clauses that need to be processed outside of the `omp.target` operation under construction. This populates an instance of a global structure with the resulting MLIR values.

The resulting list of host-evaluated values is used to initialize the `host_eval` operands when constructing the `omp.target` operation, and is then replaced with the corresponding block arguments after creating that operation's region.

Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. `num_teams`, `thread_limit`, `num_threads` and `collapse`) first check whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring to `omp.target` entry block arguments at that stage) are used instead of lowering these clauses again.

---

Patch is 34.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116219.diff

4 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+429-29)
- (added) flang/test/Lower/OpenMP/host-eval.f90 (+138)
- (added) flang/test/Lower/OpenMP/target-spmd.f90 (+191)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h (+6)


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 91f99ba4b0ca55..a206af77a2f51f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,19 @@ using namespace Fortran::lower::omp;
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
+static void genOMPDispatch(lower::AbstractConverter &converter,
+                           lower::SymMap &symTable,
+                           semantics::SemanticsContext &semaCtx,
+                           lower::pft::Evaluation &eval, mlir::Location loc,
+                           const ConstructQueue &queue,
+                           ConstructQueue::const_iterator item);
+
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc);
+
 namespace {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to a single clause.
@@ -63,6 +76,7 @@ struct EntryBlockArgsEntry {
 /// Structure holding the information needed to create and bind entry block
 /// arguments associated to all clauses that can define them.
 struct EntryBlockArgs {
+  llvm::ArrayRef<mlir::Value> hostEvalVars;
   EntryBlockArgsEntry inReduction;
   EntryBlockArgsEntry map;
   EntryBlockArgsEntry priv;
@@ -85,18 +99,146 @@ struct EntryBlockArgs {
 
   auto getVars() const {
     return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
+        hostEvalVars, inReduction.vars, map.vars, priv.vars, reduction.vars,
         taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
   }
 };
+
+/// Structure holding information that is needed to pass host-evaluated
+/// information to later lowering stages.
+class HostEvalInfo {
+public:
+  // Allow this function access to private members in order to initialize them.
+  friend void ::processHostEvalClauses(lower::AbstractConverter &,
+                                       semantics::SemanticsContext &,
+                                       lower::StatementContext &,
+                                       lower::pft::Evaluation &,
+                                       mlir::Location);
+
+  /// Fill \c vars with values stored in \c ops.
+  ///
+  /// The order in which values are stored matches the one expected by \see
+  /// bindOperands().
+  void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const {
+    vars.append(ops.loopLowerBounds);
+    vars.append(ops.loopUpperBounds);
+    vars.append(ops.loopSteps);
+
+    if (ops.numTeamsLower)
+      vars.push_back(ops.numTeamsLower);
+
+    if (ops.numTeamsUpper)
+      vars.push_back(ops.numTeamsUpper);
+
+    if (ops.numThreads)
+      vars.push_back(ops.numThreads);
+
+    if (ops.threadLimit)
+      vars.push_back(ops.threadLimit);
+  }
+
+  /// Update \c ops, replacing all values with the corresponding block argument
+  /// in \c args.
+  ///
+  /// The order in which values are stored in \c args is the same as the one
+  /// used by \see collectValues().
+  void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) {
+    assert(args.size() ==
+               ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
+                   ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.threadLimit ? 1 : 0) &&
+           "invalid block argument list");
+    int argIndex = 0;
+    for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
+      ops.loopLowerBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i)
+      ops.loopUpperBounds[i] = args[argIndex++];
+
+    for (size_t i = 0; i < ops.loopSteps.size(); ++i)
+      ops.loopSteps[i] = args[argIndex++];
+
+    if (ops.numTeamsLower)
+      ops.numTeamsLower = args[argIndex++];
+
+    if (ops.numTeamsUpper)
+      ops.numTeamsUpper = args[argIndex++];
+
+    if (ops.numThreads)
+      ops.numThreads = args[argIndex++];
+
+    if (ops.threadLimit)
+      ops.threadLimit = args[argIndex++];
+  }
+
+  /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
+  /// values and Fortran symbols, respectively, if they have already been
+  /// initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  /// evaluated in the host device.
+  bool apply(mlir::omp::LoopNestOperands &clauseOps,
+             llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) {
+    if (iv.empty() || loopNestApplied) {
+      loopNestApplied = true;
+      return false;
+    }
+
+    loopNestApplied = true;
+    clauseOps.loopLowerBounds = ops.loopLowerBounds;
+    clauseOps.loopUpperBounds = ops.loopUpperBounds;
+    clauseOps.loopSteps = ops.loopSteps;
+    ivOut.append(iv);
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized but not yet applied.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  /// evaluated in the host device.
+  bool apply(mlir::omp::ParallelOperands &clauseOps) {
+    if (!ops.numThreads || parallelApplied) {
+      parallelApplied = true;
+      return false;
+    }
+
+    parallelApplied = true;
+    clauseOps.numThreads = ops.numThreads;
+    return true;
+  }
+
+  /// Update \p clauseOps with the corresponding host-evaluated values if they
+  /// have already been initialized.
+  ///
+  /// \returns whether an update was performed. If not, these clauses were not
+  /// evaluated in the host device.
+  bool apply(mlir::omp::TeamsOperands &clauseOps) {
+    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+      return false;
+
+    clauseOps.numTeamsLower = ops.numTeamsLower;
+    clauseOps.numTeamsUpper = ops.numTeamsUpper;
+    clauseOps.threadLimit = ops.threadLimit;
+    return true;
+  }
+
+private:
+  mlir::omp::HostEvaluatedOperands ops;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  bool loopNestApplied = false, parallelApplied = false;
+};
 } // namespace
 
-static void genOMPDispatch(lower::AbstractConverter &converter,
-                           lower::SymMap &symTable,
-                           semantics::SemanticsContext &semaCtx,
-                           lower::pft::Evaluation &eval, mlir::Location loc,
-                           const ConstructQueue &queue,
-                           ConstructQueue::const_iterator item);
+/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target
+/// operations being created.
+///
+/// The current implementation prevents nested 'target' regions from breaking
+/// the handling of the outer region by keeping a stack of information
+/// structures, but it will probably still require some further work to support
+/// reverse offloading.
+static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -219,6 +361,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter,
   };
 
   // Process in clause name alphabetical order to match block arguments order.
+  // Do not bind host_eval variables because they cannot be used inside of the
+  // corresponding region, except for very specific cases handled separately.
   bindPrivateLike(args.inReduction.syms, args.inReduction.vars,
                   op.getInReductionBlockArgs());
   bindMapLike(args.map.syms, op.getMapBlockArgs());
@@ -256,6 +400,246 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars,
   });
 }
 
+/// Get the directive enumeration value corresponding to the given OpenMP
+/// construct PFT node.
+llvm::omp::Directive
+extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
+  return common::visit(
+      common::visitors{
+          [](const parser::OpenMPAllocatorsConstruct &c) {
+            return llvm::omp::OMPD_allocators;
+          },
+          [](const parser::OpenMPAtomicConstruct &c) {
+            return llvm::omp::OMPD_atomic;
+          },
+          [](const parser::OpenMPBlockConstruct &c) {
+            return std::get<parser::OmpBlockDirective>(
+                       std::get<parser::OmpBeginBlockDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPCriticalConstruct &c) {
+            return llvm::omp::OMPD_critical;
+          },
+          [](const parser::OpenMPDeclarativeAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPExecutableAllocate &c) {
+            return llvm::omp::OMPD_allocate;
+          },
+          [](const parser::OpenMPLoopConstruct &c) {
+            return std::get<parser::OmpLoopDirective>(
+                       std::get<parser::OmpBeginLoopDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPSectionConstruct &c) {
+            return llvm::omp::OMPD_section;
+          },
+          [](const parser::OpenMPSectionsConstruct &c) {
+            return std::get<parser::OmpSectionsDirective>(
+                       std::get<parser::OmpBeginSectionsDirective>(c.t).t)
+                .v;
+          },
+          [](const parser::OpenMPStandaloneConstruct &c) {
+            return common::visit(
+                common::visitors{
+                    [](const parser::OpenMPSimpleStandaloneConstruct &c) {
+                      return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
+                          .v;
+                    },
+                    [](const parser::OpenMPFlushConstruct &c) {
+                      return llvm::omp::OMPD_flush;
+                    },
+                    [](const parser::OpenMPCancelConstruct &c) {
+                      return llvm::omp::OMPD_cancel;
+                    },
+                    [](const parser::OpenMPCancellationPointConstruct &c) {
+                      return llvm::omp::OMPD_cancellation_point;
+                    },
+                    [](const parser::OpenMPDepobjConstruct &c) {
+                      return llvm::omp::OMPD_depobj;
+                    }},
+                c.u);
+          }},
+      ompConstruct.u);
+}
+
+/// Populate the global \see hostEvalInfo after processing clauses for the given
+/// \p eval OpenMP target construct, or nested constructs, if these must be
+/// evaluated outside of the target region per the spec.
+///
+/// In particular, this will ensure that in 'target teams' and equivalent nested
+/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated
+/// in the host. Additionally, loop bounds, steps and the \c num_threads clause
+/// will also be evaluated in the host if a target SPMD construct is detected
+/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting).
+///
+/// The result, stored as a global, is intended to be used to populate the \c
+/// host_eval operands of the associated \c omp.target operation, and also to be
+/// checked and used by later lowering steps to populate the corresponding
+/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest
+/// operations.
+static void processHostEvalClauses(lower::AbstractConverter &converter,
+                                   semantics::SemanticsContext &semaCtx,
+                                   lower::StatementContext &stmtCtx,
+                                   lower::pft::Evaluation &eval,
+                                   mlir::Location loc) {
+  // Obtain the list of clauses of the given OpenMP block or loop construct
+  // evaluation. Other evaluations passed to this lambda keep `clauses`
+  // unchanged.
+  auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval,
+                                   List<Clause> &clauses) {
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    const parser::OmpClauseList *beginClauseList = nullptr;
+    const parser::OmpClauseList *endClauseList = nullptr;
+    common::visit(
+        common::visitors{
+            [&](const parser::OpenMPBlockConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginBlockDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+              endClauseList = &std::get<parser::OmpClauseList>(
+                  std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t);
+            },
+            [&](const parser::OpenMPLoopConstruct &ompConstruct) {
+              const auto &beginDirective =
+                  std::get<parser::OmpBeginLoopDirective>(ompConstruct.t);
+              beginClauseList =
+                  &std::get<parser::OmpClauseList>(beginDirective.t);
+
+              if (auto &endDirective =
+                      std::get<std::optional<parser::OmpEndLoopDirective>>(
+                          ompConstruct.t))
+                endClauseList =
+                    &std::get<parser::OmpClauseList>(endDirective->t);
+            },
+            [&](const auto &) {}},
+        ompEval->u);
+
+    assert(beginClauseList && "expected begin directive");
+    clauses.append(makeClauses(*beginClauseList, semaCtx));
+
+    if (endClauseList)
+      clauses.append(makeClauses(*endClauseList, semaCtx));
+  };
+
+  // Return the directive that is immediately nested inside of the given
+  // `parent` evaluation, if it is its only non-end-statement nested evaluation
+  // and it represents an OpenMP construct.
+  auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent)
+      -> std::optional<llvm::omp::Directive> {
+    if (!parent.hasNestedEvaluations())
+      return std::nullopt;
+
+    llvm::omp::Directive dir;
+    auto &nested = parent.getFirstNestedEvaluation();
+    if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>())
+      dir = extractOmpDirective(*ompEval);
+    else
+      return std::nullopt;
+
+    for (auto &sibling : parent.getNestedEvaluations())
+      if (&sibling != &nested && !sibling.isEndStmt())
+        return std::nullopt;
+
+    return dir;
+  };
+
+  // Process the given evaluation assuming it's part of a 'target' construct or
+  // captured by one, and store results in the global `hostEvalInfo`.
+  std::function<void(lower::pft::Evaluation &, const List<Clause> &)>
+      processEval;
+  processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) {
+    using namespace llvm::omp;
+    ClauseProcessor cp(converter, semaCtx, clauses);
+
+    // Call `processEval` recursively with the immediately nested evaluation and
+    // its corresponding clauses if there is a single nested evaluation
+    // representing an OpenMP directive that passes the given test.
+    auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) {
+      std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval);
+      if (!nestedDir || !test(*nestedDir))
+        return;
+
+      lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation();
+      List<lower::omp::Clause> nestedClauses;
+      extractClauses(nestedEval, nestedClauses);
+      processEval(nestedEval, nestedClauses);
+    };
+
+    const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+    if (!ompEval)
+      return;
+
+    HostEvalInfo &hostInfo = hostEvalInfo.back();
+
+    switch (extractOmpDirective(*ompEval)) {
+    // Cases where 'teams' and target SPMD clauses might be present.
+    case OMPD_teams_distribute_parallel_do:
+    case OMPD_teams_distribute_parallel_do_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute_parallel_do:
+    case OMPD_target_teams_distribute_parallel_do_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_distribute_parallel_do:
+    case OMPD_distribute_parallel_do_simd:
+      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      break;
+
+    // Cases where 'teams' clauses might be present, and target SPMD is
+    // possible by looking at nested evaluations.
+    case OMPD_teams:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      processSingleNestedIf([](Directive nestedDir) {
+        return nestedDir == OMPD_distribute_parallel_do ||
+               nestedDir == OMPD_distribute_parallel_do_simd;
+      });
+      break;
+
+    // Cases where only 'teams' host-evaluated clauses might be present.
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      [[fallthrough]];
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      break;
+
+    // Standalone 'target' case.
+    case OMPD_target: {
+      processSingleNestedIf(
+          [](Directive nestedDir) { return topTeamsSet.test(nestedDir); });
+      break;
+    }
+    default:
+      break;
+    }
+  };
+
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
+  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
+  assert(ompEval &&
+         llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
+         "expected TARGET construct evaluation");
+
+  // Use the whole list of clauses passed to the construct here, rather than the
+  // ones only applied to omp.target.
+  List<lower::omp::Clause> clauses;
+  extractClauses(eval, clauses);
+  processEval(eval, clauses);
+}
+
 static lower::pft::Evaluation *
 getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) {
   // Return the Evaluation of the innermost collapsed loop, or the current one
@@ -638,11 +1022,11 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Type> types;
   llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
+  unsigned numVars =
+      args.hostEvalVars.size() + args.inReduction.vars.size() +
+      args.map.vars.size() + args.priv.vars.size() +
+      args.reduction.vars.size() + args.taskReduction.vars.size() +
+      args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size();
   types.reserve(numVars);
   locs.reserve(numVars);
 
@@ -655,6 +1039,7 @@ static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
 
   // Populate block arguments in clause name alphabetical order to match
   // expected order by the BlockArgOpenMPOpInterface.
+  extractTypeLoc(args.hostEvalVars);
   extractTypeLoc(args.inReduction.vars);
   extractTypeLoc(args.map.vars);
   extractTypeLoc(args.priv.vars);
@@ -991,12 +1376,15 @@ static void genBodyOfTargetOp(
     mlir::omp::TargetOp &targetOp, const EntryBlockArgs &args,
     const mlir::Location &currentLocation, const ConstructQueue &queue,
     ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
+  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
+
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
 
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(converter, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
+  hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1172,7 +1560,10 @@ genLoopNestClauses(lower::AbstractConverter &converter, mlir::Location loc,
... [truncated]
``````````

</details>

https://github.com/llvm/llvm-project/pull/116219
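For context, here is a rough, hand-written sketch (not part of the patch, and not verifier-checked output) of the IR shape this lowering aims for when, say, a `num_teams` clause on `!$omp target teams` is evaluated in the host. The SSA names, the constant value and the elided nested operations are hypothetical:

```mlir
// Illustrative sketch only. The num_teams expression is evaluated by host
// code outside the target region, passed through the host_eval operand, and
// is only visible inside the region as the corresponding entry block argument.
%num_teams = arith.constant 8 : i32
omp.target host_eval(%num_teams -> %nt : i32) {
  omp.teams num_teams(to %nt : i32) {
    // distribute/parallel/wsloop/loop_nest nesting elided
    omp.terminator
  }
  omp.terminator
}
```

When lowering the nested `teams` (and, for target SPMD nests, `parallel` and the loop bounds), the global host-eval structure hands back `%nt` instead of re-lowering the clause expression inside the target region, which is what the `apply()` overloads above implement.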