This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5e2e8c7ce5ad568641890d5790e38e9707e3b014
Author: Pxl <pxl...@qq.com>
AuthorDate: Thu Jun 13 23:00:27 2024 +0800

    [Bug](load) fix use-after-free on load channel in workload group scheduler (#36272)
    
    ## Proposed changes
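    Fix a use-after-free on the load channel in the workload group scheduler,
    introduced by #36176.

    `LoadChannel` and `LoadStream` looked up their query context through
    `FragmentMgr::get_or_erase_query_ctx`, which reads `_query_ctx_map` without
    taking the `FragmentMgr` lock. They now call the new
    `get_or_erase_query_ctx_with_lock`, which acquires `_lock` before the lookup.
    The unlocked lookup is renamed to the private `_get_or_erase_query_ctx`, and
    the now-unused `get_query_context` is removed.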
---
 be/src/runtime/fragment_mgr.cpp | 37 ++++++++++++++++---------------------
 be/src/runtime/fragment_mgr.h   |  6 +++---
 be/src/runtime/load_channel.cpp |  3 ++-
 be/src/runtime/load_stream.cpp  |  2 +-
 4 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp
index e8b25a96717..9451135da6e 100644
--- a/be/src/runtime/fragment_mgr.cpp
+++ b/be/src/runtime/fragment_mgr.cpp
@@ -529,7 +529,7 @@ Status FragmentMgr::start_query_execution(const 
PExecPlanFragmentStartRequest* r
     TUniqueId query_id;
     query_id.__set_hi(request->query_id().hi());
     query_id.__set_lo(request->query_id().lo());
-    if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+    if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
         q_ctx->set_ready_to_execute(Status::OK());
     } else {
         return Status::InternalError(
@@ -560,7 +560,7 @@ void FragmentMgr::remove_pipeline_context(
     }
 }
 
-std::shared_ptr<QueryContext> FragmentMgr::get_or_erase_query_ctx(TUniqueId 
query_id) {
+std::shared_ptr<QueryContext> FragmentMgr::_get_or_erase_query_ctx(const 
TUniqueId& query_id) {
     auto search = _query_ctx_map.find(query_id);
     if (search != _query_ctx_map.end()) {
         if (auto q_ctx = search->second.lock()) {
@@ -575,13 +575,19 @@ std::shared_ptr<QueryContext> 
FragmentMgr::get_or_erase_query_ctx(TUniqueId quer
     return nullptr;
 }
 
+std::shared_ptr<QueryContext> FragmentMgr::get_or_erase_query_ctx_with_lock(
+        const TUniqueId& query_id) {
+    std::unique_lock<std::mutex> lock(_lock);
+    return _get_or_erase_query_ctx(query_id);
+}
+
 template <typename Params>
 Status FragmentMgr::_get_query_ctx(const Params& params, TUniqueId query_id, 
bool pipeline,
                                    std::shared_ptr<QueryContext>& query_ctx) {
     if (params.is_simplified_param) {
         // Get common components from _query_ctx_map
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
         } else {
             return Status::InternalError(
@@ -593,7 +599,7 @@ Status FragmentMgr::_get_query_ctx(const Params& params, 
TUniqueId query_id, boo
         // Find _query_ctx_map, in case some other request has already
         // create the query fragments context.
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
             return Status::OK();
         }
@@ -691,7 +697,7 @@ std::string FragmentMgr::dump_pipeline_tasks(int64_t 
duration) {
 }
 
 std::string FragmentMgr::dump_pipeline_tasks(TUniqueId& query_id) {
-    if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+    if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
         return q_ctx->print_all_pipeline_context();
     } else {
         return fmt::format("Query context (query id = {}) not found. \n", 
print_id(query_id));
@@ -787,23 +793,12 @@ void FragmentMgr::_set_scan_concurrency(const Param& 
params, QueryContext* query
 #endif
 }
 
-Status FragmentMgr::get_query_context(const TUniqueId& query_id,
-                                      std::shared_ptr<QueryContext>* 
query_ctx) {
-    std::lock_guard<std::mutex> state_lock(_lock);
-    if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
-        *query_ctx = q_ctx;
-    } else {
-        return Status::InternalError("Query context not found for query {}", 
print_id(query_id));
-    }
-    return Status::OK();
-}
-
 void FragmentMgr::cancel_query(const TUniqueId query_id, const Status reason) {
     std::shared_ptr<QueryContext> query_ctx = nullptr;
     std::vector<TUniqueId> all_instance_ids;
     {
         std::lock_guard<std::mutex> state_lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
             // Copy instanceids to avoid concurrent modification.
             // And to reduce the scope of lock.
@@ -1137,7 +1132,7 @@ Status FragmentMgr::send_filter_size(const 
PSendFilterSizeRequest* request) {
         query_id.__set_hi(queryid.hi);
         query_id.__set_lo(queryid.lo);
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
         } else {
             return Status::InvalidArgument("Query context (query-id: {}) not 
found",
@@ -1156,7 +1151,7 @@ Status FragmentMgr::sync_filter_size(const 
PSyncFilterSizeRequest* request) {
         query_id.__set_hi(queryid.hi);
         query_id.__set_lo(queryid.lo);
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
         } else {
             return Status::InvalidArgument("Query context (query-id: {}) not 
found",
@@ -1178,7 +1173,7 @@ Status FragmentMgr::merge_filter(const 
PMergeFilterRequest* request,
         query_id.__set_hi(queryid.hi);
         query_id.__set_lo(queryid.lo);
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_ctx = q_ctx;
         } else {
             return Status::InvalidArgument("Query context (query-id: {}) not 
found",
@@ -1219,7 +1214,7 @@ Status FragmentMgr::get_realtime_exec_status(const 
TUniqueId& query_id,
 
     {
         std::lock_guard<std::mutex> lock(_lock);
-        if (auto q_ctx = get_or_erase_query_ctx(query_id)) {
+        if (auto q_ctx = _get_or_erase_query_ctx(query_id)) {
             query_context = q_ctx;
         } else {
             return Status::NotFound("Query {} has been released", 
print_id(query_id));
diff --git a/be/src/runtime/fragment_mgr.h b/be/src/runtime/fragment_mgr.h
index 5355f51a217..dba9bcde398 100644
--- a/be/src/runtime/fragment_mgr.h
+++ b/be/src/runtime/fragment_mgr.h
@@ -135,8 +135,6 @@ public:
 
     ThreadPool* get_thread_pool() { return _thread_pool.get(); }
 
-    Status get_query_context(const TUniqueId& query_id, 
std::shared_ptr<QueryContext>* query_ctx);
-
     int32_t running_query_num() {
         std::unique_lock<std::mutex> ctx_lock(_lock);
         return _query_ctx_map.size();
@@ -150,9 +148,11 @@ public:
     Status get_realtime_exec_status(const TUniqueId& query_id,
                                     TReportExecStatusParams* exec_status);
 
-    std::shared_ptr<QueryContext> get_or_erase_query_ctx(TUniqueId query_id);
+    std::shared_ptr<QueryContext> get_or_erase_query_ctx_with_lock(const 
TUniqueId& query_id);
 
 private:
+    std::shared_ptr<QueryContext> _get_or_erase_query_ctx(const TUniqueId& 
query_id);
+
     template <typename Param>
     void _set_scan_concurrency(const Param& params, QueryContext* query_ctx);
 
diff --git a/be/src/runtime/load_channel.cpp b/be/src/runtime/load_channel.cpp
index f307a6e3545..cd3f3aa5af5 100644
--- a/be/src/runtime/load_channel.cpp
+++ b/be/src/runtime/load_channel.cpp
@@ -43,7 +43,8 @@ LoadChannel::LoadChannel(const UniqueId& load_id, int64_t 
timeout_s, bool is_hig
           _backend_id(backend_id),
           _enable_profile(enable_profile) {
     std::shared_ptr<QueryContext> query_context =
-            
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx(_load_id.to_thrift());
+            
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock(
+                    _load_id.to_thrift());
     if (query_context != nullptr) {
         _query_thread_context = {_load_id.to_thrift(), 
query_context->query_mem_tracker};
     } else {
diff --git a/be/src/runtime/load_stream.cpp b/be/src/runtime/load_stream.cpp
index d4132cada9f..b896994b1ef 100644
--- a/be/src/runtime/load_stream.cpp
+++ b/be/src/runtime/load_stream.cpp
@@ -336,7 +336,7 @@ LoadStream::LoadStream(PUniqueId load_id, LoadStreamMgr* 
load_stream_mgr, bool e
     TUniqueId load_tid = ((UniqueId)load_id).to_thrift();
 #ifndef BE_TEST
     std::shared_ptr<QueryContext> query_context =
-            
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx(load_tid);
+            
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock(load_tid);
     if (query_context != nullptr) {
         _query_thread_context = {load_tid, query_context->query_mem_tracker};
     } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to