Author: Eugene Zhulenev Date: 2021-01-20T13:23:39-08:00 New Revision: a2223b09b10a4cc87b5e9c4a36ab9401c46610f6
URL: https://github.com/llvm/llvm-project/commit/a2223b09b10a4cc87b5e9c4a36ab9401c46610f6 DIFF: https://github.com/llvm/llvm-project/commit/a2223b09b10a4cc87b5e9c4a36ab9401c46610f6.diff LOG: [mlir:async] Fix data races in AsyncRuntime A resumed coroutine can potentially deallocate the token/value/group and destroy the mutex before the std::unique_lock destructor runs. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D95037 Added: Modified: mlir/lib/ExecutionEngine/AsyncRuntime.cpp Removed: ################################################################################ diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp index a20bd6d1e996..e38ebf92cd84 100644 --- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp +++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp @@ -136,13 +136,14 @@ struct AsyncToken : public RefCounted { // asynchronously executed task. If the caller immediately will drop its // reference we must ensure that the token will be alive until the // asynchronous operation is completed. - AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {} + AsyncToken(AsyncRuntime *runtime) + : RefCounted(runtime, /*count=*/2), ready(false) {} - // Internal state below guarded by a mutex. + std::atomic<bool> ready; + + // Pending awaiters are guarded by a mutex. std::mutex mu; std::condition_variable cv; - - bool ready = false; std::vector<std::function<void()>> awaiters; }; @@ -152,17 +153,17 @@ struct AsyncToken : public RefCounted { struct AsyncValue : public RefCounted { // AsyncValue similar to an AsyncToken created with a reference count of 2. AsyncValue(AsyncRuntime *runtime, int32_t size) - : RefCounted(runtime, /*count=*/2), storage(size) {} - - // Internal state below guarded by a mutex. 
- std::mutex mu; - std::condition_variable cv; + : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {} - bool ready = false; - std::vector<std::function<void()>> awaiters; + std::atomic<bool> ready; // Use vector of bytes to store async value payload. std::vector<int8_t> storage; + + // Pending awaiters are guarded by a mutex. + std::mutex mu; + std::condition_variable cv; + std::vector<std::function<void()>> awaiters; }; // Async group provides a mechanism to group together multiple async tokens or @@ -175,10 +176,9 @@ struct AsyncGroup : public RefCounted { std::atomic<int> pendingTokens; std::atomic<int> rank; - // Internal state below guarded by a mutex. + // Pending awaiters are guarded by a mutex. std::mutex mu; std::condition_variable cv; - std::vector<std::function<void()>> awaiters; }; @@ -291,13 +291,13 @@ extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) { extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { std::unique_lock<std::mutex> lock(token->mu); if (!token->ready) - token->cv.wait(lock, [token] { return token->ready; }); + token->cv.wait(lock, [token] { return token->ready.load(); }); } extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) { std::unique_lock<std::mutex> lock(value->mu); if (!value->ready) - value->cv.wait(lock, [value] { return value->ready; }); + value->cv.wait(lock, [value] { return value->ready.load(); }); } extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) { @@ -319,34 +319,37 @@ extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) { extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle, CoroResume resume) { - std::unique_lock<std::mutex> lock(token->mu); auto execute = [handle, resume]() { (*resume)(handle); }; - if (token->ready) + if (token->ready) { execute(); - else + } else { + std::unique_lock<std::mutex> lock(token->mu); token->awaiters.push_back([execute]() { execute(); }); + } } extern "C" void 
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, CoroHandle handle, CoroResume resume) { - std::unique_lock<std::mutex> lock(value->mu); auto execute = [handle, resume]() { (*resume)(handle); }; - if (value->ready) + if (value->ready) { execute(); - else + } else { + std::unique_lock<std::mutex> lock(value->mu); value->awaiters.push_back([execute]() { execute(); }); + } } extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, CoroResume resume) { - std::unique_lock<std::mutex> lock(group->mu); auto execute = [handle, resume]() { (*resume)(handle); }; - if (group->pendingTokens == 0) + if (group->pendingTokens == 0) { execute(); - else + } else { + std::unique_lock<std::mutex> lock(group->mu); group->awaiters.push_back([execute]() { execute(); }); + } } //===----------------------------------------------------------------------===// _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits