================ @@ -114,24 +115,221 @@ static inline void clearModule(Module &M) { // TODO: simplify. eraseFromModule(*M.ifuncs().begin()); } +static inline SmallVector<std::reference_wrapper<Use>> +collectIndirectableUses(GlobalVariable *G) { + // We are interested only in use chains that end in an Instruction. + SmallVector<std::reference_wrapper<Use>> Uses; + + SmallVector<std::reference_wrapper<Use>> Tmp(G->use_begin(), G->use_end()); + while (!Tmp.empty()) { + Use &U = Tmp.back(); + Tmp.pop_back(); + if (isa<Instruction>(U.getUser())) + Uses.emplace_back(U); + else + transform(U.getUser()->uses(), std::back_inserter(Tmp), + [](auto &&U) { return std::ref(U); }); + } + + return Uses; +} + +static inline GlobalVariable *getGlobalForName(GlobalVariable *G) { + // Create an anonymous global which stores the variable's name, which will be + // used by the HIPSTDPAR runtime to look up the program-wide symbol. + LLVMContext &Ctx = G->getContext(); + auto *CDS = ConstantDataArray::getString(Ctx, G->getName()); + + GlobalVariable *N = G->getParent()->getOrInsertGlobal("", CDS->getType()); + N->setInitializer(CDS); + N->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + N->setConstant(true); + + return N; +} + +static inline GlobalVariable *getIndirectionGlobal(Module *M) { + // Create an anonymous global which stores a pointer to a pointer, which will + // be externally initialised by the HIPSTDPAR runtime with the address of the + // program-wide symbol. 
+ Type *PtrTy = PointerType::get( + M->getContext(), M->getDataLayout().getDefaultGlobalsAddressSpace()); + GlobalVariable *NewG = M->getOrInsertGlobal("", PtrTy); + + NewG->setInitializer(PoisonValue::get(NewG->getValueType())); + NewG->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + NewG->setConstant(true); + NewG->setExternallyInitialized(true); + + return NewG; +} + +static inline Constant * +appendIndirectedGlobal(const GlobalVariable *IndirectionTable, + SmallVector<Constant *> &SymbolIndirections, + GlobalVariable *ToIndirect) { + Module *M = ToIndirect->getParent(); + + auto *InitTy = cast<StructType>(IndirectionTable->getValueType()); + auto *SymbolListTy = cast<StructType>(InitTy->getStructElementType(2)); + Type *NameTy = SymbolListTy->getElementType(0); + Type *IndirectTy = SymbolListTy->getElementType(1); + + Constant *NameG = getGlobalForName(ToIndirect); + Constant *IndirectG = getIndirectionGlobal(M); + Constant *Entry = ConstantStruct::get( + SymbolListTy, {ConstantExpr::getAddrSpaceCast(NameG, NameTy), + ConstantExpr::getAddrSpaceCast(IndirectG, IndirectTy)}); + SymbolIndirections.push_back(Entry); + + return IndirectG; +} + +static void fillIndirectionTable(GlobalVariable *IndirectionTable, + SmallVector<Constant *> Indirections) { + Module *M = IndirectionTable->getParent(); + size_t SymCnt = Indirections.size(); + + auto *InitTy = cast<StructType>(IndirectionTable->getValueType()); + Type *SymbolListTy = InitTy->getStructElementType(1); + auto *SymbolTy = cast<StructType>(InitTy->getStructElementType(2)); + + Constant *Count = ConstantInt::get(InitTy->getStructElementType(0), SymCnt); + M->removeGlobalVariable(IndirectionTable); + GlobalVariable *Symbols = + M->getOrInsertGlobal("", ArrayType::get(SymbolTy, SymCnt)); + Symbols->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage); + Symbols->setInitializer( + ConstantArray::get(ArrayType::get(SymbolTy, SymCnt), {Indirections})); + Symbols->setConstant(true); + + Constant *ASCSymbols = 
ConstantExpr::getAddrSpaceCast(Symbols, SymbolListTy); + Constant *Init = ConstantStruct::get( + InitTy, {Count, ASCSymbols, PoisonValue::get(SymbolTy)}); + M->insertGlobalVariable(IndirectionTable); + IndirectionTable->setInitializer(Init); +} + +static void replaceWithIndirectUse(const Use &U, const GlobalVariable *G, + Constant *IndirectedG) { + auto *I = cast<Instruction>(U.getUser()); + + IRBuilder<> Builder(I); + Value *Op = I->getOperand(U.getOperandNo()); + + // We walk back up the use chain, which could be an arbitrarily long sequence + // of constexpr AS casts, ptr-to-int and GEP instructions, until we reach the + // indirected global. + while (auto *CE = dyn_cast<ConstantExpr>(Op)) { + assert((CE->getOpcode() == Instruction::GetElementPtr || + CE->getOpcode() == Instruction::AddrSpaceCast || + CE->getOpcode() == Instruction::PtrToInt) && + "Only GEP, ASCAST or PTRTOINT constant uses supported!"); + + Instruction *NewI = Builder.Insert(CE->getAsInstruction()); + I->replaceUsesOfWith(Op, NewI); + I = NewI; + Op = I->getOperand(0); + Builder.SetInsertPoint(I); + } + + assert(Op == G && "Must reach indirected global!"); + + Builder.GetInsertPoint()->setOperand( + 0, Builder.CreateLoad(G->getType(), IndirectedG)); +} + +static inline bool isValidIndirectionTable(GlobalVariable *IndirectionTable) { + std::string W; + raw_string_ostream OS(W); + + Type *Ty = IndirectionTable->getValueType(); + bool Valid = false; + + if (!isa<StructType>(Ty)) { + OS << "The Indirection Table must be a struct type; "; + Ty->print(OS); + OS << " is incorrect.\n"; + } else if (cast<StructType>(Ty)->getNumElements() != 3u) { + OS << "The Indirection Table must have 3 elements; " + << cast<StructType>(Ty)->getNumElements() << " is incorrect.\n"; + } else if (!isa<IntegerType>(cast<StructType>(Ty)->getStructElementType(0))) { + OS << "The first element in the Indirection Table must be an integer; "; + cast<StructType>(Ty)->getStructElementType(0)->print(OS); + OS << " is 
incorrect.\n"; + } else if (!isa<PointerType>(cast<StructType>(Ty)->getStructElementType(1))) { + OS << "The second element in the Indirection Table must be a pointer; "; + cast<StructType>(Ty)->getStructElementType(1)->print(OS); + OS << " is incorrect.\n"; + } else if (!isa<StructType>(cast<StructType>(Ty)->getStructElementType(2))) { + OS << "The third element in the Indirection Table must be a struct type; "; + cast<StructType>(Ty)->getStructElementType(2)->print(OS); + OS << " is incorrect.\n"; + } else { + Valid = true; + } + + if (!Valid) + IndirectionTable->getContext().diagnose(DiagnosticInfoGeneric(W, DS_Error)); + + return Valid; +} + +static void indirectGlobals(GlobalVariable *IndirectionTable, + SmallVector<GlobalVariable *> ToIndirect) { + // We replace globals with an indirected access via a pointer that will get + // set by the HIPSTDPAR runtime, using their accessible, program-wide unique + // address as set by the host linker-loader. + SmallVector<Constant *> SymbolIndirections; + for (auto &&G : ToIndirect) { + SmallVector<std::reference_wrapper<Use>> Uses = collectIndirectableUses(G); + + if (Uses.empty()) + continue; + + Constant *IndirectedGlobal = + appendIndirectedGlobal(IndirectionTable, SymbolIndirections, G); + + for_each(Uses, + [=](auto &&U) { replaceWithIndirectUse(U, G, IndirectedGlobal); }); + + eraseFromModule(*G); + } + + if (SymbolIndirections.empty()) + return; + + fillIndirectionTable(IndirectionTable, std::move(SymbolIndirections)); +} + static inline void maybeHandleGlobals(Module &M) { unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace(); - for (auto &&G : M.globals()) { // TODO: should we handle these in the FE? 
+ + SmallVector<GlobalVariable *> ToIndirect; + for (auto &&G : M.globals()) { if (!checkIfSupported(G)) return clearModule(M); - - if (G.isThreadLocal()) - continue; - if (G.isConstant()) - continue; if (G.getAddressSpace() != GlobAS) continue; - if (G.getLinkage() != GlobalVariable::ExternalLinkage) + if (G.isConstant() && G.hasInitializer() && G.hasAtLeastLocalUnnamedAddr()) continue; - G.setLinkage(GlobalVariable::ExternalWeakLinkage); - G.setInitializer(nullptr); - G.setExternallyInitialized(true); + ToIndirect.push_back(&G); + } + + if (ToIndirect.empty()) + return; + + if (auto *IT = M.getNamedGlobal("__hipstdpar_symbol_indirection_table")) { + if (!isValidIndirectionTable(IT)) + return clearModule(M); ---------------- jmmartinez wrote:
Maybe as an NFC change for another PR, but is `clearModule` the same as `Module::dropAllReferences`? https://github.com/llvm/llvm-project/pull/146813 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits