https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/68364
>From 2d3b34476df53f39d6cc6b7eee02b9d0d33e7a04 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Wed, 4 Oct 2023 15:55:55 -0500 Subject: [PATCH 1/5] [OpenMPIRBuilder] Add clauses to teams This patch adds `num_teams` (upperbound) and `thread_limit` clauses to `OpenMPIRBuilder`. --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 7 +- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 14 ++- .../Frontend/OpenMPIRBuilderTest.cpp | 115 ++++++++++++++++++ 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 1699ed3aeab7661..8745b6df9e86330 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1893,8 +1893,13 @@ class OpenMPIRBuilder { /// /// \param Loc The location where the teams construct was encountered. /// \param BodyGenCB Callback that will generate the region code. + /// \param NumTeamsUpper Upper bound on the number of teams. + /// \param ThreadLimit on the number of threads that may participate in a + /// contention group created by each team. 
InsertPointTy createTeams(const LocationDescription &Loc, - BodyGenCallbackTy BodyGenCB); + BodyGenCallbackTy BodyGenCB, + Value *NumTeamsUpper = nullptr, + Value *ThreadLimit = nullptr); /// Generate conditional branch and relevant BasicBlocks through which private /// threads copy the 'copyin' variables from Master copy to threadprivate diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 9c70d384e55db2b..62bc7b3d40ca43a 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5733,7 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare( OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTeams(const LocationDescription &Loc, - BodyGenCallbackTy BodyGenCB) { + BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper, + Value *ThreadLimit) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -5771,6 +5772,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, BasicBlock *AllocaBB = splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); + // Push num_teams + if (NumTeamsUpper || ThreadLimit) { + NumTeamsUpper = + NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper; + ThreadLimit = ThreadLimit == nullptr ? 
Builder.getInt32(0) : ThreadLimit; + Value *ThreadNum = getOrCreateThreadID(Ident); + Builder.CreateCall( + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams), + {Ident, ThreadNum, NumTeamsUpper, ThreadLimit}); + } + OutlineInfo OI; OI.EntryBB = AllocaBB; OI.ExitBB = ExitBB; diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index fd524f6067ee0ea..88b7e4b397e46de 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4082,6 +4082,121 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) { [](Instruction &inst) { return isa<ICmpInst>(&inst); })); } +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + Function *FakeFunction = + Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Builder.restoreIP( + OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin())); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + Function *PushNumTeamsRTL = + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); + EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0)); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin()); + + // Verifying that the next instruction to 
execute is kmpc_fork_teams + BranchInst *BrInst = + dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction()); + ASSERT_NE(BrInst, nullptr); + ASSERT_EQ(BrInst->getNumSuccessors(), 1U); + Instruction *NextInstruction = + BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime(); + CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction); + ASSERT_NE(ForkTeamsCI, nullptr); + EXPECT_EQ(ForkTeamsCI->getCalledFunction(), + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); +} + +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + Function *FakeFunction = + Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin())); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + // M->print(dbgs(), nullptr); +} + +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + BasicBlock *CodegenBB = splitBB(Builder, true); + Builder.SetInsertPoint(CodegenBB); + + Value *NumTeamsUpper = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper"); + Value *ThreadLimit = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(20), "threadLimit"); + + Function *FakeFunction = + 
Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + Builder.restoreIP( + OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit)); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + // M->print(dbgs(), nullptr); +} + /// Returns the single instruction of InstTy type in BB that uses the value V. /// If there is more than one such instruction, returns null. template <typename InstTy> >From 8393f14fb9a5b9f2cf2b8745cebe3d0b702c9541 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Thu, 5 Oct 2023 17:57:11 -0500 Subject: [PATCH 2/5] Add testcases --- .../Frontend/OpenMPIRBuilderTest.cpp | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 88b7e4b397e46de..496c60ba38605ce 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4157,7 +4157,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { ASSERT_FALSE(verifyModule(*M)); - // M->print(dbgs(), nullptr); + Function *PushNumTeamsRTL = + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); + EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin()); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0)); + + // Verifying that the next instruction to execute is kmpc_fork_teams + BranchInst *BrInst = + 
dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction()); + ASSERT_NE(BrInst, nullptr); + ASSERT_EQ(BrInst->getNumSuccessors(), 1U); + Instruction *NextInstruction = + BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime(); + CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction); + ASSERT_NE(ForkTeamsCI, nullptr); + EXPECT_EQ(ForkTeamsCI->getCalledFunction(), + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); } TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { @@ -4194,8 +4215,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { ASSERT_FALSE(verifyModule(*M)); - // M->print(dbgs(), nullptr); -} + Function *PushNumTeamsRTL = + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); + EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit); + + // Verifying that the next instruction to execute is kmpc_fork_teams + BranchInst *BrInst = + dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction()); + ASSERT_NE(BrInst, nullptr); + ASSERT_EQ(BrInst->getNumSuccessors(), 1U); + Instruction *NextInstruction = + BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime(); + CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction); + ASSERT_NE(ForkTeamsCI, nullptr); + EXPECT_EQ(ForkTeamsCI->getCalledFunction(), + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));} /// Returns the single instruction of InstTy type in BB that uses the value V. /// If there is more than one such instruction, returns null. 
>From 9f368708a33a87dd9fea8944082c54ac37dd85c9 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Thu, 5 Oct 2023 17:58:02 -0500 Subject: [PATCH 3/5] Formatting --- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 496c60ba38605ce..fb87389023910c2 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4236,7 +4236,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction); ASSERT_NE(ForkTeamsCI, nullptr); EXPECT_EQ(ForkTeamsCI->getCalledFunction(), - OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams));} + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); +} /// Returns the single instruction of InstTy type in BB that uses the value V. /// If there is more than one such instruction, returns null. 
>From 25870b3f64b9a07452e9207c30d15dc960c69f18 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Mon, 9 Oct 2023 09:10:18 -0500 Subject: [PATCH 4/5] Address comments --- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index bd38f5bece16df1..fef592718f79c95 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4091,9 +4091,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { Builder.CreateCall(FakeFunction, {}); }; - OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - Builder.restoreIP( - OMPBuilder.createTeams(Builder, BodyGenCB, nullptr, F->arg_begin())); + // `F` has an argument - an integer, so we use that as the thread limit. + Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB, + /*NumTeamsUpper=*/nullptr, + /*ThreadLimit=*/F->arg_begin())); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4141,8 +4142,10 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { Builder.CreateCall(FakeFunction, {}); }; - OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, F->arg_begin())); + // `F` already has an integer argument, so we use that as upper bound to + // `num_teams` + Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, + /*NumTeamsUpper=*/F->arg_begin())); Builder.CreateRetVoid(); OMPBuilder.finalize(); @@ -4184,6 +4187,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { BasicBlock *CodegenBB = splitBB(Builder, true); Builder.SetInsertPoint(CodegenBB); + // Generate values for `num_teams` and `thread_limit` using the first argument + // of the testing function. 
Value *NumTeamsUpper = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper"); Value *ThreadLimit = >From c13e639039cd57d5ce9e97c1075dc21f662dbee8 Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay <shraiysh.vais...@amd.com> Date: Mon, 9 Oct 2023 11:45:56 -0500 Subject: [PATCH 5/5] Address comment about lower bound --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 4 + .../include/llvm/Frontend/OpenMP/OMPKinds.def | 1 + llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 26 ++++-- .../Frontend/OpenMPIRBuilderTest.cpp | 92 ++++++++++++++----- 4 files changed, 93 insertions(+), 30 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index ba679e2998eb413..9d2adf229b78654 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1917,11 +1917,15 @@ class OpenMPIRBuilder { /// /// \param Loc The location where the teams construct was encountered. /// \param BodyGenCB Callback that will generate the region code. + /// \param NumTeamsLower Lower bound on number of teams. If this is nullptr, + /// it is as if lower bound is specified as equal to upperbound. If + /// this is non-null, then upperbound must also be non-null. /// \param NumTeamsUpper Upper bound on the number of teams. /// \param ThreadLimit on the number of threads that may participate in a /// contention group created by each team. 
InsertPointTy createTeams(const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB, + Value *NumTeamsLower = nullptr, Value *NumTeamsUpper = nullptr, Value *ThreadLimit = nullptr); diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def index 176b883fe68f7ad..4823c4cc6b833ec 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -387,6 +387,7 @@ __OMP_RTL(__kmpc_cancellationpoint, false, Int32, IdentPtr, Int32, Int32) __OMP_RTL(__kmpc_fork_teams, true, Void, IdentPtr, Int32, ParallelTaskPtr) __OMP_RTL(__kmpc_push_num_teams, false, Void, IdentPtr, Int32, Int32, Int32) +__OMP_RTL(__kmpc_push_num_teams_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32) __OMP_RTL(__kmpc_set_thread_limit, false, Void, IdentPtr, Int32, Int32) __OMP_RTL(__kmpc_copyprivate, false, Void, IdentPtr, Int32, SizeTy, VoidPtr, diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 315967dc2b2a6f6..a658990f2d45355 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -5733,8 +5733,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare( OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTeams(const LocationDescription &Loc, - BodyGenCallbackTy BodyGenCB, Value *NumTeamsUpper, - Value *ThreadLimit) { + BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower, + Value *NumTeamsUpper, Value *ThreadLimit) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -5773,14 +5773,24 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, splitBB(Builder, /*CreateBranch=*/true, "teams.alloca"); // Push num_teams - if (NumTeamsUpper || ThreadLimit) { - NumTeamsUpper = - NumTeamsUpper == nullptr ? Builder.getInt32(0) : NumTeamsUpper; - ThreadLimit = ThreadLimit == nullptr ? 
Builder.getInt32(0) : ThreadLimit; + if (NumTeamsLower || NumTeamsUpper || ThreadLimit) { + assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) && + "if lowerbound is non-null, then upperbound must also be non-null " + "for bounds on num_teams"); + + if (NumTeamsUpper == nullptr) + NumTeamsUpper = Builder.getInt32(0); + + if (NumTeamsLower == nullptr) + NumTeamsLower = NumTeamsUpper; + + if (ThreadLimit == nullptr) + ThreadLimit = Builder.getInt32(0); + Value *ThreadNum = getOrCreateThreadID(Ident); Builder.CreateCall( - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams), - {Ident, ThreadNum, NumTeamsUpper, ThreadLimit}); + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51), + {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit}); } // Generate the body of teams. InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin()); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index fef592718f79c95..37400a9be0d14a3 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -4093,6 +4093,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { // `F` has an argument - an integer, so we use that as the thread limit. 
Builder.restoreIP(OMPBuilder.createTeams(/*=*/Builder, BodyGenCB, + /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/nullptr, /*ThreadLimit=*/F->arg_begin())); @@ -4101,16 +4102,13 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { ASSERT_FALSE(verifyModule(*M)); - Function *PushNumTeamsRTL = - OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); - EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); - CallInst *PushNumTeamsCallInst = - findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); ASSERT_NE(PushNumTeamsCallInst, nullptr); EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), Builder.getInt32(0)); - EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin()); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0)); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), &*F->arg_begin()); // Verifying that the next instruction to execute is kmpc_fork_teams BranchInst *BrInst = @@ -4125,7 +4123,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) { OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); } -TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); OMPBuilder.initialize(); @@ -4145,6 +4143,7 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { // `F` already has an integer argument, so we use that as upper bound to // `num_teams` Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, + /*NumTeamsLower=*/nullptr, /*NumTeamsUpper=*/F->arg_begin())); Builder.CreateRetVoid(); @@ -4152,16 +4151,66 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeams) { ASSERT_FALSE(verifyModule(*M)); - Function *PushNumTeamsRTL = - OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); - EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); - CallInst *PushNumTeamsCallInst = 
- findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); ASSERT_NE(PushNumTeamsCallInst, nullptr); EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), &*F->arg_begin()); - EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), Builder.getInt32(0)); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), &*F->arg_begin()); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0)); + + // Verifying that the next instruction to execute is kmpc_fork_teams + BranchInst *BrInst = + dyn_cast<BranchInst>(PushNumTeamsCallInst->getNextNonDebugInstruction()); + ASSERT_NE(BrInst, nullptr); + ASSERT_EQ(BrInst->getNumSuccessors(), 1U); + Instruction *NextInstruction = + BrInst->getSuccessor(0)->getFirstNonPHIOrDbgOrLifetime(); + CallInst *ForkTeamsCI = dyn_cast_if_present<CallInst>(NextInstruction); + ASSERT_NE(ForkTeamsCI, nullptr); + EXPECT_EQ(ForkTeamsCI->getCalledFunction(), + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)); +} + +TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> &Builder = OMPBuilder.Builder; + Builder.SetInsertPoint(BB); + + Function *FakeFunction = + Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::ExternalLinkage, "fakeFunction", M.get()); + + Value *NumTeamsLower = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower"); + Value *NumTeamsUpper = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper"); + + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { + Builder.restoreIP(CodeGenIP); + Builder.CreateCall(FakeFunction, {}); + }; + + // Pass the two values derived from `F`'s integer argument above as the + // lower and upper bounds of `num_teams`. + Builder.restoreIP( + OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower,
NumTeamsUpper)); + + Builder.CreateRetVoid(); + OMPBuilder.finalize(); + + ASSERT_FALSE(verifyModule(*M)); + + CallInst *PushNumTeamsCallInst = + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); + ASSERT_NE(PushNumTeamsCallInst, nullptr); + + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), Builder.getInt32(0)); // Verifying that the next instruction to execute is kmpc_fork_teams BranchInst *BrInst = @@ -4189,6 +4238,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { // Generate values for `num_teams` and `thread_limit` using the first argument // of the testing function. + Value *NumTeamsLower = + Builder.CreateAdd(F->arg_begin(), Builder.getInt32(5), "numTeamsLower"); Value *NumTeamsUpper = Builder.CreateAdd(F->arg_begin(), Builder.getInt32(10), "numTeamsUpper"); Value *ThreadLimit = @@ -4204,24 +4255,21 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) { }; OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); - Builder.restoreIP( - OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsUpper, ThreadLimit)); + Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB, NumTeamsLower, + NumTeamsUpper, ThreadLimit)); Builder.CreateRetVoid(); OMPBuilder.finalize(); ASSERT_FALSE(verifyModule(*M)); - Function *PushNumTeamsRTL = - OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams); - EXPECT_EQ(PushNumTeamsRTL->getNumUses(), 1U); - CallInst *PushNumTeamsCallInst = - findSingleCall(F, OMPRTL___kmpc_push_num_teams, OMPBuilder); + findSingleCall(F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder); ASSERT_NE(PushNumTeamsCallInst, nullptr); - EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsUpper); - EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), ThreadLimit); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(2), NumTeamsLower); + 
EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(3), NumTeamsUpper); + EXPECT_EQ(PushNumTeamsCallInst->getArgOperand(4), ThreadLimit); // Verifying that the next instruction to execute is kmpc_fork_teams BranchInst *BrInst = _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits