================ @@ -114,24 +115,221 @@ static inline void clearModule(Module &M) { // TODO: simplify. eraseFromModule(*M.ifuncs().begin()); } +static inline SmallVector<std::reference_wrapper<Use>> +collectIndirectableUses(GlobalVariable *G) { + // We are interested only in use chains that end in an Instruction. + SmallVector<std::reference_wrapper<Use>> Uses; + + SmallVector<std::reference_wrapper<Use>> Tmp(G->use_begin(), G->use_end()); + while (!Tmp.empty()) { + Use &U = Tmp.back(); + Tmp.pop_back(); + if (isa<Instruction>(U.getUser())) + Uses.emplace_back(U); + else + transform(U.getUser()->uses(), std::back_inserter(Tmp), + [](auto &&U) { return std::ref(U); }); + } + + return Uses; +} + +static inline GlobalVariable *getGlobalForName(GlobalVariable *G) { + // Create an anonymous global which stores the variable's name, which will be + // used by the HIPSTDPAR runtime to look up the program-wide symbol. + LLVMContext &Ctx = G->getContext(); + auto *CDS = ConstantDataArray::getString(Ctx, G->getName()); + + GlobalVariable *N = G->getParent()->getOrInsertGlobal("", CDS->getType()); + N->setInitializer(CDS); + N->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + N->setConstant(true); + + return N; +} + +static inline GlobalVariable *getIndirectionGlobal(Module *M) { + // Create an anonymous global which stores a pointer to a pointer, which will + // be externally initialised by the HIPSTDPAR runtime with the address of the + // program-wide symbol. 
+ Type *PtrTy = PointerType::get( + M->getContext(), M->getDataLayout().getDefaultGlobalsAddressSpace()); + GlobalVariable *NewG = M->getOrInsertGlobal("", PtrTy); + + NewG->setInitializer(PoisonValue::get(NewG->getValueType())); + NewG->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + NewG->setConstant(true); + NewG->setExternallyInitialized(true); + + return NewG; +} + +static inline Constant * +appendIndirectedGlobal(const GlobalVariable *IndirectionTable, + SmallVector<Constant *> &SymbolIndirections, + GlobalVariable *ToIndirect) { + Module *M = ToIndirect->getParent(); + + auto *InitTy = cast<StructType>(IndirectionTable->getValueType()); + auto *SymbolListTy = cast<StructType>(InitTy->getStructElementType(2)); + Type *NameTy = SymbolListTy->getElementType(0); + Type *IndirectTy = SymbolListTy->getElementType(1); + + Constant *NameG = getGlobalForName(ToIndirect); + Constant *IndirectG = getIndirectionGlobal(M); + Constant *Entry = ConstantStruct::get( + SymbolListTy, {ConstantExpr::getAddrSpaceCast(NameG, NameTy), + ConstantExpr::getAddrSpaceCast(IndirectG, IndirectTy)}); + SymbolIndirections.push_back(Entry); + + return IndirectG; +} + +static void fillIndirectionTable(GlobalVariable *IndirectionTable, + SmallVector<Constant *> Indirections) { + Module *M = IndirectionTable->getParent(); + size_t SymCnt = Indirections.size(); + + auto *InitTy = cast<StructType>(IndirectionTable->getValueType()); + Type *SymbolListTy = InitTy->getStructElementType(1); + auto *SymbolTy = cast<StructType>(InitTy->getStructElementType(2)); + + Constant *Count = ConstantInt::get(InitTy->getStructElementType(0), SymCnt); + M->removeGlobalVariable(IndirectionTable); + GlobalVariable *Symbols = + M->getOrInsertGlobal("", ArrayType::get(SymbolTy, SymCnt)); + Symbols->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + Symbols->setInitializer( + ConstantArray::get(ArrayType::get(SymbolTy, SymCnt), {Indirections})); + Symbols->setConstant(true); + + Constant *ASCSymbols = 
ConstantExpr::getAddrSpaceCast(Symbols, SymbolListTy); + Constant *Init = ConstantStruct::get( + InitTy, {Count, ASCSymbols, PoisonValue::get(SymbolTy)}); + M->insertGlobalVariable(IndirectionTable); + IndirectionTable->setInitializer(Init); +} + +static void replaceWithIndirectUse(const Use &U, const GlobalVariable *G, + Constant *IndirectedG) { + auto *I = cast<Instruction>(U.getUser()); + + IRBuilder<> Builder(I); + Value *Op = I->getOperand(U.getOperandNo()); + + // We walk back up the use chain, which could be an arbitrarily long sequence + // of constexpr AS casts, ptr-to-int and GEP instructions, until we reach the + // indirected global. + while (auto *CE = dyn_cast<ConstantExpr>(Op)) { + assert((CE->getOpcode() == Instruction::GetElementPtr || + CE->getOpcode() == Instruction::AddrSpaceCast || + CE->getOpcode() == Instruction::PtrToInt) && + "Only GEP, ASCAST or PTRTOINT constant uses supported!"); + + Instruction *NewI = Builder.Insert(CE->getAsInstruction()); + I->replaceUsesOfWith(Op, NewI); + I = NewI; + Op = I->getOperand(0); + Builder.SetInsertPoint(I); + } + + assert(Op == G && "Must reach indirected global!"); + + Builder.GetInsertPoint()->setOperand( ---------------- jmmartinez wrote:
Isn't this the same as `I->setOperand(0, ...)`? Also, is it possible that there is no cast/GEP/ptrtoint at all, and the global is used directly by some other kind of user, one where it is not operand 0, such as a memset/memcpy builtin? https://github.com/llvm/llvm-project/pull/146813 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits