This is an automated email from the ASF dual-hosted git repository.

zouxinyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d1bf9eeb04 [fix](memory) Refactor refresh workload groups weighted 
memory ratio and record refresh interval memory growth (#38168)
8d1bf9eeb04 is described below

commit 8d1bf9eeb04905a43574d1931414aaa63b625b1b
Author: Xinyi Zou <zouxiny...@gmail.com>
AuthorDate: Mon Jul 22 21:11:48 2024 +0800

    [fix](memory) Refactor refresh workload groups weighted memory ratio and 
record refresh interval memory growth (#38168)
    
    1. Rename `refresh_wg_memory_info` to `refresh_wg_weighted_memory_ratio`.
    - Previously, the `total memory used` and `weighted memory` of workload
    groups were refreshed periodically and did not change between refreshes.
    - Now only the `weighted memory ratio` is refreshed periodically, and
    `weighted memory` is computed in real time as (`total memory used` +
    `refresh interval memory growth`) * `weighted memory ratio`. Workload
    groups and queries store only the `weighted memory ratio`, which is more
    accurate.
    - A workload group's `total memory used` is refreshed whenever
    `make_memory_tracker_snapshots` is executed, which is called by the
    `workload groups memory GC` and by `refresh_wg_weighted_memory_ratio`.
    
    2. Add a workload group weak ptr to `Thread Context`, and record the memory
    growth of the workload group within each memory ratio refresh interval.
    - If the thread runs a query or load, it is initialized with the memory
    tracker, the query id and the workload group wptr.
    - Otherwise (not a query or load), it is initialized with the memory
    tracker, an empty query id and the default normal workload group.
---
 be/src/common/config.cpp                           |   3 +-
 be/src/common/config.h                             |   4 +-
 be/src/common/daemon.cpp                           |  11 +-
 be/src/common/daemon.h                             |   2 +-
 be/src/pipeline/exec/spill_utils.h                 |   6 +-
 be/src/pipeline/pipeline_fragment_context.cpp      |   3 +-
 be/src/pipeline/pipeline_fragment_context.h        |   2 -
 be/src/pipeline/pipeline_task.cpp                  |   2 +-
 be/src/runtime/fragment_mgr.cpp                    |   7 +-
 be/src/runtime/load_channel.cpp                    |   4 +-
 be/src/runtime/load_stream.cpp                     |   4 +-
 be/src/runtime/memory/mem_tracker_limiter.h        |   2 +-
 be/src/runtime/memory/thread_mem_tracker_mgr.h     |  20 +++-
 be/src/runtime/query_context.h                     |  11 +-
 be/src/runtime/thread_context.cpp                  |  36 ++++---
 be/src/runtime/thread_context.h                    |  34 ++++--
 be/src/runtime/workload_group/workload_group.cpp   |  34 ++++--
 be/src/runtime/workload_group/workload_group.h     |  40 ++++++--
 .../workload_group/workload_group_manager.cpp      | 114 +++++++++------------
 .../workload_group/workload_group_manager.h        |   2 +-
 be/src/vec/exec/scan/scanner_context.cpp           |   3 +-
 be/src/vec/runtime/vdata_stream_recvr.cpp          |   8 +-
 be/src/vec/runtime/vdata_stream_recvr.h            |   6 +-
 .../runtime/memory/thread_mem_tracker_mgr_test.cpp |  39 ++++---
 24 files changed, 235 insertions(+), 162 deletions(-)

diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp
index 5222100170e..f66a7dd17c5 100644
--- a/be/src/common/config.cpp
+++ b/be/src/common/config.cpp
@@ -601,8 +601,7 @@ DEFINE_mInt32(memory_gc_sleep_time_ms, "1000");
 // Sleep time in milliseconds between memtbale flush mgr refresh iterations
 DEFINE_mInt64(memtable_mem_tracker_refresh_interval_ms, "5");
 
-// Sleep time in milliseconds between refresh iterations of workload group 
memory statistics
-DEFINE_mInt64(wg_mem_refresh_interval_ms, "50");
+DEFINE_mInt64(wg_weighted_memory_ratio_refresh_interval_ms, "50");
 
 // percent of (active memtables size / all memtables size) when reach hard 
limit
 DEFINE_mInt32(memtable_hard_limit_active_percent, "50");
diff --git a/be/src/common/config.h b/be/src/common/config.h
index 53261ab2fb9..fd38924f47e 100644
--- a/be/src/common/config.h
+++ b/be/src/common/config.h
@@ -658,8 +658,8 @@ DECLARE_mInt32(memory_gc_sleep_time_ms);
 // Sleep time in milliseconds between memtbale flush mgr memory refresh 
iterations
 DECLARE_mInt64(memtable_mem_tracker_refresh_interval_ms);
 
-// Sleep time in milliseconds between refresh iterations of workload group 
memory statistics
-DECLARE_mInt64(wg_mem_refresh_interval_ms);
+// Sleep time in milliseconds between refresh iterations of workload group 
weighted memory ratio
+DECLARE_mInt64(wg_weighted_memory_ratio_refresh_interval_ms);
 
 // percent of (active memtables size / all memtables size) when reach hard 
limit
 DECLARE_mInt32(memtable_hard_limit_active_percent);
diff --git a/be/src/common/daemon.cpp b/be/src/common/daemon.cpp
index c97904f5677..7667820b83f 100644
--- a/be/src/common/daemon.cpp
+++ b/be/src/common/daemon.cpp
@@ -392,11 +392,11 @@ void Daemon::je_purge_dirty_pages_thread() const {
     } while (true);
 }
 
-void Daemon::wg_mem_used_refresh_thread() {
-    // Refresh memory usage and limit of workload groups
+void Daemon::wg_weighted_memory_ratio_refresh_thread() {
+    // Refresh weighted memory ratio of workload groups
     while (!_stop_background_threads_latch.wait_for(
-            std::chrono::milliseconds(config::wg_mem_refresh_interval_ms))) {
-        
doris::ExecEnv::GetInstance()->workload_group_mgr()->refresh_wg_memory_info();
+            
std::chrono::milliseconds(config::wg_weighted_memory_ratio_refresh_interval_ms)))
 {
+        
doris::ExecEnv::GetInstance()->workload_group_mgr()->refresh_wg_weighted_memory_ratio();
     }
 }
 
@@ -441,7 +441,8 @@ void Daemon::start() {
     CHECK(st.ok()) << st;
 
     st = Thread::create(
-            "Daemon", "wg_mem_refresh_thread", [this]() { 
this->wg_mem_used_refresh_thread(); },
+            "Daemon", "wg_weighted_memory_ratio_refresh_thread",
+            [this]() { this->wg_weighted_memory_ratio_refresh_thread(); },
             &_threads.emplace_back());
 
     if (config::enable_be_proc_monitor) {
diff --git a/be/src/common/daemon.h b/be/src/common/daemon.h
index 9dfb079b904..2a8adf20e46 100644
--- a/be/src/common/daemon.h
+++ b/be/src/common/daemon.h
@@ -44,7 +44,7 @@ private:
     void calculate_metrics_thread();
     void je_purge_dirty_pages_thread() const;
     void report_runtime_query_statistics_thread();
-    void wg_mem_used_refresh_thread();
+    void wg_weighted_memory_ratio_refresh_thread();
     void be_proc_monitor_thread();
 
     CountDownLatch _stop_background_threads_latch;
diff --git a/be/src/pipeline/exec/spill_utils.h 
b/be/src/pipeline/exec/spill_utils.h
index f2f19512cbd..635a6a6bbbc 100644
--- a/be/src/pipeline/exec/spill_utils.h
+++ b/be/src/pipeline/exec/spill_utils.h
@@ -33,8 +33,6 @@ public:
     SpillRunnable(RuntimeState* state, const 
std::shared_ptr<BasicSharedState>& shared_state,
                   std::function<void()> func)
             : _state(state),
-              _mem_tracker(state->get_query_ctx()->query_mem_tracker),
-              _task_id(state->query_id()),
               _task_context_holder(state->get_task_execution_context()),
               _shared_state_holder(shared_state),
               _func(std::move(func)) {}
@@ -42,7 +40,7 @@ public:
     ~SpillRunnable() override = default;
 
     void run() override {
-        SCOPED_ATTACH_TASK_WITH_ID(_mem_tracker, _task_id);
+        SCOPED_ATTACH_TASK(_state);
         Defer defer([&] {
             std::function<void()> tmp;
             std::swap(tmp, _func);
@@ -66,8 +64,6 @@ public:
 
 private:
     RuntimeState* _state;
-    std::shared_ptr<MemTrackerLimiter> _mem_tracker;
-    TUniqueId _task_id;
     std::weak_ptr<TaskExecutionContext> _task_context_holder;
     std::weak_ptr<BasicSharedState> _shared_state_holder;
     std::function<void()> _func;
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp 
b/be/src/pipeline/pipeline_fragment_context.cpp
index 1ab40723641..45f49fc09b9 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -124,12 +124,11 @@ PipelineFragmentContext::PipelineFragmentContext(
           _is_report_on_cancel(true),
           _report_status_cb(report_status_cb) {
     _fragment_watcher.start();
-    _query_thread_context = {query_id, _query_ctx->query_mem_tracker};
 }
 
 PipelineFragmentContext::~PipelineFragmentContext() {
     // The memory released by the query end is recorded in the query mem 
tracker.
-    
SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_thread_context.query_mem_tracker);
+    SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_ctx->query_mem_tracker);
     auto st = _query_ctx->exec_status();
     _query_ctx.reset();
     for (size_t i = 0; i < _tasks.size(); i++) {
diff --git a/be/src/pipeline/pipeline_fragment_context.h 
b/be/src/pipeline/pipeline_fragment_context.h
index 3b6c73dbef4..7597c3ce9b5 100644
--- a/be/src/pipeline/pipeline_fragment_context.h
+++ b/be/src/pipeline/pipeline_fragment_context.h
@@ -214,8 +214,6 @@ private:
 
     std::shared_ptr<QueryContext> _query_ctx;
 
-    QueryThreadContext _query_thread_context;
-
     MonotonicStopWatch _fragment_watcher;
     RuntimeProfile::Counter* _prepare_timer = nullptr;
     RuntimeProfile::Counter* _init_context_timer = nullptr;
diff --git a/be/src/pipeline/pipeline_task.cpp 
b/be/src/pipeline/pipeline_task.cpp
index 80b23d94011..dbe09e817eb 100644
--- a/be/src/pipeline/pipeline_task.cpp
+++ b/be/src/pipeline/pipeline_task.cpp
@@ -406,7 +406,7 @@ bool PipelineTask::should_revoke_memory(RuntimeState* 
state, int64_t revocable_m
     } else if (is_wg_mem_low_water_mark) {
         int64_t query_weighted_limit = 0;
         int64_t query_weighted_consumption = 0;
-        query_ctx->get_weighted_mem_info(query_weighted_limit, 
query_weighted_consumption);
+        query_ctx->get_weighted_memory(query_weighted_limit, 
query_weighted_consumption);
         if (query_weighted_limit == 0 || query_weighted_consumption < 
query_weighted_limit) {
             return false;
         }
diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp
index 5389bf2b7ec..d6bbba016be 100644
--- a/be/src/runtime/fragment_mgr.cpp
+++ b/be/src/runtime/fragment_mgr.cpp
@@ -710,7 +710,7 @@ Status FragmentMgr::exec_plan_fragment(const 
TPipelineFragmentParams& params,
 
     std::shared_ptr<QueryContext> query_ctx;
     RETURN_IF_ERROR(_get_query_ctx(params, params.query_id, true, query_ctx));
-    SCOPED_ATTACH_TASK_WITH_ID(query_ctx->query_mem_tracker, params.query_id);
+    SCOPED_ATTACH_TASK(query_ctx.get());
     int64_t duration_ns = 0;
     std::shared_ptr<pipeline::PipelineFragmentContext> context =
             std::make_shared<pipeline::PipelineFragmentContext>(
@@ -1073,7 +1073,8 @@ Status FragmentMgr::apply_filterv2(const 
PPublishFilterRequestV2* request,
                 runtime_filter_mgr = 
pip_context->get_query_ctx()->runtime_filter_mgr();
                 pool = &pip_context->get_query_ctx()->obj_pool;
                 query_thread_context = 
{pip_context->get_query_ctx()->query_id(),
-                                        
pip_context->get_query_ctx()->query_mem_tracker};
+                                        
pip_context->get_query_ctx()->query_mem_tracker,
+                                        
pip_context->get_query_ctx()->workload_group()};
             } else {
                 return Status::InternalError("Non-pipeline is disabled!");
             }
@@ -1165,7 +1166,7 @@ Status FragmentMgr::merge_filter(const 
PMergeFilterRequest* request,
                                            queryid.to_string());
         }
     }
-    SCOPED_ATTACH_TASK_WITH_ID(query_ctx->query_mem_tracker, 
query_ctx->query_id());
+    SCOPED_ATTACH_TASK(query_ctx.get());
     auto merge_status = filter_controller->merge(request, attach_data);
     return merge_status;
 }
diff --git a/be/src/runtime/load_channel.cpp b/be/src/runtime/load_channel.cpp
index b7df43b65b1..99f0a0b3d5b 100644
--- a/be/src/runtime/load_channel.cpp
+++ b/be/src/runtime/load_channel.cpp
@@ -29,6 +29,7 @@
 #include "runtime/memory/mem_tracker.h"
 #include "runtime/tablets_channel.h"
 #include "runtime/thread_context.h"
+#include "runtime/workload_group/workload_group_manager.h"
 
 namespace doris {
 
@@ -46,7 +47,8 @@ LoadChannel::LoadChannel(const UniqueId& load_id, int64_t 
timeout_s, bool is_hig
             
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock(
                     _load_id.to_thrift());
     if (query_context != nullptr) {
-        _query_thread_context = {_load_id.to_thrift(), 
query_context->query_mem_tracker};
+        _query_thread_context = {_load_id.to_thrift(), 
query_context->query_mem_tracker,
+                                 query_context->workload_group()};
     } else {
         _query_thread_context = {
                 _load_id.to_thrift(),
diff --git a/be/src/runtime/load_stream.cpp b/be/src/runtime/load_stream.cpp
index 8adc90364e1..c818c4664a0 100644
--- a/be/src/runtime/load_stream.cpp
+++ b/be/src/runtime/load_stream.cpp
@@ -40,6 +40,7 @@
 #include "runtime/load_channel.h"
 #include "runtime/load_stream_mgr.h"
 #include "runtime/load_stream_writer.h"
+#include "runtime/workload_group/workload_group_manager.h"
 #include "util/debug_points.h"
 #include "util/runtime_profile.h"
 #include "util/thrift_util.h"
@@ -361,7 +362,8 @@ LoadStream::LoadStream(PUniqueId load_id, LoadStreamMgr* 
load_stream_mgr, bool e
     std::shared_ptr<QueryContext> query_context =
             
ExecEnv::GetInstance()->fragment_mgr()->get_or_erase_query_ctx_with_lock(load_tid);
     if (query_context != nullptr) {
-        _query_thread_context = {load_tid, query_context->query_mem_tracker};
+        _query_thread_context = {load_tid, query_context->query_mem_tracker,
+                                 query_context->workload_group()};
     } else {
         _query_thread_context = {load_tid, MemTrackerLimiter::create_shared(
                                                    
MemTrackerLimiter::Type::LOAD,
diff --git a/be/src/runtime/memory/mem_tracker_limiter.h 
b/be/src/runtime/memory/mem_tracker_limiter.h
index dd2b89029cb..e5c5cb1bc03 100644
--- a/be/src/runtime/memory/mem_tracker_limiter.h
+++ b/be/src/runtime/memory/mem_tracker_limiter.h
@@ -221,7 +221,7 @@ public:
     }
 
     // Iterator into mem_tracker_limiter_pool for this object. Stored to have 
O(1) remove.
-    std::list<std::weak_ptr<MemTrackerLimiter>>::iterator 
tg_tracker_limiter_group_it;
+    std::list<std::weak_ptr<MemTrackerLimiter>>::iterator 
wg_tracker_limiter_group_it;
 
 private:
     friend class ThreadMemTrackerMgr;
diff --git a/be/src/runtime/memory/thread_mem_tracker_mgr.h 
b/be/src/runtime/memory/thread_mem_tracker_mgr.h
index 9d36cd2d807..39e896d0f18 100644
--- a/be/src/runtime/memory/thread_mem_tracker_mgr.h
+++ b/be/src/runtime/memory/thread_mem_tracker_mgr.h
@@ -33,6 +33,7 @@
 #include "runtime/memory/global_memory_arbitrator.h"
 #include "runtime/memory/mem_tracker.h"
 #include "runtime/memory/mem_tracker_limiter.h"
+#include "runtime/workload_group/workload_group.h"
 #include "util/stack_util.h"
 #include "util/uid_util.h"
 
@@ -71,6 +72,10 @@ public:
 
     TUniqueId query_id() { return _query_id; }
 
+    void set_wg_wptr(const std::weak_ptr<WorkloadGroup>& wg_wptr) { _wg_wptr = 
wg_wptr; }
+
+    void reset_wg_wptr() { _wg_wptr.reset(); }
+
     void start_count_scope_mem() {
         CHECK(init());
         _scope_mem = _reserved_mem; // consume in advance
@@ -151,6 +156,7 @@ private:
     std::shared_ptr<MemTrackerLimiter> _limiter_tracker;
     MemTrackerLimiter* _limiter_tracker_raw = nullptr;
     std::vector<MemTracker*> _consumer_tracker_stack;
+    std::weak_ptr<WorkloadGroup> _wg_wptr;
 
     // If there is a memory new/delete operation in the consume method, it may 
enter infinite recursion.
     bool _stop_consume = false;
@@ -287,8 +293,16 @@ inline bool ThreadMemTrackerMgr::try_reserve(int64_t size) 
{
     if (!_limiter_tracker_raw->try_consume(size)) {
         return false;
     }
+    auto wg_ptr = _wg_wptr.lock(); // may be null (no wg attached or wg destroyed)
+    if (wg_ptr && !wg_ptr->add_wg_refresh_interval_memory_growth(size)) {
+        _limiter_tracker_raw->release(size); // rollback
+        return false;
+    }
     if (!doris::GlobalMemoryArbitrator::try_reserve_process_memory(size)) {
-        _limiter_tracker_raw->release(size); // rollback
+        _limiter_tracker_raw->release(size); // rollback limiter tracker
+        if (wg_ptr) {
+            wg_ptr->sub_wg_refresh_interval_memory_growth(size); // rollback wg growth
+        }
         return false;
     }
     if (_count_scope_mem) {
@@ -306,6 +320,10 @@ inline void ThreadMemTrackerMgr::release_reserved() {
         
doris::GlobalMemoryArbitrator::release_process_reserved_memory(_reserved_mem +
                                                                        
_untracked_mem);
         _limiter_tracker_raw->release(_reserved_mem);
+        auto wg_ptr = _wg_wptr.lock(); // may be null (no wg attached or wg destroyed)
+        if (wg_ptr) {
+            wg_ptr->sub_wg_refresh_interval_memory_growth(_reserved_mem);
+        }
         if (_count_scope_mem) {
             _scope_mem -= _reserved_mem;
         }
diff --git a/be/src/runtime/query_context.h b/be/src/runtime/query_context.h
index b565214ef22..0d7870a0e1d 100644
--- a/be/src/runtime/query_context.h
+++ b/be/src/runtime/query_context.h
@@ -230,15 +230,16 @@ public:
         return _running_big_mem_op_num.load(std::memory_order_relaxed);
     }
 
-    void set_weighted_mem(int64_t weighted_limit, int64_t 
weighted_consumption) {
+    void set_weighted_memory(int64_t weighted_limit, double weighted_ratio) {
         std::lock_guard<std::mutex> l(_weighted_mem_lock);
-        _weighted_consumption = weighted_consumption;
         _weighted_limit = weighted_limit;
+        _weighted_ratio = weighted_ratio;
     }
-    void get_weighted_mem_info(int64_t& weighted_limit, int64_t& 
weighted_consumption) {
+
+    void get_weighted_memory(int64_t& weighted_limit, int64_t& 
weighted_consumption) {
         std::lock_guard<std::mutex> l(_weighted_mem_lock);
         weighted_limit = _weighted_limit;
-        weighted_consumption = _weighted_consumption;
+        weighted_consumption = int64_t(query_mem_tracker->consumption() * 
_weighted_ratio);
     }
 
     DescriptorTbl* desc_tbl = nullptr;
@@ -311,7 +312,7 @@ private:
     std::mutex _pipeline_map_write_lock;
 
     std::mutex _weighted_mem_lock;
-    int64_t _weighted_consumption = 0;
+    double _weighted_ratio = 0;
     int64_t _weighted_limit = 0;
 
     std::mutex _profile_mutex;
diff --git a/be/src/runtime/thread_context.cpp 
b/be/src/runtime/thread_context.cpp
index 6f69eb9e134..c89f532e592 100644
--- a/be/src/runtime/thread_context.cpp
+++ b/be/src/runtime/thread_context.cpp
@@ -18,7 +18,9 @@
 #include "runtime/thread_context.h"
 
 #include "common/signal_handler.h"
+#include "runtime/query_context.h"
 #include "runtime/runtime_state.h"
+#include "runtime/workload_group/workload_group_manager.h"
 
 namespace doris {
 class MemTracker;
@@ -26,34 +28,38 @@ class MemTracker;
 QueryThreadContext ThreadContext::query_thread_context() {
     DCHECK(doris::pthread_context_ptr_init);
     ORPHAN_TRACKER_CHECK();
-    return {_task_id, thread_mem_tracker_mgr->limiter_mem_tracker()};
+    return {_task_id, thread_mem_tracker_mgr->limiter_mem_tracker(), _wg_wptr};
 }
 
-AttachTask::AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
-                       const TUniqueId& task_id) {
+void AttachTask::init(const QueryThreadContext& query_thread_context) {
     ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(task_id);
-    thread_context()->attach_task(task_id, mem_tracker);
+    signal::set_signal_task_id(query_thread_context.query_id);
+    thread_context()->attach_task(query_thread_context.query_id,
+                                  query_thread_context.query_mem_tracker,
+                                  query_thread_context.wg_wptr);
 }
 
 AttachTask::AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(TUniqueId());
-    thread_context()->attach_task(TUniqueId(), mem_tracker);
+    QueryThreadContext query_thread_context = {TUniqueId(), mem_tracker};
+    init(query_thread_context);
 }
 
 AttachTask::AttachTask(RuntimeState* runtime_state) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(runtime_state->query_id());
     signal::set_signal_is_nereids(runtime_state->is_nereids());
-    thread_context()->attach_task(runtime_state->query_id(), 
runtime_state->query_mem_tracker());
+    QueryThreadContext query_thread_context = {runtime_state->query_id(),
+                                               
runtime_state->query_mem_tracker(),
+                                               
runtime_state->get_query_ctx()->workload_group()};
+    init(query_thread_context);
 }
 
 AttachTask::AttachTask(const QueryThreadContext& query_thread_context) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(query_thread_context.query_id);
-    thread_context()->attach_task(query_thread_context.query_id,
-                                  query_thread_context.query_mem_tracker);
+    init(query_thread_context);
+}
+
+AttachTask::AttachTask(QueryContext* query_ctx) {
+    QueryThreadContext query_thread_context = {query_ctx->query_id(), 
query_ctx->query_mem_tracker,
+                                               query_ctx->workload_group()};
+    init(query_thread_context);
 }
 
 AttachTask::~AttachTask() {
diff --git a/be/src/runtime/thread_context.h b/be/src/runtime/thread_context.h
index 45f64a3739a..4a731424b9e 100644
--- a/be/src/runtime/thread_context.h
+++ b/be/src/runtime/thread_context.h
@@ -45,8 +45,6 @@
 // This will save some info about a working thread in the thread context.
 // Looking forward to tracking memory during thread execution into 
MemTrackerLimiter.
 #define SCOPED_ATTACH_TASK(arg1) auto VARNAME_LINENUM(attach_task) = 
AttachTask(arg1)
-#define SCOPED_ATTACH_TASK_WITH_ID(arg1, arg2) \
-    auto VARNAME_LINENUM(attach_task) = AttachTask(arg1, arg2)
 
 // Switch MemTrackerLimiter for count memory during thread execution.
 // Used after SCOPED_ATTACH_TASK, in order to count the memory into another
@@ -74,8 +72,6 @@
 // thread context need to be initialized, required by Allocator and elsewhere.
 #define SCOPED_ATTACH_TASK(arg1, ...) \
     auto VARNAME_LINENUM(scoped_tls_at) = doris::ScopedInitThreadContext()
-#define SCOPED_ATTACH_TASK_WITH_ID(arg1, arg2) \
-    auto VARNAME_LINENUM(scoped_tls_atwi) = doris::ScopedInitThreadContext()
 #define SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(arg1) \
     auto VARNAME_LINENUM(scoped_tls_stmtl) = doris::ScopedInitThreadContext()
 #define SCOPED_CONSUME_MEM_TRACKER(mem_tracker) \
@@ -124,6 +120,7 @@ class ThreadContext;
 class MemTracker;
 class RuntimeState;
 class QueryThreadContext;
+class WorkloadGroup;
 
 extern bthread_key_t btls_key;
 
@@ -158,7 +155,8 @@ public:
     ~ThreadContext() = default;
 
     void attach_task(const TUniqueId& task_id,
-                     const std::shared_ptr<MemTrackerLimiter>& mem_tracker) {
+                     const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
+                     const std::weak_ptr<WorkloadGroup>& wg_wptr) {
         // will only attach_task at the beginning of the thread function, 
there should be no duplicate attach_task.
         DCHECK(mem_tracker);
         // Orphan is thread default tracker.
@@ -166,16 +164,20 @@ public:
                 << ", thread mem tracker label: " << 
thread_mem_tracker()->label()
                 << ", attach mem tracker label: " << mem_tracker->label();
         _task_id = task_id;
+        _wg_wptr = wg_wptr;
         thread_mem_tracker_mgr->attach_limiter_tracker(mem_tracker);
         thread_mem_tracker_mgr->set_query_id(_task_id);
+        thread_mem_tracker_mgr->set_wg_wptr(_wg_wptr);
         thread_mem_tracker_mgr->enable_wait_gc();
         thread_mem_tracker_mgr->reset_query_cancelled_flag(false);
     }
 
     void detach_task() {
         _task_id = TUniqueId();
+        _wg_wptr.reset();
         thread_mem_tracker_mgr->detach_limiter_tracker();
         thread_mem_tracker_mgr->set_query_id(TUniqueId());
+        thread_mem_tracker_mgr->reset_wg_wptr();
         thread_mem_tracker_mgr->disable_wait_gc();
     }
 
@@ -226,12 +228,15 @@ public:
         thread_mem_tracker_mgr->release_reserved();
     }
 
+    std::weak_ptr<WorkloadGroup> workload_group() { return _wg_wptr; }
+
     int thread_local_handle_count = 0;
     int skip_memory_check = 0;
     int skip_large_memory_check = 0;
 
 private:
     TUniqueId _task_id;
+    std::weak_ptr<WorkloadGroup> _wg_wptr;
 };
 
 class ThreadLocalHandle {
@@ -313,6 +318,11 @@ static ThreadContext* thread_context() {
 class QueryThreadContext {
 public:
     QueryThreadContext() = default;
+    QueryThreadContext(const TUniqueId& query_id,
+                       const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
+                       const std::weak_ptr<WorkloadGroup>& wg_wptr)
+            : query_id(query_id), query_mem_tracker(mem_tracker), 
wg_wptr(wg_wptr) {}
+    // If use WorkloadGroup and can get WorkloadGroup ptr, must as a parameter.
     QueryThreadContext(const TUniqueId& query_id,
                        const std::shared_ptr<MemTrackerLimiter>& mem_tracker)
             : query_id(query_id), query_mem_tracker(mem_tracker) {}
@@ -324,6 +334,7 @@ public:
         ORPHAN_TRACKER_CHECK();
         query_id = doris::thread_context()->task_id();
         query_mem_tracker = 
doris::thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
+        wg_wptr = doris::thread_context()->workload_group();
 #else
         query_id = TUniqueId();
         query_mem_tracker = 
doris::ExecEnv::GetInstance()->orphan_mem_tracker();
@@ -332,6 +343,7 @@ public:
 
     TUniqueId query_id;
     std::shared_ptr<MemTrackerLimiter> query_mem_tracker;
+    std::weak_ptr<WorkloadGroup> wg_wptr;
 };
 
 class ScopeMemCountByHook {
@@ -363,15 +375,18 @@ public:
 
 class AttachTask {
 public:
-    explicit AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
-                        const TUniqueId& task_id);
-
+    // not query or load, initialize with memory tracker, empty query id and 
default normal workload group.
     explicit AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker);
 
+    // is query or load, initialize with memory tracker, query id and workload 
group wptr.
     explicit AttachTask(RuntimeState* runtime_state);
 
+    explicit AttachTask(QueryContext* query_ctx);
+
     explicit AttachTask(const QueryThreadContext& query_thread_context);
 
+    void init(const QueryThreadContext& query_thread_context);
+
     ~AttachTask();
 };
 
@@ -386,7 +401,8 @@ public:
 
     explicit SwitchThreadMemTrackerLimiter(const QueryThreadContext& 
query_thread_context) {
         ThreadLocalHandle::create_thread_local_if_not_exits();
-        DCHECK(thread_context()->task_id() == query_thread_context.query_id);
+        DCHECK(thread_context()->task_id() ==
+               query_thread_context.query_id); // workload group alse not 
change
         DCHECK(query_thread_context.query_mem_tracker);
         _old_mem_tracker = 
thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
         thread_context()->thread_mem_tracker_mgr->attach_limiter_tracker(
diff --git a/be/src/runtime/workload_group/workload_group.cpp 
b/be/src/runtime/workload_group/workload_group.cpp
index f4d1e0d4f7e..131daf1a4f4 100644
--- a/be/src/runtime/workload_group/workload_group.cpp
+++ b/be/src/runtime/workload_group/workload_group.cpp
@@ -107,39 +107,59 @@ void WorkloadGroup::check_and_update(const 
WorkloadGroupInfo& tg_info) {
     }
 }
 
-int64_t WorkloadGroup::memory_used() {
+int64_t WorkloadGroup::make_memory_tracker_snapshots(
+        std::list<std::shared_ptr<MemTrackerLimiter>>* tracker_snapshots) {
     int64_t used_memory = 0;
     for (auto& mem_tracker_group : _mem_tracker_limiter_pool) {
         std::lock_guard<std::mutex> l(mem_tracker_group.group_lock);
         for (const auto& trackerWptr : mem_tracker_group.trackers) {
             auto tracker = trackerWptr.lock();
             CHECK(tracker != nullptr);
+            if (tracker_snapshots != nullptr) {
+                tracker_snapshots->insert(tracker_snapshots->end(), tracker);
+            }
             used_memory += tracker->consumption();
         }
     }
+    refresh_memory(used_memory);
     return used_memory;
 }
 
-void WorkloadGroup::set_weighted_memory_used(int64_t wg_total_mem_used, double 
ratio) {
-    _weighted_mem_used.store(int64_t(wg_total_mem_used * ratio), 
std::memory_order_relaxed);
+int64_t WorkloadGroup::memory_used() {
+    return make_memory_tracker_snapshots(nullptr);
+}
+
+void WorkloadGroup::refresh_memory(int64_t used_memory) {
+    // refresh total memory used.
+    _total_mem_used = used_memory;
+    // reserve memory is recorded in the query mem tracker
+    // and _total_mem_used already contains all the current reserve memory.
+    // so after refreshing _total_mem_used, reset 
_wg_refresh_interval_memory_growth.
+    _wg_refresh_interval_memory_growth.store(0.0);
+}
+
+void WorkloadGroup::set_weighted_memory_ratio(double ratio) {
+    _weighted_mem_ratio = ratio;
 }
 
 void WorkloadGroup::add_mem_tracker_limiter(std::shared_ptr<MemTrackerLimiter> 
mem_tracker_ptr) {
+    std::unique_lock<std::shared_mutex> wlock(_mutex);
     auto group_num = mem_tracker_ptr->group_num();
     std::lock_guard<std::mutex> 
l(_mem_tracker_limiter_pool[group_num].group_lock);
-    mem_tracker_ptr->tg_tracker_limiter_group_it =
+    mem_tracker_ptr->wg_tracker_limiter_group_it =
             _mem_tracker_limiter_pool[group_num].trackers.insert(
                     _mem_tracker_limiter_pool[group_num].trackers.end(), 
mem_tracker_ptr);
 }
 
 void 
WorkloadGroup::remove_mem_tracker_limiter(std::shared_ptr<MemTrackerLimiter> 
mem_tracker_ptr) {
+    std::unique_lock<std::shared_mutex> wlock(_mutex);
     auto group_num = mem_tracker_ptr->group_num();
     std::lock_guard<std::mutex> 
l(_mem_tracker_limiter_pool[group_num].group_lock);
-    if (mem_tracker_ptr->tg_tracker_limiter_group_it !=
+    if (mem_tracker_ptr->wg_tracker_limiter_group_it !=
         _mem_tracker_limiter_pool[group_num].trackers.end()) {
         _mem_tracker_limiter_pool[group_num].trackers.erase(
-                mem_tracker_ptr->tg_tracker_limiter_group_it);
-        mem_tracker_ptr->tg_tracker_limiter_group_it =
+                mem_tracker_ptr->wg_tracker_limiter_group_it);
+        mem_tracker_ptr->wg_tracker_limiter_group_it =
                 _mem_tracker_limiter_pool[group_num].trackers.end();
     }
 }
diff --git a/be/src/runtime/workload_group/workload_group.h 
b/be/src/runtime/workload_group/workload_group.h
index a82efab0904..a19ac53474f 100644
--- a/be/src/runtime/workload_group/workload_group.h
+++ b/be/src/runtime/workload_group/workload_group.h
@@ -76,7 +76,12 @@ public:
         return _memory_limit;
     };
 
+    // make memory snapshots and refresh total memory used at the same time.
+    int64_t make_memory_tracker_snapshots(
+            std::list<std::shared_ptr<MemTrackerLimiter>>* tracker_snapshots);
+    // call make_memory_tracker_snapshots, so also refresh total memory used.
     int64_t memory_used();
+    void refresh_memory(int64_t used_memory);
 
     int spill_threshold_low_water_mark() const {
         return _spill_low_watermark.load(std::memory_order_relaxed);
@@ -85,10 +90,31 @@ public:
         return _spill_high_watermark.load(std::memory_order_relaxed);
     }
 
-    void set_weighted_memory_used(int64_t wg_total_mem_used, double ratio);
+    void set_weighted_memory_ratio(double ratio);
+    bool add_wg_refresh_interval_memory_growth(int64_t size) {
+        // `weighted_mem_used` is a rough memory usage in this group,
+        // because we can only get a precise memory usage by MemTracker which 
is not include page cache.
+        auto weighted_mem_used =
+                int64_t((_total_mem_used + 
_wg_refresh_interval_memory_growth.load() + size) *
+                        _weighted_mem_ratio);
+        if ((weighted_mem_used > ((double)_memory_limit *
+                                  
_spill_high_watermark.load(std::memory_order_relaxed) / 100))) {
+            return false;
+        } else {
+            _wg_refresh_interval_memory_growth.fetch_add(size);
+            return true;
+        }
+    }
+    void sub_wg_refresh_interval_memory_growth(int64_t size) {
+        _wg_refresh_interval_memory_growth.fetch_sub(size);
+    }
 
     void check_mem_used(bool* is_low_wartermark, bool* is_high_wartermark) 
const {
-        auto weighted_mem_used = 
_weighted_mem_used.load(std::memory_order_relaxed);
+        // `weighted_mem_used` is a rough memory usage in this group,
+        // because we can only get a precise memory usage by MemTracker which 
is not include page cache.
+        auto weighted_mem_used =
+                int64_t((_total_mem_used + 
_wg_refresh_interval_memory_growth.load()) *
+                        _weighted_mem_ratio);
         *is_low_wartermark =
                 (weighted_mem_used > ((double)_memory_limit *
                                       
_spill_low_watermark.load(std::memory_order_relaxed) / 100));
@@ -137,7 +163,7 @@ public:
 
     bool can_be_dropped() {
         std::shared_lock<std::shared_mutex> r_lock(_mutex);
-        return _is_shutdown && _query_ctxs.size() == 0;
+        return _is_shutdown && _query_ctxs.empty();
     }
 
     int query_num() {
@@ -169,9 +195,11 @@ private:
     std::string _name;
     int64_t _version;
     int64_t _memory_limit; // bytes
-    // `_weighted_mem_used` is a rough memory usage in this group,
-    // because we can only get a precise memory usage by MemTracker which is 
not include page cache.
-    std::atomic_int64_t _weighted_mem_used = 0; // bytes
+    // last value of make_memory_tracker_snapshots, refresh every time 
make_memory_tracker_snapshots is called.
+    std::atomic_int64_t _total_mem_used = 0; // bytes
+    // last value of refresh_wg_weighted_memory_ratio.
+    std::atomic<double> _weighted_mem_ratio = 0.0;
+    std::atomic_int64_t _wg_refresh_interval_memory_growth;
     bool _enable_memory_overcommit;
     std::atomic<uint64_t> _cpu_share;
     std::vector<TrackerLimiterGroup> _mem_tracker_limiter_pool;
diff --git a/be/src/runtime/workload_group/workload_group_manager.cpp 
b/be/src/runtime/workload_group/workload_group_manager.cpp
index 6813bfd3b75..9e595841c67 100644
--- a/be/src/runtime/workload_group/workload_group_manager.cpp
+++ b/be/src/runtime/workload_group/workload_group_manager.cpp
@@ -149,46 +149,33 @@ void 
WorkloadGroupMgr::delete_workload_group_by_ids(std::set<uint64_t> used_wg_i
 
 struct WorkloadGroupMemInfo {
     int64_t total_mem_used = 0;
-    int64_t weighted_mem_used = 0;
-    bool is_low_wartermark = false;
-    bool is_high_wartermark = false;
-    double mem_used_ratio = 0;
+    std::list<std::shared_ptr<MemTrackerLimiter>> tracker_snapshots =
+            std::list<std::shared_ptr<MemTrackerLimiter>>();
 };
-void WorkloadGroupMgr::refresh_wg_memory_info() {
+
+void WorkloadGroupMgr::refresh_wg_weighted_memory_ratio() {
     std::shared_lock<std::shared_mutex> r_lock(_group_mutex);
-    // workload group id -> workload group queries
-    std::unordered_map<uint64_t, std::unordered_map<TUniqueId, 
std::weak_ptr<QueryContext>>>
-            all_wg_queries;
-    for (auto& [wg_id, wg] : _workload_groups) {
-        all_wg_queries.insert({wg_id, wg->queries()});
-    }
 
+    // 1. make all workload groups memory snapshots(refresh workload groups 
total memory used at the same time)
+    // and calculate total memory used of all queries.
     int64_t all_queries_mem_used = 0;
-
-    // calculate total memory used of each workload group and total memory 
used of all queries
     std::unordered_map<uint64_t, WorkloadGroupMemInfo> wgs_mem_info;
-    for (auto& [wg_id, wg_queries] : all_wg_queries) {
-        int64_t wg_total_mem_used = 0;
-        for (const auto& [query_id, query_ctx_ptr] : wg_queries) {
-            if (auto query_ctx = query_ctx_ptr.lock()) {
-                wg_total_mem_used += 
query_ctx->query_mem_tracker->consumption();
-            }
-        }
-        all_queries_mem_used += wg_total_mem_used;
-        wgs_mem_info[wg_id] = {wg_total_mem_used};
+    for (auto& [wg_id, wg] : _workload_groups) {
+        wgs_mem_info[wg_id].total_mem_used =
+                
wg->make_memory_tracker_snapshots(&wgs_mem_info[wg_id].tracker_snapshots);
+        all_queries_mem_used += wgs_mem_info[wg_id].total_mem_used;
     }
-
-    auto process_memory_usage = GlobalMemoryArbitrator::process_memory_usage();
     if (all_queries_mem_used <= 0) {
         return;
     }
 
-    all_queries_mem_used = std::min(process_memory_usage, 
all_queries_mem_used);
-
+    // 2. calculate weighted ratio.
     // process memory used is actually bigger than all_queries_mem_used,
     // because memory of page cache, allocator cache, segment cache etc. are 
included
     // in proc_vm_rss.
     // we count these cache memories equally on workload groups.
+    auto process_memory_usage = GlobalMemoryArbitrator::process_memory_usage();
+    all_queries_mem_used = std::min(process_memory_usage, 
all_queries_mem_used);
     double ratio = (double)process_memory_usage / (double)all_queries_mem_used;
     if (ratio <= 1.25) {
         std::string debug_msg =
@@ -200,66 +187,57 @@ void WorkloadGroupMgr::refresh_wg_memory_info() {
     }
 
     for (auto& wg : _workload_groups) {
+        // 3.1 calculate query weighted memory limit of task group
         auto wg_mem_limit = wg.second->memory_limit();
-        auto& wg_mem_info = wgs_mem_info[wg.first];
-        wg_mem_info.weighted_mem_used = int64_t(wg_mem_info.total_mem_used * 
ratio);
-        wg_mem_info.mem_used_ratio = (double)wg_mem_info.weighted_mem_used / 
wg_mem_limit;
-
-        wg.second->set_weighted_memory_used(wg_mem_info.total_mem_used, ratio);
-
-        auto spill_low_water_mark = 
wg.second->spill_threshold_low_water_mark();
-        auto spill_high_water_mark = 
wg.second->spill_threashold_high_water_mark();
-        wg_mem_info.is_high_wartermark = (wg_mem_info.weighted_mem_used >
-                                          ((double)wg_mem_limit * 
spill_high_water_mark / 100));
-        wg_mem_info.is_low_wartermark = (wg_mem_info.weighted_mem_used >
-                                         ((double)wg_mem_limit * 
spill_low_water_mark / 100));
-
-        // calculate query weighted memory limit of task group
-        const auto& wg_queries = all_wg_queries[wg.first];
-        auto wg_query_count = wg_queries.size();
+        auto wg_query_count = wgs_mem_info[wg.first].tracker_snapshots.size();
         int64_t query_weighted_mem_limit =
                 wg_query_count ? (wg_mem_limit + wg_query_count) / 
wg_query_count : wg_mem_limit;
 
+        // 3.2 set all workload groups weighted memory ratio and all query 
weighted memory limit and ratio.
+        wg.second->set_weighted_memory_ratio(ratio);
+        for (const auto& query : wg.second->queries()) {
+            auto query_ctx = query.second.lock();
+            if (!query_ctx) {
+                continue;
+            }
+            query_ctx->set_weighted_memory(query_weighted_mem_limit, ratio);
+        }
+
+        // 3.3 only print debug logs, if workload groups is_high_wartermark or 
is_low_wartermark.
+        auto weighted_mem_used = int64_t(wgs_mem_info[wg.first].total_mem_used 
* ratio);
+        bool is_high_wartermark =
+                (weighted_mem_used >
+                 ((double)wg_mem_limit * 
wg.second->spill_threashold_high_water_mark() / 100));
+        bool is_low_wartermark =
+                (weighted_mem_used >
+                 ((double)wg_mem_limit * 
wg.second->spill_threshold_low_water_mark() / 100));
         std::string debug_msg;
-        if (wg_mem_info.is_high_wartermark || wg_mem_info.is_low_wartermark) {
+        if (is_high_wartermark || is_low_wartermark) {
             debug_msg = fmt::format(
                     "\nWorkload Group {}: mem limit: {}, mem used: {}, 
weighted mem used: {}, used "
                     "ratio: {}, query "
                     "count: {}, query_weighted_mem_limit: {}",
                     wg.second->name(), PrettyPrinter::print(wg_mem_limit, 
TUnit::BYTES),
-                    PrettyPrinter::print(wg_mem_info.total_mem_used, 
TUnit::BYTES),
-                    PrettyPrinter::print(wg_mem_info.weighted_mem_used, 
TUnit::BYTES),
-                    wg_mem_info.mem_used_ratio, wg_query_count,
+                    
PrettyPrinter::print(wgs_mem_info[wg.first].total_mem_used, TUnit::BYTES),
+                    PrettyPrinter::print(weighted_mem_used, TUnit::BYTES),
+                    (double)weighted_mem_used / wg_mem_limit, wg_query_count,
                     PrettyPrinter::print(query_weighted_mem_limit, 
TUnit::BYTES));
 
             debug_msg += "\n  Query Memory Summary:";
-        } else {
-            continue;
-        }
-        // check whether queries need to revoke memory for task group
-        for (const auto& query : wg_queries) {
-            auto query_ctx = query.second.lock();
-            if (!query_ctx) {
-                continue;
-            }
-            auto query_consumption = 
query_ctx->query_mem_tracker->consumption();
-            auto query_weighted_consumption = int64_t(query_consumption * 
ratio);
-            query_ctx->set_weighted_mem(query_weighted_mem_limit, 
query_weighted_consumption);
-
-            if (wg_mem_info.is_high_wartermark || 
wg_mem_info.is_low_wartermark) {
+            // check whether queries need to revoke memory for task group
+            for (const auto& query_mem_tracker : 
wgs_mem_info[wg.first].tracker_snapshots) {
                 debug_msg += fmt::format(
                         "\n    MemTracker Label={}, Parent Label={}, Used={}, 
WeightedUsed={}, "
                         "Peak={}",
-                        query_ctx->query_mem_tracker->label(),
-                        query_ctx->query_mem_tracker->parent_label(),
-                        PrettyPrinter::print(query_consumption, TUnit::BYTES),
-                        PrettyPrinter::print(query_weighted_consumption, 
TUnit::BYTES),
-                        
PrettyPrinter::print(query_ctx->query_mem_tracker->peak_consumption(),
-                                             TUnit::BYTES));
+                        query_mem_tracker->label(), 
query_mem_tracker->parent_label(),
+                        PrettyPrinter::print(query_mem_tracker->consumption(), 
TUnit::BYTES),
+                        
PrettyPrinter::print(int64_t(query_mem_tracker->consumption() * ratio),
+                                             TUnit::BYTES),
+                        
PrettyPrinter::print(query_mem_tracker->peak_consumption(), TUnit::BYTES));
             }
-        }
-        if (wg_mem_info.is_high_wartermark || wg_mem_info.is_low_wartermark) {
             LOG_EVERY_T(INFO, 1) << debug_msg;
+        } else {
+            continue;
         }
     }
 }
diff --git a/be/src/runtime/workload_group/workload_group_manager.h 
b/be/src/runtime/workload_group/workload_group_manager.h
index 8aeb8f988a3..37539ada8d8 100644
--- a/be/src/runtime/workload_group/workload_group_manager.h
+++ b/be/src/runtime/workload_group/workload_group_manager.h
@@ -54,7 +54,7 @@ public:
 
     bool enable_cpu_hard_limit() { return _enable_cpu_hard_limit.load(); }
 
-    void refresh_wg_memory_info();
+    void refresh_wg_weighted_memory_ratio();
 
 private:
     std::shared_mutex _group_mutex;
diff --git a/be/src/vec/exec/scan/scanner_context.cpp 
b/be/src/vec/exec/scan/scanner_context.cpp
index 1de8e7a7950..4b09260f69a 100644
--- a/be/src/vec/exec/scan/scanner_context.cpp
+++ b/be/src/vec/exec/scan/scanner_context.cpp
@@ -106,7 +106,8 @@ ScannerContext::ScannerContext(
         }
     }
 
-    _query_thread_context = {_query_id, _state->query_mem_tracker()};
+    _query_thread_context = {_query_id, _state->query_mem_tracker(),
+                             _state->get_query_ctx()->workload_group()};
     _dependency = dependency;
 }
 
diff --git a/be/src/vec/runtime/vdata_stream_recvr.cpp 
b/be/src/vec/runtime/vdata_stream_recvr.cpp
index 562a4239545..cae099c31bd 100644
--- a/be/src/vec/runtime/vdata_stream_recvr.cpp
+++ b/be/src/vec/runtime/vdata_stream_recvr.cpp
@@ -328,10 +328,8 @@ VDataStreamRecvr::VDataStreamRecvr(VDataStreamMgr* 
stream_mgr, RuntimeState* sta
                                    int num_senders, bool is_merging, 
RuntimeProfile* profile)
         : HasTaskExecutionCtx(state),
           _mgr(stream_mgr),
-#ifdef USE_MEM_TRACKER
-          _query_mem_tracker(state->query_mem_tracker()),
-          _query_id(state->query_id()),
-#endif
+          _query_thread_context(state->query_id(), state->query_mem_tracker(),
+                                state->get_query_ctx()->workload_group()),
           _fragment_instance_id(fragment_instance_id),
           _dest_node_id(dest_node_id),
           _row_desc(row_desc),
@@ -403,7 +401,7 @@ Status VDataStreamRecvr::create_merger(const 
VExprContextSPtrs& ordering_expr,
 
 Status VDataStreamRecvr::add_block(const PBlock& pblock, int sender_id, int 
be_number,
                                    int64_t packet_seq, 
::google::protobuf::Closure** done) {
-    SCOPED_ATTACH_TASK_WITH_ID(_query_mem_tracker, _query_id);
+    SCOPED_ATTACH_TASK(_query_thread_context);
     int use_sender_id = _is_merging ? sender_id : 0;
     return _sender_queues[use_sender_id]->add_block(pblock, be_number, 
packet_seq, done);
 }
diff --git a/be/src/vec/runtime/vdata_stream_recvr.h 
b/be/src/vec/runtime/vdata_stream_recvr.h
index 1983b7309c0..68fc2eac347 100644
--- a/be/src/vec/runtime/vdata_stream_recvr.h
+++ b/be/src/vec/runtime/vdata_stream_recvr.h
@@ -43,6 +43,7 @@
 #include "common/status.h"
 #include "runtime/descriptors.h"
 #include "runtime/task_execution_context.h"
+#include "runtime/thread_context.h"
 #include "util/runtime_profile.h"
 #include "util/stopwatch.hpp"
 #include "vec/core/block.h"
@@ -120,10 +121,7 @@ private:
     // DataStreamMgr instance used to create this recvr. (Not owned)
     VDataStreamMgr* _mgr = nullptr;
 
-#ifdef USE_MEM_TRACKER
-    std::shared_ptr<MemTrackerLimiter> _query_mem_tracker = nullptr;
-    TUniqueId _query_id;
-#endif
+    QueryThreadContext _query_thread_context;
 
     // Fragment and node id of the destination exchange node this receiver is 
used by.
     TUniqueId _fragment_instance_id;
diff --git a/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp 
b/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
index 29c2759fcb7..ab15fce05a7 100644
--- a/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
+++ b/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
@@ -26,7 +26,18 @@
 
 namespace doris {
 
-TEST(ThreadMemTrackerMgrTest, ConsumeMemory) {
+class ThreadMemTrackerMgrTest : public testing::Test {
+public:
+    ThreadMemTrackerMgrTest() = default;
+    ~ThreadMemTrackerMgrTest() override = default;
+
+    void SetUp() override {}
+
+protected:
+    std::shared_ptr<WorkloadGroup> workload_group;
+};
+
+TEST_F(ThreadMemTrackerMgrTest, ConsumeMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ConsumeMemory");
@@ -34,7 +45,7 @@ TEST(ThreadMemTrackerMgrTest, ConsumeMemory) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->consume_memory(size1);
     // size1 < config::mem_tracker_consume_min_size_bytes, not consume mem 
tracker.
     EXPECT_EQ(t->consumption(), 0);
@@ -80,7 +91,7 @@ TEST(ThreadMemTrackerMgrTest, Boundary) {
     // TODO, Boundary check may not be necessary, add some `IF` maybe increase 
cost time.
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
+TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, "UT-NestedSwitchMemTracker1");
@@ -92,7 +103,7 @@ TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t1->consumption(), size1 + size2);
@@ -152,7 +163,7 @@ TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     EXPECT_EQ(t1->consumption(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
+TEST_F(ThreadMemTrackerMgrTest, MultiMemTracker) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-MultiMemTracker1");
@@ -162,7 +173,7 @@ TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     thread_context->consume_memory(size1);
@@ -213,7 +224,7 @@ TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
     EXPECT_EQ(t3->consumption(), size1 + size2 - size1);
 }
 
-TEST(ThreadMemTrackerMgrTest, ScopedCount) {
+TEST_F(ThreadMemTrackerMgrTest, ScopedCount) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ScopedCount");
@@ -221,7 +232,7 @@ TEST(ThreadMemTrackerMgrTest, ScopedCount) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->thread_mem_tracker_mgr->start_count_scope_mem();
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
@@ -239,7 +250,7 @@ TEST(ThreadMemTrackerMgrTest, ScopedCount) {
     EXPECT_EQ(scope_mem, size1 + size2 + size1 + size2 + size1);
 }
 
-TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, ReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ReserveMemory");
@@ -248,7 +259,7 @@ TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t->consumption(), size1 + size2);
@@ -338,7 +349,7 @@ TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, "UT-NestedReserveMemory");
@@ -346,7 +357,7 @@ TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->try_reserve_memory(size3);
     EXPECT_EQ(t->consumption(), size3);
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), size3);
@@ -386,7 +397,7 @@ TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTrackerReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTrackerReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, 
"UT-NestedSwitchMemTrackerReserveMemory1");
@@ -399,7 +410,7 @@ TEST(ThreadMemTrackerMgrTest, 
NestedSwitchMemTrackerReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->try_reserve_memory(size3);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t1->consumption(), size3);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to