================ @@ -1126,6 +1133,185 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag, Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin()); } +// Callback used to create OpenMP runtime calls to support +// omp parallel clause for the device. +// We need to use this callback to replace call to the OutlinedFn in OuterFn +// by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51) +static void +targetParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, + Function *OuterFn, Value *Ident, Value *IfCondition, + Value *NumThreads, Instruction *PrivTID, + AllocaInst *PrivTIDAddr, Value *ThreadID, + const SmallVector<Instruction *, 4> &ToBeDeleted) { + // Add some known attributes. + Module &M = OMPIRBuilder->M; + IRBuilder<> &Builder = OMPIRBuilder->Builder; + OutlinedFn.addParamAttr(0, Attribute::NoAlias); + OutlinedFn.addParamAttr(1, Attribute::NoAlias); + OutlinedFn.addParamAttr(0, Attribute::NoUndef); + OutlinedFn.addParamAttr(1, Attribute::NoUndef); + OutlinedFn.addFnAttr(Attribute::NoUnwind); + + assert(OutlinedFn.arg_size() >= 2 && + "Expected at least tid and bounded tid as arguments"); + unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; + + CallInst *CI = cast<CallInst>(OutlinedFn.user_back()); + assert(CI && "Expected call instruction to outlined function"); + CI->getParent()->setName("omp_parallel"); + // Replace direct call to the outlined function by the call to + // __kmpc_parallel_51 + Builder.SetInsertPoint(CI); + + // Build call __kmpc_parallel_51 + auto PtrTy = Type::getInt8PtrTy(M.getContext()); + Value *Void = ConstantPointerNull::get(PtrTy); + // Add alloca for kernel args. Put this instruction at the beginning + // of the function. 
+ OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP(); + Builder.SetInsertPoint(&OuterFn->front(), + OuterFn->front().getFirstInsertionPt()); + AllocaInst *ArgsAlloca = + Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars)); + Value *Args = + Builder.CreatePointerCast(ArgsAlloca, Type::getInt8PtrTy(M.getContext())); ---------------- jdoerfert wrote:
```
-      Type::getInt8PtrTy(M.getContext())
+      PtrTy
```

Check your patches for things like this.

https://github.com/llvm/llvm-project/pull/67000
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits