llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Lang Hames (lhames)

<details>
<summary>Changes</summary>

This removes ThreadSafeContext::Lock, ThreadSafeContext::getLock, and 
ThreadSafeContext::getContext, and replaces them with a 
ThreadSafeContext::withContextDo method (and const overload).

The new method can be used to access an existing ThreadSafeContext-wrapped 
LLVMContext in a safe way:

ThreadSafeContext TSCtx = ... ;
TSCtx.withContextDo([](LLVMContext *Ctx) {
  // this closure has exclusive access to Ctx.
});

The new API enforces correct locking, whereas the old APIs relied on manual 
locking (which almost no in-tree code performed, relying instead on incidental 
exclusive access to the ThreadSafeContext).

---

Patch is 31.30 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/146819.diff


18 Files Affected:

- (modified) clang/lib/Interpreter/Interpreter.cpp (+8-5) 
- (modified) 
llvm/examples/OrcV2Examples/LLJITWithThinLTOSummaries/LLJITWithThinLTOSummaries.cpp
 (+3-1) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsBasicUsage/OrcV2CBindingsBasicUsage.c 
(+5-5) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsDumpObjects/OrcV2CBindingsDumpObjects.c
 (+3-2) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsIRTransforms/OrcV2CBindingsIRTransforms.c
 (+2-2) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsLazy/OrcV2CBindingsLazy.c (+6-4) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsMCJITLikeMemoryManager/OrcV2CBindingsMCJITLikeMemoryManager.c
 (+6-5) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsRemovableCode/OrcV2CBindingsRemovableCode.c
 (+6-5) 
- (modified) 
llvm/examples/OrcV2Examples/OrcV2CBindingsVeryLazy/OrcV2CBindingsVeryLazy.c 
(+5-5) 
- (modified) llvm/include/llvm-c/Orc.h (+19-7) 
- (modified) llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h (+28-44) 
- (modified) llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp (+3-3) 
- (modified) llvm/lib/ExecutionEngine/Orc/Speculation.cpp (-2) 
- (modified) llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp (+5-3) 
- (modified) llvm/tools/lli/lli.cpp (+2-1) 
- (modified) 
llvm/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp (+17-25) 
- (modified) llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp (+2-2) 
- (modified) llvm/unittests/ExecutionEngine/Orc/ThreadSafeModuleTest.cpp 
(+31-43) 


``````````diff
diff --git a/clang/lib/Interpreter/Interpreter.cpp 
b/clang/lib/Interpreter/Interpreter.cpp
index 2f110659d19a4..98fc0a5e383a3 100644
--- a/clang/lib/Interpreter/Interpreter.cpp
+++ b/clang/lib/Interpreter/Interpreter.cpp
@@ -373,8 +373,11 @@ Interpreter::Interpreter(std::unique_ptr<CompilerInstance> 
Instance,
   auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
   TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
 
-  Act = std::make_unique<IncrementalAction>(*CI, *TSCtx->getContext(), ErrOut,
-                                            *this, std::move(Consumer));
+  Act = TSCtx->withContextDo([&](llvm::LLVMContext *Ctx) {
+    return std::make_unique<IncrementalAction>(*CI, *Ctx, ErrOut, *this,
+                                               std::move(Consumer));
+  });
+
   if (ErrOut)
     return;
   CI->ExecuteAction(*Act);
@@ -495,10 +498,10 @@ 
Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
   std::unique_ptr<Interpreter> Interp = std::move(*InterpOrErr);
 
   llvm::Error Err = llvm::Error::success();
-  llvm::LLVMContext &LLVMCtx = *Interp->TSCtx->getContext();
 
-  auto DeviceAct =
-      std::make_unique<IncrementalAction>(*DCI, LLVMCtx, Err, *Interp);
+  auto DeviceAct = Interp->TSCtx->withContextDo([&](llvm::LLVMContext *Ctx) {
+    return std::make_unique<IncrementalAction>(*DCI, *Ctx, Err, *Interp);
+  });
 
   if (Err)
     return std::move(Err);
diff --git 
a/llvm/examples/OrcV2Examples/LLJITWithThinLTOSummaries/LLJITWithThinLTOSummaries.cpp
 
b/llvm/examples/OrcV2Examples/LLJITWithThinLTOSummaries/LLJITWithThinLTOSummaries.cpp
index c55aa73d50277..84f17871c03e4 100644
--- 
a/llvm/examples/OrcV2Examples/LLJITWithThinLTOSummaries/LLJITWithThinLTOSummaries.cpp
+++ 
b/llvm/examples/OrcV2Examples/LLJITWithThinLTOSummaries/LLJITWithThinLTOSummaries.cpp
@@ -169,7 +169,9 @@ Expected<ThreadSafeModule> loadModule(StringRef Path,
 
   MemoryBufferRef BitcodeBufferRef = (**BitcodeBuffer).getMemBufferRef();
   Expected<std::unique_ptr<Module>> M =
-      parseBitcodeFile(BitcodeBufferRef, *TSCtx.getContext());
+      TSCtx.withContextDo([&](LLVMContext *Ctx) {
+        return parseBitcodeFile(BitcodeBufferRef, *Ctx);
+      });
   if (!M)
     return M.takeError();
 
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsBasicUsage/OrcV2CBindingsBasicUsage.c
 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsBasicUsage/OrcV2CBindingsBasicUsage.c
index 286fa8baac4f8..b95462f340f2f 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsBasicUsage/OrcV2CBindingsBasicUsage.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsBasicUsage/OrcV2CBindingsBasicUsage.c
@@ -22,11 +22,8 @@ int handleError(LLVMErrorRef Err) {
 }
 
 LLVMOrcThreadSafeModuleRef createDemoModule(void) {
-  // Create a new ThreadSafeContext and underlying LLVMContext.
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-
-  // Get a reference to the underlying LLVMContext.
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  // Create an LLVMContext.
+  LLVMContextRef Ctx = LLVMContextCreate();
 
   // Create a new LLVM module.
   LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@@ -57,6 +54,9 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
   //  - Free the builder.
   LLVMDisposeBuilder(Builder);
 
+  // Create a new ThreadSafeContext to hold the context.
+  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
+
   // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
   // ThreadSafeModule.
   LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsDumpObjects/OrcV2CBindingsDumpObjects.c
 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsDumpObjects/OrcV2CBindingsDumpObjects.c
index 1b4102625fa1b..42a27c7054d47 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsDumpObjects/OrcV2CBindingsDumpObjects.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsDumpObjects/OrcV2CBindingsDumpObjects.c
@@ -31,8 +31,7 @@ int handleError(LLVMErrorRef Err) {
 }
 
 LLVMOrcThreadSafeModuleRef createDemoModule(void) {
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  LLVMContextRef Ctx = LLVMContextCreate();
   LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
   LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()};
   LLVMTypeRef SumFunctionType =
@@ -45,6 +44,8 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
   LLVMValueRef SumArg1 = LLVMGetParam(SumFunction, 1);
   LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result");
   LLVMBuildRet(Builder, Result);
+  LLVMOrcThreadSafeContextRef TSCtx =
+      LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
   LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
   LLVMOrcDisposeThreadSafeContext(TSCtx);
   return TSM;
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsIRTransforms/OrcV2CBindingsIRTransforms.c
 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsIRTransforms/OrcV2CBindingsIRTransforms.c
index 41ae6e53db1d6..62904d006da61 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsIRTransforms/OrcV2CBindingsIRTransforms.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsIRTransforms/OrcV2CBindingsIRTransforms.c
@@ -32,8 +32,7 @@ int handleError(LLVMErrorRef Err) {
 }
 
 LLVMOrcThreadSafeModuleRef createDemoModule(void) {
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  LLVMContextRef Ctx = LLVMContextCreate();
   LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
   LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()};
   LLVMTypeRef SumFunctionType =
@@ -47,6 +46,7 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
   LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result");
   LLVMBuildRet(Builder, Result);
   LLVMDisposeBuilder(Builder);
+  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
   LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
   LLVMOrcDisposeThreadSafeContext(TSCtx);
   return TSM;
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsLazy/OrcV2CBindingsLazy.c 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsLazy/OrcV2CBindingsLazy.c
index 33398c8cb9816..9c31f93899201 100644
--- a/llvm/examples/OrcV2Examples/OrcV2CBindingsLazy/OrcV2CBindingsLazy.c
+++ b/llvm/examples/OrcV2Examples/OrcV2CBindingsLazy/OrcV2CBindingsLazy.c
@@ -67,11 +67,9 @@ const char MainMod[] =
 LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
                                 const char *Name,
                                 LLVMOrcThreadSafeModuleRef *TSM) {
-  // Create a new ThreadSafeContext and underlying LLVMContext.
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
 
-  // Get a reference to the underlying LLVMContext.
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  // Create an LLVMContext for the Module.
+  LLVMContextRef Ctx = LLVMContextCreate();
 
   // Wrap Source in a MemoryBuffer
   LLVMMemoryBufferRef MB =
@@ -85,6 +83,10 @@ LLVMErrorRef parseExampleModule(const char *Source, size_t 
Len,
     // TODO: LLVMDisposeMessage(ErrMsg);
   }
 
+  // Create a new ThreadSafeContext to hold the context.
+  LLVMOrcThreadSafeContextRef TSCtx =
+      LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
+
   // Our module is now complete. Wrap it and our ThreadSafeContext in a
   // ThreadSafeModule.
   *TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsMCJITLikeMemoryManager/OrcV2CBindingsMCJITLikeMemoryManager.c
 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsMCJITLikeMemoryManager/OrcV2CBindingsMCJITLikeMemoryManager.c
index f85430bcfda4a..6962c6980e787 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsMCJITLikeMemoryManager/OrcV2CBindingsMCJITLikeMemoryManager.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsMCJITLikeMemoryManager/OrcV2CBindingsMCJITLikeMemoryManager.c
@@ -150,11 +150,8 @@ int handleError(LLVMErrorRef Err) {
 }
 
 LLVMOrcThreadSafeModuleRef createDemoModule(void) {
-  // Create a new ThreadSafeContext and underlying LLVMContext.
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-
-  // Get a reference to the underlying LLVMContext.
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  // Create an LLVMContext.
+  LLVMContextRef Ctx = LLVMContextCreate();
 
   // Create a new LLVM module.
   LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@@ -182,6 +179,10 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
   //  - Build the return instruction.
   LLVMBuildRet(Builder, Result);
 
+  // Create a new ThreadSafeContext to hold the context.
+  LLVMOrcThreadSafeContextRef TSCtx =
+      LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
+
   // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
   // ThreadSafeModule.
   LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsRemovableCode/OrcV2CBindingsRemovableCode.c
 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsRemovableCode/OrcV2CBindingsRemovableCode.c
index 7f84a3d413435..7f8a9cd334c6b 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsRemovableCode/OrcV2CBindingsRemovableCode.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsRemovableCode/OrcV2CBindingsRemovableCode.c
@@ -22,11 +22,8 @@ int handleError(LLVMErrorRef Err) {
 }
 
 LLVMOrcThreadSafeModuleRef createDemoModule(void) {
-  // Create a new ThreadSafeContext and underlying LLVMContext.
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-
-  // Get a reference to the underlying LLVMContext.
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  // Create an LLVMContext.
+  LLVMContextRef Ctx = LLVMContextCreate();
 
   // Create a new LLVM module.
   LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@@ -57,6 +54,10 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
   //  - Free the builder.
   LLVMDisposeBuilder(Builder);
 
+  // Create a new ThreadSafeContext to hold the context.
+  LLVMOrcThreadSafeContextRef TSCtx =
+      LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
+
   // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
   // ThreadSafeModule.
   LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
diff --git 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsVeryLazy/OrcV2CBindingsVeryLazy.c 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsVeryLazy/OrcV2CBindingsVeryLazy.c
index 85651f728399e..3c1ff8392eff4 100644
--- 
a/llvm/examples/OrcV2Examples/OrcV2CBindingsVeryLazy/OrcV2CBindingsVeryLazy.c
+++ 
b/llvm/examples/OrcV2Examples/OrcV2CBindingsVeryLazy/OrcV2CBindingsVeryLazy.c
@@ -74,11 +74,8 @@ LLVMErrorRef applyDataLayout(void *Ctx, LLVMModuleRef M) {
 LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
                                 const char *Name,
                                 LLVMOrcThreadSafeModuleRef *TSM) {
-  // Create a new ThreadSafeContext and underlying LLVMContext.
-  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
-
-  // Get a reference to the underlying LLVMContext.
-  LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
+  // Create an LLVMContext.
+  LLVMContextRef Ctx = LLVMContextCreate();
 
   // Wrap Source in a MemoryBuffer
   LLVMMemoryBufferRef MB =
@@ -93,6 +90,9 @@ LLVMErrorRef parseExampleModule(const char *Source, size_t 
Len,
     return Err;
   }
 
+  // Create a new ThreadSafeContext to hold the context.
+  LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
+
   // Our module is now complete. Wrap it and our ThreadSafeContext in a
   // ThreadSafeModule.
   *TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
diff --git a/llvm/include/llvm-c/Orc.h b/llvm/include/llvm-c/Orc.h
index 743ba1d581782..ee80f6f9f9892 100644
--- a/llvm/include/llvm-c/Orc.h
+++ b/llvm/include/llvm-c/Orc.h
@@ -1062,20 +1062,32 @@ LLVMErrorRef 
LLVMOrcCreateStaticLibrarySearchGeneratorForPath(
     const char *FileName);
 
 /**
- * Create a ThreadSafeContext containing a new LLVMContext.
+ * Create a ThreadSafeContextRef containing a new LLVMContext.
  *
  * Ownership of the underlying ThreadSafeContext data is shared: Clients
- * can and should dispose of their ThreadSafeContext as soon as they no longer
- * need to refer to it directly. Other references (e.g. from ThreadSafeModules)
- * will keep the data alive as long as it is needed.
+ * can and should dispose of their ThreadSafeContextRef as soon as they no
+ * longer need to refer to it directly. Other references (e.g. from
+ * ThreadSafeModules) will keep the underlying data alive as long as it is
+ * needed.
  */
 LLVMOrcThreadSafeContextRef LLVMOrcCreateNewThreadSafeContext(void);
 
 /**
- * Get a reference to the wrapped LLVMContext.
+ * Create a ThreadSafeContextRef from a given LLVMContext, which must not be
+ * associated with any existing ThreadSafeContext.
+ *
+ * The underlying ThreadSafeContext will take ownership of the LLVMContext
+ * object, so clients should not free the LLVMContext passed to this
+ * function.
+ *
+ * Ownership of the underlying ThreadSafeContext data is shared: Clients
+ * can and should dispose of their ThreadSafeContextRef as soon as they no
+ * longer need to refer to it directly. Other references (e.g. from
+ * ThreadSafeModules) will keep the underlying data alive as long as it is
+ * needed.
  */
-LLVMContextRef
-LLVMOrcThreadSafeContextGetContext(LLVMOrcThreadSafeContextRef TSCtx);
+LLVMOrcThreadSafeContextRef
+LLVMOrcCreateNewThreadSafeContextFromLLVMContext(LLVMContextRef Ctx);
 
 /**
  * Dispose of a ThreadSafeContext.
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h 
b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h
index b61c8b8563a1a..f1353777f6ce9 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h
@@ -36,16 +36,6 @@ class ThreadSafeContext {
   };
 
 public:
-  // RAII based lock for ThreadSafeContext.
-  class [[nodiscard]] Lock {
-  public:
-    Lock(std::shared_ptr<State> S) : S(std::move(S)), L(this->S->Mutex) {}
-
-  private:
-    std::shared_ptr<State> S;
-    std::unique_lock<std::recursive_mutex> L;
-  };
-
   /// Construct a null context.
   ThreadSafeContext() = default;
 
@@ -56,17 +46,20 @@ class ThreadSafeContext {
            "Can not construct a ThreadSafeContext from a nullptr");
   }
 
-  /// Returns a pointer to the LLVMContext that was used to construct this
-  /// instance, or null if the instance was default constructed.
-  LLVMContext *getContext() { return S ? S->Ctx.get() : nullptr; }
-
-  /// Returns a pointer to the LLVMContext that was used to construct this
-  /// instance, or null if the instance was default constructed.
-  const LLVMContext *getContext() const { return S ? S->Ctx.get() : nullptr; }
+  template <typename Func> decltype(auto) withContextDo(Func &&F) {
+    if (auto TmpS = S) {
+      std::lock_guard<std::recursive_mutex> Lock(TmpS->Mutex);
+      return F(TmpS->Ctx.get());
+    } else
+      return F((LLVMContext *)nullptr);
+  }
 
-  Lock getLock() const {
-    assert(S && "Can not lock an empty ThreadSafeContext");
-    return Lock(S);
+  template <typename Func> decltype(auto) withContextDo(Func &&F) const {
+    if (auto TmpS = S) {
+      std::lock_guard<std::recursive_mutex> Lock(TmpS->Mutex);
+      return F(const_cast<const LLVMContext *>(TmpS->Ctx.get()));
+    } else
+      return F((const LLVMContext *)nullptr);
   }
 
 private:
@@ -89,10 +82,7 @@ class ThreadSafeModule {
     // *before* the context that it depends on.
     // We also need to lock the context to make sure the module tear-down
     // does not overlap any other work on the context.
-    if (M) {
-      auto L = TSCtx.getLock();
-      M = nullptr;
-    }
+    TSCtx.withContextDo([this](LLVMContext *Ctx) { M = nullptr; });
     M = std::move(Other.M);
     TSCtx = std::move(Other.TSCtx);
     return *this;
@@ -111,45 +101,39 @@ class ThreadSafeModule {
 
   ~ThreadSafeModule() {
     // We need to lock the context while we destruct the module.
-    if (M) {
-      auto L = TSCtx.getLock();
-      M = nullptr;
-    }
+    TSCtx.withContextDo([this](LLVMContext *Ctx) { M = nullptr; });
   }
 
   /// Boolean conversion: This ThreadSafeModule will evaluate to true if it
   /// wraps a non-null module.
-  explicit operator bool() const {
-    if (M) {
-      assert(TSCtx.getContext() &&
-             "Non-null module must have non-null context");
-      return true;
-    }
-    return false;
-  }
+  explicit operator bool() const { return !!M; }
 
   /// Locks the associated ThreadSafeContext and calls the given function
   /// on the contained Module.
   template <typename Func> decltype(auto) withModuleDo(Func &&F) {
-    assert(M && "Can not call on null module");
-    auto Lock = TSCtx.getLock();
-    return F(*M);
+    return TSCtx.withContextDo([&](LLVMContext *) {
+      assert(M && "Can not call on null module");
+      return F(*M);
+    });
   }
 
   /// Locks the associated ThreadSafeContext and calls the given function
   /// on the contained Module.
   template <typename Func> decltype(auto) withModuleDo(Func &&F) const {
-    assert(M && "Can not call on null module");
-    auto Lock = TSCtx.getLock();
-    return F(*M);
+    return TSCtx.withContextDo([&](const LLVMContext *) {
+      assert(M && "Can not call on null module");
+      return F(*M);
+    });
   }
 
   /// Locks the associated ThreadSafeContext and calls the given function,
   /// passing the contained std::unique_ptr<Module>. The given function should
   /// consume the Module.
   template <typename Func> decltype(auto) consumingModuleDo(Func &&F) {
-    auto Lock = TSCtx.getLock();
-    return F(std::move(M));
+    return TSCtx.withContextDo([&](LLVMContext *) {
+      assert(M && "Can not call on null module");
+      return F(std::move(M));
+    });
   }
 
   /// Get a raw pointer to the contained module without locking the context.
diff --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp 
b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
index 9999e1ff3bf00..fd805fbf01fb7 100644
--- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
@@ -729,9 +729,9 @@ LLVMOrcThreadSafeContextRef 
LLVMOrcCreateNewThreadSafeContext(void) {
   return wrap(new ThreadSafeContext(std::make_unique<LLVMContext>()));
 }
 
-LLVMContextRef
-LLVMOrcThreadSafeContextGetContext(LLVMOrcThreadSafeContextRef TSCtx) {
-  return wrap(unwrap(TSCtx)->getContext());
+LLVMOrcThreadSafeContextRef
+LLVMOrcCreateNewThreadSafeContextFromLLVMContext(LLVMContextRef Ctx) {
+  return wrap(new 
ThreadSafeContext(std::unique_ptr<LLVMContext>(unwrap(Ctx))));
 }
 
 void LLVMOrcDisposeThreadSafeContext(LLVMOrcThreadSafeContextRef TSCtx) {
diff --git a/llvm/lib/ExecutionEngine/Orc/Speculation.cpp 
b/llvm/lib/ExecutionEngine/Orc/Speculation.cpp
index 74b9eb29bdccf..fee94b96a9e8a 100644
--- a/llvm/lib/ExecutionEngine/Orc/Speculation.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/Speculation.cpp
@@ -60,8 +60,6 @@ void 
IRSpeculationLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
                               ThreadSafeModule TSM) {
 
   assert(TSM && "Speculation Layer received Null Module ?");
-  assert(TSM.getContext().getContext() != nullptr &&
-         "Module with null LLVMContext?");
 
   // Instrumentation of runtime calls, lock the Module
   TSM.withModuleDo([this, &R](Module &M) {
diff --git a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp 
b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp
index 2e128dd237443..c927f21494697 100644
--- a/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp
@@ -53,9 +53,11 @@ ThreadSafeModule cloneToNewContext(const ThreadSafeModule 
&TSM,
         "cloned module buffer");
     ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>());
 
-    auto ClonedModule = cantFail(
-        parseBitcodeFile(ClonedModuleBufferRef, *NewTS...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146819
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to