https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051
>From 2fbe762b53bb6d6ffdce2b5ae3d6de30584ed93b Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Wed, 27 Nov 2024 11:33:01 +0000 Subject: [PATCH 1/3] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 125 +++++++-- .../Frontend/OpenMPIRBuilderTest.cpp | 256 +++++++++++++++++- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +- 4 files changed, 383 insertions(+), 34 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index f475e34497105fd..444bc280df9f89b 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2237,6 +2237,26 @@ class OpenMPIRBuilder { int32_t MinThreads = 1; }; + /// Container to pass LLVM IR runtime values or constants related to the + /// number of teams and threads with which the kernel must be launched, as + /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These + /// must be defined in the host prior to the call to the kernel launch OpenMP + /// RTL function. + struct TargetKernelRuntimeAttrs { + SmallVector<Value *, 3> MaxTeams = {nullptr}; + Value *MinTeams = nullptr; + SmallVector<Value *, 3> TargetThreadLimit = {nullptr}; + SmallVector<Value *, 3> TeamsThreadLimit = {nullptr}; + + /// 'parallel' construct 'num_threads' clause value, if present and it is a + /// target SPMD kernel. + Value *MaxThreads = nullptr; + + /// Total number of iterations of the target SPMD kernel or null if it is a + /// generic kernel. 
+ Value *LoopTripCount = nullptr; + }; + /// Data structure that contains the needed information to construct the /// kernel args vector. struct TargetKernelArgs { @@ -2905,11 +2925,14 @@ class OpenMPIRBuilder { /// /// \param Loc where the target data construct was encountered. /// \param IsOffloadEntry whether it is an offload entry. + /// \param IsSPMD whether it is a target SPMD kernel. /// \param CodeGenIP The insertion point where the call to the outlined /// function should be emitted. /// \param EntryInfo The entry information about the function. /// \param DefaultAttrs Structure containing the default numbers of threads /// and teams to launch the kernel with. + /// \param RuntimeAttrs Structure containing the runtime numbers of threads + /// and teams to launch the kernel with. /// \param Inputs The input values to the region that will be passed. /// as arguments to the outlined function. /// \param BodyGenCB Callback that will generate the region code. @@ -2919,11 +2942,12 @@ class OpenMPIRBuilder { // dependency information as passed in the depend clause // \param HasNowait Whether the target construct has a `nowait` clause or not. 
InsertPointOrErrorTy createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, + const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 4c4d8f867fba511..cc299a9f46ce788 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6731,8 +6731,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() { return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit); } +static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List, + Module &M) { + if (List.empty()) + return; + + Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0); + + // Convert List to what ConstantArray needs. 
+ SmallVector<Constant *, 8> UsedArray; + UsedArray.reserve(List.size()); + for (auto Item : List) + UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(&*Item), PtrTy)); + + ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size()); + auto *GV = + new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage, + llvm::ConstantArray::get(ArrTy, UsedArray), Name); + + GV->setSection("llvm.metadata"); +} + +static void +emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + StringRef FunctionName, OMPTgtExecModeFlags Mode, + std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) { + auto *Int8Ty = Type::getInt8Ty(Builder.getContext()); + auto *GVMode = new llvm::GlobalVariable( + OMPBuilder.M, Int8Ty, /*isConstant=*/true, + llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode), + Twine(FunctionName, "_exec_mode")); + GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility); + LLVMCompilerUsed.emplace_back(GVMode); +} + static Expected<Function *> createOutlinedFunction( - OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, StringRef FuncName, SmallVectorImpl<Value *> &Inputs, OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, @@ -6762,6 +6797,15 @@ static Expected<Function *> createOutlinedFunction( auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M); + if (OMPBuilder.Config.isTargetDevice()) { + std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed; + emitExecutionMode(OMPBuilder, Builder, FuncName, + IsSPMD ? OMP_TGT_EXEC_MODE_SPMD + : OMP_TGT_EXEC_MODE_GENERIC, + LLVMCompilerUsed); + emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M); + } + // Save insert point. 
IRBuilder<>::InsertPointGuard IPG(Builder); // If there's a DISubprogram associated with current function, then @@ -6802,7 +6846,7 @@ static Expected<Function *> createOutlinedFunction( // Insert target init call in the device compilation pass. if (OMPBuilder.Config.isTargetDevice()) Builder.restoreIP( - OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs)); + OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs)); BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock(); @@ -6998,7 +7042,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder, static Error emitTargetOutlinedFunction( OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry, - TargetRegionEntryInfo &EntryInfo, + bool IsSPMD, TargetRegionEntryInfo &EntryInfo, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, Function *&OutlinedFn, Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs, @@ -7007,7 +7051,7 @@ static Error emitTargetOutlinedFunction( OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = [&](StringRef EntryFnName) { - return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs, + return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs, EntryFnName, Inputs, CBFunc, ArgAccessorFuncCB); }; @@ -7307,6 +7351,7 @@ static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, + const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl<Value *> &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, @@ -7388,11 +7433,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, /*ForEndCall=*/false); SmallVector<Value *, 3> NumTeamsC; + for (auto [DefaultVal, RuntimeVal] : + zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams)) + NumTeamsC.push_back(RuntimeVal ? 
RuntimeVal : Builder.getInt32(DefaultVal)); + + // Calculate number of threads: 0 if no clauses specified, otherwise it is the + // minimum between optional THREAD_LIMIT and NUM_THREADS clauses. + auto InitMaxThreadsClause = [&Builder](Value *Clause) { + if (Clause) + Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(), + /*isSigned=*/false); + return Clause; + }; + auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) { + if (Clause) + Result = Result + ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause), + Result, Clause) + : Clause; + }; + + // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so + // the NUM_THREADS clause is overridden by THREAD_LIMIT. SmallVector<Value *, 3> NumThreadsC; - for (auto V : DefaultAttrs.MaxTeams) - NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); - for (auto V : DefaultAttrs.MaxThreads) - NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V)); + Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1 + ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads) + : nullptr; + + for (auto [TeamsVal, TargetVal] : llvm::zip_equal( + RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) { + Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); + Value *NumThreads = InitMaxThreadsClause(TargetVal); + + CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads); + CombineMaxThreadsClauses(MaxThreadsClause, NumThreads); + + NumThreadsC.push_back(NumThreads ? 
NumThreads : Builder.getInt32(0)); + } unsigned NumTargetItems = Info.NumberOfPtrs; // TODO: Use correct device ID @@ -7401,14 +7478,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize); Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize, llvm::omp::IdentFlag(0), 0); - // TODO: Use correct NumIterations - Value *NumIterations = Builder.getInt64(0); + + Value *TripCount = RuntimeAttrs.LoopTripCount + ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount, + Builder.getInt64Ty(), + /*isSigned=*/false) + : Builder.getInt64(0); + // TODO: Use correct DynCGGroupMem Value *DynCGGroupMem = Builder.getInt32(0); - KArgs = OpenMPIRBuilder::TargetKernelArgs( - NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC, - DynCGGroupMem, HasNoWait); + KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount, + NumTeamsC, NumThreadsC, + DynCGGroupMem, HasNoWait); // The presence of certain clauses on the target directive require the // explicit generation of the target task. 
@@ -7430,13 +7512,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, - InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, + const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, + InsertPointTy AllocaIP, InsertPointTy CodeGenIP, + TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, + const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, SmallVector<DependData> Dependencies, bool HasNowait) { + assert((!RuntimeAttrs.LoopTripCount || IsSPMD) && + "trip count not expected if IsSPMD=false"); if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7449,16 +7535,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // the target region itself is generated using the callbacks CBFunc // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( - *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) + *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs, + OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) return Err; // If we are not on the target device, then we need to generate code // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. 
if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait); + emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, + OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, + HasNowait); return Builder.saveIP(); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index b0688d6215e42d3..a8c786b5886afe0 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6123,7 +6123,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { OMPBuilder.setConfig(Config); F->setName("func"); IRBuilder<> Builder(BB); - auto Int32Ty = Builder.getInt32Ty(); + auto *Int32Ty = Builder.getInt32Ty(); AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr"); AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr"); @@ -6183,11 +6183,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { - /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; - OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = - OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20); + RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30); + RuntimeAttrs.MaxThreads = Builder.getInt32(40); + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(), + Builder.saveIP(), EntryInfo, 
DefaultAttrs, RuntimeAttrs, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); OMPBuilder.finalize(); @@ -6207,6 +6211,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { StringRef FunctionName = KernelLaunchFunc->getName(); EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + // Check num_teams and num_threads in call arguments + EXPECT_TRUE(Call->arg_size() >= 4); + Value *NumTeamsArg = Call->getArgOperand(2); + EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg)); + EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue()); + Value *NumThreadsArg = Call->getArgOperand(3); + EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg)); + EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue()); + + // Check num_teams and num_threads kernel arguments (use number 5 starting + // from the end and counting the call to __tgt_target_kernel as the first use) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 4); + Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3); + EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr)); + Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(NumTeamsStore)); + Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg)); + auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg); + EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements()); + EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2)); + Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2); + EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr)); + Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser(); + 
EXPECT_TRUE(isa<StoreInst>(NumThreadsStore)); + Value *NumThreadsStoreArg = + cast<StoreInst>(NumThreadsStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg)); + auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg); + EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements()); + EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1)); + EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2)); + // Check the fallback call BasicBlock *FallbackBlock = Branch->getSuccessor(0); Iter = FallbackBlock->rbegin(); @@ -6297,9 +6338,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); @@ -6378,6 +6421,197 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { auto *ExitBlock = EntryBlockBranch->getSuccessor(1); EXPECT_EQ(ExitBlock->getName(), "worker.exit"); EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI())); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false); + OMPBuilder.setConfig(Config); + F->setName("func"); + IRBuilder<> Builder(BB); + + auto BodyGenCB = [&](InsertPointTy, + InsertPointTy CodeGenIP) -> InsertPointTy { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + auto SimpleArgAccessorCB = + [&](llvm::Argument &, llvm::Value *, llvm::Value *&, + llvm::OpenMPIRBuilder::InsertPointTy, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + llvm::SmallVector<llvm::Value *> Inputs; + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) + -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); + OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); + OpenMPIRBuilder::TargetKernelDefaultAttrs 
DefaultAttrs = { + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(), + Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check the kernel launch sequence + auto Iter = F->getEntryBlock().rbegin(); + EXPECT_TRUE(isa<BranchInst>(&*(Iter))); + BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter)); + EXPECT_TRUE(isa<CmpInst>(&*(++Iter))); + EXPECT_TRUE(isa<CallInst>(&*(++Iter))); + CallInst *Call = dyn_cast<CallInst>(&*(Iter)); + + // Check that the kernel launch function is called + Function *KernelLaunchFunc = Call->getCalledFunction(); + EXPECT_NE(KernelLaunchFunc, nullptr); + StringRef FunctionName = KernelLaunchFunc->getName(); + EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel")); + + // Check the trip count kernel argument (use number 5 starting from the end + // and counting the call to __tgt_target_kernel as the first use) + Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1); + EXPECT_TRUE(KernelArgs->getNumUses() >= 6); + Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5); + EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr)); + Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser(); + EXPECT_TRUE(isa<StoreInst>(TripCountStore)); + Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand(); + EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg)); + EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue()); + + // Check the fallback call + BasicBlock *FallbackBlock = Branch->getSuccessor(0); + Iter = 
FallbackBlock->rbegin(); + CallInst *FCall = dyn_cast<CallInst>(&*(++Iter)); + // 'F' has a dummy DISubprogram which causes OutlinedFunc to also + // have a DISubprogram. In this case, the call to OutlinedFunc needs + // to have a debug loc, otherwise verifier will complain. + FCall->setDebugLoc(DL); + EXPECT_NE(FCall, nullptr); + + // Check that the outlined function exists with the expected prefix + Function *OutlinedFunc = FCall->getCalledFunction(); + EXPECT_NE(OutlinedFunc, nullptr); + StringRef FunctionName2 = OutlinedFunc->getName(); + EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading")); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + +TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.setConfig( + OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false, + /*OpenMPOffloadMandatory=*/false, + /*HasRequiresReverseOffload=*/false, + /*HasRequiresUnifiedAddress=*/false, + /*HasRequiresUnifiedSharedMemory=*/false, + /*HasRequiresDynamicAllocators=*/false)); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + + Function *OutlinedFn = nullptr; + llvm::SmallVector<llvm::Value *> CapturedArgs; + + auto SimpleArgAccessorCB = + [&](llvm::Argument &, llvm::Value *, llvm::Value *&, + llvm::OpenMPIRBuilder::InsertPointTy, + llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; + + llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) + -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + + auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) + -> OpenMPIRBuilder::InsertPointTy { + Builder.restoreIP(CodeGenIP); + OutlinedFn = CodeGenIP.getBlock()->getParent(); + return Builder.saveIP(); + }; + + IRBuilder<>::InsertPoint 
EntryIP(&F->getEntryBlock(), + F->getEntryBlock().getFirstInsertionPt()); + TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2, + /*Line=*/3, /*Count=*/0); + + OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { + /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; + OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); + assert(AfterIP && "unexpected error"); + Builder.restoreIP(*AfterIP); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + // Check outlined function + EXPECT_FALSE(verifyModule(*M, &errs())); + EXPECT_NE(OutlinedFn, nullptr); + EXPECT_NE(F, OutlinedFn); + + EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage()); + // Account for the "implicit" first argument. + EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3"); + EXPECT_EQ(OutlinedFn->arg_size(), 1U); + + // Check global exec_mode. 
+ GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used"); + EXPECT_NE(Used, nullptr); + Constant *UsedInit = Used->getInitializer(); + EXPECT_NE(UsedInit, nullptr); + EXPECT_TRUE(isa<ConstantArray>(UsedInit)); + auto *UsedInitData = cast<ConstantArray>(UsedInit); + EXPECT_EQ(1U, UsedInitData->getNumOperands()); + Constant *ExecMode = UsedInitData->getOperand(0); + EXPECT_TRUE(isa<GlobalVariable>(ExecMode)); + Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer(); + EXPECT_NE(ExecModeValue, nullptr); + EXPECT_TRUE(isa<ConstantInt>(ExecModeValue)); + EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD, + cast<ConstantInt>(ExecModeValue)->getZExtValue()); } TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { @@ -6448,9 +6682,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs, - CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); + Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, + BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); Builder.restoreIP(*AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index cca2613ce102afa..f30ba2c29261625 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3951,9 +3951,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, allocaIP, codeGenIP); }; - // TODO: Populate default attributes based on the construct and clauses. 
+ // TODO: Populate default and runtime attributes based on the construct and + // clauses. llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs; llvm::SmallVector<llvm::Value *, 4> kernelInput; for (size_t i = 0; i < mapVars.size(); ++i) { @@ -3973,9 +3975,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, - defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, - targetOp.getNowait()); + ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(), + entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, + bodyCB, argAccessorCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); >From e9ea3a501bb97a4ddaa85b795874e343cae40f2b Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Wed, 27 Nov 2024 12:08:46 +0000 Subject: [PATCH 2/3] Address review comments --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 10 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 104 +++++++----------- .../Frontend/OpenMPIRBuilderTest.cpp | 46 ++++---- 3 files changed, 68 insertions(+), 92 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 444bc280df9f89b..3a640fbd7336951 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1387,9 +1387,6 @@ class OpenMPIRBuilder { /// Supporting functions for Reductions CodeGen. private: - /// Emit the llvm.used metadata. - void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List); - /// Get the id of the current thread on the GPU. 
Value *getGPUThreadID(); @@ -2011,6 +2008,13 @@ class OpenMPIRBuilder { /// Value. GlobalValue *createGlobalFlag(unsigned Value, StringRef Name); + /// Emit the llvm.used metadata. + void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List); + + /// Emit the kernel execution mode. + GlobalVariable *emitKernelExecutionMode(StringRef KernelName, + omp::OMPTgtExecModeFlags Mode); + /// Generate control flow and cleanup for cancellation. /// /// \param CancelFlag Flag indicating if the cancellation is performed. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index cc299a9f46ce788..dcf2515311eabcd 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) { return GV; } +void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) { + if (List.empty()) + return; + + // Convert List to what ConstantArray needs. 
+ SmallVector<Constant *, 8> UsedArray; + UsedArray.resize(List.size()); + for (unsigned I = 0, E = List.size(); I != E; ++I) + UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(&*List[I]), Builder.getPtrTy()); + + if (UsedArray.empty()) + return; + ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size()); + + auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, + ConstantArray::get(ATy, UsedArray), Name); + + GV->setSection("llvm.metadata"); +} + +GlobalVariable * +OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName, + OMPTgtExecModeFlags Mode) { + auto *Int8Ty = Builder.getInt8Ty(); + auto *GVMode = new GlobalVariable( + M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage, + ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode")); + GVMode->setVisibility(GlobalVariable::ProtectedVisibility); + return GVMode; +} + Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, IdentFlag LocFlags, @@ -2246,28 +2278,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) { return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT); } -void OpenMPIRBuilder::emitUsed(StringRef Name, - std::vector<WeakTrackingVH> &List) { - if (List.empty()) - return; - - // Convert List to what ConstantArray needs. 
- SmallVector<Constant *, 8> UsedArray; - UsedArray.resize(List.size()); - for (unsigned I = 0, E = List.size(); I != E; ++I) - UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast( - cast<Constant>(&*List[I]), Builder.getPtrTy()); - - if (UsedArray.empty()) - return; - ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size()); - - auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, - ConstantArray::get(ATy, UsedArray), Name); - - GV->setSection("llvm.metadata"); -} - Value *OpenMPIRBuilder::getGPUThreadID() { return Builder.CreateCall( getOrCreateRuntimeFunction(M, @@ -6731,41 +6741,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() { return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit); } -static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List, - Module &M) { - if (List.empty()) - return; - - Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0); - - // Convert List to what ConstantArray needs. 
- SmallVector<Constant *, 8> UsedArray; - UsedArray.reserve(List.size()); - for (auto Item : List) - UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast( - cast<Constant>(&*Item), PtrTy)); - - ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size()); - auto *GV = - new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage, - llvm::ConstantArray::get(ArrTy, UsedArray), Name); - - GV->setSection("llvm.metadata"); -} - -static void -emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - StringRef FunctionName, OMPTgtExecModeFlags Mode, - std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) { - auto *Int8Ty = Type::getInt8Ty(Builder.getContext()); - auto *GVMode = new llvm::GlobalVariable( - OMPBuilder.M, Int8Ty, /*isConstant=*/true, - llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode), - Twine(FunctionName, "_exec_mode")); - GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility); - LLVMCompilerUsed.emplace_back(GVMode); -} - static Expected<Function *> createOutlinedFunction( OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, @@ -6798,12 +6773,9 @@ static Expected<Function *> createOutlinedFunction( Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M); if (OMPBuilder.Config.isTargetDevice()) { - std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed; - emitExecutionMode(OMPBuilder, Builder, FuncName, - IsSPMD ? OMP_TGT_EXEC_MODE_SPMD - : OMP_TGT_EXEC_MODE_GENERIC, - LLVMCompilerUsed); - emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M); + Value *ExecMode = OMPBuilder.emitKernelExecutionMode( + FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC); + OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode}); } // Save insert point. @@ -7460,8 +7432,8 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, ? 
InitMaxThreadsClause(RuntimeAttrs.MaxThreads) : nullptr; - for (auto [TeamsVal, TargetVal] : llvm::zip_equal( - RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) { + for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit, + RuntimeAttrs.TargetThreadLimit)) { Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal); Value *NumThreads = InitMaxThreadsClause(TargetVal); @@ -7521,8 +7493,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, SmallVector<DependData> Dependencies, bool HasNowait) { - assert((!RuntimeAttrs.LoopTripCount || IsSPMD) && - "trip count not expected if IsSPMD=false"); if (!updateToLocation(Loc)) return InsertPointTy(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index a8c786b5886afe0..e4845256633b9c8 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6459,18 +6459,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { return Builder.saveIP(); }; - auto SimpleArgAccessorCB = - [&](llvm::Argument &, llvm::Value *, llvm::Value *&, - llvm::OpenMPIRBuilder::InsertPointTy, - llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { - Builder.restoreIP(CodeGenIP); - return Builder.saveIP(); - }; + auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&, + OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; - llvm::SmallVector<llvm::Value *> Inputs; - llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; - auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) - -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + SmallVector<Value *> Inputs; + OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = + [&](OpenMPIRBuilder::InsertPointTy) 
-> OpenMPIRBuilder::MapInfosTy & { + return CombinedInfos; + }; TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); @@ -6547,19 +6548,20 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); Function *OutlinedFn = nullptr; - llvm::SmallVector<llvm::Value *> CapturedArgs; + SmallVector<Value *> CapturedArgs; - auto SimpleArgAccessorCB = - [&](llvm::Argument &, llvm::Value *, llvm::Value *&, - llvm::OpenMPIRBuilder::InsertPointTy, - llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) { - Builder.restoreIP(CodeGenIP); - return Builder.saveIP(); - }; + auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&, + OpenMPIRBuilder::InsertPointTy, + OpenMPIRBuilder::InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + return Builder.saveIP(); + }; - llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos; - auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy) - -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; }; + OpenMPIRBuilder::MapInfosTy CombinedInfos; + auto GenMapInfoCB = + [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & { + return CombinedInfos; + }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, OpenMPIRBuilder::InsertPointTy CodeGenIP) >From b1e4eb5699afa9582a136c589adb9c256cc0bc66 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Wed, 4 Dec 2024 14:16:17 +0000 Subject: [PATCH 3/3] Fine-grained control of kernel execution mode --- clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp | 7 ++++- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 22 +++++++------- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 30 +++++++++---------- .../Frontend/OpenMPIRBuilderTest.cpp | 15 ++++++---- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 7 +++-- 5 files changed, 47 insertions(+), 34 deletions(-) diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp index 
659783a813c83ef..515dbe379eb6e3d 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -20,6 +20,7 @@ #include "clang/AST/StmtVisitor.h" #include "clang/Basic/Cuda.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPGridValues.h" using namespace clang; @@ -748,7 +749,11 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D, computeMinAndMaxThreadsAndTeams(D, CGF, Attrs); CGBuilderTy &Bld = CGF.Builder; - Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs)); + Bld.restoreIP(OMPBuilder.createTargetInit( + Bld, + IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD + : llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, + Attrs)); if (!IsSPMD) emitGenericVarsProlog(CGF, EST.Loc); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 3a640fbd7336951..580b2b3e2341580 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2243,21 +2243,21 @@ class OpenMPIRBuilder { /// Container to pass LLVM IR runtime values or constants related to the /// number of teams and threads with which the kernel must be launched, as - /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These - /// must be defined in the host prior to the call to the kernel launch OpenMP - /// RTL function. + /// well as the trip count of the loop, if it is an SPMD or Generic-SPMD + /// kernel. These must be defined in the host prior to the call to the kernel + /// launch OpenMP RTL function. 
struct TargetKernelRuntimeAttrs { SmallVector<Value *, 3> MaxTeams = {nullptr}; Value *MinTeams = nullptr; SmallVector<Value *, 3> TargetThreadLimit = {nullptr}; SmallVector<Value *, 3> TeamsThreadLimit = {nullptr}; - /// 'parallel' construct 'num_threads' clause value, if present and it is a - /// target SPMD kernel. + /// 'parallel' construct 'num_threads' clause value, if present and it is an + /// SPMD kernel. Value *MaxThreads = nullptr; - /// Total number of iterations of the target SPMD kernel or null if it is a - /// generic kernel. + /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if + /// it is a generic kernel. Value *LoopTripCount = nullptr; }; @@ -2763,11 +2763,12 @@ class OpenMPIRBuilder { /// Create a runtime call for kmpc_target_init /// /// \param Loc The insert and source location description. + /// \param ExecFlags Kernel execution mode flags. /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not. /// \param Attrs Structure containing the default numbers of threads and teams /// to launch the kernel with. InsertPointTy createTargetInit( - const LocationDescription &Loc, bool IsSPMD, + const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags, const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs); /// Create a runtime call for kmpc_target_deinit @@ -2929,7 +2930,7 @@ class OpenMPIRBuilder { /// /// \param Loc where the target data construct was encountered. /// \param IsOffloadEntry whether it is an offload entry. - /// \param IsSPMD whether it is a target SPMD kernel. + /// \param ExecFlags kernel execution mode flags. /// \param CodeGenIP The insertion point where the call to the outlined /// function should be emitted. /// \param EntryInfo The entry information about the function. @@ -2946,7 +2947,8 @@ class OpenMPIRBuilder { // dependency information as passed in the depend clause // \param HasNowait Whether the target construct has a `nowait` clause or not. 
InsertPointOrErrorTy createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, + const LocationDescription &Loc, bool IsOffloadEntry, + omp::OMPTgtExecModeFlags ExecFlags, OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index dcf2515311eabcd..28f85460624dd63 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6124,7 +6124,7 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate( } OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( - const LocationDescription &Loc, bool IsSPMD, + const LocationDescription &Loc, omp::OMPTgtExecModeFlags ExecFlags, const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) { assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() && "expected num_threads and num_teams to be specified"); @@ -6135,9 +6135,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( uint32_t SrcLocStrSize; Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); - Constant *IsSPMDVal = ConstantInt::getSigned( - Int8, IsSPMD ? 
OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC); - Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD); + Constant *IsSPMDVal = ConstantInt::getSigned(Int8, ExecFlags); + Constant *UseGenericStateMachineVal = + ConstantInt::getSigned(Int8, ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD); Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true); Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0); @@ -6742,7 +6742,8 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() { } static Expected<Function *> createOutlinedFunction( - OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD, + OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + omp::OMPTgtExecModeFlags ExecFlags, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, StringRef FuncName, SmallVectorImpl<Value *> &Inputs, OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc, @@ -6773,8 +6774,7 @@ static Expected<Function *> createOutlinedFunction( Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M); if (OMPBuilder.Config.isTargetDevice()) { - Value *ExecMode = OMPBuilder.emitKernelExecutionMode( - FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC); + Value *ExecMode = OMPBuilder.emitKernelExecutionMode(FuncName, ExecFlags); OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode}); } @@ -6818,7 +6818,7 @@ static Expected<Function *> createOutlinedFunction( // Insert target init call in the device compilation pass. 
if (OMPBuilder.Config.isTargetDevice()) Builder.restoreIP( - OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs)); + OMPBuilder.createTargetInit(Builder, ExecFlags, DefaultAttrs)); BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock(); @@ -7014,7 +7014,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder, static Error emitTargetOutlinedFunction( OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry, - bool IsSPMD, TargetRegionEntryInfo &EntryInfo, + omp::OMPTgtExecModeFlags ExecFlags, TargetRegionEntryInfo &EntryInfo, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, Function *&OutlinedFn, Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs, @@ -7023,8 +7023,8 @@ static Error emitTargetOutlinedFunction( OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction = [&](StringRef EntryFnName) { - return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs, - EntryFnName, Inputs, CBFunc, + return createOutlinedFunction(OMPBuilder, Builder, ExecFlags, + DefaultAttrs, EntryFnName, Inputs, CBFunc, ArgAccessorFuncCB); }; @@ -7484,9 +7484,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, } OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( - const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD, - InsertPointTy AllocaIP, InsertPointTy CodeGenIP, - TargetRegionEntryInfo &EntryInfo, + const LocationDescription &Loc, bool IsOffloadEntry, + omp::OMPTgtExecModeFlags ExecFlags, InsertPointTy AllocaIP, + InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, const TargetKernelRuntimeAttrs &RuntimeAttrs, SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB, @@ -7505,7 +7505,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // the target region itself is generated using the callbacks CBFunc // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( - 
*this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs, + *this, Builder, IsOffloadEntry, ExecFlags, EntryInfo, DefaultAttrs, OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) return Err; diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index e4845256633b9c8..90a0a92888310cc 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -6189,7 +6189,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30); RuntimeAttrs.MaxThreads = Builder.getInt32(40); OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(), + OmpLoc, /*IsOffloadEntry=*/true, + omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); @@ -6340,7 +6341,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + Loc, /*IsOffloadEntry=*/true, + omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP, EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); @@ -6480,7 +6482,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(), + OmpLoc, /*IsOffloadEntry=*/true, + 
omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, Builder.saveIP(), Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); @@ -6580,7 +6583,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP, + Loc, /*IsOffloadEntry=*/true, + omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, EntryIP, EntryIP, EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); @@ -6686,7 +6690,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget( - Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP, + Loc, /*IsOffloadEntry=*/true, + omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, EntryIP, EntryIP, EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); assert(AfterIP && "unexpected error"); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index f30ba2c29261625..acdbbcd5eafa21e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3975,9 +3975,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, 
builder.saveIP(), - entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, - bodyCB, argAccessorCB, dds, targetOp.getNowait()); + ompLoc, isOffloadEntry, llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, + allocaIP, builder.saveIP(), entryInfo, defaultAttrs, runtimeAttrs, + kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds, + targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits