skatrak updated this revision to Diff 531437.
skatrak added a comment.
Updated the patch to integrate with the related patch D149337
<https://reviews.llvm.org/D149337> and to address the reviewers' comments.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D147218/new/
https://reviews.llvm.org/D147218
Files:
flang/include/flang/Lower/OpenMP.h
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/OpenMP.cpp
flang/test/Lower/OpenMP/requires-notarget.f90
flang/test/Lower/OpenMP/requires.f90
Index: flang/test/Lower/OpenMP/requires.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,13 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! This test checks the lowering of requires into the omp.requires module attribute when device code (the declare target subroutine f) is present
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
+program requires
+ !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
+
+subroutine f
+ !$omp declare target
+end subroutine f
Index: flang/test/Lower/OpenMP/requires-notarget.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,11 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! This test checks that requires lowering into MLIR skips creating the
+! omp.requires attribute with target-related clauses if there are no target
+! regions or device functions in the compilation unit
+
+!CHECK: module attributes {
+!CHECK-NOT: omp.requires
+program requires
+ !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
Index: flang/lib/Lower/OpenMP.cpp
===================================================================
--- flang/lib/Lower/OpenMP.cpp
+++ flang/lib/Lower/OpenMP.cpp
@@ -2594,16 +2594,14 @@
converter.bindSymbol(sym, symThreadprivateExv);
}
-void handleDeclareTarget(Fortran::lower::AbstractConverter &converter,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPDeclareTargetConstruct
- &declareTargetConstruct) {
- llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
- Fortran::semantics::Symbol>,
- 0>
- symbolAndClause;
- mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
+/// Extract the list of function and variable symbols affected by the given
+/// 'declare target' directive and return the intended device type for them.
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
+ Fortran::semantics::Symbol>> &symbolAndClause) {
+ // Gather the symbols and clauses
auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList,
mlir::omp::DeclareTargetCaptureClause clause) {
for (const Fortran::parser::OmpObject &ompObject : objList.v) {
@@ -2628,6 +2626,7 @@
Fortran::parser::OmpDeviceTypeClause::Type::Any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
+
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
// Case: declare target(func, var1, var2)
@@ -2662,6 +2661,29 @@
}
}
+ switch (deviceType) {
+ case Fortran::parser::OmpDeviceTypeClause::Type::Any:
+ return mlir::omp::DeclareTargetDeviceType::any;
+ case Fortran::parser::OmpDeviceTypeClause::Type::Host:
+ return mlir::omp::DeclareTargetDeviceType::host;
+ case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
+ return mlir::omp::DeclareTargetDeviceType::nohost;
+ }
+ llvm_unreachable("unexpected OmpDeviceTypeClause::Type value");
+}
+
+static void genDeclareTarget(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct
+ &declareTargetConstruct) {
+ llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+ Fortran::semantics::Symbol>,
+ 0>
+ symbolAndClause;
+ mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+ mlir::omp::DeclareTargetDeviceType deviceType =
+ getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
+
for (std::pair<mlir::omp::DeclareTargetCaptureClause,
Fortran::semantics::Symbol>
symClause : symbolAndClause) {
@@ -2688,35 +2710,46 @@
converter.getCurrentLocation(),
"Attempt to apply declare target on unsupported operation");
- mlir::omp::DeclareTargetDeviceType newDeviceType;
- switch (deviceType) {
- case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
- newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost;
- break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Host:
- newDeviceType = mlir::omp::DeclareTargetDeviceType::host;
- break;
- case Fortran::parser::OmpDeviceTypeClause::Type::Any:
- newDeviceType = mlir::omp::DeclareTargetDeviceType::any;
- break;
- }
-
// The function or global already has a declare target applied to it,
// very likely through implicit capture (usage in another declare
// target function/subroutine). It should be marked as any if it has
// been assigned both host and nohost, else we skip, as there is no
// change
if (declareTargetOp.isDeclareTarget()) {
- if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType)
+ if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
declareTargetOp.setDeclareTarget(
mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause));
continue;
}
- declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause));
+ declareTargetOp.setDeclareTarget(deviceType, std::get<0>(symClause));
}
}
+/// Set \p ompDeviceCodeFound if \p ompDecl is a 'declare target' directive
+/// applying to a function or subroutine with a device type other than 'host'.
+void Fortran::lower::analyzeOpenMPDeclarativeConstruct(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+ bool &ompDeviceCodeFound) {
+ std::visit(
+ Fortran::common::visitors{
+ [&](const Fortran::parser::OpenMPDeclareTargetConstruct &ompDeclTarget) {
+ mlir::omp::DeclareTargetDeviceType targetType =
+ Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+ converter, eval, ompDeclTarget)
+ .value_or(mlir::omp::DeclareTargetDeviceType::host);
+
+ ompDeviceCodeFound =
+ ompDeviceCodeFound ||
+ targetType != mlir::omp::DeclareTargetDeviceType::host;
+ },
+ [&](const auto &) {},
+ },
+ ompDecl.u);
+}
+
void Fortran::lower::genOpenMPDeclarativeConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
@@ -2739,11 +2772,14 @@
},
[&](const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
- handleDeclareTarget(converter, eval, declareTargetConstruct);
+ genDeclareTarget(converter, eval, declareTargetConstruct);
},
[&](const Fortran::parser::OpenMPRequiresConstruct
&requiresConstruct) {
- TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
+ // Requires directives are gathered and processed in semantics in
+ // order to support modules, and then combined in the lowering
+ // bridge before triggering codegen just once. Hence, there is no
+ // need for codegen for each individual occurrence here.
},
[&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
// The directive is lowered when instantiating the variable to
@@ -2965,3 +3001,99 @@
}
}
}
+
+/// Return the device type of the given 'declare target' directive if at least
+/// one of its listed symbols is a function or subroutine, std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPDeclareTargetConstruct
+ &declareTargetConstruct) {
+ llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+ Fortran::semantics::Symbol>,
+ 0>
+ symbolAndClause;
+ mlir::omp::DeclareTargetDeviceType deviceType =
+ getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
+
+ // Return the device type only if at least one of the targets for the
+ // directive is a function or subroutine. Symbols with no matching
+ // operation in the module (null lookup result) are skipped.
+ mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+ for (const std::pair<mlir::omp::DeclareTargetCaptureClause,
+ Fortran::semantics::Symbol>
+ &sym : symbolAndClause) {
+ mlir::Operation *op =
+ mod.lookupSymbol(converter.mangleName(std::get<1>(sym)));
+
+ if (op && mlir::isa<mlir::func::FuncOp>(op))
+ return deviceType;
+ }
+
+ return std::nullopt;
+}
+
+/// Check whether the given construct is an OpenMP target construct. Both
+/// block-associated (e.g. 'target teams') and loop-associated (e.g.
+/// 'target teams distribute') forms are recognized.
+bool Fortran::lower::isOpenMPTargetConstruct(
+ const Fortran::parser::OpenMPConstruct &omp) {
+ llvm::omp::Directive dir = llvm::omp::Directive::OMPD_unknown;
+ if (const auto *blockDir =
+ std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
+ const auto &beginBlockDir{
+ std::get<Fortran::parser::OmpBeginBlockDirective>(blockDir->t)};
+ dir = std::get<Fortran::parser::OmpBlockDirective>(beginBlockDir.t).v;
+ } else if (const auto *loopDir =
+ std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
+ const auto &beginLoopDir{
+ std::get<Fortran::parser::OmpBeginLoopDirective>(loopDir->t)};
+ dir = std::get<Fortran::parser::OmpLoopDirective>(beginLoopDir.t).v;
+ }
+
+ switch (dir) {
+ case llvm::omp::Directive::OMPD_target:
+ case llvm::omp::Directive::OMPD_target_parallel:
+ case llvm::omp::Directive::OMPD_target_parallel_do:
+ case llvm::omp::Directive::OMPD_target_parallel_do_simd:
+ case llvm::omp::Directive::OMPD_target_simd:
+ case llvm::omp::Directive::OMPD_target_teams:
+ case llvm::omp::Directive::OMPD_target_teams_distribute:
+ case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
+ case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
+ case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/// Collect the device-related clauses of an OpenMP 'requires' directive into
+/// a flag set. Other clauses (e.g. atomic_default_mem_order) are ignored.
+omp::ClauseRequires Fortran::lower::extractOpenMPRequiresClauses(
+ const Fortran::parser::OmpClauseList &clauseList) {
+ using omp::ClauseRequires, Fortran::parser::OmpClause;
+ auto requiresFlags = ClauseRequires::none;
+
+ for (const OmpClause &clause : clauseList.v) {
+ if (std::get_if<OmpClause::DynamicAllocators>(&clause.u))
+ requiresFlags = requiresFlags | ClauseRequires::dynamic_allocators;
+ else if (std::get_if<OmpClause::ReverseOffload>(&clause.u))
+ requiresFlags = requiresFlags | ClauseRequires::reverse_offload;
+ else if (std::get_if<OmpClause::UnifiedAddress>(&clause.u))
+ requiresFlags = requiresFlags | ClauseRequires::unified_address;
+ else if (std::get_if<OmpClause::UnifiedSharedMemory>(&clause.u))
+ requiresFlags = requiresFlags | ClauseRequires::unified_shared_memory;
+ }
+
+ return requiresFlags;
+}
+
+/// Attach the given 'requires' flags to \p mod if it implements the OpenMP
+/// offload module interface; otherwise, do nothing.
+void Fortran::lower::genOpenMPRequires(Operation *mod,
+ omp::ClauseRequires flags) {
+ if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod))
+ offloadMod.setRequires(flags);
+}
Index: flang/lib/Lower/Bridge.cpp
===================================================================
--- flang/lib/Lower/Bridge.cpp
+++ flang/lib/Lower/Bridge.cpp
@@ -50,6 +50,8 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Runtime/iostat.h"
#include "flang/Semantics/runtime-type-info.h"
+#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/PatternMatch.h"
@@ -288,20 +290,34 @@
// that they are available before lowering any function that may use
// them.
bool hasMainProgram = false;
+ Fortran::semantics::OmpRequiresFlags ompRequiresFlags =
+ Fortran::semantics::OmpRequiresFlags::None;
+ std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+ ompAtomicDefaultMemOrder;
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::visit(Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) {
if (f.isMainProgram())
hasMainProgram = true;
declareFunction(f);
+ ompProcessTopLevelSymbol(f.getScope().symbol(),
+ ompRequiresFlags,
+ ompAtomicDefaultMemOrder);
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
+ ompProcessTopLevelSymbol(m.getScope().symbol(),
+ ompRequiresFlags,
+ ompAtomicDefaultMemOrder);
+ },
+ [&](Fortran::lower::pft::BlockDataUnit &b) {
+ ompProcessTopLevelSymbol(b.symTab.symbol(),
+ ompRequiresFlags,
+ ompAtomicDefaultMemOrder);
},
- [&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
},
u);
@@ -344,6 +360,24 @@
fir::runtime::genEnvironmentDefaults(*builder, toLocation(),
bridge.getEnvironmentDefaults());
});
+
+ // Set the module attribute related to OpenMP requires directives
+ if (ompDeviceCodeFound) {
+ using MlirRequires = mlir::omp::ClauseRequires;
+ using SemaRequires = Fortran::semantics::OmpRequiresFlags;
+ MlirRequires flags = MlirRequires::none;
+
+ if (ompRequiresFlags & SemaRequires::ReverseOffload)
+ flags = flags | MlirRequires::reverse_offload;
+ if (ompRequiresFlags & SemaRequires::UnifiedAddress)
+ flags = flags | MlirRequires::unified_address;
+ if (ompRequiresFlags & SemaRequires::UnifiedSharedMemory)
+ flags = flags | MlirRequires::unified_shared_memory;
+ if (ompRequiresFlags & SemaRequires::DynamicAllocators)
+ flags = flags | MlirRequires::dynamic_allocators;
+
+ Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), flags);
+ }
}
/// Declare a function.
@@ -1191,6 +1225,49 @@
activeConstructStack.pop_back();
}
+ /// Accumulate the OpenMP 'requires' clauses attached to a top-level program
+ /// unit symbol, checking that atomic_default_mem_order clauses are consistent.
+ void ompProcessTopLevelSymbol(
+ const Fortran::semantics::Symbol *symbol,
+ Fortran::semantics::OmpRequiresFlags &ompRequiresFlags,
+ std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+ &ompAtomicDefaultMemOrder) {
+ if (!symbol)
+ return;
+
+ Fortran::common::visit(
+ [&](const auto &details) {
+ if constexpr (std::is_base_of_v<
+ Fortran::semantics::WithOmpDeclarative,
+ std::decay_t<decltype(details)>>) {
+ // Collect OpenMP 'requires' clauses.
+ if (details.has_ompRequires())
+ ompRequiresFlags |= *details.ompRequires();
+
+ // Make sure any atomic_default_mem_order OpenMP 'requires' clauses
+ // obtained for different top-level symbols match.
+ if (details.has_ompAtomicDefaultMemOrder()) {
+ Fortran::parser::OmpAtomicDefaultMemOrderClause::Type memOrder{
+ *details.ompAtomicDefaultMemOrder()};
+ if (ompAtomicDefaultMemOrder &&
+ memOrder != *ompAtomicDefaultMemOrder)
+ fir::emitFatalError(
+ getCurrentLocation(),
+ llvm::StringRef{
+ "incompatible OpenMP requires atomic_default_mem_order "
+ "clauses found: '"} +
+ Fortran::parser::OmpAtomicDefaultMemOrderClause::
+ EnumToString(memOrder) +
+ llvm::StringRef{"' and '"} +
+ Fortran::parser::OmpAtomicDefaultMemOrderClause::
+ EnumToString(*ompAtomicDefaultMemOrder));
+ ompAtomicDefaultMemOrder = memOrder;
+ }
+ }
+ },
+ symbol->details());
+ }
+
//===--------------------------------------------------------------------===//
// Termination of symbolically referenced execution units
//===--------------------------------------------------------------------===//
@@ -2201,10 +2278,16 @@
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
+
+ // Register if a target region was found
+ ompDeviceCodeFound =
+ ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
}
void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+ analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl,
+ ompDeviceCodeFound);
genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
@@ -4530,6 +4613,10 @@
/// A counter for uniquing names in `literalNamesMap`.
std::uint64_t uniqueLitId = 0;
+
+ /// Whether an OpenMP target region or declare target function/subroutine
+ /// intended for device offloading has been detected
+ bool ompDeviceCodeFound = false;
};
} // namespace
Index: flang/include/flang/Lower/OpenMP.h
===================================================================
--- flang/include/flang/Lower/OpenMP.h
+++ flang/include/flang/Lower/OpenMP.h
@@ -13,13 +13,10 @@
#ifndef FORTRAN_LOWER_OPENMP_H
#define FORTRAN_LOWER_OPENMP_H
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include <cinttypes>
+#include <optional>
-namespace mlir {
-class Value;
-class Operation;
-} // namespace mlir
-
namespace fir {
class FirOpBuilder;
class ConvertOp;
@@ -29,6 +26,7 @@
namespace parser {
struct OpenMPConstruct;
struct OpenMPDeclarativeConstruct;
+struct OpenMPDeclareTargetConstruct;
struct OmpEndLoopDirective;
struct OmpClauseList;
} // namespace parser
@@ -44,6 +42,11 @@
void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
const parser::OpenMPConstruct &);
+/// Set \p ompDeviceCodeFound if the given construct is a 'declare target'
+/// applying to a function or subroutine with a device type other than 'host'.
+void analyzeOpenMPDeclarativeConstruct(
+ AbstractConverter &, pft::Evaluation &,
+ const parser::OpenMPDeclarativeConstruct &, bool &ompDeviceCodeFound);
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
const parser::OpenMPDeclarativeConstruct &);
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
@@ -56,6 +59,25 @@
void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
mlir::Value, fir::ConvertOp * = nullptr);
void removeStoreOp(mlir::Operation *, mlir::Value);
+
+/// Return the device type of a 'declare target' directive if at least one of
+/// its listed symbols is a function or subroutine, or std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+getOpenMPDeclareTargetFunctionDevice(
+ AbstractConverter &, pft::Evaluation &,
+ const parser::OpenMPDeclareTargetConstruct &);
+/// Check whether the given construct is an OpenMP target construct (e.g.
+/// 'target' or 'target teams').
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
+
+/// Collect the device-related clauses of an OpenMP 'requires' clause list
+/// into a flag set. Other clauses are ignored.
+mlir::omp::ClauseRequires
+extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &);
+/// Set the OpenMP 'requires' flags on the given operation if it implements
+/// the OpenMP offload module interface; do nothing otherwise.
+void genOpenMPRequires(mlir::Operation *, mlir::omp::ClauseRequires);
+
} // namespace lower
} // namespace Fortran
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits