https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81623
>From 655dce519efb87f8d3babf3b7a5d6132bb82e2a6 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Wed, 21 Feb 2024 15:51:38 -0600 Subject: [PATCH] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work. --- flang/include/flang/Evaluate/tools.h | 23 ++ flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 218 ++++++++---------- flang/lib/Lower/OpenMP/ClauseProcessor.h | 29 ++- flang/lib/Lower/OpenMP/Clauses.cpp | 6 - flang/lib/Lower/OpenMP/Clauses.h | 6 + flang/lib/Lower/OpenMP/OpenMP.cpp | 182 +++++++-------- flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 155 ++++++------- flang/lib/Lower/OpenMP/ReductionProcessor.h | 23 +- flang/lib/Lower/OpenMP/Utils.cpp | 41 ++-- flang/lib/Lower/OpenMP/Utils.h | 10 +- 10 files changed, 348 insertions(+), 345 deletions(-) diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index d257da1a709642..e9999974944e88 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) { } } +struct ExtractSubstringHelper { + template <typename T> static std::optional<Substring> visit(T &&) { + return std::nullopt; + } + + static std::optional<Substring> visit(const Substring &e) { return e; } + + template <typename T> + static std::optional<Substring> visit(const Designator<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } + + template <typename T> + static std::optional<Substring> visit(const Expr<T> &e) { + return std::visit([](auto &&s) { return visit(s); }, e.u); + } +}; + +template <typename A> +std::optional<Substring> ExtractSubstring(const A &x) { + return ExtractSubstringHelper::visit(x); +} + // If an expression is simply a whole symbol data designator, // extract and return that symbol, else null. template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 9987cd73fc7670..6e45a939333d62 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -87,7 +87,7 @@ getSimdModifier(const omp::clause::Schedule &clause) { static void genAllocateClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpAllocateClause &ompAllocateClause, + const omp::clause::Allocate &clause, llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, llvm::SmallVectorImpl<mlir::Value> &allocateOperands) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -95,21 +95,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, Fortran::lower::StatementContext stmtCtx; mlir::Value allocatorOperand; - const Fortran::parser::OmpObjectList &ompObjectList = - std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); - const auto &allocateModifier = std::get< - std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>( - ompAllocateClause.t); + const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t); + const auto &modifier = + std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t); // If the allocate modifier is present, check if we only use the allocator // submodifier. ALIGN in this context is unimplemented const bool onlyAllocator = - allocateModifier && - std::holds_alternative< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); + modifier && + std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>( + modifier->u); - if (allocateModifier && !onlyAllocator) { + if (modifier && !onlyAllocator) { TODO(currentLocation, "OmpAllocateClause ALIGN modifier"); } @@ -117,20 +114,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter, // to list of allocators, otherwise, add default allocator to // list of allocators. if (onlyAllocator) { - const auto &allocatorValue = std::get< - Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>( - allocateModifier->u); - allocatorOperand = fir::getBase(converter.genExprValue( - *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx)); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + const auto &value = + std::get<omp::clause::Allocate::Modifier::Allocator>(modifier->u); + mlir::Value operand = + fir::getBase(converter.genExprValue(value.v, stmtCtx)); + allocatorOperands.append(objectList.size(), operand); } else { - allocatorOperand = firOpBuilder.createIntegerConstant( + mlir::Value operand = firOpBuilder.createIntegerConstant( currentLocation, firOpBuilder.getI32Type(), 1); - allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), - allocatorOperand); + allocatorOperands.append(objectList.size(), operand); } - genObjectList(ompObjectList, converter, allocateOperands); + genObjectList(objectList, converter, allocateOperands); } static mlir::omp::ClauseProcBindKindAttr @@ -157,20 +151,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder, static mlir::omp::ClauseTaskDependAttr genDependKindAttr(fir::FirOpBuilder &firOpBuilder, - const Fortran::parser::OmpClause::Depend *dependClause) { + const omp::clause::Depend &clause) { mlir::omp::ClauseTaskDepend pbKind; - switch ( - std::get<Fortran::parser::OmpDependenceType>( - std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u) - .t) - .v) { - case Fortran::parser::OmpDependenceType::Type::In: + const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u); + switch (std::get<omp::clause::Depend::Type>(inOut.t)) { + case omp::clause::Depend::Type::In: pbKind = mlir::omp::ClauseTaskDepend::taskdependin; break; - case Fortran::parser::OmpDependenceType::Type::Out: + case omp::clause::Depend::Type::Out: pbKind = mlir::omp::ClauseTaskDepend::taskdependout; break; - case Fortran::parser::OmpDependenceType::Type::Inout: + case omp::clause::Depend::Type::Inout: pbKind = mlir::omp::ClauseTaskDepend::taskdependinout; break; default: @@ -181,45 +172,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder, pbKind); } -static mlir::Value getIfClauseOperand( - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpClause::If *ifClause, - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Location clauseLocation) { +static mlir::Value +getIfClauseOperand(Fortran::lower::AbstractConverter &converter, + const omp::clause::If &clause, + omp::clause::If::DirectiveNameModifier directiveName, + mlir::Location clauseLocation) { // Only consider the clause if it's intended for the given directive. - auto &directive = std::get< - std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>( - ifClause->v.t); + auto &directive = + std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t); if (directive && directive.value() != directiveName) return nullptr; Fortran::lower::StatementContext stmtCtx; fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); mlir::Value ifVal = fir::getBase( - converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); + converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx)); return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(), ifVal); } static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpObjectList &useDeviceClause, + const omp::ObjectList &objects, llvm::SmallVectorImpl<mlir::Value> &operands, llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) { - genObjectList(useDeviceClause, converter, operands); + genObjectList(objects, converter, operands); for (mlir::Value &operand : operands) { checkMapType(operand.getLoc(), operand.getType()); useDeviceTypes.push_back(operand.getType()); useDeviceLocs.push_back(operand.getLoc()); } - for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - useDeviceSymbols.push_back(sym); - } + for (const omp::Object &object : objects) + useDeviceSymbols.push_back(object.id()); } //===----------------------------------------------------------------------===// @@ -527,10 +514,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const { bool ClauseProcessor::processAllocate( llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const { - return findRepeatableClause<ClauseTy::Allocate>( - [&](const ClauseTy::Allocate *allocateClause, + return findRepeatableClause<omp::clause::Allocate>( + [&](const omp::clause::Allocate &clause, const Fortran::parser::CharBlock &) { - genAllocateClause(converter, allocateClause->v, allocatorOperands, + genAllocateClause(converter, clause, allocatorOperands, allocateOperands); }); } @@ -547,12 +534,12 @@ bool ClauseProcessor::processCopyin() const { if (converter.isPresentShallowLookup(*sym)) converter.copyHostAssociateVar(*sym, copyAssignIP); }; - bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>( - [&](const ClauseTy::Copyin *copyinClause, + bool hasCopyin = findRepeatableClause<omp::clause::Copyin>( + [&](const omp::clause::Copyin &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const omp::Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.id(); + assert(sym && "Expecting symbol"); if (const auto *commonDetails = sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) { for (const auto &mem : commonDetails->objects()) @@ -716,13 +703,11 @@ bool ClauseProcessor::processCopyPrivate( copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp)); }; - bool hasCopyPrivate = findRepeatableClause<ClauseTy::Copyprivate>( - [&](const ClauseTy::Copyprivate *copyPrivateClause, + bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>( + [&](const clause::Copyprivate &clause, const Fortran::parser::CharBlock &) { - const Fortran::parser::OmpObjectList &ompObjectList = - copyPrivateClause->v; - for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); + for (const Object &object : clause.v) { + Fortran::semantics::Symbol *sym = object.id(); if (const auto *commonDetails = sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) { for (const auto &mem : commonDetails->objects()) @@ -741,38 +726,30 @@ bool ClauseProcessor::processDepend( llvm::SmallVectorImpl<mlir::Value> &dependOperands) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause<ClauseTy::Depend>( - [&](const ClauseTy::Depend *dependClause, + return findRepeatableClause<omp::clause::Depend>( + [&](const omp::clause::Depend &clause, const Fortran::parser::CharBlock &) { - const std::list<Fortran::parser::Designator> &depVal = - std::get<std::list<Fortran::parser::Designator>>( - std::get<Fortran::parser::OmpDependClause::InOut>( - dependClause->v.u) - .t); + assert(std::holds_alternative<omp::clause::Depend::InOut>(clause.u) && + "Only InOut is handled at the moment"); + const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u); + const auto &objects = std::get<omp::ObjectList>(inOut.t); + mlir::omp::ClauseTaskDependAttr dependTypeOperand = - genDependKindAttr(firOpBuilder, dependClause); - dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(), - dependTypeOperand); - for (const Fortran::parser::Designator &ompObject : depVal) { - Fortran::semantics::Symbol *sym = nullptr; - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::DataRef &designator) { - if (const Fortran::parser::Name *name = - std::get_if<Fortran::parser::Name>(&designator.u)) { - sym = name->symbol; - } else if (std::get_if<Fortran::common::Indirection< - Fortran::parser::ArrayElement>>( - &designator.u)) { - TODO(converter.getCurrentLocation(), - "array sections not supported for task depend"); - } - }, - [&](const Fortran::parser::Substring &designator) { - TODO(converter.getCurrentLocation(), - "substring not supported for task depend"); - }}, - (ompObject).u); + genDependKindAttr(firOpBuilder, clause); + dependTypeOperands.append(objects.size(), dependTypeOperand); + + for (const omp::Object &object : objects) { + assert(object.ref() && "Expecting designator"); + + if (Fortran::evaluate::ExtractSubstring(*object.ref())) { + TODO(converter.getCurrentLocation(), + "substring not supported for task depend"); + } else if (Fortran::evaluate::IsArrayElement(*object.ref())) { + TODO(converter.getCurrentLocation(), + "array sections not supported for task depend"); + } + + Fortran::semantics::Symbol *sym = object.id(); const mlir::Value variable = converter.getSymbolAddress(*sym); dependOperands.push_back(variable); } @@ -780,14 +757,14 @@ bool ClauseProcessor::processDepend( } bool ClauseProcessor::processIf( - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, + omp::clause::If::DirectiveNameModifier directiveName, mlir::Value &result) const { bool found = false; - findRepeatableClause<ClauseTy::If>( - [&](const ClauseTy::If *ifClause, + findRepeatableClause<omp::clause::If>( + [&](const omp::clause::If &clause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); - mlir::Value operand = getIfClauseOperand(converter, ifClause, + mlir::Value operand = getIfClauseOperand(converter, clause, directiveName, clauseLocation); // Assume that, at most, a single 'if' clause will be applicable to the // given directive. @@ -801,12 +778,11 @@ bool ClauseProcessor::processIf( bool ClauseProcessor::processLink( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::Link>( - [&](const ClauseTy::Link *linkClause, - const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::Link>( + [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) { // Case: declare target link(var1, var2)... gatherFuncAndVarSyms( - linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result); + clause.v, mlir::omp::DeclareTargetCaptureClause::link, result); }); } @@ -843,7 +819,7 @@ bool ClauseProcessor::processMap( llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - return findRepeatableClause<ClauseTy::Map>( + return findRepeatableClause2<ClauseTy::Map>( [&](const ClauseTy::Map *mapClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); @@ -935,43 +911,41 @@ bool ClauseProcessor::processReduction( llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols) const { - return findRepeatableClause<ClauseTy::Reduction>( - [&](const ClauseTy::Reduction *reductionClause, + return findRepeatableClause<omp::clause::Reduction>( + [&](const omp::clause::Reduction &clause, const Fortran::parser::CharBlock &) { ReductionProcessor rp; - rp.addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols, - reductionSymbols); + rp.addReductionDecl(currentLocation, converter, clause, reductionVars, + reductionDeclSymbols, reductionSymbols); }); } bool ClauseProcessor::processSectionsReduction( mlir::Location currentLocation) const { - return findRepeatableClause<ClauseTy::Reduction>( - [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::Reduction>( + [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) { TODO(currentLocation, "OMPC_Reduction"); }); } bool ClauseProcessor::processTo( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::To>( - [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) { + return findRepeatableClause<omp::clause::To>( + [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) { // Case: declare target to(func, var1, var2)... - gatherFuncAndVarSyms(toClause->v, + gatherFuncAndVarSyms(clause.v, mlir::omp::DeclareTargetCaptureClause::to, result); }); } bool ClauseProcessor::processEnter( llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const { - return findRepeatableClause<ClauseTy::Enter>( - [&](const ClauseTy::Enter *enterClause, + return findRepeatableClause<omp::clause::Enter>( + [&](const omp::clause::Enter &clause, const Fortran::parser::CharBlock &) { // Case: declare target enter(func, var1, var2)... - gatherFuncAndVarSyms(enterClause->v, - mlir::omp::DeclareTargetCaptureClause::enter, - result); + gatherFuncAndVarSyms( + clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result); }); } @@ -981,11 +955,11 @@ bool ClauseProcessor::processUseDeviceAddr( llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) const { - return findRepeatableClause<ClauseTy::UseDeviceAddr>( - [&](const ClauseTy::UseDeviceAddr *devAddrClause, + return findRepeatableClause<omp::clause::UseDeviceAddr>( + [&](const omp::clause::UseDeviceAddr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devAddrClause->v, operands, - useDeviceTypes, useDeviceLocs, useDeviceSymbols); + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, + useDeviceLocs, useDeviceSymbols); }); } @@ -995,10 +969,10 @@ bool ClauseProcessor::processUseDevicePtr( llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols) const { - return findRepeatableClause<ClauseTy::UseDevicePtr>( - [&](const ClauseTy::UseDevicePtr *devPtrClause, + return findRepeatableClause<omp::clause::UseDevicePtr>( + [&](const omp::clause::UseDevicePtr &clause, const Fortran::parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes, + addUseDeviceClause(converter, clause.v, operands, useDeviceTypes, useDeviceLocs, useDeviceSymbols); }); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index c87fc30c88bb93..3f6adcce8ae877 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -105,9 +105,8 @@ class ClauseProcessor { llvm::SmallVectorImpl<mlir::Value> &dependOperands) const; bool processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; - bool - processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName, - mlir::Value &result) const; + bool processIf(omp::clause::If::DirectiveNameModifier directiveName, + mlir::Value &result) const; bool processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; @@ -178,6 +177,10 @@ class ClauseProcessor { /// if at least one instance was found. template <typename T> bool findRepeatableClause( + std::function<void(const T &, const Fortran::parser::CharBlock &source)> + callbackFn) const; + template <typename T> + bool findRepeatableClause2( std::function<void(const T *, const Fortran::parser::CharBlock &source)> callbackFn) const; @@ -195,7 +198,7 @@ template <typename T> bool ClauseProcessor::processMotionClauses( Fortran::lower::StatementContext &stmtCtx, llvm::SmallVectorImpl<mlir::Value> &mapOperands) { - return findRepeatableClause<T>( + return findRepeatableClause2<T>( [&](const T *motionClause, const Fortran::parser::CharBlock &source) { mlir::Location clauseLocation = converter.genLocation(source); fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); @@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause( template <typename T> bool ClauseProcessor::findRepeatableClause( + std::function<void(const T &, const Fortran::parser::CharBlock &source)> + callbackFn) const { + bool found = false; + ClauseIterator nextIt, endIt = clauses.end(); + for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { + nextIt = findClause<T>(it, endIt); + + if (nextIt != endIt) { + callbackFn(std::get<T>(nextIt->u), nextIt->source); + found = true; + ++nextIt; + } + } + return found; +} + +template <typename T> +bool ClauseProcessor::findRepeatableClause2( std::function<void(const T *, const Fortran::parser::CharBlock &source)> callbackFn) const { bool found = false; diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index 0b90b705b9e406..a3aa3d4de3cdc9 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -205,12 +205,6 @@ namespace clause { #undef EMPTY_CLASS #undef WRAPPER_CLASS -using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>; -using ProcedureDesignator = - tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>; -using ReductionOperator = - tomp::clause::ReductionOperatorT<SymIdent, SymReference>; - DefinedOperator makeDefOp(const parser::DefinedOperator &inp, semantics::SemanticsContext &semaCtx) { return DefinedOperator{ diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h index a7e563f4b0f90b..c167e34637d500 100644 --- a/flang/lib/Lower/OpenMP/Clauses.h +++ b/flang/lib/Lower/OpenMP/Clauses.h @@ -106,6 +106,12 @@ getBaseObject(const Object &object, Fortran::semantics::SemanticsContext &semaCtx); namespace clause { +using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>; +using ProcedureDesignator = + tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>; +using ReductionOperator = + tomp::clause::ReductionOperatorT<SymIdent, SymReference>; + #ifdef EMPTY_CLASS #undef EMPTY_CLASS #endif diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 7953bf83cba0fe..7445c0f13526f7 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand); cp.processNumThreads(stmtCtx, numThreadsClauseOperand); cp.processProcBind(procBindKindAttr); cp.processDefault(); @@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter, dependOperands; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processFinal(stmtCtx, finalClauseOperand); @@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData, + cp.processIf(clause::If::DirectiveNameModifier::TargetData, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs, @@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Value> mapOperands, dependOperands; llvm::SmallVector<mlir::Attribute> dependTypeOperands; - Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName; + clause::If::DirectiveNameModifier directiveName; llvm::omp::Directive directive; if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData; + directiveName = clause::If::DirectiveNameModifier::TargetEnterData; directive = llvm::omp::Directive::OMPD_target_enter_data; } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData; + directiveName = clause::If::DirectiveNameModifier::TargetExitData; directive = llvm::omp::Directive::OMPD_target_exit_data; } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) { - directiveName = - Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate; + directiveName = clause::If::DirectiveNameModifier::TargetUpdate; directive = llvm::omp::Directive::OMPD_target_update; } else { return nullptr; @@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand); cp.processDevice(stmtCtx, deviceOperand); cp.processThreadLimit(stmtCtx, threadLimitOperand); cp.processDepend(dependTypeOperands, dependOperands); @@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<mlir::Attribute> reductionDeclSymbols; ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand); cp.processAllocate(allocatorOperands, allocateOperands); cp.processDefault(); cp.processNumTeams(stmtCtx, numTeamsClauseOperand); @@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo( if (const auto *objectList{ Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) { + ObjectList objects{makeList(*objectList, semaCtx)}; // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to, + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, symbolAndClause); } else if (const auto *clauseList{ Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>( @@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, if (const auto &ompObjectList = std::get<std::optional<Fortran::parser::OmpObjectList>>( flushConstruct.t)) - genObjectList(*ompObjectList, converter, operandRange); + genObjectList2(*ompObjectList, converter, operandRange); const auto &memOrderClause = std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>( flushConstruct.t); @@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, loopVarTypeSize); cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand); cp.processReduction(loc, reductionVars, reductionDeclSymbols); - cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd, - ifClauseOperand); + cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand); cp.processSimdlen(simdlenClauseOperand); cp.processSafelen(safelenClauseOperand); cp.processTODO<Fortran::parser::OmpClause::Aligned, @@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction( const Fortran::parser::OmpClauseList &clauseList) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - for (const Fortran::parser::OmpClause &clause : clauseList.v) { + List<Clause> clauses{makeList(clauseList, semaCtx)}; + + for (const Clause &clause : clauses) { if (const auto &reductionClause = - std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { - const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( - reductionClause->v.t)}; - const auto &objectList{ - std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; + std::get_if<clause::Reduction>(&clause.u)) { + const auto &redOperator{ + std::get<clause::ReductionOperator>(reductionClause->t)}; + const auto &objects{std::get<ObjectList>(reductionClause->t)}; if (const auto *reductionOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + std::get_if<clause::DefinedOperator>(&redOperator.u)) { const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + std::get<clause::DefinedOperator::IntrinsicOperator>( reductionOp->u)}; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case clause::DefinedOperator::IntrinsicOperator::Add: + case clause::DefinedOperator::IntrinsicOperator::Multiply: + case clause::DefinedOperator::IntrinsicOperator::AND: + case clause::DefinedOperator::IntrinsicOperator::EQV: + case clause::DefinedOperator::IntrinsicOperator::OR: + case clause::DefinedOperator::IntrinsicOperator::NEQV: break; default: continue; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast<fir::ReferenceType>().getEleTy(); - if (!reductionType.isa<fir::LogicalType>()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa<fir::LogicalType>()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const Object &object : objects) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) + reductionVal = declOp.getBase(); + mlir::Type reductionType = + reductionVal.getType().cast<fir::ReferenceType>().getEleTy(); + if (!reductionType.isa<fir::LogicalType>()) { + if (!reductionType.isIntOrIndexOrFloat()) + continue; + } + for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { + if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + if (reductionType.isa<fir::LogicalType>()) { + mlir::Operation *reductionOp = findReductionChain(loadVal); + fir::ConvertOp convertOp = + getConvertFromReductionOp(reductionOp, loadVal); + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal, &convertOp); + removeStoreOp(reductionOp, reductionVal); + } else if (mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal)) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } } } } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( + std::get_if<clause::ProcedureDesignator>( &redOperator.u)) { if (!ReductionProcessor::supportedIntrinsicProcReduction( *reductionIntrinsic)) continue; ReductionProcessor::ReductionIdentifier redId = ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + for (const Object &object : objects) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; + + if (redId == ReductionProcessor::ReductionIdentifier::MAX || + redId == ReductionProcessor::ReductionIdentifier::MIN) { + assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redId == ReductionProcessor::ReductionIdentifier::IOR || + redId == ReductionProcessor::ReductionIdentifier::IEOR || + redId == ReductionProcessor::ReductionIdentifier::IAND) { + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } } diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index a8b98f3f567249..bf755b27487d95 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -23,9 +23,9 @@ namespace lower { namespace omp { ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( - const Fortran::parser::ProcedureDesignator &pd) { + const omp::clause::ProcedureDesignator &pd) { auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( - ReductionProcessor::getRealName(pd).ToString()) + getRealName(pd.v.id()).ToString()) .Case("max", ReductionIdentifier::MAX) .Case("min", ReductionIdentifier::MIN) .Case("iand", ReductionIdentifier::IAND) @@ -37,21 +37,21 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( } ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) { + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: return ReductionIdentifier::ADD; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract: + case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: return ReductionIdentifier::SUBTRACT; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: return ReductionIdentifier::MULTIPLY; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return ReductionIdentifier::AND; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return ReductionIdentifier::EQV; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return ReductionIdentifier::OR; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return ReductionIdentifier::NEQV; default: llvm_unreachable("unexpected intrinsic operator in reduction"); @@ -59,13 +59,11 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( } bool ReductionProcessor::supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - if (!name->symbol->GetUltimate().attrs().test( - Fortran::semantics::Attr::INTRINSIC)) + const omp::clause::ProcedureDesignator &pd) { + Fortran::semantics::Symbol *sym = pd.v.id(); + if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) return false; - auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString()) + auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) .Case("max", true) .Case("min", true) .Case("iand", true) @@ -84,24 +82,24 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name, } std::string ReductionProcessor::getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, mlir::Type ty) { std::string reductionName; switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case omp::clause::DefinedOperator::IntrinsicOperator::Add: reductionName = "add_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: reductionName = "multiply_reduction"; break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case omp::clause::DefinedOperator::IntrinsicOperator::AND: return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case omp::clause::DefinedOperator::IntrinsicOperator::EQV: return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case omp::clause::DefinedOperator::IntrinsicOperator::OR: return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: return "neqv_reduction"; default: reductionName = "other_reduction"; @@ -305,7 +303,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl( void ReductionProcessor::addReductionDecl( mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> @@ -313,12 +311,12 @@ void ReductionProcessor::addReductionDecl( fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::omp::ReductionDeclareOp decl; const auto &redOperator{ - std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; - const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)}; + std::get<omp::clause::ReductionOperator>(reduction.t)}; + const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; if (const auto &redDefinedOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + std::get<omp::clause::DefinedOperator::IntrinsicOperator>( redDefinedOp->u)}; ReductionIdentifier redId = getReductionType(intrinsicOp); switch (redId) { @@ -334,10 +332,41 @@ void ReductionProcessor::addReductionDecl( "Reduction of some intrinsic operators is not supported"); break; } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { + if (reductionSymbols) + reductionSymbols->push_back(symbol); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast<fir::ReferenceType>().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa<fir::LogicalType>()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, + redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + redId, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } else if (const auto *reductionIntrinsic = + std::get_if<omp::clause::ProcedureDesignator>( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + ReductionProcessor::ReductionIdentifier redId = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const omp::Object &object : objectList) { + if (const Fortran::semantics::Symbol *symbol = object.id()) { if (reductionSymbols) reductionSymbols->push_back(symbol); mlir::Value symVal = converter.getSymbolAddress(*symbol); @@ -346,68 +375,28 @@ void ReductionProcessor::addReductionDecl( mlir::Type redType = symVal.getType().cast<fir::ReferenceType>().getEleTy(); reductionVars.push_back(symVal); - if (redType.isa<fir::LogicalType>()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId, - redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - redId, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } + assert(redType.isIntOrIndexOrFloat() && "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + redId, redType, currentLocation); reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( firOpBuilder.getContext(), decl.getSymName())); } } } - } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( - &redOperator.u)) { - if (ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) { - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - if (reductionSymbols) - reductionSymbols->push_back(symbol); - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast<fir::ReferenceType>().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, - getReductionName(getRealName(*reductionIntrinsic).ToString(), - redType), - redId, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } } } const Fortran::semantics::SourceName -ReductionProcessor::getRealName(const Fortran::parser::Name *name) { - return name->symbol->GetUltimate().name(); +ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { + return symbol->GetUltimate().name(); } -const Fortran::semantics::SourceName ReductionProcessor::getRealName( - const Fortran::parser::ProcedureDesignator &pd) { - const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; - assert(name && "Invalid Reduction Intrinsic."); - return getRealName(name); +const Fortran::semantics::SourceName +ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { + return getRealName(pd.v.id()); } int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 00770fe81d1ef6..855e2aa4ad13cd 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -13,6 +13,7 @@ #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H +#include "Clauses.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Parser/parse-tree.h" #include "flang/Semantics/symbol.h" @@ -57,25 +58,25 @@ class ReductionProcessor { }; static ReductionIdentifier - getReductionType(const Fortran::parser::ProcedureDesignator &pd); + getReductionType(const omp::clause::ProcedureDesignator &pd); - static ReductionIdentifier getReductionType( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp); + static ReductionIdentifier + getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp); - static bool supportedIntrinsicProcReduction( - const Fortran::parser::ProcedureDesignator &pd); + static bool + supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd); static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::Name *name); + getRealName(const Fortran::semantics::Symbol *symbol); static const Fortran::semantics::SourceName - getRealName(const Fortran::parser::ProcedureDesignator &pd); + getRealName(const omp::clause::ProcedureDesignator &pd); static std::string getReductionName(llvm::StringRef name, mlir::Type ty); - static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty); + static std::string + getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty); /// This function returns the identity value of the operator \p /// reductionOpName. For example: @@ -112,7 +113,7 @@ class ReductionProcessor { static void addReductionDecl(mlir::Location currentLocation, Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, + const omp::clause::Reduction &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars, llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index 31b15257d18687..9a6a28ded7006d 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Utils.h" +#include "Clauses.h" #include <flang/Lower/AbstractConverter.h> #include <flang/Lower/ConvertType.h> @@ -28,9 +29,27 @@ namespace Fortran { namespace lower { namespace omp { -void genObjectList(const Fortran::parser::OmpObjectList &objectList, +void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl<mlir::Value> &operands) { + for (const Object &object : objects) { + const Fortran::semantics::Symbol *sym = object.id(); + assert(sym && "Expected Symbol"); + if (mlir::Value variable = converter.getSymbolAddress(*sym)) { + operands.push_back(variable); + } else { + if (const auto *details = + sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { + operands.push_back(converter.getSymbolAddress(details->symbol())); + converter.copySymbolBinding(details->symbol(), *sym); + } + } + } +} + +void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands) { auto addOperands = [&](Fortran::lower::SymbolRef sym) { const mlir::Value variable = converter.getSymbolAddress(sym); if (variable) { @@ -50,24 +69,10 @@ void genObjectList(const Fortran::parser::OmpObjectList &objectList, } void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, + const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { - for (const Fortran::parser::OmpObject &ompObject : objList.v) { - Fortran::common::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::Designator &designator) { - if (const Fortran::parser::Name *name = - Fortran::semantics::getDesignatorNameIfDataRef( - designator)) { - symbolAndClause.emplace_back(clause, *name->symbol); - } - }, - [&](const Fortran::parser::Name &name) { - symbolAndClause.emplace_back(clause, *name.symbol); - }}, - ompObject.u); - } + for (const Object &object : objects) + symbolAndClause.emplace_back(clause, *object.id()); } Fortran::semantics::Symbol * diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index c346f891f0797e..4ab4bc9c137071 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -9,6 +9,7 @@ #ifndef FORTRAN_LOWER_OPENMPUTILS_H #define FORTRAN_LOWER_OPENMPUTILS_H +#include "Clauses.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Value.h" @@ -50,17 +51,20 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, bool isVal = false); void gatherFuncAndVarSyms( - const Fortran::parser::OmpObjectList &objList, - mlir::omp::DeclareTargetCaptureClause clause, + const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause); Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); -void genObjectList(const Fortran::parser::OmpObjectList &objectList, +void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl<mlir::Value> &operands); +void genObjectList2(const Fortran::parser::OmpObjectList &objectList, + Fortran::lower::AbstractConverter &converter, + llvm::SmallVectorImpl<mlir::Value> &operands); + } // namespace omp } // namespace lower } // namespace Fortran _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits