skatrak updated this revision to Diff 527854.
skatrak added a comment.
Herald added subscribers: cfe-commits, hiraditya.
Herald added a project: clang.

Update the patch and remove the OpenMP dialect dependency from the generic
LLVM IR translation.

I followed @kiranchandramohan's suggestion in the comments for D147172
<https://reviews.llvm.org/D147172>, since the current approach (configuring
the `OpenMPIRBuilder` from within the generic `ModuleTranslation`) prevents
the use of OpenMP dialect-specific attributes for initializing the
`OpenMPIRBuilder` configuration. One such attribute represents 'requires'
clauses, whose flags need to be stored in the `OpenMPIRBuilderConfig`, so a
different approach is necessary. The approach implemented in this review is
based on the `amendOperation` translation flow.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D147219/new/

https://reviews.llvm.org/D147219

Files:
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
  mlir/test/Target/LLVMIR/openmp-llvm.mlir

Index: mlir/test/Target/LLVMIR/openmp-llvm.mlir
===================================================================
--- mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2543,3 +2543,13 @@
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags<assume_teams_oversubscription = true, assume_no_thread_state = true>} {}
+
+// -----
+
+// Check that OpenMP requires flags are registered by a global constructor.
+// CHECK: @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }]
+// CHECK-SAME: [{ i32, ptr, ptr } { i32 0, ptr @[[REG_FN:.*]], ptr null }]
+// CHECK: define {{.*}} @[[REG_FN]]({{.*}})
+// CHECK-NOT: }
+// CHECK:   call void @__tgt_register_requires(i64 10)
+module attributes {omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>} {}
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -20,8 +20,6 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1273,30 +1271,18 @@
 llvm::OpenMPIRBuilder *ModuleTranslation::getOpenMPBuilder() {
   if (!ompBuilder) {
     ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
+    ompBuilder->initialize();
 
-    bool isDevice = false;
-    llvm::StringRef hostIRFilePath = "";
-
-    if (Attribute deviceAttr = mlirModule->getAttr("omp.is_device"))
-      if (::llvm::isa<mlir::BoolAttr>(deviceAttr))
-        isDevice = ::llvm::dyn_cast<mlir::BoolAttr>(deviceAttr).getValue();
-
-    if (Attribute filepath = mlirModule->getAttr("omp.host_ir_filepath"))
-      if (::llvm::isa<mlir::StringAttr>(filepath))
-        hostIRFilePath =
-            ::llvm::dyn_cast<mlir::StringAttr>(filepath).getValue();
-
-    ompBuilder->initialize(hostIRFilePath);
-
-    // TODO: set the flags when available
-    llvm::OpenMPIRBuilderConfig config(
-        isDevice, /* IsTargetCodegen */ false,
+    // Flags represented as top-level OpenMP dialect attributes are set in
+    // OpenMPDialectLLVMIRTranslationInterface::amendOperation(). Here we set
+    // the default configuration.
+    ompBuilder->setConfig(llvm::OpenMPIRBuilderConfig(
+        /* IsEmbedded */ false, /* IsTargetCodegen */ false,
         /* OpenMPOffloadMandatory */ false,
         /* HasRequiresReverseOffload */ false,
         /* HasRequiresUnifiedAddress */ false,
         /* HasRequiresUnifiedSharedMemory */ false,
-        /* OpenMPOffloadMandatory */ false);
-    ompBuilder->setConfig(config);
+        /* OpenMPOffloadMandatory */ false));
   }
   return ompBuilder.get();
 }
@@ -1383,11 +1369,17 @@
     return nullptr;
   if (failed(translator.createTBAAMetadata()))
     return nullptr;
+
+  // Convert module itself before any functions and operations inside, so that
+  // the OpenMPIRBuilder is configured with the OpenMP dialect attributes
+  // attached to the module by the amendOperation() flow before then.
+  llvm::IRBuilder<> llvmBuilder(llvmContext);
+  if (failed(translator.convertOperation(*module, llvmBuilder)))
+    return nullptr;
   if (failed(translator.convertFunctions()))
     return nullptr;
 
   // Convert other top-level operations if possible.
-  llvm::IRBuilder<> llvmBuilder(llvmContext);
   for (Operation &o : getModuleBody(module).getOperations()) {
     if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp, LLVM::GlobalCtorsOp,
              LLVM::GlobalDtorsOp, LLVM::MetadataOp>(&o) &&
@@ -1397,10 +1389,6 @@
     }
   }
 
-  // Convert module itself.
-  if (failed(translator.convertOperation(*module, llvmBuilder)))
-    return nullptr;
-
   if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
     return nullptr;
 
Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -27,6 +28,8 @@
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <optional>
 
 using namespace mlir;
 
@@ -1040,7 +1043,7 @@
 }
 
 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
-llvm::AtomicOrdering
+static llvm::AtomicOrdering
 convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
   if (!ao)
     return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
@@ -1675,6 +1678,27 @@
   return bodyGenStatus;
 }
 
+/// Converts the module-level set of OpenMP requires clauses into LLVM IR using
+/// OpenMPIRBuilder.
+static LogicalResult
+convertRequiresAttr(Operation &op, omp::ClauseRequiresAttr requiresAttr,
+                    LLVM::ModuleTranslation &moduleTranslation) {
+  auto *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  // No need to read requiresAttr here, because it has already been done in
+  // translateModuleToLLVMIR(). There, flags are stored in the
+  // OpenMPIRBuilderConfig object, available to the OpenMPIRBuilder.
+  auto *regFn =
+      ompBuilder->createRegisterRequires(ompBuilder->createPlatformSpecificName(
+          {"omp_offloading", "requires_reg"}));
+
+  // Add registration function as global constructor
+  if (regFn)
+    llvm::appendToGlobalCtors(ompBuilder->M, regFn, /* Priority = */ 0);
+
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1690,6 +1714,8 @@
   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final;
 
+  /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
+  /// calls, or operation amendments
   LogicalResult
   amendOperation(Operation *op, NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final;
@@ -1697,28 +1723,72 @@
 
 } // namespace
 
-/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
-/// calls, or operation amendments
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
     Operation *op, NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
-
-  return llvm::TypeSwitch<Attribute, LogicalResult>(attribute.getValue())
-      .Case([&](mlir::omp::FlagsAttr rtlAttr) {
-        return convertFlagsAttr(op, rtlAttr, moduleTranslation);
-      })
-      .Case([&](mlir::omp::VersionAttr versionAttr) {
-        llvm::OpenMPIRBuilder *ompBuilder =
-            moduleTranslation.getOpenMPBuilder();
-        ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
-                                    versionAttr.getVersion());
-        return success();
-      })
-      .Default([&](Attribute attr) {
-        // fall through for omp attributes that do not require lowering and/or
-        // have no concrete definition and thus no type to define a case on
+  return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
+             attribute.getName())
+      .Case("omp.is_device",
+            [&](Attribute attr) {
+              if (auto deviceAttr = attr.dyn_cast<BoolAttr>()) {
+                llvm::OpenMPIRBuilderConfig &config =
+                    moduleTranslation.getOpenMPBuilder()->Config;
+                config.setIsEmbedded(deviceAttr.getValue());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.host_ir_filepath",
+            [&](Attribute attr) {
+              if (auto filepathAttr = attr.dyn_cast<StringAttr>()) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
+                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.flags",
+            [&](Attribute attr) {
+              if (auto rtlAttr = attr.dyn_cast<omp::FlagsAttr>())
+                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
+              return failure();
+            })
+      .Case("omp.version",
+            [&](Attribute attr) {
+              if (auto versionAttr = attr.dyn_cast<omp::VersionAttr>()) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
+                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
+                                            versionAttr.getVersion());
+                return success();
+              }
+              return failure();
+            })
+      .Case(
+          "omp.requires",
+          [&](Attribute attr) {
+            if (auto requiresAttr = attr.dyn_cast<omp::ClauseRequiresAttr>()) {
+              using Requires = omp::ClauseRequires;
+              Requires flags = requiresAttr.getValue();
+              llvm::OpenMPIRBuilderConfig &config =
+                  moduleTranslation.getOpenMPBuilder()->Config;
+              config.setHasRequiresReverseOffload(
+                  bitEnumContainsAll(flags, Requires::reverse_offload));
+              config.setHasRequiresUnifiedAddress(
+                  bitEnumContainsAll(flags, Requires::unified_address));
+              config.setHasRequiresUnifiedSharedMemory(
+                  bitEnumContainsAll(flags, Requires::unified_shared_memory));
+              config.setHasRequiresDynamicAllocators(
+                  bitEnumContainsAll(flags, Requires::dynamic_allocators));
+              return convertRequiresAttr(*op, requiresAttr, moduleTranslation);
+            }
+            return failure();
+          })
+      .Default([](Attribute) {
+        // Fall through for omp attributes that do not require lowering.
         return success();
-      });
+      })(attribute.getValue());
 
   return failure();
 }
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -546,31 +546,7 @@
   return Fn;
 }
 
-void OpenMPIRBuilder::initialize(StringRef HostFilePath) {
-  initializeTypes(M);
-
-  if (HostFilePath.empty())
-    return;
-
-  auto Buf = MemoryBuffer::getFile(HostFilePath);
-  if (std::error_code Err = Buf.getError()) {
-    report_fatal_error(("error opening host file from host file path inside of "
-                        "OpenMPIRBuilder: " +
-                        Err.message())
-                           .c_str());
-  }
-
-  LLVMContext Ctx;
-  auto M = expectedToErrorOrAndEmitErrors(
-      Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
-  if (std::error_code Err = M.getError()) {
-    report_fatal_error(
-        ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
-            .c_str());
-  }
-
-  loadOffloadInfoMetadata(*M.get());
-}
+void OpenMPIRBuilder::initialize() { initializeTypes(M); }
 
 void OpenMPIRBuilder::finalize(Function *Fn) {
   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
@@ -5337,6 +5313,30 @@
   }
 }
 
+void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
+  if (HostFilePath.empty())
+    return;
+
+  auto Buf = MemoryBuffer::getFile(HostFilePath);
+  if (std::error_code Err = Buf.getError()) {
+    report_fatal_error(("error opening host file from host file path inside of "
+                        "OpenMPIRBuilder: " +
+                        Err.message())
+                           .c_str());
+  }
+
+  LLVMContext Ctx;
+  auto M = expectedToErrorOrAndEmitErrors(
+      Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
+  if (std::error_code Err = M.getError()) {
+    report_fatal_error(
+        ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
+            .c_str());
+  }
+
+  loadOffloadInfoMetadata(*M.get());
+}
+
 Function *OpenMPIRBuilder::createRegisterRequires(StringRef Name) {
   // Skip the creation of the registration function if this is device codegen
   if (Config.isEmbedded())
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -159,6 +159,7 @@
 
   void setIsEmbedded(bool Value) { IsEmbedded = Value; }
   void setIsTargetCodegen(bool Value) { IsTargetCodegen = Value; }
+  void setOpenMPOffloadMandatory(bool Value) { OpenMPOffloadMandatory = Value; }
   void setFirstSeparator(StringRef FS) { FirstSeparator = FS; }
   void setSeparator(StringRef S) { Separator = S; }
 
@@ -425,14 +426,9 @@
 
   /// Initialize the internal state, this will put structures types and
   /// potentially other helpers into the underlying module. Must be called
-  /// before any other method and only once! This internal state includes
-  /// Types used in the OpenMPIRBuilder generated from OMPKinds.def as well
-  /// as loading offload metadata for device from the OpenMP host IR file
-  /// passed in as the HostFilePath argument.
-  /// \param HostFilePath The path to the host IR file, used to load in
-  /// offload metadata for the device, allowing host and device to
-  /// maintain the same metadata mapping.
-  void initialize(StringRef HostFilePath = {});
+  /// before any other method and only once! This internal state includes types
+  /// used in the OpenMPIRBuilder generated from OMPKinds.def.
+  void initialize();
 
   void setConfig(OpenMPIRBuilderConfig C) { Config = C; }
 
@@ -2243,6 +2239,15 @@
   /// loaded from bitcode file, i.e, different from OpenMPIRBuilder::M module.
   void loadOffloadInfoMetadata(Module &M);
 
+  /// Loads all the offload entries information from the host IR
+  /// metadata read from the file passed in as the HostFilePath argument. This
+  /// function is only meant to be used with device code generation.
+  ///
+  /// \param HostFilePath The path to the host IR file,
+  /// used to load in offload metadata for the device, allowing host and device
+  /// to maintain the same metadata mapping.
+  void loadOffloadInfoMetadata(StringRef HostFilePath);
+
   /// Gets (if variable with the given name already exist) or creates
   /// internal global variable with the specified Name. The created variable has
   /// linkage CommonLinkage by default and is initialized by null value.
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -1040,9 +1040,10 @@
       /*IsTargetCodegen*/ false, CGM.getLangOpts().OpenMPOffloadMandatory,
       /*HasRequiresReverseOffload*/ false, /*HasRequiresUnifiedAddress*/ false,
       hasRequiresUnifiedSharedMemory(), /*HasRequiresDynamicAllocators*/ false);
-  OMPBuilder.initialize(CGM.getLangOpts().OpenMPIsDevice
-                            ? CGM.getLangOpts().OMPHostIRFile
-                            : StringRef{});
+  OMPBuilder.initialize();
+  OMPBuilder.loadOffloadInfoMetadata(CGM.getLangOpts().OpenMPIsDevice
+                                         ? CGM.getLangOpts().OMPHostIRFile
+                                         : StringRef{});
   OMPBuilder.setConfig(Config);
 }
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to