lxfind updated this revision to Diff 337290.
lxfind added a comment.

some cleanups


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D100415/new/

https://reviews.llvm.org/D100415

Files:
  clang/lib/CodeGen/CGCoroutine.cpp
  llvm/include/llvm/IR/Intrinsics.td
  llvm/lib/Transforms/Coroutines/CoroEarly.cpp
  llvm/lib/Transforms/Coroutines/CoroInternal.h
  llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Index: llvm/lib/Transforms/Coroutines/CoroSplit.cpp
===================================================================
--- llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -2049,6 +2049,109 @@
     Fns.push_back(PrepareFn);
 }
 
+/// Return the coroutine init function that owns \p RampFunc.
+///
+/// CoroEarly names the ramp function "<init name>.ramp", so the init
+/// function is found by stripping that suffix and looking the remaining
+/// name up in the module. Returns null if no such function exists.
+static Function *getCoroInitFunction(Function &RampFunc) {
+  StringRef InitName = RampFunc.getName();
+  bool HadRampSuffix = InitName.consume_back(".ramp");
+  (void)HadRampSuffix;
+  assert(HadRampSuffix && "Ramp function name must end with .ramp");
+  return RampFunc.getParent()->getFunction(InitName);
+}
+
+/// Inline the ramp function back into the coroutine init function \p F and
+/// merge the state the two halves duplicated:
+///  - every coro.id other than the one feeding coro.begin is RAUW'd with it;
+///  - every llvm.coro.frame.get is replaced by the pointer operand of the
+///    first llvm.coro.frame.get seen for the same slot ID;
+///  - the promise slot pointer is installed as operand 1 of the surviving
+///    coro.id.
+///
+/// \returns the ramp function, which \p F no longer calls; the caller is
+/// responsible for removing it.
+static Function *inlineRampFunction(Function &F) {
+  // Match the ramp call by callee identity rather than by a prefix of F's
+  // name: a prefix match also hits recursive calls to F and unrelated
+  // callees sharing the prefix, and indirect calls have no callee at all.
+  Function *RampFunc =
+      F.getParent()->getFunction((F.getName() + ".ramp").str());
+  assert(RampFunc && "Coroutine init function has no ramp function");
+  CallInst *RampCall = nullptr;
+  for (Instruction &I : instructions(F)) {
+    auto *CI = dyn_cast<CallInst>(&I);
+    if (CI && CI->getCalledFunction() == RampFunc) {
+      RampCall = CI;
+      break;
+    }
+  }
+  assert(RampCall && "Coroutine init function must call its ramp function");
+
+  // InlineFunction erases the call site on success, so RampCall must not
+  // be touched afterwards; the callee was captured above for that reason.
+  InlineFunctionInfo IFI;
+  InlineResult Inlined = InlineFunction(*RampCall, IFI);
+  (void)Inlined;
+  assert(Inlined.isSuccess() && "Failed to inline the ramp function");
+
+  SmallVector<IntrinsicInst *, 2> CoroIds;
+  CoroBeginInst *CoroBegin = nullptr;
+  SmallVector<IntrinsicInst *, 8> CoroFrameGets;
+  for (Instruction &I : instructions(F)) {
+    auto *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II)
+      continue;
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::coro_id:
+      CoroIds.push_back(II);
+      break;
+    case Intrinsic::coro_begin:
+      CoroBegin = cast<CoroBeginInst>(II);
+      break;
+    case Intrinsic::coro_frame_get:
+      CoroFrameGets.push_back(II);
+      break;
+    }
+  }
+  assert(CoroIds.size() == 2 && "There must be two coro.id calls, from the "
+                                "init function and ramp function respectively");
+  assert(CoroBegin && "Coroutine init function must contain coro.begin");
+  CoroIdInst *RealId = cast<CoroIdInst>(CoroBegin->getId());
+  for (IntrinsicInst *I : CoroIds)
+    if (I != RealId)
+      I->replaceAllUsesWith(RealId);
+
+  // The first llvm.coro.frame.get seen for a slot supplies that slot's
+  // pointer; later ones for the same slot (the copy inlined from the ramp)
+  // are redirected to it, so both halves share one alloca per slot.
+  DenseMap<uint32_t, Instruction *> FrameSlotMap;
+  bool PromiseIdSet = false;
+  for (IntrinsicInst *FrameGet : CoroFrameGets) {
+    bool IsPromise = cast<ConstantInt>(FrameGet->getOperand(2))->getZExtValue();
+    uint32_t SlotID =
+        cast<ConstantInt>(FrameGet->getOperand(3))->getZExtValue();
+    Instruction *Ptr =
+        FrameSlotMap
+            .try_emplace(SlotID, cast<Instruction>(FrameGet->getOperand(1)))
+            .first->second;
+    FrameGet->replaceAllUsesWith(Ptr);
+    FrameGet->eraseFromParent();
+    // The promise occupies a single slot, so every promise frame get maps to
+    // the same pointer and coro.id only needs to be pointed at it once.
+    if (IsPromise && !PromiseIdSet) {
+      RealId->setOperand(1, new BitCastInst(Ptr->stripPointerCasts(),
+                                            Ptr->getType(), "", RealId));
+      PromiseIdSet = true;
+    }
+  }
+
+  return RampFunc;
+}
+
 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
                                      CGSCCAnalysisManager &AM,
                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
@@ -2082,6 +2150,8 @@
     }
   }
 
+  SmallVector<Function *, 1> UnpreparedInitFuncs;
+  SmallVector<Function *, 1> InlinedRampFuncs;
   // Split all the coroutines.
   for (LazyCallGraph::Node *N : Coroutines) {
     Function &F = N->getFunction();
@@ -2089,12 +2159,24 @@
     StringRef Value = Attr.getValueAsString();
     LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName()
                       << "' state: " << Value << "\n");
-    if (Value == UNPREPARED_FOR_SPLIT) {
+    if (Value == DO_NOT_PROCESS)
+      continue;
+    if (Value == UNPREPARED_FOR_SPLIT_RAMP) {
       // Enqueue a second iteration of the CGSCC pipeline on this SCC.
       UR.CWorklist.insert(&C);
-      F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
+      // Once we allow the ramp function to be optimized, we will split
+      // the init function directly and ignore the ramp function.
+      F.addFnAttr(CORO_PRESPLIT_ATTR, DO_NOT_PROCESS);
+      UnpreparedInitFuncs.push_back(getCoroInitFunction(F));
       continue;
     }
+    if (Value == PREPARED_FOR_SPLIT_INIT) {
+      Function *RampFunc = inlineRampFunction(F);
+      InlinedRampFuncs.push_back(RampFunc);
+      RampFunc->removeDeadConstantUsers();
+      RampFunc->dropAllReferences();
+      updateCGAndAnalysisManagerForCGSCCPass(CG, C, *N, AM, UR, FAM);
+    }
     F.removeFnAttr(CORO_PRESPLIT_ATTR);
 
     SmallVector<Function *, 4> Clones;
@@ -2109,6 +2191,23 @@
       UR.RCWorklist.insert(CG.lookupRefSCC(CG.get(*Clones[0])));
     }
   }
+  for (Function *F : UnpreparedInitFuncs)
+    F->addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT_INIT);
+  for (Function *DeadF : InlinedRampFuncs) {
+    auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF));
+    FAM.clear(*DeadF, DeadF->getName());
+    AM.clear(DeadC, DeadC.getName());
+    auto &DeadRC = DeadC.getOuterRefSCC();
+    CG.removeDeadFunction(*DeadF);
+
+    // Mark the relevant parts of the call graph as invalid so we don't visit
+    // them.
+    UR.InvalidatedSCCs.insert(&DeadC);
+    UR.InvalidatedRefSCCs.insert(&DeadRC);
+
+    DeadF->getBasicBlockList().clear();
+    M.getFunctionList().remove(DeadF);
+  }
 
   if (!PrepareFns.empty()) {
     for (auto *PrepareFn : PrepareFns) {
@@ -2179,6 +2278,7 @@
     createDevirtTriggerFunc(CG, SCC);
 
     // Split all the coroutines.
+    // FIXME: adapt to the new split model
     for (Function *F : Coroutines) {
       Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
       StringRef Value = Attr.getValueAsString();
@@ -2190,7 +2290,7 @@
         F->removeFnAttr(CORO_PRESPLIT_ATTR);
         continue;
       }
-      if (Value == UNPREPARED_FOR_SPLIT) {
+      if (Value == UNPREPARED_FOR_SPLIT_RAMP) {
         prepareForSplit(*F, CG);
         continue;
       }
Index: llvm/lib/Transforms/Coroutines/CoroInternal.h
===================================================================
--- llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -37,9 +37,11 @@
 // Async lowering similarily triggers a restart of the pipeline after it has
 // split the coroutine.
 #define CORO_PRESPLIT_ATTR "coroutine.presplit"
-#define UNPREPARED_FOR_SPLIT "0"
+#define DO_NOT_PROCESS "0" // CoroSplitPass skips functions in this state.
 #define PREPARED_FOR_SPLIT "1"
 #define ASYNC_RESTART_AFTER_SPLIT "2"
+#define UNPREPARED_FOR_SPLIT_RAMP "3" // Ramp made by CoroEarly, not yet seen by CoroSplit.
+#define PREPARED_FOR_SPLIT_INIT "4" // Init function whose ramp CoroSplit has already seen.
 
 #define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger"
 
Index: llvm/lib/Transforms/Coroutines/CoroEarly.cpp
===================================================================
--- llvm/lib/Transforms/Coroutines/CoroEarly.cpp
+++ llvm/lib/Transforms/Coroutines/CoroEarly.cpp
@@ -8,10 +8,15 @@
 
 #include "llvm/Transforms/Coroutines/CoroEarly.h"
 #include "CoroInternal.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
 #include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
 
@@ -145,6 +150,153 @@
       CB->setCannotDuplicate();
 }
 
+/// Reroute every alloca tagged with !coroutine_frame_alloca metadata through
+/// an llvm.coro.frame.get call inserted right after coro.begin, so the init
+/// and ramp halves can later be mapped onto the same frame slot.
+static void replaceFrameAllocasWithFrameGets(Function &F) {
+  Module *M = F.getParent();
+  LLVMContext &C = M->getContext();
+
+  CoroBeginInst *CoroBegin = nullptr;
+  for (Instruction &I : instructions(F)) {
+    CoroBegin = dyn_cast<CoroBeginInst>(&I);
+    if (CoroBegin)
+      break;
+  }
+  assert(CoroBegin && "Coroutine must contain a coro.begin call");
+  Instruction *InsertPoint = CoroBegin->getNextNode();
+
+  for (Instruction &I : make_early_inc_range(instructions(F))) {
+    auto *AI = dyn_cast<AllocaInst>(&I);
+    if (!AI)
+      continue;
+    MDNode *MD = AI->getMetadata("coroutine_frame_alloca");
+    if (!MD)
+      continue;
+
+    // Metadata operands are (i1 IsPromise, i32 SlotID).
+    Constant *IsPromise =
+        cast<ConstantAsMetadata>(MD->getOperand(0))->getValue();
+    Constant *SlotID = cast<ConstantAsMetadata>(MD->getOperand(1))->getValue();
+    auto *VoidPtr =
+        new BitCastInst(AI, llvm::Type::getInt8PtrTy(C), "", InsertPoint);
+    auto *FrameGet = CallInst::Create(
+        Intrinsic::getDeclaration(M, Intrinsic::coro_frame_get),
+        {CoroBegin, VoidPtr, IsPromise, SlotID}, "", InsertPoint);
+    auto *NewPtr = new BitCastInst(FrameGet, AI->getType(), "", InsertPoint);
+    // Only the bitcast feeding coro.frame.get keeps using the raw alloca.
+    AI->replaceUsesWithIf(NewPtr,
+                          [&](Use &U) { return U.getUser() != VoidPtr; });
+    AI->setMetadata("coroutine_frame_alloca", nullptr);
+  }
+}
+
+/// Clone \p F into a new "<F>.ramp" function whose only argument is the
+/// coroutine frame pointer, with the allocation and init-only paths folded.
+static Function *createRampFunctionClone(Function &F) {
+  Module *M = F.getParent();
+  LLVMContext &C = M->getContext();
+
+  llvm::Type *NewFArgTypes[] = {llvm::Type::getInt8PtrTy(C)};
+  FunctionType *NewFTy =
+      FunctionType::get(F.getReturnType(), NewFArgTypes, false);
+  Function *NewF = Function::Create(NewFTy,
+                                    GlobalValue::LinkageTypes::ExternalLinkage,
+                                    F.getName() + ".ramp");
+  M->getFunctionList().push_back(NewF);
+
+  // NOTE(review): every argument of F is mapped to undef in the ramp, so any
+  // use of an original argument that reaches the ramp (i.e. one not replaced
+  // by a frame copy) becomes undef there. Confirm no such uses exist, e.g.
+  // for an implicit object or sret pointer argument.
+  ValueToValueMapTy VMap;
+  for (Argument &A : F.args())
+    VMap[&A] = UndefValue::get(A.getType());
+  SmallVector<ReturnInst *, 4> Returns;
+  CloneFunctionInto(NewF, &F, VMap, CloneFunctionChangeType::LocalChangesOnly,
+                    Returns);
+
+  // CloneFunctionInto overwrites NewF's function attributes with F's, so the
+  // ramp's own attributes must be set after cloning. NoInline keeps the ramp
+  // out of F until CoroSplit inlines it itself; AlwaysInline, if copied from
+  // F, is incompatible with NoInline and must go.
+  NewF->removeFnAttr(Attribute::AlwaysInline);
+  NewF->addFnAttr(Attribute::NoInline);
+
+  for (Instruction &I : make_early_inc_range(instructions(*NewF))) {
+    auto *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II)
+      continue;
+    switch (II->getIntrinsicID()) {
+    default:
+      continue;
+    case Intrinsic::coro_begin:
+      // The frame already exists when the ramp runs: F passes it in.
+      II->replaceAllUsesWith(NewF->getArg(0));
+      break;
+    case Intrinsic::coro_init:
+    case Intrinsic::coro_alloc:
+      // The ramp never allocates the frame nor runs the init-only code.
+      II->replaceAllUsesWith(ConstantInt::getFalse(C));
+      break;
+    }
+    II->eraseFromParent();
+  }
+  removeUnreachableBlocks(*NewF);
+  NewF->addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT_RAMP);
+  return NewF;
+}
+
+/// Turn \p F into the init function: take the coro.init paths, and at
+/// coro.init.end return the result of calling \p RampF with the frame.
+static void finishInitFunction(Function &F, Function &RampF) {
+  LLVMContext &C = F.getContext();
+  IntrinsicInst *CoroBegin = nullptr;
+  IntrinsicInst *CoroInitEnd = nullptr;
+  for (Instruction &I : make_early_inc_range(instructions(F))) {
+    auto *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II)
+      continue;
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::coro_begin:
+      CoroBegin = II;
+      break;
+    case Intrinsic::coro_init:
+      II->replaceAllUsesWith(ConstantInt::getTrue(C));
+      II->eraseFromParent();
+      break;
+    case Intrinsic::coro_init_end:
+      CoroInitEnd = II;
+      break;
+    }
+  }
+  assert(CoroBegin && "Coroutine must contain a coro.begin call");
+  assert(CoroInitEnd && "Coroutine must contain a coro.init.end call");
+  assert(CoroInitEnd->getNextNode() ==
+             CoroInitEnd->getParent()->getTerminator() &&
+         "coro.init.end call should be at the end of the init block");
+
+  // Replace the terminator after coro.init.end with "return RampF(frame)".
+  CoroInitEnd->getNextNode()->eraseFromParent();
+  CallInst *RampResult = CallInst::Create(&RampF, {CoroBegin}, "", CoroInitEnd);
+  ReturnInst::Create(C, F.getReturnType()->isVoidTy() ? nullptr : RampResult,
+                     CoroInitEnd);
+  CoroInitEnd->eraseFromParent();
+  removeUnreachableBlocks(F);
+  F.addFnAttr(CORO_PRESPLIT_ATTR, DO_NOT_PROCESS);
+}
+
+/// Split coroutine \p F into an init function (F itself, which runs the code
+/// guarded by coro.init) and a "<F>.ramp" clone holding the rest of the body,
+/// which F calls at coro.init.end.
+static void splitRampFunction(Function &F) {
+  replaceFrameAllocasWithFrameGets(F);
+  Function *RampF = createRampFunctionClone(F);
+  finishInitFunction(F, *RampF);
+}
+
 bool Lowerer::lowerEarlyIntrinsics(Function &F) {
   bool Changed = false;
   CoroIdInst *CoroId = nullptr;
@@ -179,7 +297,6 @@
         // with a coroutine attribute.
         if (auto *CII = cast<CoroIdInst>(&I)) {
           if (CII->getInfo().isPreSplit()) {
-            F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
             setCannotDuplicate(CII);
             CII->setCoroutineSelf();
             CoroId = cast<CoroIdInst>(&I);
@@ -210,9 +327,11 @@
   // Make sure that all CoroFree reference the coro.id intrinsic.
   // Token type is not exposed through coroutine C/C++ builtins to plain C, so
   // we allow specifying none and fixing it up here.
-  if (CoroId)
+  if (CoroId) {
     for (CoroFreeInst *CF : CoroFrees)
       CF->setArgOperand(0, CoroId);
+    splitRampFunction(F);
+  }
   return Changed;
 }
 
@@ -226,6 +345,10 @@
 }
 
 PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) {
+  if (F.getFnAttribute(CORO_PRESPLIT_ATTR).getValueAsString() ==
+      UNPREPARED_FOR_SPLIT_RAMP)
+    return PreservedAnalyses::all();
+
   Module &M = *F.getParent();
   if (!declaresCoroEarlyIntrinsics(M) || !Lowerer(M).lowerEarlyIntrinsics(F))
     return PreservedAnalyses::all();
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -1274,6 +1274,12 @@
                                      ReadOnly<ArgIndex<0>>,
                                      NoCapture<ArgIndex<0>>]>;
 
+def int_coro_frame_get : Intrinsic<[llvm_ptr_ty],
+                                   [llvm_ptr_ty, llvm_ptr_ty, llvm_i1_ty, llvm_i32_ty],
+                                   [IntrNoMem]>; // (coro.begin handle, alloca ptr, i1 is_promise, i32 slot_id)
+def int_coro_init: Intrinsic<[llvm_i1_ty], [], []>; // folded to true in the init function, false in the ramp
+def int_coro_init_end: Intrinsic<[], [], []>; // point where the init function returns the ramp call's result
+
 ///===-------------------------- Other Intrinsics --------------------------===//
 //
 def int_trap : Intrinsic<[], [], [IntrNoReturn, IntrCold]>,
Index: clang/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- clang/lib/CodeGen/CGCoroutine.cpp
+++ clang/lib/CodeGen/CGCoroutine.cpp
@@ -547,7 +547,7 @@
 
   auto *EntryBB = Builder.GetInsertBlock();
   auto *AllocBB = createBasicBlock("coro.alloc");
-  auto *InitBB = createBasicBlock("coro.init");
+  auto *BeginBB = createBasicBlock("coro.begin");
   auto *FinalBB = createBasicBlock("coro.final");
   auto *RetBB = createBasicBlock("coro.ret");
 
@@ -564,7 +564,7 @@
   auto *CoroAlloc = Builder.CreateCall(
       CGM.getIntrinsic(llvm::Intrinsic::coro_alloc), {CoroId});
 
-  Builder.CreateCondBr(CoroAlloc, AllocBB, InitBB);
+  Builder.CreateCondBr(CoroAlloc, AllocBB, BeginBB);
 
   EmitBlock(AllocBB);
   auto *AllocateCall = EmitScalarExpr(S.getAllocate());
@@ -577,17 +577,17 @@
     // See if allocation was successful.
     auto *NullPtr = llvm::ConstantPointerNull::get(Int8PtrTy);
     auto *Cond = Builder.CreateICmpNE(AllocateCall, NullPtr);
-    Builder.CreateCondBr(Cond, InitBB, RetOnFailureBB);
+    Builder.CreateCondBr(Cond, BeginBB, RetOnFailureBB);
 
     // If not, return OnAllocFailure object.
     EmitBlock(RetOnFailureBB);
     EmitStmt(RetOnAllocFailure);
   }
   else {
-    Builder.CreateBr(InitBB);
+    Builder.CreateBr(BeginBB);
   }
 
-  EmitBlock(InitBB);
+  EmitBlock(BeginBB);
 
   // Pass the result of the allocation to coro.begin.
   auto *Phi = Builder.CreatePHI(VoidPtrTy, 2);
@@ -606,12 +606,36 @@
     CodeGenFunction::RunCleanupsScope ResumeScope(*this);
     EHStack.pushCleanup<CallCoroDelete>(NormalAndEHCleanup, S.getDeallocate());
 
+    // Wrap around the parameter copy with a coro.init() check.
+    // This will allows us to perform parameter copy in the init function, but
+    // not in the ramp function.
+    auto *InitBB = createBasicBlock("coro.init");
+    auto *InitReadyBB = createBasicBlock("coro.init.ready");
+    auto *CoroInit =
+        Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::coro_init));
+    Builder.CreateCondBr(CoroInit, InitBB, InitReadyBB);
+
+    EmitBlock(InitBB);
+    SmallVector<llvm::AllocaInst *, 4> FrameAllocas;
     // Create parameter copies. We do it before creating a promise, since an
     // evolution of coroutine TS may allow promise constructor to observe
     // parameter copies.
+    int ID = 0;
     for (auto *PM : S.getParamMoves()) {
       EmitStmt(PM);
       ParamReplacer.addCopy(cast<DeclStmt>(PM));
+      llvm::AllocaInst *Alloca = cast<llvm::AllocaInst>(
+          GetAddrOfLocalVar(cast<VarDecl>(cast<DeclStmt>(PM)->getSingleDecl()))
+              .getPointer());
+      Alloca->setMetadata(
+          "coroutine_frame_alloca",
+          llvm::MDNode::get(
+              getLLVMContext(),
+              {
+                  llvm::ConstantAsMetadata::get(
+                      Builder.getInt1(false)) /*IsPromise*/,
+                  llvm::ConstantAsMetadata::get(Builder.getInt32(ID++)),
+              }));
       // TODO: if(CoroParam(...)) need to surround ctor and dtor
       // for the copy, so that llvm can elide it if the copy is
       // not needed.
@@ -619,12 +643,23 @@
 
     EmitStmt(S.getPromiseDeclStmt());
 
+    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::coro_init_end));
+    Builder.CreateBr(InitReadyBB);
+    EmitBlock(InitReadyBB);
+
     Address PromiseAddr = GetAddrOfLocalVar(S.getPromiseDecl());
-    auto *PromiseAddrVoidPtr =
-        new llvm::BitCastInst(PromiseAddr.getPointer(), VoidPtrTy, "", CoroId);
-    // Update CoroId to refer to the promise. We could not do it earlier because
-    // promise local variable was not emitted yet.
-    CoroId->setArgOperand(1, PromiseAddrVoidPtr);
+    llvm::AllocaInst *PromiseAlloca =
+        cast<llvm::AllocaInst>(PromiseAddr.getPointer());
+
+    PromiseAlloca->setMetadata(
+        "coroutine_frame_alloca",
+        llvm::MDNode::get(
+            getLLVMContext(),
+            {
+                llvm::ConstantAsMetadata::get(
+                    Builder.getInt1(true)) /*IsPromise*/,
+                llvm::ConstantAsMetadata::get(Builder.getInt32(ID++)),
+            }));
 
     // Now we have the promise, initialize the GRO
     GroManager.EmitGroInit();
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to