================
@@ -114,24 +115,221 @@ static inline void clearModule(Module &M) { // TODO: 
simplify.
     eraseFromModule(*M.ifuncs().begin());
 }
 
+// Gathers every Use reachable from G's use list whose user is an Instruction.
+// Users that are not Instructions are not recorded themselves; their own uses
+// are followed instead, until an Instruction user is reached.
+static inline SmallVector<std::reference_wrapper<Use>>
+collectIndirectableUses(GlobalVariable *G) {
+  SmallVector<std::reference_wrapper<Use>> InstUses;
+
+  // LIFO worklist seeded with the direct uses of the global.
+  SmallVector<std::reference_wrapper<Use>> Worklist(G->use_begin(),
+                                                    G->use_end());
+  while (!Worklist.empty()) {
+    Use &Current = Worklist.pop_back_val().get();
+    User *Usr = Current.getUser();
+
+    if (isa<Instruction>(Usr)) {
+      InstUses.emplace_back(Current);
+      continue;
+    }
+
+    // Not an Instruction: keep walking through whatever uses this user.
+    for (Use &Next : Usr->uses())
+      Worklist.emplace_back(Next);
+  }
+
+  return InstUses;
+}
+
+// Emits a private, constant, unnamed global whose initializer is G's name as
+// a string; the HIPSTDPAR runtime uses that string to look up the
+// program-wide symbol.
+static inline GlobalVariable *getGlobalForName(GlobalVariable *G) {
+  Module *M = G->getParent();
+  Constant *NameStr =
+      ConstantDataArray::getString(G->getContext(), G->getName());
+
+  GlobalVariable *NameG = M->getOrInsertGlobal("", NameStr->getType());
+  NameG->setInitializer(NameStr);
+  NameG->setConstant(true);
+  NameG->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage);
+
+  return NameG;
+}
+
+// Emits a private, unnamed, externally initialised constant global of pointer
+// type, using the module's default globals address space. Its initializer is
+// poison; the HIPSTDPAR runtime is the party that stores the address of the
+// program-wide symbol into it.
+static inline GlobalVariable *getIndirectionGlobal(Module *M) {
+  unsigned GlobalsAS = M->getDataLayout().getDefaultGlobalsAddressSpace();
+  Type *SlotTy = PointerType::get(M->getContext(), GlobalsAS);
+
+  GlobalVariable *Slot = M->getOrInsertGlobal("", SlotTy);
+  Slot->setInitializer(PoisonValue::get(Slot->getValueType()));
+  Slot->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage);
+  Slot->setExternallyInitialized(true);
+  Slot->setConstant(true);
+
+  return Slot;
+}
+
+// Creates the name global and the indirection global for ToIndirect, appends
+// a {name, indirection} entry (typed after element 2 of the indirection
+// table's struct type) to SymbolIndirections, and returns the indirection
+// global. IndirectionTable is only inspected for its type here.
+static inline Constant *
+appendIndirectedGlobal(const GlobalVariable *IndirectionTable,
+                       SmallVector<Constant *> &SymbolIndirections,
+                       GlobalVariable *ToIndirect) {
+  // Element 2 of the table type is the struct type of a single entry.
+  auto *TableTy = cast<StructType>(IndirectionTable->getValueType());
+  auto *EntryTy = cast<StructType>(TableTy->getStructElementType(2));
+
+  Constant *NameG = getGlobalForName(ToIndirect);
+  Constant *IndirectG = getIndirectionGlobal(ToIndirect->getParent());
+
+  // Both fields are address-space cast to the types the entry declares.
+  Constant *NameField =
+      ConstantExpr::getAddrSpaceCast(NameG, EntryTy->getElementType(0));
+  Constant *IndirectField =
+      ConstantExpr::getAddrSpaceCast(IndirectG, EntryTy->getElementType(1));
+  SymbolIndirections.push_back(
+      ConstantStruct::get(EntryTy, {NameField, IndirectField}));
+
+  return IndirectG;
+}
+
+// Materialises the collected entries as a private constant array global and
+// installs {count, pointer-to-array, poison} as the indirection table's
+// initializer. The table is taken out of the module before the array global
+// is created and re-inserted afterwards (presumably so that it ends up after
+// the array in the module's global list -- confirm).
+static void fillIndirectionTable(GlobalVariable *IndirectionTable,
+                                 SmallVector<Constant *> Indirections) {
+  Module *M = IndirectionTable->getParent();
+
+  auto *TableTy = cast<StructType>(IndirectionTable->getValueType());
+  Type *CountTy = TableTy->getStructElementType(0);
+  Type *EntriesPtrTy = TableTy->getStructElementType(1);
+  auto *EntryTy = cast<StructType>(TableTy->getStructElementType(2));
+
+  const size_t NumEntries = Indirections.size();
+  Constant *Count = ConstantInt::get(CountTy, NumEntries);
+  ArrayType *EntriesTy = ArrayType::get(EntryTy, NumEntries);
+
+  M->removeGlobalVariable(IndirectionTable);
+
+  GlobalVariable *Entries = M->getOrInsertGlobal("", EntriesTy);
+  Entries->setLinkage(GlobalValue::LinkageTypes::PrivateLinkage);
+  Entries->setInitializer(ConstantArray::get(EntriesTy, {Indirections}));
+  Entries->setConstant(true);
+
+  Constant *EntriesPtr =
+      ConstantExpr::getAddrSpaceCast(Entries, EntriesPtrTy);
+  Constant *Init = ConstantStruct::get(
+      TableTy, {Count, EntriesPtr, PoisonValue::get(EntryTy)});
+
+  M->insertGlobalVariable(IndirectionTable);
+  IndirectionTable->setInitializer(Init);
+}
+
+// Rewrites the use U of G so that it goes through a load of IndirectedG
+// instead of referring to G itself.
+//
+// U must be a use whose user is an Instruction. G may be reached from that
+// instruction either directly or through a chain of GEP / addrspacecast /
+// ptrtoint constant expressions; each such constant expression is turned
+// into a real instruction placed in front of its user, and the innermost
+// reference to G is replaced by the loaded pointer.
+static void replaceWithIndirectUse(const Use &U, const GlobalVariable *G,
+                                   Constant *IndirectedG) {
+  auto *I = cast<Instruction>(U.getUser());
+
+  IRBuilder<> Builder(I);
+  // Index of the operand of I that currently leads to G. It starts at the
+  // position of U, which is not necessarily 0 (e.g. the pointer operand of a
+  // store is operand 1).
+  unsigned OpIdx = U.getOperandNo();
+  Value *Op = I->getOperand(OpIdx);
+
+  // We walk back up the use chain, which could be an arbitrarily long sequence
+  // of constexpr AS casts, ptr-to-int and GEP instructions, until we reach the
+  // indirected global.
+  while (auto *CE = dyn_cast<ConstantExpr>(Op)) {
+    assert((CE->getOpcode() == Instruction::GetElementPtr ||
+            CE->getOpcode() == Instruction::AddrSpaceCast ||
+            CE->getOpcode() == Instruction::PtrToInt) &&
+           "Only GEP, ASCAST or PTRTOINT constant uses supported!");
+
+    Instruction *NewI = Builder.Insert(CE->getAsInstruction());
+    // Only rewrite the operand slot being followed. Other slots of I that
+    // refer to the same constant are separate uses and get their own rewrite.
+    I->setOperand(OpIdx, NewI);
+    I = NewI;
+    // For all three supported opcodes the source pointer is operand 0.
+    OpIdx = 0;
+    Op = I->getOperand(OpIdx);
+    Builder.SetInsertPoint(I);
+  }
+
+  assert(Op == G && "Must reach indirected global!");
+
+  // The load is emitted right before I, then wired into the tracked operand.
+  I->setOperand(OpIdx, Builder.CreateLoad(G->getType(), IndirectedG));
+}
+
+// Checks that the indirection table's value type is a 3-element struct of the
+// form { integer, pointer, struct }. For the first requirement that does not
+// hold, an error diagnostic describing it is reported through the table's
+// LLVMContext and false is returned; otherwise returns true.
+static inline bool isValidIndirectionTable(GlobalVariable *IndirectionTable) {
+  std::string W;
+  raw_string_ostream OS(W);
+
+  // Reports whatever has been written to OS as an error; yields false.
+  auto Reject = [&]() {
+    IndirectionTable->getContext().diagnose(
+        DiagnosticInfoGeneric(W, DS_Error));
+    return false;
+  };
+  // Completes a message that ends with the offending type, then rejects.
+  auto RejectType = [&](Type *Offending) {
+    Offending->print(OS);
+    OS << " is incorrect.\n";
+    return Reject();
+  };
+
+  Type *Ty = IndirectionTable->getValueType();
+  auto *STy = dyn_cast<StructType>(Ty);
+  if (!STy) {
+    OS << "The Indirection Table must be a struct type; ";
+    return RejectType(Ty);
+  }
+
+  if (STy->getNumElements() != 3u) {
+    OS << "The Indirection Table must have 3 elements; "
+       << STy->getNumElements() << " is incorrect.\n";
+    return Reject();
+  }
+
+  Type *FirstTy = STy->getElementType(0);
+  if (!isa<IntegerType>(FirstTy)) {
+    OS << "The first element in the Indirection Table must be an integer; ";
+    return RejectType(FirstTy);
+  }
+
+  Type *SecondTy = STy->getElementType(1);
+  if (!isa<PointerType>(SecondTy)) {
+    OS << "The second element in the Indirection Table must be a pointer; ";
+    return RejectType(SecondTy);
+  }
+
+  Type *ThirdTy = STy->getElementType(2);
+  if (!isa<StructType>(ThirdTy)) {
+    OS << "The third element in the Indirection Table must be a struct type; ";
+    return RejectType(ThirdTy);
+  }
+
+  return true;
+}
+
+// Routes accesses to each global in ToIndirect through a pointer that the
+// HIPSTDPAR runtime sets, using the global's accessible, program-wide unique
+// address as established by the host linker-loader. Globals for which
+// collectIndirectableUses finds nothing are left untouched; every other
+// global gets a table entry, has its collected uses rewritten and is handed
+// to eraseFromModule. The table itself is only filled if an entry was added.
+static void indirectGlobals(GlobalVariable *IndirectionTable,
+                            SmallVector<GlobalVariable *> ToIndirect) {
+  SmallVector<Constant *> SymbolIndirections;
+
+  for (GlobalVariable *G : ToIndirect) {
+    SmallVector<std::reference_wrapper<Use>> InstUses =
+        collectIndirectableUses(G);
+    if (InstUses.empty())
+      continue;
+
+    Constant *IndirectedGlobal =
+        appendIndirectedGlobal(IndirectionTable, SymbolIndirections, G);
+
+    for (Use &U : InstUses)
+      replaceWithIndirectUse(U, G, IndirectedGlobal);
+
+    eraseFromModule(*G);
+  }
+
+  if (!SymbolIndirections.empty())
+    fillIndirectionTable(IndirectionTable, std::move(SymbolIndirections));
+}
+
 static inline void maybeHandleGlobals(Module &M) {
   unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace();
-  for (auto &&G : M.globals()) { // TODO: should we handle these in the FE?
+
+  SmallVector<GlobalVariable *> ToIndirect;
+  for (auto &&G : M.globals()) {
     if (!checkIfSupported(G))
       return clearModule(M);
-
-    if (G.isThreadLocal())
-      continue;
-    if (G.isConstant())
-      continue;
     if (G.getAddressSpace() != GlobAS)
       continue;
-    if (G.getLinkage() != GlobalVariable::ExternalLinkage)
+    if (G.isConstant() && G.hasInitializer() && G.hasAtLeastLocalUnnamedAddr())
       continue;
 
-    G.setLinkage(GlobalVariable::ExternalWeakLinkage);
-    G.setInitializer(nullptr);
-    G.setExternallyInitialized(true);
+    ToIndirect.push_back(&G);
+  }
+
+  if (ToIndirect.empty())
+    return;
+
+  if (auto *IT = M.getNamedGlobal("__hipstdpar_symbol_indirection_table")) {
+    if (!isValidIndirectionTable(IT))
+      return clearModule(M);
----------------
jmmartinez wrote:

Maybe as an NFC change for another PR, but is `clearModule` the same as
`Module::dropAllReferences`?

https://github.com/llvm/llvm-project/pull/146813
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to