This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 8ce8887b75c [branch-2.1](memory) Refactor refresh workload groups weighted memory ratio and record refresh interval memory growth (#39760)
8ce8887b75c is described below

commit 8ce8887b75cd52698d42e1d29c8db4ef30c0e3a6
Author: Xinyi Zou <zouxiny...@gmail.com>
AuthorDate: Thu Aug 22 17:33:11 2024 +0800

    [branch-2.1](memory) Refactor refresh workload groups weighted memory ratio and record refresh interval memory growth (#39760)
    
    pick #38168
    This pick overwrites the changes that #37221 made to workload_group_manager.cpp; if #37221 comes up for picking later, its changes to that file can be ignored.
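
In short: each refresh interval the backend now derives a single weighting ratio from process memory and the summed query memory trackers, pushes that ratio to every workload group and query, and lets each group account for reservations made between refreshes. A minimal standalone sketch of the ratio itself, with illustrative names rather than the committed code:

    #include <algorithm>
    #include <cstdint>

    // Every wg_weighted_memory_ratio_refresh_interval_ms one ratio is derived and
    // shared by all workload groups and queries: tracked query memory scaled by
    // this ratio approximates real process memory (page cache, allocator caches,
    // and other memory that no query tracker sees).
    double weighted_memory_ratio(int64_t process_memory_usage, int64_t all_queries_mem_used) {
        if (all_queries_mem_used <= 0) {
            return 1.0; // nothing tracked yet, no scaling
        }
        // process memory can briefly read lower than the tracker sum; clamp so the ratio stays >= 1
        all_queries_mem_used = std::min(process_memory_usage, all_queries_mem_used);
        return static_cast<double>(process_memory_usage) / static_cast<double>(all_queries_mem_used);
    }
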
---
 be/src/common/config.cpp                           |   3 +-
 be/src/common/config.h                             |   4 +-
 be/src/common/daemon.cpp                           |  11 +-
 be/src/common/daemon.h                             |   2 +-
 .../exec/partitioned_aggregation_sink_operator.cpp |   3 +-
 .../partitioned_aggregation_source_operator.cpp    |   3 +-
 .../exec/partitioned_hash_join_probe_operator.cpp  |   9 +-
 .../exec/partitioned_hash_join_sink_operator.cpp   |   8 +-
 be/src/pipeline/exec/spill_sort_sink_operator.cpp  |   3 +-
 .../pipeline/exec/spill_sort_source_operator.cpp   |   3 +-
 be/src/pipeline/pipeline_fragment_context.cpp      |   3 +-
 be/src/pipeline/pipeline_fragment_context.h        |   2 -
 .../pipeline_x/pipeline_x_fragment_context.cpp     |   2 +-
 be/src/pipeline/pipeline_x/pipeline_x_task.cpp     |   2 +-
 be/src/runtime/fragment_mgr.cpp                    |   9 +-
 be/src/runtime/load_channel.cpp                    |   4 +-
 be/src/runtime/load_stream.cpp                     |   4 +-
 be/src/runtime/memory/mem_tracker_limiter.h        |   2 +-
 be/src/runtime/memory/thread_mem_tracker_mgr.h     |  20 +++-
 be/src/runtime/query_context.h                     |  11 +-
 be/src/runtime/thread_context.cpp                  |  36 +++---
 be/src/runtime/thread_context.h                    |  34 ++++--
 be/src/runtime/workload_group/workload_group.cpp   |  34 ++++--
 be/src/runtime/workload_group/workload_group.h     |  40 ++++++-
 .../workload_group/workload_group_manager.cpp      | 121 ++++++++-------------
 .../workload_group/workload_group_manager.h        |   2 +-
 be/src/vec/exec/scan/scanner_context.cpp           |   3 +-
 be/src/vec/runtime/vdata_stream_recvr.cpp          |   8 +-
 be/src/vec/runtime/vdata_stream_recvr.h            |   6 +-
 .../runtime/memory/thread_mem_tracker_mgr_test.cpp |  39 ++++---
 30 files changed, 258 insertions(+), 173 deletions(-)

diff --git a/be/src/common/config.cpp b/be/src/common/config.cpp
index 656c5de3a98..973c52e6787 100644
--- a/be/src/common/config.cpp
+++ b/be/src/common/config.cpp
@@ -588,8 +588,7 @@ DEFINE_mInt32(memory_gc_sleep_time_ms, "1000");
 // Sleep time in milliseconds between memtbale flush mgr refresh iterations
 DEFINE_mInt64(memtable_mem_tracker_refresh_interval_ms, "5");
 
-// Sleep time in milliseconds between refresh iterations of workload group memory statistics
-DEFINE_mInt64(wg_mem_refresh_interval_ms, "50");
+DEFINE_mInt64(wg_weighted_memory_ratio_refresh_interval_ms, "50");
 
 // percent of (active memtables size / all memtables size) when reach hard limit
 DEFINE_mInt32(memtable_hard_limit_active_percent, "50");
diff --git a/be/src/common/config.h b/be/src/common/config.h
index d6080fd78b2..69c8c42563b 100644
--- a/be/src/common/config.h
+++ b/be/src/common/config.h
@@ -645,8 +645,8 @@ DECLARE_mInt32(memory_gc_sleep_time_ms);
 // Sleep time in milliseconds between memtbale flush mgr memory refresh iterations
 DECLARE_mInt64(memtable_mem_tracker_refresh_interval_ms);
 
-// Sleep time in milliseconds between refresh iterations of workload group memory statistics
-DECLARE_mInt64(wg_mem_refresh_interval_ms);
+// Sleep time in milliseconds between refresh iterations of workload group weighted memory ratio
+DECLARE_mInt64(wg_weighted_memory_ratio_refresh_interval_ms);
 
 // percent of (active memtables size / all memtables size) when reach hard limit
 DECLARE_mInt32(memtable_hard_limit_active_percent);
diff --git a/be/src/common/daemon.cpp b/be/src/common/daemon.cpp
index 00d9caa4155..61845db775a 100644
--- a/be/src/common/daemon.cpp
+++ b/be/src/common/daemon.cpp
@@ -377,11 +377,11 @@ void Daemon::je_purge_dirty_pages_thread() const {
     } while (true);
 }
 
-void Daemon::wg_mem_used_refresh_thread() {
-    // Refresh memory usage and limit of workload groups
+void Daemon::wg_weighted_memory_ratio_refresh_thread() {
+    // Refresh weighted memory ratio of workload groups
     while (!_stop_background_threads_latch.wait_for(
-            std::chrono::milliseconds(config::wg_mem_refresh_interval_ms))) {
-        doris::ExecEnv::GetInstance()->workload_group_mgr()->refresh_wg_memory_info();
+            std::chrono::milliseconds(config::wg_weighted_memory_ratio_refresh_interval_ms))) {
+        doris::ExecEnv::GetInstance()->workload_group_mgr()->refresh_wg_weighted_memory_ratio();
     }
 }
 
@@ -420,7 +420,8 @@ void Daemon::start() {
     CHECK(st.ok()) << st;
 
     st = Thread::create(
-            "Daemon", "wg_mem_refresh_thread", [this]() { this->wg_mem_used_refresh_thread(); },
+            "Daemon", "wg_weighted_memory_ratio_refresh_thread",
+            [this]() { this->wg_weighted_memory_ratio_refresh_thread(); },
             &_threads.emplace_back());
     CHECK(st.ok()) << st;
 }
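
The daemon change above is a latch-driven periodic loop: wait up to wg_weighted_memory_ratio_refresh_interval_ms, call refresh_wg_weighted_memory_ratio(), repeat until shutdown. A self-contained sketch of that pattern using only the standard library; PeriodicRefresher and its members are illustrative stand-ins, not the Doris Daemon/CountDownLatch API:

    #include <atomic>
    #include <chrono>
    #include <condition_variable>
    #include <functional>
    #include <mutex>
    #include <thread>

    class PeriodicRefresher {
    public:
        PeriodicRefresher(std::chrono::milliseconds interval, std::function<void()> refresh)
                : _interval(interval), _refresh(std::move(refresh)), _worker([this] { run(); }) {}

        ~PeriodicRefresher() {
            {
                std::lock_guard<std::mutex> l(_mu);
                _stop = true;
            }
            _cv.notify_all();
            _worker.join();
        }

    private:
        void run() {
            std::unique_lock<std::mutex> l(_mu);
            // wait_for returns false on timeout ("not stopped yet"): refresh and go around again.
            while (!_cv.wait_for(l, _interval, [this] { return _stop; })) {
                l.unlock();
                _refresh(); // plays the role of refresh_wg_weighted_memory_ratio()
                l.lock();
            }
        }

        std::chrono::milliseconds _interval;
        std::function<void()> _refresh;
        std::mutex _mu;
        std::condition_variable _cv;
        bool _stop = false;
        std::thread _worker; // declared last so everything else is ready before it starts
    };

    int main() {
        std::atomic<int> refreshes{0};
        {
            PeriodicRefresher r(std::chrono::milliseconds(50), [&] { refreshes++; });
            std::this_thread::sleep_for(std::chrono::milliseconds(200));
        } // destructor stops and joins the worker
        return refreshes.load() > 0 ? 0 : 1;
    }
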
diff --git a/be/src/common/daemon.h b/be/src/common/daemon.h
index 28f63067896..25d842c4f9d 100644
--- a/be/src/common/daemon.h
+++ b/be/src/common/daemon.h
@@ -44,7 +44,7 @@ private:
     void calculate_metrics_thread();
     void je_purge_dirty_pages_thread() const;
     void report_runtime_query_statistics_thread();
-    void wg_mem_used_refresh_thread();
+    void wg_weighted_memory_ratio_refresh_thread();
 
     CountDownLatch _stop_background_threads_latch;
     std::vector<scoped_refptr<Thread>> _threads;
diff --git a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp 
b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp
index 83de348dbdb..053e6dee0cb 100644
--- a/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_aggregation_sink_operator.cpp
@@ -263,7 +263,8 @@ Status 
PartitionedAggSinkLocalState::revoke_memory(RuntimeState* state) {
     status = 
ExecEnv::GetInstance()->spill_stream_mgr()->get_spill_io_thread_pool()->submit_func(
             [this, &parent, state, query_id, mem_tracker, shared_state_holder, 
execution_context,
              submit_timer] {
-                SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+                QueryThreadContext query_thread_context {query_id, mem_tracker};
+                SCOPED_ATTACH_TASK(query_thread_context);
                 std::shared_ptr<TaskExecutionContext> execution_context_lock;
                 auto shared_state_sptr = shared_state_holder.lock();
                 if (shared_state_sptr) {
diff --git a/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp 
b/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp
index 6d871451bfd..67101f98ff8 100644
--- a/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_aggregation_source_operator.cpp
@@ -269,7 +269,8 @@ Status 
PartitionedAggLocalState::initiate_merge_spill_partition_agg_data(Runtime
 
     auto exception_catch_func = [spill_func, query_id, mem_tracker, 
shared_state_holder,
                                  execution_context, this]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
diff --git a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp 
b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
index 3cc3c3a9d0b..d98b8cea98c 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_probe_operator.cpp
@@ -210,7 +210,8 @@ Status 
PartitionedHashJoinProbeLocalState::spill_probe_blocks(RuntimeState* stat
 
     auto exception_catch_func = [query_id, mem_tracker, shared_state_holder, 
execution_context,
                                  spill_func, this]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
@@ -338,7 +339,8 @@ Status 
PartitionedHashJoinProbeLocalState::recovery_build_blocks_from_disk(Runti
 
     auto exception_catch_func = [read_func, query_id, mem_tracker, 
shared_state_holder,
                                  execution_context, state, this]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
@@ -426,7 +428,8 @@ Status 
PartitionedHashJoinProbeLocalState::recovery_probe_blocks_from_disk(Runti
 
     auto exception_catch_func = [read_func, mem_tracker, shared_state_holder, 
execution_context,
                                  query_id, this]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp 
b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
index 45ca975a88c..65f641f0860 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
@@ -127,7 +127,7 @@ Status 
PartitionedHashJoinSinkLocalState::_revoke_unpartitioned_block(RuntimeSta
     auto spill_func = [build_blocks = std::move(build_blocks), state, 
num_slots, this]() mutable {
         Defer defer {[&]() {
             // need to reset build_block here, or else build_block will be destructed
-            // after SCOPED_ATTACH_TASK_WITH_ID and will trigger memory_orphan_check failure
+            // after SCOPED_ATTACH_TASK and will trigger memory_orphan_check failure
             build_blocks.clear();
         }};
 
@@ -216,7 +216,8 @@ Status 
PartitionedHashJoinSinkLocalState::_revoke_unpartitioned_block(RuntimeSta
 
     auto exception_catch_func = [spill_func, shared_state_holder, 
execution_context, state,
                                  query_id, mem_tracker, this]() mutable {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
@@ -289,7 +290,8 @@ Status 
PartitionedHashJoinSinkLocalState::revoke_memory(RuntimeState* state) {
 
         auto st = spill_io_pool->submit_func([this, query_id, mem_tracker, 
shared_state_holder,
                                               execution_context, 
spilling_stream, i, submit_timer] {
-            SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+            QueryThreadContext query_thread_context {query_id, mem_tracker};
+            SCOPED_ATTACH_TASK(query_thread_context);
             std::shared_ptr<TaskExecutionContext> execution_context_lock;
             auto shared_state_sptr = shared_state_holder.lock();
             if (shared_state_sptr) {
diff --git a/be/src/pipeline/exec/spill_sort_sink_operator.cpp 
b/be/src/pipeline/exec/spill_sort_sink_operator.cpp
index c945d16cf57..dfda2ff61e1 100644
--- a/be/src/pipeline/exec/spill_sort_sink_operator.cpp
+++ b/be/src/pipeline/exec/spill_sort_sink_operator.cpp
@@ -296,7 +296,8 @@ Status SpillSortSinkLocalState::revoke_memory(RuntimeState* 
state) {
 
     auto exception_catch_func = [this, query_id, mem_tracker, 
shared_state_holder,
                                  execution_context, spill_func]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
diff --git a/be/src/pipeline/exec/spill_sort_source_operator.cpp 
b/be/src/pipeline/exec/spill_sort_source_operator.cpp
index 18a3d4310fd..ab871669d3e 100644
--- a/be/src/pipeline/exec/spill_sort_source_operator.cpp
+++ b/be/src/pipeline/exec/spill_sort_source_operator.cpp
@@ -175,7 +175,8 @@ Status 
SpillSortLocalState::initiate_merge_sort_spill_streams(RuntimeState* stat
 
     auto exception_catch_func = [this, query_id, mem_tracker, 
shared_state_holder,
                                  execution_context, spill_func]() {
-        SCOPED_ATTACH_TASK_WITH_ID(mem_tracker, query_id);
+        QueryThreadContext query_thread_context {query_id, mem_tracker};
+        SCOPED_ATTACH_TASK(query_thread_context);
         std::shared_ptr<TaskExecutionContext> execution_context_lock;
         auto shared_state_sptr = shared_state_holder.lock();
         if (shared_state_sptr) {
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp 
b/be/src/pipeline/pipeline_fragment_context.cpp
index 4c677216e6a..dab359ed040 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -136,12 +136,11 @@ PipelineFragmentContext::PipelineFragmentContext(
           _create_time(MonotonicNanos()) {
     _fragment_watcher.start();
     _start_time = VecDateTimeValue::local_time();
-    _query_thread_context = {query_id, _query_ctx->query_mem_tracker};
 }
 
 PipelineFragmentContext::~PipelineFragmentContext() {
     // The memory released by the query end is recorded in the query mem 
tracker.
-    
SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_thread_context.query_mem_tracker);
+    SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_ctx->query_mem_tracker);
     auto st = _query_ctx->exec_status();
     _query_ctx.reset();
     _tasks.clear();
diff --git a/be/src/pipeline/pipeline_fragment_context.h 
b/be/src/pipeline/pipeline_fragment_context.h
index 7eabb13b772..b8d192ac096 100644
--- a/be/src/pipeline/pipeline_fragment_context.h
+++ b/be/src/pipeline/pipeline_fragment_context.h
@@ -187,8 +187,6 @@ protected:
 
     std::shared_ptr<QueryContext> _query_ctx;
 
-    QueryThreadContext _query_thread_context;
-
     MonotonicStopWatch _fragment_watcher;
     RuntimeProfile::Counter* _start_timer = nullptr;
     RuntimeProfile::Counter* _prepare_timer = nullptr;
diff --git a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp 
b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
index d736879f0eb..7d90cebc8d2 100644
--- a/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_x/pipeline_x_fragment_context.cpp
@@ -112,7 +112,7 @@ PipelineXFragmentContext::PipelineXFragmentContext(
 
 PipelineXFragmentContext::~PipelineXFragmentContext() {
     // The memory released by the query end is recorded in the query mem 
tracker.
-    
SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_thread_context.query_mem_tracker);
+    SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_query_ctx->query_mem_tracker);
     auto st = _query_ctx->exec_status();
     _tasks.clear();
     if (!_task_runtime_states.empty()) {
diff --git a/be/src/pipeline/pipeline_x/pipeline_x_task.cpp 
b/be/src/pipeline/pipeline_x/pipeline_x_task.cpp
index f05b491d50b..580c425884f 100644
--- a/be/src/pipeline/pipeline_x/pipeline_x_task.cpp
+++ b/be/src/pipeline/pipeline_x/pipeline_x_task.cpp
@@ -370,7 +370,7 @@ bool PipelineXTask::should_revoke_memory(RuntimeState* 
state, int64_t revocable_
     } else if (is_wg_mem_low_water_mark) {
         int64_t query_weighted_limit = 0;
         int64_t query_weighted_consumption = 0;
-        query_ctx->get_weighted_mem_info(query_weighted_limit, query_weighted_consumption);
+        query_ctx->get_weighted_memory(query_weighted_limit, query_weighted_consumption);
         if (query_weighted_consumption < query_weighted_limit) {
             return false;
         }
diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp
index b68839a0d62..d793a7f19e9 100644
--- a/be/src/runtime/fragment_mgr.cpp
+++ b/be/src/runtime/fragment_mgr.cpp
@@ -947,7 +947,7 @@ Status FragmentMgr::exec_plan_fragment(const 
TPipelineFragmentParams& params,
 
     std::shared_ptr<QueryContext> query_ctx;
     RETURN_IF_ERROR(_get_query_ctx(params, params.query_id, true, 
query_source, query_ctx));
-    SCOPED_ATTACH_TASK_WITH_ID(query_ctx->query_mem_tracker, params.query_id);
+    SCOPED_ATTACH_TASK(query_ctx.get());
     const bool enable_pipeline_x = 
params.query_options.__isset.enable_pipeline_x_engine &&
                                    
params.query_options.enable_pipeline_x_engine;
     if (enable_pipeline_x) {
@@ -1093,7 +1093,7 @@ Status FragmentMgr::exec_plan_fragment(const 
TPipelineFragmentParams& params,
 
             for (size_t i = 0; i < target_size; i++) {
                 RETURN_IF_ERROR(_thread_pool->submit_func([&, i]() {
-                    SCOPED_ATTACH_TASK_WITH_ID(query_ctx->query_mem_tracker, 
query_ctx->query_id());
+                    SCOPED_ATTACH_TASK(query_ctx.get());
                     prepare_status[i] = pre_and_submit(i);
                     std::unique_lock<std::mutex> lock(m);
                     prepare_done++;
@@ -1611,7 +1611,8 @@ Status FragmentMgr::apply_filterv2(const 
PPublishFilterRequestV2* request,
                 runtime_filter_mgr = 
pip_context->get_query_ctx()->runtime_filter_mgr();
                 pool = &pip_context->get_query_ctx()->obj_pool;
                query_thread_context = {pip_context->get_query_ctx()->query_id(),
-                                        pip_context->get_query_ctx()->query_mem_tracker};
+                                        pip_context->get_query_ctx()->query_mem_tracker,
+                                        pip_context->get_query_ctx()->workload_group()};
             } else {
                 auto iter = _fragment_instance_map.find(tfragment_instance_id);
                 if (iter == _fragment_instance_map.end()) {
@@ -1717,7 +1718,7 @@ Status FragmentMgr::merge_filter(const 
PMergeFilterRequest* request,
         // when filter_controller->merge is still in progress
         query_ctx = iter->second;
     }
-    SCOPED_ATTACH_TASK_WITH_ID(query_ctx->query_mem_tracker, query_ctx->query_id());
+    SCOPED_ATTACH_TASK(query_ctx.get());
     auto merge_status = filter_controller->merge(request, attach_data, 
opt_remote_rf);
     return merge_status;
 }
diff --git a/be/src/runtime/load_channel.cpp b/be/src/runtime/load_channel.cpp
index 3d8c8e1dbf3..cd86705c831 100644
--- a/be/src/runtime/load_channel.cpp
+++ b/be/src/runtime/load_channel.cpp
@@ -27,6 +27,7 @@
 #include "runtime/memory/mem_tracker.h"
 #include "runtime/tablets_channel.h"
 #include "runtime/thread_context.h"
+#include "runtime/workload_group/workload_group_manager.h"
 
 namespace doris {
 
@@ -43,7 +44,8 @@ LoadChannel::LoadChannel(const UniqueId& load_id, int64_t 
timeout_s, bool is_hig
     std::shared_ptr<QueryContext> query_context =
             
ExecEnv::GetInstance()->fragment_mgr()->get_query_context(_load_id.to_thrift());
     if (query_context != nullptr) {
-        _query_thread_context = {_load_id.to_thrift(), query_context->query_mem_tracker};
+        _query_thread_context = {_load_id.to_thrift(), query_context->query_mem_tracker,
+                                 query_context->workload_group()};
     } else {
         _query_thread_context = {
                 _load_id.to_thrift(),
diff --git a/be/src/runtime/load_stream.cpp b/be/src/runtime/load_stream.cpp
index 1f8c33995b3..619499cebdc 100644
--- a/be/src/runtime/load_stream.cpp
+++ b/be/src/runtime/load_stream.cpp
@@ -40,6 +40,7 @@
 #include "runtime/load_channel.h"
 #include "runtime/load_stream_mgr.h"
 #include "runtime/load_stream_writer.h"
+#include "runtime/workload_group/workload_group_manager.h"
 #include "util/debug_points.h"
 #include "util/runtime_profile.h"
 #include "util/thrift_util.h"
@@ -352,7 +353,8 @@ LoadStream::LoadStream(PUniqueId load_id, LoadStreamMgr* 
load_stream_mgr, bool e
     std::shared_ptr<QueryContext> query_context =
             
ExecEnv::GetInstance()->fragment_mgr()->get_query_context(load_tid);
     if (query_context != nullptr) {
-        _query_thread_context = {load_tid, query_context->query_mem_tracker};
+        _query_thread_context = {load_tid, query_context->query_mem_tracker,
+                                 query_context->workload_group()};
     } else {
         _query_thread_context = {load_tid, MemTrackerLimiter::create_shared(
                                                    
MemTrackerLimiter::Type::LOAD,
diff --git a/be/src/runtime/memory/mem_tracker_limiter.h 
b/be/src/runtime/memory/mem_tracker_limiter.h
index e6cf8410c30..67c40e1f6c5 100644
--- a/be/src/runtime/memory/mem_tracker_limiter.h
+++ b/be/src/runtime/memory/mem_tracker_limiter.h
@@ -215,7 +215,7 @@ public:
     }
 
     // Iterator into mem_tracker_limiter_pool for this object. Stored to have O(1) remove.
-    std::list<std::weak_ptr<MemTrackerLimiter>>::iterator tg_tracker_limiter_group_it;
+    std::list<std::weak_ptr<MemTrackerLimiter>>::iterator wg_tracker_limiter_group_it;
 
 private:
     friend class ThreadMemTrackerMgr;
diff --git a/be/src/runtime/memory/thread_mem_tracker_mgr.h 
b/be/src/runtime/memory/thread_mem_tracker_mgr.h
index 9d36cd2d807..39e896d0f18 100644
--- a/be/src/runtime/memory/thread_mem_tracker_mgr.h
+++ b/be/src/runtime/memory/thread_mem_tracker_mgr.h
@@ -33,6 +33,7 @@
 #include "runtime/memory/global_memory_arbitrator.h"
 #include "runtime/memory/mem_tracker.h"
 #include "runtime/memory/mem_tracker_limiter.h"
+#include "runtime/workload_group/workload_group.h"
 #include "util/stack_util.h"
 #include "util/uid_util.h"
 
@@ -71,6 +72,10 @@ public:
 
     TUniqueId query_id() { return _query_id; }
 
+    void set_wg_wptr(const std::weak_ptr<WorkloadGroup>& wg_wptr) { _wg_wptr = wg_wptr; }
+
+    void reset_wg_wptr() { _wg_wptr.reset(); }
+
     void start_count_scope_mem() {
         CHECK(init());
         _scope_mem = _reserved_mem; // consume in advance
@@ -151,6 +156,7 @@ private:
     std::shared_ptr<MemTrackerLimiter> _limiter_tracker;
     MemTrackerLimiter* _limiter_tracker_raw = nullptr;
     std::vector<MemTracker*> _consumer_tracker_stack;
+    std::weak_ptr<WorkloadGroup> _wg_wptr;
 
     // If there is a memory new/delete operation in the consume method, it may 
enter infinite recursion.
     bool _stop_consume = false;
@@ -287,8 +293,16 @@ inline bool ThreadMemTrackerMgr::try_reserve(int64_t size) {
     if (!_limiter_tracker_raw->try_consume(size)) {
         return false;
     }
+    auto wg_ptr = _wg_wptr.lock();
+    if (wg_ptr) {
+        if (!wg_ptr->add_wg_refresh_interval_memory_growth(size)) {
+            _limiter_tracker_raw->release(size); // rollback
+            return false;
+        }
+    }
     if (!doris::GlobalMemoryArbitrator::try_reserve_process_memory(size)) {
-        _limiter_tracker_raw->release(size); // rollback
+        _limiter_tracker_raw->release(size);                 // rollback
+        if (wg_ptr) { wg_ptr->sub_wg_refresh_interval_memory_growth(size); } // rollback
         return false;
     }
     if (_count_scope_mem) {
@@ -306,6 +320,10 @@ inline void ThreadMemTrackerMgr::release_reserved() {
         doris::GlobalMemoryArbitrator::release_process_reserved_memory(_reserved_mem +
                                                                        _untracked_mem);
         _limiter_tracker_raw->release(_reserved_mem);
+        auto wg_ptr = _wg_wptr.lock();
+        if (wg_ptr) {
+            wg_ptr->sub_wg_refresh_interval_memory_growth(_reserved_mem);
+        }
         if (_count_scope_mem) {
             _scope_mem -= _reserved_mem;
         }
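
try_reserve() above is where the "refresh interval memory growth" gets recorded: a reservation only succeeds if the group's projected weighted usage stays under the spill high watermark, and every failure path rolls the counter back. A standalone sketch of that accounting; Group, add_growth, and the later_step_ok flag are illustrative names, not the Doris classes:

    #include <atomic>
    #include <cstdint>
    #include <memory>

    struct Group {
        int64_t memory_limit = 1LL << 30;        // 1 GiB, assumed for the example
        int high_watermark_pct = 80;             // assumed spill high watermark
        double weighted_ratio = 1.0;
        std::atomic<int64_t> total_mem_used{0};  // refreshed once per interval
        std::atomic<int64_t> interval_growth{0}; // reservations since the last refresh

        bool add_growth(int64_t size) {
            auto projected = int64_t((total_mem_used.load() + interval_growth.load() + size) *
                                     weighted_ratio);
            if (projected > memory_limit * high_watermark_pct / 100) {
                return false;                    // reject: would cross the high watermark
            }
            interval_growth.fetch_add(size);
            return true;
        }
        void sub_growth(int64_t size) { interval_growth.fetch_sub(size); }
    };

    // Shape of the reservation path: charge the group first, roll back if a later step fails.
    bool try_reserve(const std::weak_ptr<Group>& wg, int64_t size, bool later_step_ok) {
        auto g = wg.lock();                      // the thread may not belong to any group
        if (g && !g->add_growth(size)) {
            return false;
        }
        if (!later_step_ok) {                    // e.g. the process-level reservation failed
            if (g) {
                g->sub_growth(size);             // rollback
            }
            return false;
        }
        return true;
    }

    int main() {
        auto g = std::make_shared<Group>();
        return try_reserve(g, 64LL << 20, /*later_step_ok=*/true) ? 0 : 1;
    }
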
diff --git a/be/src/runtime/query_context.h b/be/src/runtime/query_context.h
index 3d523522337..3d0b2289baa 100644
--- a/be/src/runtime/query_context.h
+++ b/be/src/runtime/query_context.h
@@ -269,15 +269,16 @@ public:
         return _running_big_mem_op_num.load(std::memory_order_relaxed);
     }
 
-    void set_weighted_mem(int64_t weighted_limit, int64_t weighted_consumption) {
+    void set_weighted_memory(int64_t weighted_limit, double weighted_ratio) {
         std::lock_guard<std::mutex> l(_weighted_mem_lock);
-        _weighted_consumption = weighted_consumption;
         _weighted_limit = weighted_limit;
+        _weighted_ratio = weighted_ratio;
     }
-    void get_weighted_mem_info(int64_t& weighted_limit, int64_t& weighted_consumption) {
+
+    void get_weighted_memory(int64_t& weighted_limit, int64_t& weighted_consumption) {
         std::lock_guard<std::mutex> l(_weighted_mem_lock);
         weighted_limit = _weighted_limit;
-        weighted_consumption = _weighted_consumption;
+        weighted_consumption = int64_t(query_mem_tracker->consumption() * _weighted_ratio);
     }
 
     DescriptorTbl* desc_tbl = nullptr;
@@ -360,7 +361,7 @@ private:
     std::mutex _pipeline_map_write_lock;
 
     std::mutex _weighted_mem_lock;
-    int64_t _weighted_consumption = 0;
+    double _weighted_ratio = 0;
     int64_t _weighted_limit = 0;
     timespec _query_arrival_timestamp;
     // Distinguish the query source, for query that comes from fe, we will have some memory structure on FE to
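
With this hunk the query context stores only the weighted limit and the ratio; weighted consumption is derived on demand from the live tracker value instead of being cached. A small standalone sketch of that accessor shape; the class and member names are illustrative, not the QueryContext API:

    #include <cstdint>
    #include <mutex>

    class WeightedMemory {
    public:
        void set(int64_t limit, double ratio) {
            std::lock_guard<std::mutex> l(_lock);
            _limit = limit;
            _ratio = ratio;
        }
        // consumption is read from the mem tracker at call time, so the weighted
        // value follows the query instead of going stale between refreshes.
        void get(int64_t tracker_consumption, int64_t& limit, int64_t& weighted) {
            std::lock_guard<std::mutex> l(_lock);
            limit = _limit;
            weighted = int64_t(tracker_consumption * _ratio);
        }

    private:
        std::mutex _lock;
        int64_t _limit = 0;
        double _ratio = 0.0;
    };
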
diff --git a/be/src/runtime/thread_context.cpp 
b/be/src/runtime/thread_context.cpp
index 6f69eb9e134..c89f532e592 100644
--- a/be/src/runtime/thread_context.cpp
+++ b/be/src/runtime/thread_context.cpp
@@ -18,7 +18,9 @@
 #include "runtime/thread_context.h"
 
 #include "common/signal_handler.h"
+#include "runtime/query_context.h"
 #include "runtime/runtime_state.h"
+#include "runtime/workload_group/workload_group_manager.h"
 
 namespace doris {
 class MemTracker;
@@ -26,34 +28,38 @@ class MemTracker;
 QueryThreadContext ThreadContext::query_thread_context() {
     DCHECK(doris::pthread_context_ptr_init);
     ORPHAN_TRACKER_CHECK();
-    return {_task_id, thread_mem_tracker_mgr->limiter_mem_tracker()};
+    return {_task_id, thread_mem_tracker_mgr->limiter_mem_tracker(), _wg_wptr};
 }
 
-AttachTask::AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
-                       const TUniqueId& task_id) {
+void AttachTask::init(const QueryThreadContext& query_thread_context) {
     ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(task_id);
-    thread_context()->attach_task(task_id, mem_tracker);
+    signal::set_signal_task_id(query_thread_context.query_id);
+    thread_context()->attach_task(query_thread_context.query_id,
+                                  query_thread_context.query_mem_tracker,
+                                  query_thread_context.wg_wptr);
 }
 
 AttachTask::AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(TUniqueId());
-    thread_context()->attach_task(TUniqueId(), mem_tracker);
+    QueryThreadContext query_thread_context = {TUniqueId(), mem_tracker};
+    init(query_thread_context);
 }
 
 AttachTask::AttachTask(RuntimeState* runtime_state) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(runtime_state->query_id());
     signal::set_signal_is_nereids(runtime_state->is_nereids());
-    thread_context()->attach_task(runtime_state->query_id(), runtime_state->query_mem_tracker());
+    QueryThreadContext query_thread_context = {runtime_state->query_id(),
+                                               runtime_state->query_mem_tracker(),
+                                               runtime_state->get_query_ctx()->workload_group()};
+    init(query_thread_context);
 }
 
 AttachTask::AttachTask(const QueryThreadContext& query_thread_context) {
-    ThreadLocalHandle::create_thread_local_if_not_exits();
-    signal::set_signal_task_id(query_thread_context.query_id);
-    thread_context()->attach_task(query_thread_context.query_id,
-                                  query_thread_context.query_mem_tracker);
+    init(query_thread_context);
+}
+
+AttachTask::AttachTask(QueryContext* query_ctx) {
+    QueryThreadContext query_thread_context = {query_ctx->query_id(), query_ctx->query_mem_tracker,
+                                               query_ctx->workload_group()};
+    init(query_thread_context);
 }
 
 AttachTask::~AttachTask() {
diff --git a/be/src/runtime/thread_context.h b/be/src/runtime/thread_context.h
index 7a4695a4e98..40b3985dec7 100644
--- a/be/src/runtime/thread_context.h
+++ b/be/src/runtime/thread_context.h
@@ -45,8 +45,6 @@
 // This will save some info about a working thread in the thread context.
 // Looking forward to tracking memory during thread execution into 
MemTrackerLimiter.
 #define SCOPED_ATTACH_TASK(arg1) auto VARNAME_LINENUM(attach_task) = AttachTask(arg1)
-#define SCOPED_ATTACH_TASK_WITH_ID(arg1, arg2) \
-    auto VARNAME_LINENUM(attach_task) = AttachTask(arg1, arg2)
 
 // Switch MemTrackerLimiter for count memory during thread execution.
 // Used after SCOPED_ATTACH_TASK, in order to count the memory into another
@@ -86,8 +84,6 @@
 // thread context need to be initialized, required by Allocator and elsewhere.
 #define SCOPED_ATTACH_TASK(arg1, ...) \
     auto VARNAME_LINENUM(scoped_tls_at) = doris::ScopedInitThreadContext()
-#define SCOPED_ATTACH_TASK_WITH_ID(arg1, arg2) \
-    auto VARNAME_LINENUM(scoped_tls_atwi) = doris::ScopedInitThreadContext()
 #define SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(arg1) \
     auto VARNAME_LINENUM(scoped_tls_stmtl) = doris::ScopedInitThreadContext()
 #define SCOPED_CONSUME_MEM_TRACKER(mem_tracker) \
@@ -121,6 +117,7 @@ class ThreadContext;
 class MemTracker;
 class RuntimeState;
 class QueryThreadContext;
+class WorkloadGroup;
 
 extern bthread_key_t btls_key;
 
@@ -155,7 +152,8 @@ public:
     ~ThreadContext() = default;
 
     void attach_task(const TUniqueId& task_id,
-                     const std::shared_ptr<MemTrackerLimiter>& mem_tracker) {
+                     const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
+                     const std::weak_ptr<WorkloadGroup>& wg_wptr) {
         // will only attach_task at the beginning of the thread function, 
there should be no duplicate attach_task.
         DCHECK(mem_tracker);
         // Orphan is thread default tracker.
@@ -163,16 +161,20 @@ public:
                 << ", thread mem tracker label: " << 
thread_mem_tracker()->label()
                 << ", attach mem tracker label: " << mem_tracker->label();
         _task_id = task_id;
+        _wg_wptr = wg_wptr;
         thread_mem_tracker_mgr->attach_limiter_tracker(mem_tracker);
         thread_mem_tracker_mgr->set_query_id(_task_id);
+        thread_mem_tracker_mgr->set_wg_wptr(_wg_wptr);
         thread_mem_tracker_mgr->enable_wait_gc();
         thread_mem_tracker_mgr->reset_query_cancelled_flag(false);
     }
 
     void detach_task() {
         _task_id = TUniqueId();
+        _wg_wptr.reset();
         thread_mem_tracker_mgr->detach_limiter_tracker();
         thread_mem_tracker_mgr->set_query_id(TUniqueId());
+        thread_mem_tracker_mgr->reset_wg_wptr();
         thread_mem_tracker_mgr->disable_wait_gc();
     }
 
@@ -223,12 +225,15 @@ public:
         thread_mem_tracker_mgr->release_reserved();
     }
 
+    std::weak_ptr<WorkloadGroup> workload_group() { return _wg_wptr; }
+
     int thread_local_handle_count = 0;
     int skip_memory_check = 0;
     int skip_large_memory_check = 0;
 
 private:
     TUniqueId _task_id;
+    std::weak_ptr<WorkloadGroup> _wg_wptr;
 };
 
 class ThreadLocalHandle {
@@ -309,6 +314,11 @@ static ThreadContext* thread_context() {
 class QueryThreadContext {
 public:
     QueryThreadContext() = default;
+    QueryThreadContext(const TUniqueId& query_id,
+                       const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
+                       const std::weak_ptr<WorkloadGroup>& wg_wptr)
+            : query_id(query_id), query_mem_tracker(mem_tracker), wg_wptr(wg_wptr) {}
+    // If a WorkloadGroup is in use and its pointer is available, it must be passed in (use the constructor above).
     QueryThreadContext(const TUniqueId& query_id,
                        const std::shared_ptr<MemTrackerLimiter>& mem_tracker)
             : query_id(query_id), query_mem_tracker(mem_tracker) {}
@@ -318,6 +328,7 @@ public:
         ORPHAN_TRACKER_CHECK();
         query_id = doris::thread_context()->task_id();
         query_mem_tracker = 
doris::thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
+        wg_wptr = doris::thread_context()->workload_group();
 #else
         query_id = TUniqueId();
         query_mem_tracker = 
doris::ExecEnv::GetInstance()->orphan_mem_tracker();
@@ -326,6 +337,7 @@ public:
 
     TUniqueId query_id;
     std::shared_ptr<MemTrackerLimiter> query_mem_tracker;
+    std::weak_ptr<WorkloadGroup> wg_wptr;
 };
 
 class ScopeMemCountByHook {
@@ -357,15 +369,18 @@ public:
 
 class AttachTask {
 public:
-    explicit AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker,
-                        const TUniqueId& task_id);
-
+    // Not a query or load: initialize with a memory tracker, an empty query id, and the default normal workload group.
     explicit AttachTask(const std::shared_ptr<MemTrackerLimiter>& mem_tracker);
 
+    // Query or load: initialize with the memory tracker, query id, and workload group wptr.
     explicit AttachTask(RuntimeState* runtime_state);
 
+    explicit AttachTask(QueryContext* query_ctx);
+
     explicit AttachTask(const QueryThreadContext& query_thread_context);
 
+    void init(const QueryThreadContext& query_thread_context);
+
     ~AttachTask();
 };
 
@@ -380,7 +395,8 @@ public:
 
     explicit SwitchThreadMemTrackerLimiter(const QueryThreadContext& query_thread_context) {
         ThreadLocalHandle::create_thread_local_if_not_exits();
-        DCHECK(thread_context()->task_id() == query_thread_context.query_id);
+        DCHECK(thread_context()->task_id() ==
+               query_thread_context.query_id); // the workload group also does not change
         DCHECK(query_thread_context.query_mem_tracker);
         _old_mem_tracker = thread_context()->thread_mem_tracker_mgr->limiter_mem_tracker();
         thread_context()->thread_mem_tracker_mgr->attach_limiter_tracker(
diff --git a/be/src/runtime/workload_group/workload_group.cpp 
b/be/src/runtime/workload_group/workload_group.cpp
index b68b1765a52..0a34ada5c70 100644
--- a/be/src/runtime/workload_group/workload_group.cpp
+++ b/be/src/runtime/workload_group/workload_group.cpp
@@ -107,39 +107,59 @@ void WorkloadGroup::check_and_update(const WorkloadGroupInfo& tg_info) {
     }
 }
 
-int64_t WorkloadGroup::memory_used() {
+int64_t WorkloadGroup::make_memory_tracker_snapshots(
+        std::list<std::shared_ptr<MemTrackerLimiter>>* tracker_snapshots) {
     int64_t used_memory = 0;
     for (auto& mem_tracker_group : _mem_tracker_limiter_pool) {
         std::lock_guard<std::mutex> l(mem_tracker_group.group_lock);
         for (const auto& trackerWptr : mem_tracker_group.trackers) {
             auto tracker = trackerWptr.lock();
             CHECK(tracker != nullptr);
+            if (tracker_snapshots != nullptr) {
+                tracker_snapshots->insert(tracker_snapshots->end(), tracker);
+            }
             used_memory += tracker->consumption();
         }
     }
+    refresh_memory(used_memory);
     return used_memory;
 }
 
-void WorkloadGroup::set_weighted_memory_used(int64_t wg_total_mem_used, double ratio) {
-    _weighted_mem_used.store(int64_t(wg_total_mem_used * ratio), std::memory_order_relaxed);
+int64_t WorkloadGroup::memory_used() {
+    return make_memory_tracker_snapshots(nullptr);
+}
+
+void WorkloadGroup::refresh_memory(int64_t used_memory) {
+    // refresh total memory used.
+    _total_mem_used = used_memory;
+    // reserve memory is recorded in the query mem tracker
+    // and _total_mem_used already contains all the current reserve memory.
+    // so after refreshing _total_mem_used, reset _wg_refresh_interval_memory_growth.
+    _wg_refresh_interval_memory_growth.store(0.0);
+}
+
+void WorkloadGroup::set_weighted_memory_ratio(double ratio) {
+    _weighted_mem_ratio = ratio;
 }
 
 void WorkloadGroup::add_mem_tracker_limiter(std::shared_ptr<MemTrackerLimiter> 
mem_tracker_ptr) {
+    std::unique_lock<std::shared_mutex> wlock(_mutex);
     auto group_num = mem_tracker_ptr->group_num();
     std::lock_guard<std::mutex> 
l(_mem_tracker_limiter_pool[group_num].group_lock);
-    mem_tracker_ptr->tg_tracker_limiter_group_it =
+    mem_tracker_ptr->wg_tracker_limiter_group_it =
             _mem_tracker_limiter_pool[group_num].trackers.insert(
                     _mem_tracker_limiter_pool[group_num].trackers.end(), 
mem_tracker_ptr);
 }
 
 void 
WorkloadGroup::remove_mem_tracker_limiter(std::shared_ptr<MemTrackerLimiter> 
mem_tracker_ptr) {
+    std::unique_lock<std::shared_mutex> wlock(_mutex);
     auto group_num = mem_tracker_ptr->group_num();
     std::lock_guard<std::mutex> 
l(_mem_tracker_limiter_pool[group_num].group_lock);
-    if (mem_tracker_ptr->tg_tracker_limiter_group_it !=
+    if (mem_tracker_ptr->wg_tracker_limiter_group_it !=
         _mem_tracker_limiter_pool[group_num].trackers.end()) {
         _mem_tracker_limiter_pool[group_num].trackers.erase(
-                mem_tracker_ptr->tg_tracker_limiter_group_it);
-        mem_tracker_ptr->tg_tracker_limiter_group_it =
+                mem_tracker_ptr->wg_tracker_limiter_group_it);
+        mem_tracker_ptr->wg_tracker_limiter_group_it =
                 _mem_tracker_limiter_pool[group_num].trackers.end();
     }
 }
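
make_memory_tracker_snapshots() above folds three things into one pass: collect shared_ptr snapshots of the live trackers, sum their consumption, and refresh the group's cached total (which also resets the per-interval growth counter). A standalone sketch of that pass with simplified containers and illustrative names:

    #include <array>
    #include <cstdint>
    #include <list>
    #include <memory>
    #include <mutex>

    struct Tracker {
        int64_t consumption = 0;
    };

    struct TrackerShard {
        std::mutex lock;
        std::list<std::weak_ptr<Tracker>> trackers;
    };

    struct Group {
        std::array<TrackerShard, 4> pool;        // sharded like _mem_tracker_limiter_pool
        int64_t total_mem_used = 0;
        int64_t interval_growth = 0;

        int64_t snapshot(std::list<std::shared_ptr<Tracker>>* out) {
            int64_t used = 0;
            for (auto& shard : pool) {
                std::lock_guard<std::mutex> l(shard.lock);
                for (auto& weak : shard.trackers) {
                    if (auto t = weak.lock()) {
                        if (out != nullptr) {
                            out->push_back(t);   // keep the tracker alive for later logging
                        }
                        used += t->consumption;
                    }
                }
            }
            total_mem_used = used;               // refresh_memory(): cache the new total
            interval_growth = 0;                 // the new total already covers interim reservations
            return used;
        }
    };

    int main() {
        Group g;
        auto t = std::make_shared<Tracker>();
        t->consumption = 128LL << 20;            // pretend one query tracked 128 MiB
        g.pool[0].trackers.push_back(t);
        std::list<std::shared_ptr<Tracker>> snapshots;
        return g.snapshot(&snapshots) == (128LL << 20) ? 0 : 1;
    }
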
diff --git a/be/src/runtime/workload_group/workload_group.h 
b/be/src/runtime/workload_group/workload_group.h
index b57e5736eb2..a53b7ac6579 100644
--- a/be/src/runtime/workload_group/workload_group.h
+++ b/be/src/runtime/workload_group/workload_group.h
@@ -77,7 +77,12 @@ public:
         return _memory_limit;
     };
 
+    // make memory snapshots and refresh total memory used at the same time.
+    int64_t make_memory_tracker_snapshots(
+            std::list<std::shared_ptr<MemTrackerLimiter>>* tracker_snapshots);
+    // call make_memory_tracker_snapshots, so also refresh total memory used.
     int64_t memory_used();
+    void refresh_memory(int64_t used_memory);
 
     int spill_threshold_low_water_mark() const {
         return _spill_low_watermark.load(std::memory_order_relaxed);
@@ -86,10 +91,31 @@ public:
         return _spill_high_watermark.load(std::memory_order_relaxed);
     }
 
-    void set_weighted_memory_used(int64_t wg_total_mem_used, double ratio);
+    void set_weighted_memory_ratio(double ratio);
+    bool add_wg_refresh_interval_memory_growth(int64_t size) {
+        // `weighted_mem_used` is a rough estimate of this group's memory usage,
+        // because MemTracker's precise accounting does not include the page cache.
+        auto weighted_mem_used =
+                int64_t((_total_mem_used + _wg_refresh_interval_memory_growth.load() + size) *
+                        _weighted_mem_ratio);
+        if ((weighted_mem_used > ((double)_memory_limit *
+                                  _spill_high_watermark.load(std::memory_order_relaxed) / 100))) {
+            return false;
+        } else {
+            _wg_refresh_interval_memory_growth.fetch_add(size);
+            return true;
+        }
+    }
+    void sub_wg_refresh_interval_memory_growth(int64_t size) {
+        _wg_refresh_interval_memory_growth.fetch_sub(size);
+    }
 
     void check_mem_used(bool* is_low_wartermark, bool* is_high_wartermark) const {
-        auto weighted_mem_used = _weighted_mem_used.load(std::memory_order_relaxed);
+        // `weighted_mem_used` is a rough estimate of this group's memory usage,
+        // because MemTracker's precise accounting does not include the page cache.
+        auto weighted_mem_used =
+                int64_t((_total_mem_used + _wg_refresh_interval_memory_growth.load()) *
+                        _weighted_mem_ratio);
         *is_low_wartermark =
                 (weighted_mem_used > ((double)_memory_limit *
                                       _spill_low_watermark.load(std::memory_order_relaxed) / 100));
@@ -138,7 +164,7 @@ public:
 
     bool can_be_dropped() {
         std::shared_lock<std::shared_mutex> r_lock(_mutex);
-        return _is_shutdown && _query_ctxs.size() == 0;
+        return _is_shutdown && _query_ctxs.empty();
     }
 
     int query_num() {
@@ -169,8 +195,12 @@ private:
     const uint64_t _id;
     std::string _name;
     int64_t _version;
-    int64_t _memory_limit;                      // bytes
-    std::atomic_int64_t _weighted_mem_used = 0; // bytes
+    int64_t _memory_limit; // bytes
+    // last value computed by make_memory_tracker_snapshots, refreshed every time it is called.
+    std::atomic_int64_t _total_mem_used = 0; // bytes
+    // last value of refresh_wg_weighted_memory_ratio.
+    std::atomic<double> _weighted_mem_ratio = 0.0;
+    std::atomic_int64_t _wg_refresh_interval_memory_growth = 0;
     bool _enable_memory_overcommit;
     std::atomic<uint64_t> _cpu_share;
     std::vector<TrackerLimiterGroup> _mem_tracker_limiter_pool;
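
Both add_wg_refresh_interval_memory_growth() and check_mem_used() above estimate current usage as (total memory used + growth since the last refresh) * ratio and compare it against the spill watermarks. A standalone sketch of that comparison; the percent defaults and names are illustrative:

    #include <cstdint>

    struct WatermarkState {
        bool low = false;
        bool high = false;
    };

    // Estimate usage the same way the hunk above does, then classify it against the
    // low/high spill watermarks expressed as a percent of the group's memory limit.
    WatermarkState check_watermarks(int64_t total_mem_used, int64_t interval_growth,
                                    double weighted_ratio, int64_t memory_limit,
                                    int low_pct = 50, int high_pct = 80) {
        auto weighted_mem_used = int64_t((total_mem_used + interval_growth) * weighted_ratio);
        WatermarkState s;
        s.low = weighted_mem_used > (double)memory_limit * low_pct / 100;
        s.high = weighted_mem_used > (double)memory_limit * high_pct / 100;
        return s;
    }
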
diff --git a/be/src/runtime/workload_group/workload_group_manager.cpp 
b/be/src/runtime/workload_group/workload_group_manager.cpp
index 7a93015030f..e9221f67db5 100644
--- a/be/src/runtime/workload_group/workload_group_manager.cpp
+++ b/be/src/runtime/workload_group/workload_group_manager.cpp
@@ -148,50 +148,34 @@ void WorkloadGroupMgr::delete_workload_group_by_ids(std::set<uint64_t> used_wg_i
 
 struct WorkloadGroupMemInfo {
     int64_t total_mem_used = 0;
-    int64_t weighted_mem_used = 0;
-    bool is_low_wartermark = false;
-    bool is_high_wartermark = false;
-    double mem_used_ratio = 0;
+    std::list<std::shared_ptr<MemTrackerLimiter>> tracker_snapshots =
+            std::list<std::shared_ptr<MemTrackerLimiter>>();
 };
-void WorkloadGroupMgr::refresh_wg_memory_info() {
+
+void WorkloadGroupMgr::refresh_wg_weighted_memory_ratio() {
     std::shared_lock<std::shared_mutex> r_lock(_group_mutex);
-    // workload group id -> workload group queries
-    std::unordered_map<uint64_t, std::unordered_map<TUniqueId, 
std::weak_ptr<QueryContext>>>
-            all_wg_queries;
-    for (auto& [wg_id, wg] : _workload_groups) {
-        all_wg_queries.insert({wg_id, wg->queries()});
-    }
 
+    // 1. make memory snapshots of all workload groups (refreshing each group's total memory used at the same time)
+    // and calculate the total memory used by all queries.
     int64_t all_queries_mem_used = 0;
-
-    // calculate total memory used of each workload group and total memory used of all queries
     std::unordered_map<uint64_t, WorkloadGroupMemInfo> wgs_mem_info;
-    for (auto& [wg_id, wg_queries] : all_wg_queries) {
-        int64_t wg_total_mem_used = 0;
-        for (const auto& [query_id, query_ctx_ptr] : wg_queries) {
-            if (auto query_ctx = query_ctx_ptr.lock()) {
-                wg_total_mem_used += 
query_ctx->query_mem_tracker->consumption();
-            }
-        }
-        all_queries_mem_used += wg_total_mem_used;
-        wgs_mem_info[wg_id] = {wg_total_mem_used};
+    for (auto& [wg_id, wg] : _workload_groups) {
+        wgs_mem_info[wg_id].total_mem_used =
+                wg->make_memory_tracker_snapshots(&wgs_mem_info[wg_id].tracker_snapshots);
+        all_queries_mem_used += wgs_mem_info[wg_id].total_mem_used;
     }
-
-    // *TODO*, modify to use 
doris::GlobalMemoryArbitrator::process_memory_usage().
-    auto proc_vm_rss = PerfCounters::get_vm_rss();
     if (all_queries_mem_used <= 0) {
         return;
     }
 
-    if (proc_vm_rss < all_queries_mem_used) {
-        all_queries_mem_used = proc_vm_rss;
-    }
-
+    // 2. calculate weighted ratio.
     // process memory used is actually bigger than all_queries_mem_used,
     // because memory of page cache, allocator cache, segment cache etc. are 
included
     // in proc_vm_rss.
     // we count these cache memories equally on workload groups.
-    double ratio = (double)proc_vm_rss / (double)all_queries_mem_used;
+    auto process_memory_usage = GlobalMemoryArbitrator::process_memory_usage();
+    all_queries_mem_used = std::min(process_memory_usage, all_queries_mem_used);
+    double ratio = (double)process_memory_usage / (double)all_queries_mem_used;
     if (ratio <= 1.25) {
         std::string debug_msg =
                 fmt::format("\nProcess Memory Summary: {}, {}, all quries mem: 
{}",
@@ -202,66 +186,57 @@ void WorkloadGroupMgr::refresh_wg_memory_info() {
     }
 
     for (auto& wg : _workload_groups) {
+        // 3.1 calculate query weighted memory limit of task group
         auto wg_mem_limit = wg.second->memory_limit();
-        auto& wg_mem_info = wgs_mem_info[wg.first];
-        wg_mem_info.weighted_mem_used = int64_t(wg_mem_info.total_mem_used * 
ratio);
-        wg_mem_info.mem_used_ratio = (double)wg_mem_info.weighted_mem_used / 
wg_mem_limit;
-
-        wg.second->set_weighted_memory_used(wg_mem_info.total_mem_used, ratio);
-
-        auto spill_low_water_mark = 
wg.second->spill_threshold_low_water_mark();
-        auto spill_high_water_mark = 
wg.second->spill_threashold_high_water_mark();
-        wg_mem_info.is_high_wartermark = (wg_mem_info.weighted_mem_used >
-                                          ((double)wg_mem_limit * 
spill_high_water_mark / 100));
-        wg_mem_info.is_low_wartermark = (wg_mem_info.weighted_mem_used >
-                                         ((double)wg_mem_limit * 
spill_low_water_mark / 100));
-
-        // calculate query weighted memory limit of task group
-        const auto& wg_queries = all_wg_queries[wg.first];
-        auto wg_query_count = wg_queries.size();
+        auto wg_query_count = wgs_mem_info[wg.first].tracker_snapshots.size();
         int64_t query_weighted_mem_limit =
                 wg_query_count ? (wg_mem_limit + wg_query_count) / 
wg_query_count : wg_mem_limit;
 
+        // 3.2 set the group's weighted memory ratio and every query's weighted memory limit and ratio.
+        wg.second->set_weighted_memory_ratio(ratio);
+        for (const auto& query : wg.second->queries()) {
+            auto query_ctx = query.second.lock();
+            if (!query_ctx) {
+                continue;
+            }
+            query_ctx->set_weighted_memory(query_weighted_mem_limit, ratio);
+        }
+
+        // 3.3 only print debug logs if the workload group is above the low or high watermark.
+        auto weighted_mem_used = int64_t(wgs_mem_info[wg.first].total_mem_used * ratio);
+        bool is_high_wartermark =
+                (weighted_mem_used >
+                 ((double)wg_mem_limit * wg.second->spill_threashold_high_water_mark() / 100));
+        bool is_low_wartermark =
+                (weighted_mem_used >
+                 ((double)wg_mem_limit * wg.second->spill_threshold_low_water_mark() / 100));
         std::string debug_msg;
-        if (wg_mem_info.is_high_wartermark || wg_mem_info.is_low_wartermark) {
+        if (is_high_wartermark || is_low_wartermark) {
             debug_msg = fmt::format(
                     "\nWorkload Group {}: mem limit: {}, mem used: {}, 
weighted mem used: {}, used "
                     "ratio: {}, query "
                     "count: {}, query_weighted_mem_limit: {}",
                     wg.second->name(), PrettyPrinter::print(wg_mem_limit, 
TUnit::BYTES),
-                    PrettyPrinter::print(wg_mem_info.total_mem_used, 
TUnit::BYTES),
-                    PrettyPrinter::print(wg_mem_info.weighted_mem_used, 
TUnit::BYTES),
-                    wg_mem_info.mem_used_ratio, wg_query_count,
+                    
PrettyPrinter::print(wgs_mem_info[wg.first].total_mem_used, TUnit::BYTES),
+                    PrettyPrinter::print(weighted_mem_used, TUnit::BYTES),
+                    (double)weighted_mem_used / wg_mem_limit, wg_query_count,
                     PrettyPrinter::print(query_weighted_mem_limit, 
TUnit::BYTES));
 
             debug_msg += "\n  Query Memory Summary:";
-        } else {
-            continue;
-        }
-        // check whether queries need to revoke memory for task group
-        for (const auto& query : wg_queries) {
-            auto query_ctx = query.second.lock();
-            if (!query_ctx) {
-                continue;
-            }
-            auto query_consumption = 
query_ctx->query_mem_tracker->consumption();
-            auto query_weighted_consumption = int64_t(query_consumption * 
ratio);
-            query_ctx->set_weighted_mem(query_weighted_mem_limit, 
query_weighted_consumption);
-
-            if (wg_mem_info.is_high_wartermark || 
wg_mem_info.is_low_wartermark) {
+            // check whether queries need to revoke memory for task group
+            for (const auto& query_mem_tracker : wgs_mem_info[wg.first].tracker_snapshots) {
                 debug_msg += fmt::format(
                         "\n    MemTracker Label={}, Parent Label={}, Used={}, 
WeightedUsed={}, "
                         "Peak={}",
-                        query_ctx->query_mem_tracker->label(),
-                        query_ctx->query_mem_tracker->parent_label(),
-                        PrettyPrinter::print(query_consumption, TUnit::BYTES),
-                        PrettyPrinter::print(query_weighted_consumption, 
TUnit::BYTES),
-                        
PrettyPrinter::print(query_ctx->query_mem_tracker->peak_consumption(),
-                                             TUnit::BYTES));
+                        query_mem_tracker->label(), 
query_mem_tracker->parent_label(),
+                        PrettyPrinter::print(query_mem_tracker->consumption(), 
TUnit::BYTES),
+                        
PrettyPrinter::print(int64_t(query_mem_tracker->consumption() * ratio),
+                                             TUnit::BYTES),
+                        
PrettyPrinter::print(query_mem_tracker->peak_consumption(), TUnit::BYTES));
             }
-        }
-        if (wg_mem_info.is_high_wartermark || wg_mem_info.is_low_wartermark) {
-            LOG_EVERY_T(INFO, 10) << debug_msg;
+            LOG_EVERY_T(INFO, 1) << debug_msg;
+        } else {
+            continue;
         }
     }
 }
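
Putting the pieces of refresh_wg_weighted_memory_ratio() together, each interval does: snapshot every group's trackers, derive one ratio, then push the ratio plus an evenly divided weighted limit to every query. A condensed standalone sketch of that flow with simplified containers and no logging; it is not the Doris implementation:

    #include <algorithm>
    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    struct Query {
        int64_t tracker_consumption = 0;
        int64_t weighted_limit = 0;
        double weighted_ratio = 0.0;
    };

    struct Group {
        int64_t memory_limit = 0;
        double weighted_ratio = 0.0;
        std::vector<Query*> queries;
    };

    // One refresh pass over all workload groups, mirroring steps 1-3 in the hunk above.
    void refresh(std::unordered_map<uint64_t, Group>& groups, int64_t process_memory_usage) {
        // 1. snapshot: sum the tracked memory of every query in every group.
        int64_t all_queries_mem_used = 0;
        for (auto& [wg_id, wg] : groups) {
            for (auto* q : wg.queries) {
                all_queries_mem_used += q->tracker_consumption;
            }
        }
        if (all_queries_mem_used <= 0) {
            return;
        }
        // 2. one ratio for everyone: process memory / tracked memory, clamped to >= 1.
        all_queries_mem_used = std::min(process_memory_usage, all_queries_mem_used);
        double ratio = (double)process_memory_usage / (double)all_queries_mem_used;
        // 3. distribute: each query gets the ratio and an equal share of its group's limit.
        for (auto& [wg_id, wg] : groups) {
            wg.weighted_ratio = ratio;
            auto n = (int64_t)wg.queries.size();
            int64_t query_weighted_limit = n ? (wg.memory_limit + n) / n : wg.memory_limit;
            for (auto* q : wg.queries) {
                q->weighted_limit = query_weighted_limit;
                q->weighted_ratio = ratio;
            }
        }
    }
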
diff --git a/be/src/runtime/workload_group/workload_group_manager.h 
b/be/src/runtime/workload_group/workload_group_manager.h
index 8aeb8f988a3..37539ada8d8 100644
--- a/be/src/runtime/workload_group/workload_group_manager.h
+++ b/be/src/runtime/workload_group/workload_group_manager.h
@@ -54,7 +54,7 @@ public:
 
     bool enable_cpu_hard_limit() { return _enable_cpu_hard_limit.load(); }
 
-    void refresh_wg_memory_info();
+    void refresh_wg_weighted_memory_ratio();
 
 private:
     std::shared_mutex _group_mutex;
diff --git a/be/src/vec/exec/scan/scanner_context.cpp 
b/be/src/vec/exec/scan/scanner_context.cpp
index 19a645c1a95..7a9d7c3550b 100644
--- a/be/src/vec/exec/scan/scanner_context.cpp
+++ b/be/src/vec/exec/scan/scanner_context.cpp
@@ -112,7 +112,8 @@ ScannerContext::ScannerContext(RuntimeState* state, const 
TupleDescriptor* outpu
         (_local_state && _local_state->should_run_serial())) {
         _max_thread_num = 1;
     }
-    _query_thread_context = {_query_id, _state->query_mem_tracker()};
+    _query_thread_context = {_query_id, _state->query_mem_tracker(),
+                             _state->get_query_ctx()->workload_group()};
 }
 
 ScannerContext::ScannerContext(doris::RuntimeState* state, 
doris::vectorized::VScanNode* parent,
diff --git a/be/src/vec/runtime/vdata_stream_recvr.cpp 
b/be/src/vec/runtime/vdata_stream_recvr.cpp
index ac90e277080..cb483e986c8 100644
--- a/be/src/vec/runtime/vdata_stream_recvr.cpp
+++ b/be/src/vec/runtime/vdata_stream_recvr.cpp
@@ -338,10 +338,8 @@ VDataStreamRecvr::VDataStreamRecvr(VDataStreamMgr* 
stream_mgr, RuntimeState* sta
                                    int num_senders, bool is_merging, 
RuntimeProfile* profile)
         : HasTaskExecutionCtx(state),
           _mgr(stream_mgr),
-#ifdef USE_MEM_TRACKER
-          _query_mem_tracker(state->query_mem_tracker()),
-          _query_id(state->query_id()),
-#endif
+          _query_thread_context(state->query_id(), state->query_mem_tracker(),
+                                state->get_query_ctx()->workload_group()),
           _fragment_instance_id(fragment_instance_id),
           _dest_node_id(dest_node_id),
           _row_desc(row_desc),
@@ -424,7 +422,7 @@ Status VDataStreamRecvr::create_merger(const 
VExprContextSPtrs& ordering_expr,
 
 Status VDataStreamRecvr::add_block(const PBlock& pblock, int sender_id, int 
be_number,
                                    int64_t packet_seq, 
::google::protobuf::Closure** done) {
-    SCOPED_ATTACH_TASK_WITH_ID(_query_mem_tracker, _query_id);
+    SCOPED_ATTACH_TASK(_query_thread_context);
     int use_sender_id = _is_merging ? sender_id : 0;
     return _sender_queues[use_sender_id]->add_block(pblock, be_number, 
packet_seq, done);
 }
diff --git a/be/src/vec/runtime/vdata_stream_recvr.h 
b/be/src/vec/runtime/vdata_stream_recvr.h
index 3832a10c4f2..cb44565e8c2 100644
--- a/be/src/vec/runtime/vdata_stream_recvr.h
+++ b/be/src/vec/runtime/vdata_stream_recvr.h
@@ -43,6 +43,7 @@
 #include "common/status.h"
 #include "runtime/descriptors.h"
 #include "runtime/task_execution_context.h"
+#include "runtime/thread_context.h"
 #include "util/runtime_profile.h"
 #include "util/stopwatch.hpp"
 #include "vec/core/block.h"
@@ -128,10 +129,7 @@ private:
     // DataStreamMgr instance used to create this recvr. (Not owned)
     VDataStreamMgr* _mgr = nullptr;
 
-#ifdef USE_MEM_TRACKER
-    std::shared_ptr<MemTrackerLimiter> _query_mem_tracker = nullptr;
-    TUniqueId _query_id;
-#endif
+    QueryThreadContext _query_thread_context;
 
     // Fragment and node id of the destination exchange node this receiver is 
used by.
     TUniqueId _fragment_instance_id;
diff --git a/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp 
b/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
index 29c2759fcb7..ab15fce05a7 100644
--- a/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
+++ b/be/test/runtime/memory/thread_mem_tracker_mgr_test.cpp
@@ -26,7 +26,18 @@
 
 namespace doris {
 
-TEST(ThreadMemTrackerMgrTest, ConsumeMemory) {
+class ThreadMemTrackerMgrTest : public testing::Test {
+public:
+    ThreadMemTrackerMgrTest() = default;
+    ~ThreadMemTrackerMgrTest() override = default;
+
+    void SetUp() override {}
+
+protected:
+    std::shared_ptr<WorkloadGroup> workload_group;
+};
+
+TEST_F(ThreadMemTrackerMgrTest, ConsumeMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ConsumeMemory");
@@ -34,7 +45,7 @@ TEST(ThreadMemTrackerMgrTest, ConsumeMemory) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->consume_memory(size1);
     // size1 < config::mem_tracker_consume_min_size_bytes, not consume mem 
tracker.
     EXPECT_EQ(t->consumption(), 0);
@@ -80,7 +91,7 @@ TEST(ThreadMemTrackerMgrTest, Boundary) {
     // TODO, Boundary check may not be necessary, add some `IF` maybe increase 
cost time.
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
+TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, "UT-NestedSwitchMemTracker1");
@@ -92,7 +103,7 @@ TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t1->consumption(), size1 + size2);
@@ -152,7 +163,7 @@ TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTracker) {
     EXPECT_EQ(t1->consumption(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
+TEST_F(ThreadMemTrackerMgrTest, MultiMemTracker) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-MultiMemTracker1");
@@ -162,7 +173,7 @@ TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     thread_context->consume_memory(size1);
@@ -213,7 +224,7 @@ TEST(ThreadMemTrackerMgrTest, MultiMemTracker) {
     EXPECT_EQ(t3->consumption(), size1 + size2 - size1);
 }
 
-TEST(ThreadMemTrackerMgrTest, ScopedCount) {
+TEST_F(ThreadMemTrackerMgrTest, ScopedCount) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ScopedCount");
@@ -221,7 +232,7 @@ TEST(ThreadMemTrackerMgrTest, ScopedCount) {
     int64_t size1 = 4 * 1024;
     int64_t size2 = 4 * 1024 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->thread_mem_tracker_mgr->start_count_scope_mem();
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
@@ -239,7 +250,7 @@ TEST(ThreadMemTrackerMgrTest, ScopedCount) {
     EXPECT_EQ(scope_mem, size1 + size2 + size1 + size2 + size1);
 }
 
-TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, ReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t =
             MemTrackerLimiter::create_shared(MemTrackerLimiter::Type::OTHER, 
"UT-ReserveMemory");
@@ -248,7 +259,7 @@ TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->consume_memory(size1);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t->consumption(), size1 + size2);
@@ -338,7 +349,7 @@ TEST(ThreadMemTrackerMgrTest, ReserveMemory) {
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, "UT-NestedReserveMemory");
@@ -346,7 +357,7 @@ TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t);
+    thread_context->attach_task(TUniqueId(), t, workload_group);
     thread_context->try_reserve_memory(size3);
     EXPECT_EQ(t->consumption(), size3);
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), size3);
@@ -386,7 +397,7 @@ TEST(ThreadMemTrackerMgrTest, NestedReserveMemory) {
     EXPECT_EQ(doris::GlobalMemoryArbitrator::process_reserved_memory(), 0);
 }
 
-TEST(ThreadMemTrackerMgrTest, NestedSwitchMemTrackerReserveMemory) {
+TEST_F(ThreadMemTrackerMgrTest, NestedSwitchMemTrackerReserveMemory) {
     std::unique_ptr<ThreadContext> thread_context = 
std::make_unique<ThreadContext>();
     std::shared_ptr<MemTrackerLimiter> t1 = MemTrackerLimiter::create_shared(
             MemTrackerLimiter::Type::OTHER, 
"UT-NestedSwitchMemTrackerReserveMemory1");
@@ -399,7 +410,7 @@ TEST(ThreadMemTrackerMgrTest, 
NestedSwitchMemTrackerReserveMemory) {
     int64_t size2 = 4 * 1024 * 1024;
     int64_t size3 = size2 * 1024;
 
-    thread_context->attach_task(TUniqueId(), t1);
+    thread_context->attach_task(TUniqueId(), t1, workload_group);
     thread_context->try_reserve_memory(size3);
     thread_context->consume_memory(size2);
     EXPECT_EQ(t1->consumption(), size3);

