efriedma updated this revision to Diff 426814.
efriedma edited the summary of this revision.
efriedma added a comment.

Switch to use llvm.seh_localunwind intrinsic.

I'd appreciate any feedback at this point.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D124642/new/

https://reviews.llvm.org/D124642

Files:
  clang/lib/CodeGen/CGException.cpp
  clang/lib/CodeGen/CGStmt.cpp
  clang/lib/CodeGen/CodeGenFunction.h
  llvm/include/llvm/IR/Intrinsics.td
  llvm/lib/CodeGen/AsmPrinter/WinException.cpp
  llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
  llvm/lib/CodeGen/WinEHPrepare.cpp
  llvm/lib/IR/Verifier.cpp

Index: llvm/lib/IR/Verifier.cpp
===================================================================
--- llvm/lib/IR/Verifier.cpp
+++ llvm/lib/IR/Verifier.cpp
@@ -4612,6 +4612,8 @@
                 F->getIntrinsicID() == Intrinsic::experimental_patchpoint_i64 ||
                 F->getIntrinsicID() == Intrinsic::experimental_gc_statepoint ||
                 F->getIntrinsicID() == Intrinsic::wasm_rethrow ||
+                F->getIntrinsicID() == Intrinsic::seh_localunwind ||
                 IsAttachedCallOperand(F, CBI, i),
             "Cannot invoke an intrinsic other than donothing, patchpoint, "
-            "statepoint, coro_resume, coro_destroy or clang.arc.attachedcall",
+            "statepoint, coro_resume, coro_destroy, seh_localunwind or "
+            "clang.arc.attachedcall",
Index: llvm/lib/CodeGen/WinEHPrepare.cpp
===================================================================
--- llvm/lib/CodeGen/WinEHPrepare.cpp
+++ llvm/lib/CodeGen/WinEHPrepare.cpp
@@ -380,6 +380,20 @@
     const Function *Filter = dyn_cast<Function>(FilterOrNull);
     assert((Filter || FilterOrNull->isNullValue()) &&
            "unexpected filter value");
+    // Filters named __IsLocalUnwind are treated specially: we want to catch
+    // unwinds from _local_unwind, but not catchrets in the same funclet.
+    // (They both need to point at the same catchswitch to pass the verifier
+    // checks for nesting.) To make this work, we mess with the state
+    // numbering: the "parent" of any cleanupret pointing to this catchpad is
+    // actually this catchpad's parent.
+    //
+    // Note that _local_unwind looks for unwind table entries for the
+    // catchpad; if there aren't any, it assumes the catchpad doesn't have a
+    // parent.
+    bool IsLocalUnwind =
+        Filter && Filter->getName().startswith("__IsLocalUnwind");
+    if (IsLocalUnwind)
+      Filter = nullptr;
     int TryState = addSEHExcept(FuncInfo, ParentState, Filter, CatchPadBB);
 
     // Everything in the __try block uses TryState as its parent state.
@@ -390,7 +404,7 @@
       if ((PredBlock = getEHPadFromPredecessor(PredBlock,
                                                CatchSwitch->getParentPad())))
         calculateSEHStateNumbers(FuncInfo, PredBlock->getFirstNonPHI(),
-                                 TryState);
+                                 IsLocalUnwind ? ParentState : TryState);
 
     // Everything in the __except block unwinds to ParentState, just like code
     // outside the __try.
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2937,6 +2937,41 @@
       DAG.setRoot(DAG.getNode(ISD::INTRINSIC_VOID, getCurSDLoc(), VTs, Ops));
       break;
     }
+    case Intrinsic::seh_localunwind: {
+      if (!isa<CatchSwitchInst>(EHPadBB->getTerminator())) {
+        report_fatal_error("localunwind doesn't point to catchswitch");
+      }
+      auto *CatchSwitch = cast<CatchSwitchInst>(EHPadBB->getTerminator());
+      if (CatchSwitch->getNumHandlers() == 0) {
+        report_fatal_error("catchswitch with no handler");
+      }
+
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      TargetLowering::ArgListEntry SP, DestBB;
+      Type *PtrTy = PointerType::getInt8PtrTy(*DAG.getContext());
+      EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout());
+      SP.Node = DAG.getNode(ISD::FRAMEADDR, getCurSDLoc(), PtrVT,
+                            DAG.getIntPtrConstant(0, getCurSDLoc()));
+      SP.Ty = PtrTy;
+      FuncInfo.MBBMap[*CatchSwitch->handler_begin()]->setHasAddressTaken();
+      DestBB.Node =
+          DAG.getBlockAddress(BlockAddress::get(const_cast<BasicBlock *>(
+                                  *CatchSwitch->handler_begin())),
+                              PtrVT);
+      DestBB.Ty = PtrTy;
+      TargetLowering::ArgListTy Args{SP, DestBB};
+
+      SDValue Callee = DAG.getExternalSymbol("_local_unwind", PtrVT);
+      TargetLowering::CallLoweringInfo CLI(DAG);
+      CLI.setDebugLoc(getCurSDLoc())
+          .setChain(getRoot())
+          .setCallee(CallingConv::C, Type::getVoidTy(*DAG.getContext()), Callee,
+                     std::move(Args))
+          .setNoReturn();
+      CLI.CB = &I;
+      lowerInvokable(CLI, EHPadBB);
+      break;
+    }
     }
   } else if (I.countOperandBundlesOfType(LLVMContext::OB_deopt)) {
     // Currently we do not lower any intrinsic calls with deopt operand bundles.
Index: llvm/lib/CodeGen/AsmPrinter/WinException.cpp
===================================================================
--- llvm/lib/CodeGen/AsmPrinter/WinException.cpp
+++ llvm/lib/CodeGen/AsmPrinter/WinException.cpp
@@ -626,6 +626,17 @@
     LastStartLabel = StateChange.NewStartLabel;
     LastEHState = StateChange.NewState;
   }
+  for (auto Entry : FuncInfo.SEHUnwindMap) {
+    if (!Entry.IsFinally && Entry.ToState != -1) {
+      // Mark up the destination of _local_unwind so it doesn't unwind
+      // too far.
+      //
+      // FIXME: Can this overlap with the EH_LABEL for an invoke?
+      auto *Handler = Entry.Handler.get<MachineBasicBlock *>();
+      const MCSymbol *Begin = Handler->getSymbol();
+      emitSEHActionsForRange(FuncInfo, Begin, Begin, Entry.ToState);
+    }
+  }
 
   OS.emitLabel(TableEnd);
 }
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -541,6 +541,9 @@
 def int_seh_scope_begin : Intrinsic<[], [], [IntrNoMem]>;
 def int_seh_scope_end : Intrinsic<[], [], [IntrNoMem]>;
 
+// Call _local_unwind to unwind to a local catchpad.
+def int_seh_localunwind : Intrinsic<[], [], [IntrNoReturn]>;
+
 // Note: we treat stacksave/stackrestore as writemem because we don't otherwise
 // model their dependencies on allocas.
 def int_stacksave     : DefaultAttrsIntrinsic<[llvm_ptr_ty]>,
Index: clang/lib/CodeGen/CodeGenFunction.h
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.h
+++ clang/lib/CodeGen/CodeGenFunction.h
@@ -672,6 +672,16 @@
   /// a value from the top of the stack.
   SmallVector<Address, 1> SEHCodeSlotStack;
 
+  /// Flags (allocas named "retnow") recording that a child __finally block
+  /// executed a return; pushed for each __try whose __finally has a return.
+  SmallVector<Address, 1> SEHRetNowStack;
+
+  /// Pointer to the parent function's SEHRetNowStack.back() flag.
+  Address SEHRetNowParent = Address::invalid();
+
+  /// Pointer to the root function's ReturnValue variable.
+  Address SEHReturnValue = Address::invalid();
+
   /// Value returned by __exception_info intrinsic.
   llvm::Value *SEHInfo = nullptr;
 
@@ -3281,11 +3291,14 @@
   void EmitCXXTryStmt(const CXXTryStmt &S);
   void EmitSEHTryStmt(const SEHTryStmt &S);
   void EmitSEHLeaveStmt(const SEHLeaveStmt &S);
-  void EnterSEHTryStmt(const SEHTryStmt &S);
-  void ExitSEHTryStmt(const SEHTryStmt &S);
+  void EnterSEHTryStmt(const SEHTryStmt &S, bool &ContainsRetStmt);
+  void ExitSEHTryStmt(const SEHTryStmt &S, bool ContainsRetStmt);
   void VolatilizeTryBlocks(llvm::BasicBlock *BB,
                            llvm::SmallPtrSet<llvm::BasicBlock *, 10> &V);
 
+  void EmitSEHLocalUnwind();
+  llvm::Function *GenerateSEHIsLocalUnwindFunction();
+
   void pushSEHCleanup(CleanupKind kind,
                       llvm::Function *FinallyFunc);
   void startOutlinedSEHHelper(CodeGenFunction &ParentCGF, bool IsFilter,
Index: clang/lib/CodeGen/CGStmt.cpp
===================================================================
--- clang/lib/CodeGen/CGStmt.cpp
+++ clang/lib/CodeGen/CGStmt.cpp
@@ -1269,10 +1269,19 @@
                         ReturnLocation);
   }

-  // Returning from an outlined SEH helper is UB, and we already warn on it.
+  Address ReturnValue = this->ReturnValue;
   if (IsOutlinedSEHHelper) {
-    Builder.CreateUnreachable();
-    Builder.ClearInsertionPoint();
+    if (SEHRetNowParent.isValid()) {
+      // Return from a __finally: tell the parent to return once the finally
+      // completes, and store any result into the root function's slot.
+      Builder.CreateStore(Builder.getInt8(1), SEHRetNowParent);
+      ReturnValue = SEHReturnValue;
+    } else {
+      // Returning from any other outlined SEH helper is UB, and we already
+      // warn on it.
+      Builder.CreateUnreachable();
+      Builder.ClearInsertionPoint();
+    }
   }

   // Emit the result value, even if unused, to evaluate the side effects.
Index: clang/lib/CodeGen/CGException.cpp
===================================================================
--- clang/lib/CodeGen/CGException.cpp
+++ clang/lib/CodeGen/CGException.cpp
@@ -1638,7 +1638,8 @@
 }
 
 void CodeGenFunction::EmitSEHTryStmt(const SEHTryStmt &S) {
-  EnterSEHTryStmt(S);
+  bool ContainsRetStmt = false;
+  EnterSEHTryStmt(S, ContainsRetStmt);
   {
     JumpDest TryExit = getJumpDestInCurrentScope("__try.__leave");
 
@@ -1667,7 +1668,7 @@
     else
       delete TryExit.getBlock();
   }
-  ExitSEHTryStmt(S);
+  ExitSEHTryStmt(S, ContainsRetStmt);
 }
 
 //  Recursively walk through blocks in a _try
@@ -1702,8 +1703,9 @@
 namespace {
 struct PerformSEHFinally final : EHScopeStack::Cleanup {
   llvm::Function *OutlinedFinally;
-  PerformSEHFinally(llvm::Function *OutlinedFinally)
-      : OutlinedFinally(OutlinedFinally) {}
+  bool RetFromFinally;
+  PerformSEHFinally(llvm::Function *OutlinedFinally, bool RetFromFinally)
+      : OutlinedFinally(OutlinedFinally), RetFromFinally(RetFromFinally) {}
 
   void Emit(CodeGenFunction &CGF, Flags F) override {
     ASTContext &Context = CGF.getContext();
@@ -1747,6 +1749,21 @@
 
     auto Callee = CGCallee::forDirect(OutlinedFinally);
     CGF.EmitCall(FnInfo, Callee, ReturnValueSlot(), Args);
+
+    if (F.isForEHCleanup() && RetFromFinally) {
+      llvm::BasicBlock *AbnormalCont = CGF.createBasicBlock("if.then");
+      llvm::BasicBlock *NormalCont = CGF.createBasicBlock("if.end");
+      llvm::Value *ShouldRetLoad =
+          CGF.Builder.CreateLoad(CGF.SEHRetNowStack.back());
+      llvm::Value *ShouldRet = CGF.Builder.CreateIsNotNull(ShouldRetLoad);
+
+      CGF.Builder.CreateCondBr(ShouldRet, AbnormalCont, NormalCont);
+      CGF.EmitBlock(AbnormalCont);
+      CGF.EmitSEHLocalUnwind();
+      CGF.Builder.CreateUnreachable();
+
+      CGF.EmitBlock(NormalCont);
+    }
   }
 };
 } // end anonymous namespace
@@ -1758,12 +1775,13 @@
   const VarDecl *ParentThis;
   llvm::SmallSetVector<const VarDecl *, 4> Captures;
   Address SEHCodeSlot = Address::invalid();
+  bool ContainsRetStmt = false;
   CaptureFinder(CodeGenFunction &ParentCGF, const VarDecl *ParentThis)
       : ParentCGF(ParentCGF), ParentThis(ParentThis) {}
 
   // Return true if we need to do any capturing work.
   bool foundCaptures() {
-    return !Captures.empty() || SEHCodeSlot.isValid();
+    return !Captures.empty() || SEHCodeSlot.isValid() || ContainsRetStmt;
   }
 
   void Visit(const Stmt *S) {
@@ -1805,6 +1823,25 @@
       break;
     }
   }
+
+  void VisitReturnStmt(const ReturnStmt *) { ContainsRetStmt = true; }
+};
+} // end anonymous namespace
+
+namespace {
+/// Determine whether a statement (or any sub-statement) is a ReturnStmt.
+struct ReturnStmtFinder : ConstStmtVisitor<ReturnStmtFinder> {
+  bool ContainsRetStmt = false;
+
+  void Visit(const Stmt *S) {
+    // Note whether this is a return statement, then recurse.
+    ConstStmtVisitor::Visit(S);
+    for (const Stmt *Child : S->children())
+      if (Child)
+        Visit(Child);
+  }
+
+  void VisitReturnStmt(const ReturnStmt *) { ContainsRetStmt = true; }
 };
 } // end anonymous namespace
 
@@ -1853,7 +1890,8 @@
                                          bool IsFilter) {
   // Find all captures in the Stmt.
   CaptureFinder Finder(ParentCGF, ParentCGF.CXXABIThisDecl);
-  Finder.Visit(OutlinedStmt);
+  if (OutlinedStmt)
+    Finder.Visit(OutlinedStmt);
 
   // We can exit early on x86_64 when there are no captures. We just have to
   // save the exception code in filters so that __exception_code() works.
@@ -1991,6 +2029,19 @@

   if (IsFilter)
     EmitSEHExceptionCodeSave(ParentCGF, ParentFP, EntryFP);
+
+  // EnterSEHTryStmt pushes a retnow slot only right before outlining a
+  // __finally that contains a return, so only __finally helpers may bind to
+  // it; a ReturnStmt found in a filter (e.g. inside a lambda) is ignored.
+  if (!IsFilter && Finder.ContainsRetStmt) {
+    SEHRetNowParent = recoverAddrOfEscapedLocal(
+        ParentCGF, ParentCGF.SEHRetNowStack.back(), ParentFP);
+    Address ParentSEHRetVal =
+        ParentCGF.ParentCGF ? ParentCGF.SEHReturnValue : ParentCGF.ReturnValue;
+    if (ParentSEHRetVal.isValid())
+      SEHReturnValue =
+          recoverAddrOfEscapedLocal(ParentCGF, ParentSEHRetVal, ParentFP);
+  }
 }

 /// Arrange a function prototype that can be called by Windows exception
@@ -2150,19 +2198,93 @@
 
 void CodeGenFunction::pushSEHCleanup(CleanupKind Kind,
                                      llvm::Function *FinallyFunc) {
-  EHStack.pushCleanup<PerformSEHFinally>(Kind, FinallyFunc);
+  EHStack.pushCleanup<PerformSEHFinally>(Kind, FinallyFunc, false);
 }
 
-void CodeGenFunction::EnterSEHTryStmt(const SEHTryStmt &S) {
+void CodeGenFunction::EnterSEHTryStmt(const SEHTryStmt &S,
+                                      bool &ContainsRetStmt) {
   CodeGenFunction HelperCGF(CGM, /*suppressNewContext=*/true);
   HelperCGF.ParentCGF = this;
   if (const SEHFinallyStmt *Finally = S.getFinallyHandler()) {
+    ReturnStmtFinder Finder;
+    Finder.Visit(Finally);
+    ContainsRetStmt = Finder.ContainsRetStmt;
+    if (ContainsRetStmt) {
+      // Suppose we have something like:
+      // __try {
+      //   f1();
+      // } __finally {
+      //   f2();
+      //   if (z)
+      //     return;
+      //   f3();
+      // }
+      //
+      // We want to generate code something like this, where "StopUnwinding()"
+      // refers to the operation of aborting the unwind, and jumping back
+      // to normal code.
+      //
+      //  int immediate_return = 0;
+      //  __try {
+      //    f1();
+      //  } __finally {
+      //    f2();
+      //    if (z) {
+      //      immediate_return = 1;
+      //      goto end_of_finally;
+      //    }
+      //    f3();
+      //    end_of_finally:
+      //    if (_abnormal_termination())
+      //      StopUnwinding();
+      //  }
+      //  if (immediate_return) {
+      //    return;
+      //  }
+      //
+      // To handle the non-unwind case, we need to synthesize the
+      // "immediate_return" variable, and use it to change control flow
+      // after the finally block.
+      //
+      // To make "StopUnwinding()" work, we use _local_unwind.  This function
+      // tells the SEH unwinder to recompute the unwind action: instead of
+      // using the __except handler that was already computed, stop unwinding
+      // when the unwinder reaches the current function.  (The mechanism used
+      // here is unofficially called a "collided unwind".)
+      //
+      // We represent the destination of _local_unwind with a fake CatchPad:
+      // when the backend sees a filter named "__IsLocalUnwind", it arranges
+      // the unwind tables so that _local_unwind stops at that CatchPad, but
+      // other unwinding ignores it.
+      //
+      // Note that this construct could itself be inside an __try or __finally
+      // block.
+      //
+      // If it's inside the __try of a __try/__finally, the outer __finally
+      // executes before the function returns.
+      //
+      // If it's inside a __finally, we need to jump out of that __finally
+      // in a similar way.
+
+      // Create the "retnow" flag; it is set when the __finally returns.
+      SEHRetNowStack.push_back(
+          CreateTempAlloca(CGM.Int8Ty, CharUnits::fromQuantity(1), "retnow"));
+      Builder.CreateStore(Builder.getInt8(0), SEHRetNowStack.back());
+
+      // Create the exception filter.
+      EHCatchScope *CatchScope = EHStack.pushCatch(1);
+      llvm::Function *FilterFunc = GenerateSEHIsLocalUnwindFunction();
+      llvm::Constant *OpaqueFunc =
+          llvm::ConstantExpr::getBitCast(FilterFunc, Int8PtrTy);
+      CatchScope->setHandler(0, OpaqueFunc, createBasicBlock("__except.ret"));
+    }
     // Outline the finally block.
     llvm::Function *FinallyFunc =
         HelperCGF.GenerateSEHFinallyFunction(*this, *Finally);
 
     // Push a cleanup for __finally blocks.
-    EHStack.pushCleanup<PerformSEHFinally>(NormalAndEHCleanup, FinallyFunc);
+    EHStack.pushCleanup<PerformSEHFinally>(NormalAndEHCleanup, FinallyFunc,
+                                           ContainsRetStmt);
     return;
   }
 
@@ -2194,10 +2316,78 @@
   CatchScope->setHandler(0, OpaqueFunc, createBasicBlock("__except.ret"));
 }

-void CodeGenFunction::ExitSEHTryStmt(const SEHTryStmt &S) {
+llvm::Function *CodeGenFunction::GenerateSEHIsLocalUnwindFunction() {
+  // __IsLocalUnwind is only a marker declaration (it has no body): the
+  // backend recognizes catchpad filters by this name (see WinEHPrepare).
+  if (llvm::Function *F = CGM.getModule().getFunction("__IsLocalUnwind"))
+    return F;
+
+  llvm::LLVMContext &Ctx = getLLVMContext();
+  llvm::Type *ArgTys[] = {llvm::Type::getInt8PtrTy(Ctx),
+                          llvm::Type::getInt8PtrTy(Ctx)};
+  return llvm::Function::Create(
+      llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), ArgTys, false),
+      llvm::GlobalVariable::ExternalWeakLinkage, "__IsLocalUnwind",
+      &CGM.getModule());
+}
+
+void CodeGenFunction::EmitSEHLocalUnwind() {
+  EmitRuntimeCallOrInvoke(CGM.getIntrinsic(llvm::Intrinsic::seh_localunwind));
+}
+
+void CodeGenFunction::ExitSEHTryStmt(const SEHTryStmt &S,
+                                     bool ContainsRetStmt) {
   // Just pop the cleanup if it's a __finally block.
   if (S.getFinallyHandler()) {
     PopCleanupBlock();
+    if (ContainsRetStmt) {
+      // Create __except block and control flow handling for return from
+      // __finally. See comment in EnterSEHTryStmt.
+      //
+      // First, create the point where we check for a return
+      // from the __finally.
+      llvm::BasicBlock *ContBB = createBasicBlock("__finally.cont");
+      if (HaveInsertPoint())
+        Builder.CreateBr(ContBB);
+
+      EmitBlock(ContBB);
+
+      // On the normal path, check if we have a return-from-finally.
+      llvm::BasicBlock *AbnormalCont = createBasicBlock("if.then");
+      llvm::BasicBlock *NormalCont = createBasicBlock("if.end");
+      llvm::Value *ShouldRetLoad = Builder.CreateLoad(SEHRetNowStack.back());
+      llvm::Value *ShouldRet = Builder.CreateIsNotNull(ShouldRetLoad);
+
+      Builder.CreateCondBr(ShouldRet, AbnormalCont, NormalCont);
+
+      // Emit the catchswitch/catchpad for the __IsLocalUnwind marker handler.
+      EHCatchScope &CatchScope = cast<EHCatchScope>(*EHStack.begin());
+      emitCatchDispatchBlock(*this, CatchScope);
+
+      // Grab the block before we pop the handler.
+      llvm::BasicBlock *CatchPadBB = CatchScope.getHandler(0).Block;
+      EHStack.popCatch();
+
+      // The catch block only catches return-from-finally.
+      EmitBlockAfterUses(CatchPadBB);
+      llvm::CatchPadInst *CPI =
+          cast<llvm::CatchPadInst>(CatchPadBB->getFirstNonPHI());
+      Builder.CreateCatchRet(CPI, AbnormalCont);
+      EmitBlock(AbnormalCont);
+
+      // If the try block is nested inside a finally block, forward the
+      // return from __finally to the parent function.
+      if (SEHRetNowParent.isValid())
+        Builder.CreateStore(Builder.getInt8(1), SEHRetNowParent);
+      EmitBranchThroughCleanup(ReturnBlock);
+
+      EmitBlock(NormalCont);
+
+      // Done with this __try's retnow flag; pop it so that an enclosing
+      // __try/__finally (and its PerformSEHFinally cleanup) reads its own
+      // flag through SEHRetNowStack.back().
+      SEHRetNowStack.pop_back();
+    }
     return;
   }

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to