llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-openmp Author: Krzysztof Parzyszek (kparzysz) <details> <summary>Changes</summary> …essor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work. --- Patch is 51.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81623.diff 2 Files Affected: - (modified) flang/include/flang/Evaluate/tools.h (+23) - (modified) flang/lib/Lower/OpenMP.cpp (+305-327) ``````````diff diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index d257da1a709642..e9999974944e88 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) { } } +struct ExtractSubstringHelper { + template <typename T> static std::optional<Substring> visit(T &&) { + return std::nullopt; + } + + static std::optional<Substring> visit(const Substring &e) { return e; } + + template <typename T> + static std::optional<Substring> visit(const Designator<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + template <typename T> + static std::optional<Substring> visit(const Expr<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } +}; + +template <typename A> +std::optional<Substring> ExtractSubstring(const A &x) { + return ExtractSubstringHelper::visit(x); +} + // If an expression is simply a whole symbol data designator, // extract and return that symbol, else null. template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) { diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index d7a93db15a4bb8..4b21ab934c9393 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { return sym; } -static void genObjectList(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl<mlir::Value> &operands) { +static void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands) { auto addOperands = [&](Fortran::lower::SymbolRef sym) { const mlir::Value variable = converter.getSymbolAddress(sym); if (variable) { @@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList, } } -static void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); - } -} - static Fortran::lower::pft::Evaluation * getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -1257,6 +1236,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses, } } // namespace omp +static void genObjectList(const omp::ObjectList &objects, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands) { + for (const omp::Object &object : objects) { + const Fortran::semantics::Symbol *sym = object.sym; + assert(sym && "Expected Symbol"); + if (mlir::Value variable = converter.getSymbolAddress(*sym)) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), *sym); + } + } + } +} + +static void gatherFuncAndVarSyms( + const omp::ObjectList &objects, + mlir::omp::DeclareTargetCaptureClause clause, + llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { + for (const omp::Object &object : objects) + symbolAndClause.emplace_back(clause, *object.sym); +} + //===----------------------------------------------------------------------===// // DataSharingProcessor //===----------------------------------------------------------------------===// @@ -1718,9 +1723,8 @@ class ClauseProcessor { llvm::SmallVectorImpl<mlir::Value> &dependOperands) const; bool processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; + bool processIf(omp::clause::If::DirectiveNameModifier directiveName, + mlir::Value &result) const; bool processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; @@ -1815,6 +1819,26 @@ class ClauseProcessor { /// if at least one instance was found. template <typename T> bool findRepeatableClause( + std::function<void(const T &, const Fortran::parser::CharBlock &source)> + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.end(); + for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { + nextIt = findClause<T>(it, endIt); + + if (nextIt != endIt) { + callbackFn(std::get<T>(nextIt->u), nextIt->source); + found = true; + ++nextIt; + } + } + return found; + } + + /// Call `callbackFn` for each occurrence of the given clause. Return `true` + /// if at least one instance was found. + template <typename T> + bool findRepeatableClause2( std::function<void(const T *, const Fortran::parser::CharBlock &source)> callbackFn) const { bool found = false; @@ -1880,9 +1904,9 @@ class ReductionProcessor { IEOR }; static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd) { + getReductionType(const omp::clause::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( - getRealName(pd).ToString()) + getRealName(pd.v.sym).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) @@ -1894,35 +1918,33 @@ class ReductionProcessor { } static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); } } - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) + static bool + supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) { + Fortran::semantics::Symbol *sym = pd.v.sym; + if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString()) + auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) @@ -1933,15 +1955,13 @@ class ReductionProcessor { } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); + getRealName(const Fortran::semantics::Symbol *symbol) { + return symbol->GetUltimate().name(); } static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); + getRealName(const omp::clause::ProcedureDesignator &pd) { + return getRealName(pd.v.sym); } static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { @@ -1951,25 +1971,25 @@ class ReductionProcessor { .str(); } - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { + static std::string + getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { std::string reductionName; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; @@ -2213,7 +2233,7 @@ class ReductionProcessor { static void addReductionDecl(mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> @@ -2221,13 +2241,12 @@ class ReductionProcessor { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ - std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; - const auto &objectList{ - std::get<Fortran::parser::OmpObjectList>(reduction.t)}; + std::get<omp::clause::ReductionOperator>(reduction.t)}; + const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; if (const auto &redDefinedOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + std::get<omp::clause::DefinedOperator::IntrinsicOperator>( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { @@ -2243,10 +2262,41 @@ class ReductionProcessor { "Reduction of some intrinsic operators is not supported"); break; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast<fir::ReferenceType>().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa<fir::LogicalType>()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, + redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + redId, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } else if (const auto *reductionIntrinsic = + std::get_if<omp::clause::ProcedureDesignator>( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.sym) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); @@ -2255,55 +2305,18 @@ class ReductionProcessor { mlir::Type redType = symVal.getType().cast<fir::ReferenceType>().getEleTy(); reductionVars.push_back(symVal); - if (redType.isa<fir::LogicalType>()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - redId, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - redId, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } - } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast<fir::ReferenceType>().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - redId, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } } } }; @@ -2365,7 +2378,7 @@ getSimdModifier(const omp::clause::Schedule &clause) { static void genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, + const omp::clause::Allocate &clause, llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, llvm::SmallVectorImpl<mlir::Value> &allocateOperands) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -2373,21 +2386,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>( - ompAllocateClause.t); + const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t); + const auto &modifier = + std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t); // If the allocate modifier is present, check if we only use the allocator // submodifier. ALIGN in this context is unimplemented const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); + modifier && + std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>( + modifier->u); - if (allocateModifier && !onlyAllocator) { + if (modifier && !onlyAllocator) { TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } @@ -2395,20 +2405,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, // to list of allocators, otherwise, add default allocator to // list of allocators. if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/81623 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits