ggeorgakoudis created this revision.
Herald added subscribers: ormris, guansong, hiraditya, yaxunl.
ggeorgakoudis requested review of this revision.
Herald added a reviewer: jdoerfert.
Herald added a reviewer: jdoerfert.
Herald added a reviewer: sstefan1.
Herald added subscribers: llvm-commits, openmp-commits, cfe-commits, bbn, 
sstefan1.
Herald added a reviewer: baziotis.
Herald added projects: clang, OpenMP, LLVM.

Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D106746

Files:
  clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
  llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
  llvm/lib/Transforms/IPO/OpenMPOpt.cpp
  openmp/libomptarget/deviceRTLs/common/include/target.h
  openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
  openmp/libomptarget/deviceRTLs/common/src/parallel.cu
  openmp/libomptarget/deviceRTLs/common/support.h
  openmp/libomptarget/deviceRTLs/interface.h

Index: openmp/libomptarget/deviceRTLs/interface.h
===================================================================
--- openmp/libomptarget/deviceRTLs/interface.h
+++ openmp/libomptarget/deviceRTLs/interface.h
@@ -417,8 +417,9 @@
 
 // non standard
 EXTERN int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
-                                 bool UseGenericStateMachine,
-                           bool RequiresFullRuntime);
+                                  bool IsSPMDGuarded,
+                                  bool UseGenericStateMachine,
+                                  bool RequiresFullRuntime);
 EXTERN void __kmpc_target_deinit(ident_t *Ident, bool IsSPMD,
                            bool RequiresFullRuntime);
 EXTERN void __kmpc_kernel_prepare_parallel(void *WorkFn);
@@ -449,6 +450,8 @@
 // SPMD execution mode interrogation function.
 EXTERN int8_t __kmpc_is_spmd_exec_mode();
 
+EXTERN int8_t __kmpc_is_spmd_guarded_exec_mode();
+
 /// Return true if the hardware thread id \p Tid represents the OpenMP main
 /// thread in generic mode outside of a parallel region.
 EXTERN int8_t __kmpc_is_generic_main_thread(kmp_int32 Tid);
Index: openmp/libomptarget/deviceRTLs/common/support.h
===================================================================
--- openmp/libomptarget/deviceRTLs/common/support.h
+++ openmp/libomptarget/deviceRTLs/common/support.h
@@ -22,13 +22,14 @@
 enum ExecutionMode {
   Spmd = 0x00u,
   Generic = 0x01u,
-  ModeMask = 0x01u,
+  SpmdGuarded = 0x02u,
+  ModeMask = 0x03u,
 };
 
 enum RuntimeMode {
   RuntimeInitialized = 0x00u,
-  RuntimeUninitialized = 0x02u,
-  RuntimeMask = 0x02u,
+  RuntimeUninitialized = 0x04u,
+  RuntimeMask = 0x04u,
 };
 
 void setExecutionParameters(ExecutionMode EMode, RuntimeMode RMode);
Index: openmp/libomptarget/deviceRTLs/common/src/parallel.cu
===================================================================
--- openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -300,7 +300,36 @@
   }
 
   if (__kmpc_is_spmd_exec_mode()) {
+    // Store spmd_guarded status to check after the parallel region executes.
+    int is_spmd_guarded = __kmpc_is_spmd_guarded_exec_mode();
+    if (is_spmd_guarded) {
+      // No barrier is needed on entry since this will be called only from
+      // non-guarded SPMD execution.
+
+      // Disable SPMD guarding for the parallel region. Runtime support is not needed
+      // by construction of SPMD guarded regions, so a simple assignment to Spmd is
+      // enough. Also, a preceding barrier is unnecessary since all threads must be
+      // in non-guarded context when reaching this point.
+      if (__kmpc_get_hardware_thread_id_in_block() == 0)
+        execution_param = Spmd;
+
+      // Barrier to ensure all threads are updated to Spmd.
+      __kmpc_barrier_simple_spmd(ident, 0);
+    }
+
     __kmp_invoke_microtask(global_tid, 0, fn, args, nargs);
+
+    if (is_spmd_guarded) {
+      // Re-enable SPMD guarding. Runtime support is not needed by construction.
+      // Barrier to ensure all threads have finished Spmd execution before
+      // re-enabling guarding.
+      __kmpc_barrier_simple_spmd(ident, 0);
+      if (__kmpc_get_hardware_thread_id_in_block() == 0)
+        execution_param = SpmdGuarded;
+
+      // Barrier to ensure all threads are updated to SpmdGuarded.
+      __kmpc_barrier_simple_spmd(ident, 0);
+    }
     return;
   }
 
Index: openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
===================================================================
--- openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
+++ openmp/libomptarget/deviceRTLs/common/src/omptarget.cu
@@ -82,11 +82,12 @@
   omptarget_nvptx_workFn = 0;
 }
 
-static void __kmpc_spmd_kernel_init(bool RequiresFullRuntime) {
+static void __kmpc_spmd_kernel_init(bool IsSPMDGuarded, bool RequiresFullRuntime) {
   PRINT0(LD_IO, "call to __kmpc_spmd_kernel_init\n");
 
-  setExecutionParameters(Spmd, RequiresFullRuntime ? RuntimeInitialized
-                         : RuntimeUninitialized);
+  setExecutionParameters(IsSPMDGuarded ? SpmdGuarded : Spmd,
+                         RequiresFullRuntime ? RuntimeInitialized
+                                             : RuntimeUninitialized);
   int threadId = __kmpc_get_hardware_thread_id_in_block();
   if (threadId == 0) {
     usedSlotIdx = __kmpc_impl_smid() % MAX_SM;
@@ -160,7 +161,12 @@
 
 // Return true if the current target region is executed in SPMD mode.
 EXTERN int8_t __kmpc_is_spmd_exec_mode() {
-  return (execution_param & ModeMask) == Spmd;
+  return ((execution_param & ModeMask) == Spmd ||
+          (execution_param & ModeMask) == SpmdGuarded);
+}
+
+EXTERN __attribute__((used,retain)) int8_t __kmpc_is_spmd_guarded_exec_mode() {
+   return ((execution_param & ModeMask) == SpmdGuarded);
 }
 
 EXTERN int8_t __kmpc_is_generic_main_thread(kmp_int32 Tid) {
@@ -202,12 +208,12 @@
 }
 
 EXTERN
-int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
+int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, bool IsSPMDGuarded,
                            bool UseGenericStateMachine,
                            bool RequiresFullRuntime) {
   int TId = __kmpc_get_hardware_thread_id_in_block();
   if (IsSPMD)
-    __kmpc_spmd_kernel_init(RequiresFullRuntime);
+    __kmpc_spmd_kernel_init(IsSPMDGuarded, RequiresFullRuntime);
   else
     __kmpc_generic_kernel_init();
 
Index: openmp/libomptarget/deviceRTLs/common/include/target.h
===================================================================
--- openmp/libomptarget/deviceRTLs/common/include/target.h
+++ openmp/libomptarget/deviceRTLs/common/include/target.h
@@ -72,7 +72,7 @@
 ///
 /// \param Ident               Source location identification, can be NULL.
 ///
-int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
+int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, bool IsSPMDGuarded,
                            bool UseGenericStateMachine,
                            bool RequiresFullRuntime);
 
Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -39,6 +39,7 @@
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/CodeExtractor.h"
 
 using namespace llvm;
@@ -503,7 +504,7 @@
   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
   /// false.
-  BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
+  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
 
   /// The __kmpc_target_init call in this kernel, if any. If we find more than
   /// one we abort as the kernel is malformed.
@@ -2756,6 +2757,12 @@
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
       : AAKernelInfo(IRP, A) {}
 
+  SmallPtrSet<Instruction *, 4> GuardedInstructions;
+
+  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
+    return GuardedInstructions;
+  }
+
   /// See AbstractAttribute::initialize(...).
   void initialize(Attributor &A) override {
     // This is a high-level transform that might change the constant arguments
@@ -2849,6 +2856,29 @@
       return Val;
     };
 
+    Attributor::SimplifictionCallbackTy IsSPMDGuardedModeSimplifyCB =
+        [&](const IRPosition &IRP, const AbstractAttribute *AA,
+            bool &UsedAssumedInformation) -> Optional<Value *> {
+      // IRP represents the "IsSPMDGuarded" argument of a __kmpc_target_init
+      // call. We will answer this one with the internal state: true only if
+      // SPMD mode is still assumed and the SPMDCompatibilityTracker holds
+      // instructions that need guarding.
+      if (!SPMDCompatibilityTracker.isValidState())
+        return nullptr;
+      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+        if (AA)
+          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
+        UsedAssumedInformation = true;
+      } else {
+        UsedAssumedInformation = false;
+      }
+
+      auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
+                                       (SPMDCompatibilityTracker.isAssumed() &&
+                                        !SPMDCompatibilityTracker.empty()));
+      return Val;
+    };
+
     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
         [&](const IRPosition &IRP, const AbstractAttribute *AA,
             bool &UsedAssumedInformation) -> Optional<Value *> {
@@ -2871,9 +2901,10 @@
     };
 
     constexpr const int InitIsSPMDArgNo = 1;
+    constexpr const int InitIsSPMDGuardedArgNo = 2;
+    constexpr const int InitUseStateMachineArgNo = 3;
+    constexpr const int InitRequiresFullRuntimeArgNo = 4;
     constexpr const int DeinitIsSPMDArgNo = 1;
-    constexpr const int InitUseStateMachineArgNo = 2;
-    constexpr const int InitRequiresFullRuntimeArgNo = 3;
     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
     A.registerSimplificationCallback(
         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
@@ -2881,6 +2912,9 @@
     A.registerSimplificationCallback(
         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
         IsSPMDModeSimplifyCB);
+    A.registerSimplificationCallback(
+        IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDGuardedArgNo),
+        IsSPMDGuardedModeSimplifyCB);
     A.registerSimplificationCallback(
         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
         IsSPMDModeSimplifyCB);
@@ -2952,6 +2986,225 @@
       return false;
     }
 
+    auto CreateGuardedRegion = [&](Instruction *RegionStartI,
+                                   Instruction *RegionEndI) {
+      LoopInfo *LI = nullptr;
+      DominatorTree *DT = nullptr;
+      MemorySSAUpdater *MSU = nullptr;
+      using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+
+      BasicBlock *ParentBB = RegionStartI->getParent();
+      Function *Fn = ParentBB->getParent();
+      Module &M = *Fn->getParent();
+
+      // Create all the blocks and logic.
+      // ParentBB:
+      //    IsSPMDGuarded = __kmpc_is_spmd_guarded_exec_mode()
+      //    if (IsSPMDGuarded)
+      //        goto RegionCheckTidBB
+      // RegionNotguardedBB:
+      //    <execute instructions not guarded>
+      //    goto RegionExitBB
+      // RegionCheckTidBB:
+      //    Tid = __kmpc_get_hardware_thread_id_in_block()
+      //    if (Tid != 0)
+      //        goto RegionBarrierBB
+      // RegionStartBB:
+      //    <execute instructions guarded>
+      //    goto RegionEndBB
+      // RegionEndBB:
+      //    <store escaping values to shared mem>
+      //    goto RegionBarrierBB
+      // RegionBarrierBB:
+      //    __kmpc_barrier_simple_spmd()
+      //    // second barrier is omitted if lacking escaping values.
+      //    <load escaping values from shared mem>
+      //    __kmpc_barrier_simple_spmd()
+      //    goto RegionExitBB
+      // RegionExitBB:
+      //    <execute rest of instructions>
+
+      BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
+                                           DT, LI, MSU, "region.guarded.end");
+      BasicBlock *RegionBarrierBB =
+          SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
+                     MSU, "region.barrier");
+      BasicBlock *RegionExitBB =
+          SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
+                     DT, LI, MSU, "region.exit");
+      BasicBlock *RegionStartBB =
+          SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
+
+      // Create a clone that contains a non-guarded version for parallel
+      // execution.
+      ValueToValueMapTy VMap;
+      BasicBlock *RegionNotguardedBB =
+          CloneBasicBlock(RegionStartBB, VMap, ".not");
+      RegionNotguardedBB->insertInto(Fn, RegionStartBB);
+      RegionNotguardedBB->getTerminator()->setSuccessor(0, RegionExitBB);
+
+      assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
+             "Expected a different CFG");
+
+      BasicBlock *RegionCheckTidBB = SplitBlock(
+          ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
+
+      // Register basic blocks with the Attributor.
+      A.registerManifestAddedBasicBlock(*RegionEndBB);
+      A.registerManifestAddedBasicBlock(*RegionBarrierBB);
+      A.registerManifestAddedBasicBlock(*RegionExitBB);
+      A.registerManifestAddedBasicBlock(*RegionStartBB);
+      A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
+      A.registerManifestAddedBasicBlock(*RegionNotguardedBB);
+
+      bool HasBroadcastValues = false;
+      // Find escaping outputs from the guarded region to outside users and
+      // broadcast their values to them.
+      for (Instruction &I : *RegionStartBB) {
+        SmallPtrSet<Instruction *, 4> OutsideUsers;
+        for (User *Usr : I.users()) {
+          Instruction &UsrI = *cast<Instruction>(Usr);
+          if (UsrI.getParent() != RegionStartBB) {
+            dbgs() << "For I " << I << " in BB " << I.getParent()->getName()
+                   << " found outside user UsrI " << UsrI << " in BB "
+                   << UsrI.getParent()->getName() << "\n";
+            OutsideUsers.insert(&UsrI);
+          }
+        }
+
+        if (OutsideUsers.empty())
+          continue;
+
+        HasBroadcastValues = true;
+
+        // Emit a global variable in shared memory to store the broadcasted
+        // value.
+        auto *SharedMem = new GlobalVariable(
+            M, I.getType(), /* IsConstant */ false,
+            GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
+            I.getName() + ".guarded.output.alloc", nullptr,
+            GlobalValue::NotThreadLocal,
+            static_cast<unsigned>(AddressSpace::Shared));
+
+        // Emit a store instruction to update the value.
+        new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
+
+        LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
+                                       I.getName() + ".guarded.output.load",
+                                       RegionBarrierBB->getTerminator());
+
+        PHINode *PN = PHINode::Create(I.getType(), 2, ".phi.guarded",
+                                      &*RegionExitBB->getFirstInsertionPt());
+        PN->addIncoming(VMap[&I], RegionNotguardedBB);
+        PN->addIncoming(LoadI, RegionBarrierBB);
+        // Replace uses of the output value in outside users with the PHI node.
+        for (Instruction *UsrI : OutsideUsers) {
+          dbgs() << "PN " << *PN << " in UsrI " << *UsrI << " in BB "
+                 << UsrI->getParent()->getName() << "\n";
+          assert(UsrI->getParent() == RegionExitBB &&
+                 "Expected escaping users in exit region");
+          UsrI->replaceUsesOfWith(&I, PN);
+        }
+      }
+
+      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+
+      // Add check for SPMD-guarded execution mode in ParentBB.
+      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
+      ParentBB->getTerminator()->eraseFromParent();
+      OpenMPIRBuilder::LocationDescription Loc(
+          InsertPointTy(ParentBB, ParentBB->end()), DL);
+      OMPInfoCache.OMPBuilder.updateToLocation(Loc);
+      auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
+      Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
+      FunctionCallee HardwareTidFn =
+          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
+      // There's got to be a better way to do that...
+      FunctionCallee IsSPMDGuardedFn =
+          M.getFunction("__kmpc_is_spmd_guarded_exec_mode.internalized");
+      // OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+      //    M, OMPRTL___kmpc_is_spmd_guarded_exec_mode);
+      Value *Tid =
+          OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
+      Value *IsSPMDGuarded =
+          OMPInfoCache.OMPBuilder.Builder.CreateCall(IsSPMDGuardedFn, {});
+      Value *IsSPMDGuardedCheck =
+          OMPInfoCache.OMPBuilder.Builder.CreateIsNull(IsSPMDGuarded);
+      OMPInfoCache.OMPBuilder.Builder
+          .CreateCondBr(IsSPMDGuardedCheck, RegionNotguardedBB,
+                        RegionCheckTidBB)
+          ->setDebugLoc(DL);
+
+      // Add check for Tid in RegionCheckTidBB
+      RegionCheckTidBB->getTerminator()->eraseFromParent();
+      OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
+          InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
+      OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
+      Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
+      OMPInfoCache.OMPBuilder.Builder
+          .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
+          ->setDebugLoc(DL);
+
+      // First barrier for synchronization, ensures main thread has updated
+      // values.
+      FunctionCallee BarrierFn =
+          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+              M, OMPRTL___kmpc_barrier_simple_spmd);
+      OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
+          RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
+      OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
+          ->setDebugLoc(DL);
+
+      // Second barrier ensures workers have read broadcast values.
+      if (HasBroadcastValues)
+        CallInst::Create(BarrierFn, {Ident, Tid}, "",
+                         RegionBarrierBB->getTerminator())
+            ->setDebugLoc(DL);
+    };
+
+    // SmallPtrSet<BasicBlock *, 4> GuardedBasicBlocks;
+    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
+
+    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
+      BasicBlock *BB = GuardedI->getParent();
+      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
+          IRPosition::function(*GuardedI->getFunction()), nullptr,
+          DepClassTy::NONE);
+      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
+      auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
+      // Continue if instruction is already guarded.
+      if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
+        continue;
+
+      Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
+      for (Instruction &I : *BB) {
+        // If instruction I needs to be guarded update the guarded region
+        // bounds.
+        if (SPMDCompatibilityTracker.contains(&I)) {
+          CalleeAAFunction.getGuardedInstructions().insert(&I);
+          if (GuardedRegionStart)
+            GuardedRegionEnd = &I;
+          else
+            GuardedRegionStart = GuardedRegionEnd = &I;
+
+          continue;
+        }
+
+        // Instruction I does not need guarding, store
+        // any region found and reset bounds.
+        if (GuardedRegionStart) {
+          GuardedRegions.push_back(
+              std::make_pair(GuardedRegionStart, GuardedRegionEnd));
+          GuardedRegionStart = nullptr;
+          GuardedRegionEnd = nullptr;
+        }
+      }
+    }
+
+    for (auto &GR : GuardedRegions)
+      CreateGuardedRegion(GR.first, GR.second);
+
     // Adjust the global exec mode flag that tells the runtime what mode this
     // kernel is executed in.
     Function *Kernel = getAnchorScope();
@@ -2970,14 +3223,18 @@
 
     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
     const int InitIsSPMDArgNo = 1;
+    const int InitIsSPMDGuardedArgNo = 2;
+    const int InitUseStateMachineArgNo = 3;
+    const int InitRequiresFullRuntimeArgNo = 4;
     const int DeinitIsSPMDArgNo = 1;
-    const int InitUseStateMachineArgNo = 2;
-    const int InitRequiresFullRuntimeArgNo = 3;
     const int DeinitRequiresFullRuntimeArgNo = 2;
 
     auto &Ctx = getAnchorValue().getContext();
     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
                              *ConstantInt::getBool(Ctx, 1));
+    A.changeUseAfterManifest(
+        KernelInitCB->getArgOperandUse(InitIsSPMDGuardedArgNo),
+        *ConstantInt::getBool(Ctx, 1));
     A.changeUseAfterManifest(
         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
         *ConstantInt::getBool(Ctx, 0));
@@ -3005,7 +3262,7 @@
            "Custom state machine with invalid parallel region states?");
 
     const int InitIsSPMDArgNo = 1;
-    const int InitUseStateMachineArgNo = 2;
+    const int InitUseStateMachineArgNo = 3;
 
     // Check if the current configuration is non-SPMD and generic state machine.
     // If we already have SPMD mode or a custom state machine we do not need to
@@ -3283,8 +3540,21 @@
         if (llvm::all_of(Objects,
                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
           return true;
+        // Check for AAHeapToStack moved objects to avoid guarding.
+        auto *HS = A.lookupAAFor<AAHeapToStack>(
+            IRPosition::function(*I.getFunction()), this, DepClassTy::OPTIONAL);
+        if (HS)
+          if (llvm::all_of(Objects, [HS](const Value *Obj) {
+                auto *CB = dyn_cast<CallBase>(Obj);
+                if (!CB)
+                  return false;
+                return HS->isAssumedHeapToStack(*CB);
+              })) {
+            return true;
+          }
       }
-      // For now we give up on everything but stores.
+
+      // Insert instruction that needs guarding.
       SPMDCompatibilityTracker.insert(&I);
       return true;
     };
@@ -3470,7 +3740,8 @@
       // We do not look into tasks right now, just give up.
       SPMDCompatibilityTracker.insert(&CB);
       ReachedUnknownParallelRegions.insert(&CB);
-      break;
+      indicatePessimisticFixpoint();
+      return;
     case OMPRTL___kmpc_alloc_shared:
     case OMPRTL___kmpc_free_shared:
       // Return without setting a fixpoint, to be resolved in updateImpl.
@@ -3479,7 +3750,8 @@
       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
       // generally.
       SPMDCompatibilityTracker.insert(&CB);
