This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5688c2836415bf8f6f25b6a9a6029b965c17d0b7
Author: Pxl <pxl...@qq.com>
AuthorDate: Thu Apr 11 11:22:10 2024 +0800

    [Bug](runtime-filter) try to fix heap use after free on runtime filter send filter size (#33465) (#33522)
---
 be/src/runtime/fragment_mgr.cpp                    | 42 ++++++++++-----------
 be/src/runtime/runtime_filter_mgr.cpp              |  6 +--
 be/src/runtime/runtime_state.cpp                   | 43 +++-------------------
 be/src/runtime/runtime_state.h                     |  5 ---
 be/test/exprs/runtime_filter_test.cpp              |  7 +---
 .../serde/data_type_serde_mysql_test.cpp           |  3 +-
 be/test/vec/exec/vtablet_sink_test.cpp             | 12 ++++--
 be/test/vec/exprs/vexpr_test.cpp                   |  6 +--
 be/test/vec/runtime/vdata_stream_test.cpp          |  3 +-
 9 files changed, 42 insertions(+), 85 deletions(-)

diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp
index 9fff093b3bf..3cde0778f62 100644
--- a/be/src/runtime/fragment_mgr.cpp
+++ b/be/src/runtime/fragment_mgr.cpp
@@ -824,30 +824,30 @@ Status FragmentMgr::exec_plan_fragment(const TPipelineFragmentParams& params,
         }
         g_fragmentmgr_prepare_latency << (duration_ns / 1000);
 
-        for (size_t i = 0; i < params.local_params.size(); i++) {
-            std::shared_ptr<RuntimeFilterMergeControllerEntity> handler;
-            RETURN_IF_ERROR(_runtimefilter_controller.add_entity(
-                    params.local_params[i], params.query_id, 
params.query_options, &handler,
-                    
RuntimeFilterParamsContext::create(context->get_runtime_state())));
-            if (!i && handler) {
-                query_ctx->set_merge_controller_handler(handler);
-            }
-            const TUniqueId& fragment_instance_id = 
params.local_params[i].fragment_instance_id;
-            {
-                std::lock_guard<std::mutex> lock(_lock);
-                auto iter = _pipeline_map.find(fragment_instance_id);
-                if (iter != _pipeline_map.end()) {
-                    // Duplicated
-                    return Status::OK();
-                }
-                
query_ctx->fragment_instance_ids.push_back(fragment_instance_id);
-            }
+        std::shared_ptr<RuntimeFilterMergeControllerEntity> handler;
+        RETURN_IF_ERROR(_runtimefilter_controller.add_entity(
+                params.local_params[0], params.query_id, params.query_options, 
&handler,
+                
RuntimeFilterParamsContext::create(context->get_runtime_state())));
+        if (handler) {
+            query_ctx->set_merge_controller_handler(handler);
+        }
 
-            if (!params.__isset.need_wait_execution_trigger ||
-                !params.need_wait_execution_trigger) {
-                query_ctx->set_ready_to_execute_only();
+        for (const auto& local_param : params.local_params) {
+            const TUniqueId& fragment_instance_id = 
local_param.fragment_instance_id;
+            std::lock_guard<std::mutex> lock(_lock);
+            auto iter = _pipeline_map.find(fragment_instance_id);
+            if (iter != _pipeline_map.end()) {
+                return Status::InternalError(
+                        "exec_plan_fragment input duplicated 
fragment_instance_id({})",
+                        UniqueId(fragment_instance_id).to_string());
             }
+            query_ctx->fragment_instance_ids.push_back(fragment_instance_id);
         }
+
+        if (!params.__isset.need_wait_execution_trigger || 
!params.need_wait_execution_trigger) {
+            query_ctx->set_ready_to_execute_only();
+        }
+
         {
             std::lock_guard<std::mutex> lock(_lock);
             std::vector<TUniqueId> ins_ids;
diff --git a/be/src/runtime/runtime_filter_mgr.cpp b/be/src/runtime/runtime_filter_mgr.cpp
index f0407e42ce2..f1adb4ed198 100644
--- a/be/src/runtime/runtime_filter_mgr.cpp
+++ b/be/src/runtime/runtime_filter_mgr.cpp
@@ -574,8 +574,7 @@ Status RuntimeFilterMergeController::acquire(
     std::lock_guard<std::mutex> guard(_controller_mutex[shard]);
     auto iter = _filter_controller_map[shard].find(query_id);
     if (iter == _filter_controller_map[shard].end()) {
-        LOG(WARNING) << "not found entity, query-id:" << query_id.to_string();
-        return Status::InvalidArgument("not found entity");
+        return Status::InvalidArgument("not found entity, query-id:{}", 
query_id.to_string());
     }
     *handle = _filter_controller_map[shard][query_id].lock();
     if (*handle == nullptr) {
@@ -591,7 +590,8 @@ void RuntimeFilterMergeController::remove_entity(UniqueId query_id) {
 }
 
 RuntimeFilterParamsContext* RuntimeFilterParamsContext::create(RuntimeState* 
state) {
-    RuntimeFilterParamsContext* params = state->obj_pool()->add(new 
RuntimeFilterParamsContext());
+    RuntimeFilterParamsContext* params =
+            state->get_query_ctx()->obj_pool.add(new 
RuntimeFilterParamsContext());
     params->runtime_filter_wait_infinitely = 
state->runtime_filter_wait_infinitely();
     params->runtime_filter_wait_time_ms = state->runtime_filter_wait_time_ms();
     params->enable_pipeline_exec = state->enable_pipeline_exec();
diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp
index fcbf20c0f72..2713ee441dd 100644
--- a/be/src/runtime/runtime_state.cpp
+++ b/be/src/runtime/runtime_state.cpp
@@ -49,41 +49,6 @@
 namespace doris {
 using namespace ErrorCode;
 
-// for ut only
-RuntimeState::RuntimeState(const TUniqueId& fragment_instance_id,
-                           const TQueryOptions& query_options, const 
TQueryGlobals& query_globals,
-                           ExecEnv* exec_env,
-                           const std::shared_ptr<MemTrackerLimiter>& 
query_mem_tracker)
-        : _profile("Fragment " + print_id(fragment_instance_id)),
-          _load_channel_profile("<unnamed>"),
-          _obj_pool(new ObjectPool()),
-          _data_stream_recvrs_pool(new ObjectPool()),
-          _unreported_error_idx(0),
-          _is_cancelled(false),
-          _per_fragment_instance_idx(0),
-          _num_rows_load_total(0),
-          _num_rows_load_filtered(0),
-          _num_rows_load_unselected(0),
-          _num_print_error_rows(0),
-          _num_bytes_load_total(0),
-          _num_finished_scan_range(0),
-          _load_job_id(-1),
-          _normal_row_number(0),
-          _error_row_number(0),
-          _error_log_file(nullptr) {
-    Status status = init(fragment_instance_id, query_options, query_globals, 
exec_env);
-    DCHECK(status.ok());
-    _query_mem_tracker = query_mem_tracker;
-#ifdef BE_TEST
-    if (_query_mem_tracker == nullptr) {
-        init_mem_trackers();
-    }
-#endif
-    DCHECK(_query_mem_tracker != nullptr && _query_mem_tracker->label() != 
"Orphan");
-    _runtime_filter_mgr.reset(new RuntimeFilterMgr(
-            TUniqueId(), RuntimeFilterParamsContext::create(this), 
_query_mem_tracker));
-}
-
 RuntimeState::RuntimeState(const TPlanFragmentExecParams& fragment_exec_params,
                            const TQueryOptions& query_options, const 
TQueryGlobals& query_globals,
                            ExecEnv* exec_env, QueryContext* ctx,
@@ -121,9 +86,11 @@ RuntimeState::RuntimeState(const TPlanFragmentExecParams& fragment_exec_params,
     }
 #endif
     DCHECK(_query_mem_tracker != nullptr && _query_mem_tracker->label() != 
"Orphan");
-    _runtime_filter_mgr = std::make_unique<RuntimeFilterMgr>(
-            fragment_exec_params.query_id, 
RuntimeFilterParamsContext::create(this),
-            _query_mem_tracker);
+    if (ctx) {
+        _runtime_filter_mgr = std::make_unique<RuntimeFilterMgr>(
+                fragment_exec_params.query_id, 
RuntimeFilterParamsContext::create(this),
+                _query_mem_tracker);
+    }
     if (fragment_exec_params.__isset.runtime_filter_params) {
         _query_ctx->runtime_filter_mgr()->set_runtime_filter_params(
                 fragment_exec_params.runtime_filter_params);
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index 6fae242d53c..07655c71b6c 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -66,11 +66,6 @@ class RuntimeState {
     ENABLE_FACTORY_CREATOR(RuntimeState);
 
 public:
-    // for ut only
-    RuntimeState(const TUniqueId& fragment_instance_id, const TQueryOptions& 
query_options,
-                 const TQueryGlobals& query_globals, ExecEnv* exec_env,
-                 const std::shared_ptr<MemTrackerLimiter>& query_mem_tracker);
-
     RuntimeState(const TPlanFragmentExecParams& fragment_exec_params,
                  const TQueryOptions& query_options, const TQueryGlobals& 
query_globals,
                  ExecEnv* exec_env, QueryContext* ctx,
diff --git a/be/test/exprs/runtime_filter_test.cpp b/be/test/exprs/runtime_filter_test.cpp
index 64310e811f7..36d7cd885dd 100644
--- a/be/test/exprs/runtime_filter_test.cpp
+++ b/be/test/exprs/runtime_filter_test.cpp
@@ -35,12 +35,7 @@ TTypeDesc create_type_desc(PrimitiveType type, int precision, int scale);
 class RuntimeFilterTest : public testing::Test {
 public:
     RuntimeFilterTest() {}
-    virtual void SetUp() {
-        ExecEnv* exec_env = ExecEnv::GetInstance();
-        exec_env = nullptr;
-        _runtime_stat = RuntimeState::create_unique(_fragment_id, 
_query_options, _query_globals,
-                                                    exec_env, nullptr);
-    }
+    virtual void SetUp() {}
     virtual void TearDown() { _obj_pool.clear(); }
 
 private:
diff --git a/be/test/vec/data_types/serde/data_type_serde_mysql_test.cpp b/be/test/vec/data_types/serde/data_type_serde_mysql_test.cpp
index 35c169da7aa..5ba8af8b81f 100644
--- a/be/test/vec/data_types/serde/data_type_serde_mysql_test.cpp
+++ b/be/test/vec/data_types/serde/data_type_serde_mysql_test.cpp
@@ -87,8 +87,7 @@ void serialize_and_deserialize_mysql_test() {
     // make desc and generate block
     vectorized::VExprContextSPtrs _output_vexpr_ctxs;
     _output_vexpr_ctxs.resize(cols.size());
-    doris::RuntimeState runtime_stat(doris::TUniqueId(), 
doris::TQueryOptions(),
-                                     doris::TQueryGlobals(), nullptr, nullptr);
+    doris::RuntimeState runtime_stat;
     ObjectPool object_pool;
     int col_idx = 0;
     for (auto t : cols) {
diff --git a/be/test/vec/exec/vtablet_sink_test.cpp b/be/test/vec/exec/vtablet_sink_test.cpp
index cb6873b719c..e9011d73df8 100644
--- a/be/test/vec/exec/vtablet_sink_test.cpp
+++ b/be/test/vec/exec/vtablet_sink_test.cpp
@@ -403,7 +403,8 @@ public:
         TQueryOptions query_options;
         query_options.batch_size = 1;
         query_options.be_exec_version = be_exec_version;
-        RuntimeState state(fragment_id, query_options, TQueryGlobals(), _env, 
nullptr);
+        RuntimeState state;
+        state.set_query_options(query_options);
         std::shared_ptr<TaskExecutionContext> task_ctx_lock =
                 std::make_shared<TaskExecutionContext>();
         state.set_task_execution_context(task_ctx_lock);
@@ -524,7 +525,8 @@ TEST_F(VOlapTableSinkTest, convert) {
     TQueryOptions query_options;
     query_options.batch_size = 1024;
     query_options.be_exec_version = 1;
-    RuntimeState state(fragment_id, query_options, TQueryGlobals(), _env, 
nullptr);
+    RuntimeState state;
+    state.set_query_options(query_options);
     std::shared_ptr<TaskExecutionContext> task_ctx_lock = 
std::make_shared<TaskExecutionContext>();
     state.set_task_execution_context(task_ctx_lock);
 
@@ -655,7 +657,8 @@ TEST_F(VOlapTableSinkTest, add_block_failed) {
     TQueryOptions query_options;
     query_options.batch_size = 1;
     query_options.be_exec_version = 1;
-    RuntimeState state(fragment_id, query_options, TQueryGlobals(), _env, 
nullptr);
+    RuntimeState state;
+    state.set_query_options(query_options);
     std::shared_ptr<TaskExecutionContext> task_ctx_lock = 
std::make_shared<TaskExecutionContext>();
     state.set_task_execution_context(task_ctx_lock);
 
@@ -770,7 +773,8 @@ TEST_F(VOlapTableSinkTest, decimal) {
     TQueryOptions query_options;
     query_options.batch_size = 1;
     query_options.be_exec_version = 1;
-    RuntimeState state(fragment_id, query_options, TQueryGlobals(), _env, 
nullptr);
+    RuntimeState state;
+    state.set_query_options(query_options);
     std::shared_ptr<TaskExecutionContext> task_ctx_lock = 
std::make_shared<TaskExecutionContext>();
     state.set_task_execution_context(task_ctx_lock);
 
diff --git a/be/test/vec/exprs/vexpr_test.cpp b/be/test/vec/exprs/vexpr_test.cpp
index 079fbe5f377..a670c443c27 100644
--- a/be/test/vec/exprs/vexpr_test.cpp
+++ b/be/test/vec/exprs/vexpr_test.cpp
@@ -61,8 +61,7 @@ TEST(TEST_VEXPR, ABSTEST) {
     doris::vectorized::VExprContextSPtr context;
     static_cast<void>(doris::vectorized::VExpr::create_expr_tree(exprx, 
context));
 
-    doris::RuntimeState runtime_stat(doris::TUniqueId(), 
doris::TQueryOptions(),
-                                     doris::TQueryGlobals(), nullptr, nullptr);
+    doris::RuntimeState runtime_stat;
     runtime_stat.set_desc_tbl(desc_tbl);
     auto state = doris::Status::OK();
     state = context->prepare(&runtime_stat, row_desc);
@@ -153,8 +152,7 @@ TEST(TEST_VEXPR, ABSTEST2) {
     doris::vectorized::VExprContextSPtr context;
     static_cast<void>(doris::vectorized::VExpr::create_expr_tree(exprx, 
context));
 
-    doris::RuntimeState runtime_stat(doris::TUniqueId(), 
doris::TQueryOptions(),
-                                     doris::TQueryGlobals(), nullptr, nullptr);
+    doris::RuntimeState runtime_stat;
     DescriptorTbl desc_tbl;
     desc_tbl._slot_desc_map[0] = tuple_desc->slots()[0];
     runtime_stat.set_desc_tbl(&desc_tbl);
diff --git a/be/test/vec/runtime/vdata_stream_test.cpp b/be/test/vec/runtime/vdata_stream_test.cpp
index 19d4ac81afe..2a030d8e947 100644
--- a/be/test/vec/runtime/vdata_stream_test.cpp
+++ b/be/test/vec/runtime/vdata_stream_test.cpp
@@ -149,8 +149,7 @@ TEST_F(VDataStreamTest, BasicTest) {
     auto tuple_desc = 
const_cast<doris::TupleDescriptor*>(desc_tbl->get_tuple_descriptor(0));
     doris::RowDescriptor row_desc(tuple_desc, false);
 
-    doris::RuntimeState runtime_stat(doris::TUniqueId(), 
doris::TQueryOptions(),
-                                     doris::TQueryGlobals(), nullptr, nullptr);
+    doris::RuntimeState runtime_stat;
     std::shared_ptr<TaskExecutionContext> task_ctx_lock = 
std::make_shared<TaskExecutionContext>();
     runtime_stat.set_task_execution_context(task_ctx_lock);
     runtime_stat.set_desc_tbl(desc_tbl);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to