-      break;
+      indicatePessimisticFixpoint();
+      return;
     }
     // All other OpenMP runtime calls will not reach parallel regions so they
     // can be safely ignored for now. Since it is a known OpenMP runtime call we
Index: llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2193,13 +2193,17 @@
 }
 
 OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) {
+OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
+                                  bool IsSPMDGuarded,
+                                  bool RequiresFullRuntime) {
   if (!updateToLocation(Loc))
     return Loc.IP;
 
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
   Value *Ident = getOrCreateIdent(SrcLocStr);
   ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
+  ConstantInt *IsSPMDGuardedVal =
+      ConstantInt::getBool(Int32->getContext(), IsSPMDGuarded);
   ConstantInt *UseGenericStateMachine =
       ConstantInt::getBool(Int32->getContext(), !IsSPMD);
   ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
@@ -2208,7 +2212,8 @@
       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
 
   CallInst *ThreadKind =
-      Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});
+      Builder.CreateCall(Fn, {Ident, IsSPMDVal, IsSPMDGuardedVal,
+                              UseGenericStateMachine, RequiresFullRuntimeVal});
 
   Value *ExecUserCode = Builder.CreateICmpEQ(
       ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code");
Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -412,7 +412,7 @@
           /* Int */ Int32, /* kmp_task_t */ VoidPtr)
 
 /// OpenMP Device runtime functions
-__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1)
+__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1, Int1)
 __OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1)
 __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
 __OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
@@ -438,6 +438,7 @@
 __OMP_RTL(__kmpc_get_shared_variables, false, Void, VoidPtrPtrPtr)
 __OMP_RTL(__kmpc_parallel_level, false, Int8, )
 __OMP_RTL(__kmpc_is_spmd_exec_mode, false, Int8, )
+__OMP_RTL(__kmpc_is_spmd_guarded_exec_mode, false, Int8, )
 __OMP_RTL(__kmpc_barrier_simple_spmd, false, Void, IdentPtr, Int32)
 
 __OMP_RTL(__kmpc_warp_active_thread_mask, false, LanemaskTy,)
Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -791,7 +791,8 @@
   /// \param Loc The insert and source location description.
   /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
   /// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
-  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
+  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
+                                 bool IsSPMDGuarded, bool RequiresFullRuntime);
 
   /// Create a runtime call for kmpc_target_deinit
   ///
Index: clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -1049,7 +1049,8 @@
 void CGOpenMPRuntimeGPU::emitKernelInit(CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
   CGBuilderTy &Bld = CGF.Builder;
-  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, requiresFullRuntime()));
+  Bld.restoreIP(OMPBuilder.createTargetInit(
+      Bld, IsSPMD, /* IsSPMDGuarded */ false, requiresFullRuntime()));
   IsInTargetMasterThreadRegion = IsSPMD;
   if (!IsSPMD)
     emitGenericVarsProlog(CGF, EST.Loc);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D106746: [Ope... Giorgis Georgakoudis via Phabricator via cfe-commits

Reply via email to