This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new b7a62a06ec1 [branch-3.0](runtime filter) pick #43885 (#44130)
b7a62a06ec1 is described below

commit b7a62a06ec1e0b885dbc8030caeb3589d71d79cc
Author: Gabriel <liwenqi...@selectdb.com>
AuthorDate: Mon Nov 18 14:21:40 2024 +0800

    [branch-3.0](runtime filter) pick #43885 (#44130)
---
 be/src/exprs/runtime_filter.cpp                    | 83 ++++++++++++----------
 be/src/exprs/runtime_filter.h                      |  6 +-
 be/src/pipeline/exec/hashjoin_build_sink.cpp       |  9 +--
 be/src/pipeline/exec/hashjoin_build_sink.h         |  4 +-
 .../exec/nested_loop_join_build_operator.cpp       |  8 +--
 .../exec/nested_loop_join_build_operator.h         |  3 +-
 .../exec/partitioned_hash_join_sink_operator.cpp   |  8 ++-
 .../exec/partitioned_hash_join_sink_operator.h     |  3 +-
 be/src/pipeline/pipeline_fragment_context.cpp      | 48 ++++---------
 be/src/pipeline/pipeline_fragment_context.h        |  4 +-
 be/src/runtime/query_context.cpp                   |  2 +-
 be/src/runtime/runtime_filter_mgr.cpp              | 73 +++++++++++--------
 be/src/runtime/runtime_filter_mgr.h                | 44 +++++++-----
 be/src/runtime/runtime_state.cpp                   | 15 ++--
 be/src/runtime/runtime_state.h                     |  1 -
 15 files changed, 153 insertions(+), 158 deletions(-)

diff --git a/be/src/exprs/runtime_filter.cpp b/be/src/exprs/runtime_filter.cpp
index b9bf837cf56..f28cc53dcb8 100644
--- a/be/src/exprs/runtime_filter.cpp
+++ b/be/src/exprs/runtime_filter.cpp
@@ -994,17 +994,17 @@ void IRuntimeFilter::insert_batch(const 
vectorized::ColumnPtr column, size_t sta
 Status IRuntimeFilter::publish(bool publish_local) {
     DCHECK(is_producer());
 
-    auto send_to_remote = [&](IRuntimeFilter* filter) {
+    auto send_to_remote_targets = [&](IRuntimeFilter* filter) {
         TNetworkAddress addr;
         DCHECK(_state != nullptr);
-        RETURN_IF_ERROR(_state->runtime_filter_mgr->get_merge_addr(&addr));
+        
RETURN_IF_ERROR(_state->global_runtime_filter_mgr()->get_merge_addr(&addr));
         return filter->push_to_remote(&addr);
     };
-    auto send_to_local = [&](std::shared_ptr<RuntimePredicateWrapper> wrapper) 
{
-        std::vector<std::shared_ptr<IRuntimeFilter>> filters;
-        
RETURN_IF_ERROR(_state->runtime_filter_mgr->get_consume_filters(_filter_id, 
filters));
-        DCHECK(!filters.empty());
-        // push down
+    auto send_to_local_targets = [&](std::shared_ptr<RuntimePredicateWrapper> 
wrapper,
+                                     bool global) {
+        std::vector<std::shared_ptr<IRuntimeFilter>> filters =
+                global ? 
_state->global_runtime_filter_mgr()->get_consume_filters(_filter_id)
+                       : 
_state->local_runtime_filter_mgr()->get_consume_filters(_filter_id);
         for (auto filter : filters) {
             filter->_wrapper = wrapper;
             filter->update_runtime_filter_type_to_profile();
@@ -1012,32 +1012,36 @@ Status IRuntimeFilter::publish(bool publish_local) {
         }
         return Status::OK();
     };
-    auto do_local_merge = [&]() {
-        LocalMergeFilters* local_merge_filters = nullptr;
-        
RETURN_IF_ERROR(_state->runtime_filter_mgr->get_local_merge_producer_filters(
-                _filter_id, &local_merge_filters));
-        std::lock_guard l(*local_merge_filters->lock);
-        
RETURN_IF_ERROR(local_merge_filters->filters[0]->merge_from(_wrapper.get()));
-        local_merge_filters->merge_time--;
-        if (local_merge_filters->merge_time == 0) {
-            if (_has_local_target) {
-                
RETURN_IF_ERROR(send_to_local(local_merge_filters->filters[0]->_wrapper));
-            } else {
-                
RETURN_IF_ERROR(send_to_remote(local_merge_filters->filters[0].get()));
+    auto do_merge = [&]() {
+        if 
(!_state->global_runtime_filter_mgr()->get_consume_filters(_filter_id).empty()) 
{
+            LocalMergeFilters* local_merge_filters = nullptr;
+            
RETURN_IF_ERROR(_state->global_runtime_filter_mgr()->get_local_merge_producer_filters(
+                    _filter_id, &local_merge_filters));
+            std::lock_guard l(*local_merge_filters->lock);
+            
RETURN_IF_ERROR(local_merge_filters->filters[0]->merge_from(_wrapper.get()));
+            local_merge_filters->merge_time--;
+            if (local_merge_filters->merge_time == 0) {
+                if (_has_local_target) {
+                    RETURN_IF_ERROR(
+                            
send_to_local_targets(local_merge_filters->filters[0]->_wrapper, true));
+                } else {
+                    
RETURN_IF_ERROR(send_to_remote_targets(local_merge_filters->filters[0].get()));
+                }
             }
         }
         return Status::OK();
     };
 
-    if (_need_local_merge && _has_local_target) {
-        RETURN_IF_ERROR(do_local_merge());
-    } else if (_has_local_target) {
-        RETURN_IF_ERROR(send_to_local(_wrapper));
+    if (_has_local_target) {
+        // A runtime filter may have multiple targets and some of those are 
local-merge RF and others are not.
+        // So for all runtime filters' producers, `publish` should notify all 
consumers in global RF mgr which manages local-merge RF and local RF mgr which 
manages others.
+        RETURN_IF_ERROR(do_merge());
+        RETURN_IF_ERROR(send_to_local_targets(_wrapper, false));
     } else if (!publish_local) {
-        if (_is_broadcast_join || _state->be_exec_version < USE_NEW_SERDE) {
-            RETURN_IF_ERROR(send_to_remote(this));
+        if (_is_broadcast_join || _state->get_query_ctx()->be_exec_version() < 
USE_NEW_SERDE) {
+            RETURN_IF_ERROR(send_to_remote_targets(this));
         } else {
-            RETURN_IF_ERROR(do_local_merge());
+            RETURN_IF_ERROR(do_merge());
         }
     } else {
         // remote broadcast join only push onetime in build shared hash table
@@ -1096,13 +1100,16 @@ public:
 Status IRuntimeFilter::send_filter_size(RuntimeState* state, uint64_t 
local_filter_size) {
     DCHECK(is_producer());
 
-    if (_need_local_merge) {
+    if 
(!_state->global_runtime_filter_mgr()->get_consume_filters(_filter_id).empty()) 
{
         LocalMergeFilters* local_merge_filters = nullptr;
-        
RETURN_IF_ERROR(_state->runtime_filter_mgr->get_local_merge_producer_filters(
+        
RETURN_IF_ERROR(_state->global_runtime_filter_mgr()->get_local_merge_producer_filters(
                 _filter_id, &local_merge_filters));
         std::lock_guard l(*local_merge_filters->lock);
         local_merge_filters->merge_size_times--;
         local_merge_filters->local_merged_size += local_filter_size;
+        if (_has_local_target) {
+            set_synced_size(local_filter_size);
+        }
         if (local_merge_filters->merge_size_times) {
             return Status::OK();
         } else {
@@ -1122,9 +1129,9 @@ Status IRuntimeFilter::send_filter_size(RuntimeState* 
state, uint64_t local_filt
 
     TNetworkAddress addr;
     DCHECK(_state != nullptr);
-    RETURN_IF_ERROR(_state->runtime_filter_mgr->get_merge_addr(&addr));
+    
RETURN_IF_ERROR(_state->global_runtime_filter_mgr()->get_merge_addr(&addr));
     std::shared_ptr<PBackendService_Stub> stub(
-            _state->exec_env->brpc_internal_client_cache()->get_client(addr));
+            
_state->get_query_ctx()->exec_env()->brpc_internal_client_cache()->get_client(addr));
     if (!stub) {
         return Status::InternalError("Get rpc stub failed, host={}, port={}", 
addr.hostname,
                                      addr.port);
@@ -1137,8 +1144,8 @@ Status IRuntimeFilter::send_filter_size(RuntimeState* 
state, uint64_t local_filt
     auto closure = SyncSizeClosure::create_unique(request, callback, 
_dependency,
                                                   _wrapper->_context, 
this->debug_string());
     auto* pquery_id = request->mutable_query_id();
-    pquery_id->set_hi(_state->query_id.hi());
-    pquery_id->set_lo(_state->query_id.lo());
+    pquery_id->set_hi(_state->get_query_ctx()->query_id().hi);
+    pquery_id->set_lo(_state->get_query_ctx()->query_id().lo);
 
     auto* source_addr = request->mutable_source_addr();
     source_addr->set_hostname(BackendOptions::get_local_backend().host);
@@ -1157,7 +1164,7 @@ Status IRuntimeFilter::send_filter_size(RuntimeState* 
state, uint64_t local_filt
 Status IRuntimeFilter::push_to_remote(const TNetworkAddress* addr) {
     DCHECK(is_producer());
     std::shared_ptr<PBackendService_Stub> stub(
-            _state->exec_env->brpc_internal_client_cache()->get_client(*addr));
+            
_state->get_query_ctx()->exec_env()->brpc_internal_client_cache()->get_client(*addr));
     if (!stub) {
         return Status::InternalError(
                 fmt::format("Get rpc stub failed, host={}, port={}", 
addr->hostname, addr->port));
@@ -1172,8 +1179,8 @@ Status IRuntimeFilter::push_to_remote(const 
TNetworkAddress* addr) {
     int len = 0;
 
     auto* pquery_id = merge_filter_request->mutable_query_id();
-    pquery_id->set_hi(_state->query_id.hi());
-    pquery_id->set_lo(_state->query_id.lo());
+    pquery_id->set_hi(_state->get_query_ctx()->query_id().hi);
+    pquery_id->set_lo(_state->get_query_ctx()->query_id().lo);
 
     auto* pfragment_instance_id = 
merge_filter_request->mutable_fragment_instance_id();
     pfragment_instance_id->set_hi(BackendOptions::get_local_backend().id);
@@ -1183,7 +1190,7 @@ Status IRuntimeFilter::push_to_remote(const 
TNetworkAddress* addr) {
     merge_filter_request->set_is_pipeline(true);
     auto column_type = _wrapper->column_type();
     
RETURN_IF_CATCH_EXCEPTION(merge_filter_request->set_column_type(to_proto(column_type)));
-    merge_filter_callback->cntl_->set_timeout_ms(wait_time_ms());
+    
merge_filter_callback->cntl_->set_timeout_ms(_state->get_query_ctx()->execution_timeout());
 
     if (get_ignored()) {
         merge_filter_request->set_filter_type(PFilterType::UNKNOW_FILTER);
@@ -1227,8 +1234,8 @@ Status 
IRuntimeFilter::get_push_expr_ctxs(std::list<vectorized::VExprContextSPtr
 
 void IRuntimeFilter::update_state() {
     DCHECK(is_consumer());
-    auto execution_timeout = _state->execution_timeout * 1000;
-    auto runtime_filter_wait_time_ms = _state->runtime_filter_wait_time_ms;
+    auto execution_timeout = _state->get_query_ctx()->execution_timeout() * 
1000;
+    auto runtime_filter_wait_time_ms = 
_state->get_query_ctx()->runtime_filter_wait_time_ms();
     // bitmap filter is precise filter and only filter once, so it must be 
applied.
     int64_t wait_times_ms = _runtime_filter_type == 
RuntimeFilterType::BITMAP_FILTER
                                     ? execution_timeout
diff --git a/be/src/exprs/runtime_filter.h b/be/src/exprs/runtime_filter.h
index 629e5fa2550..9e0e93433d5 100644
--- a/be/src/exprs/runtime_filter.h
+++ b/be/src/exprs/runtime_filter.h
@@ -201,8 +201,8 @@ public:
               _role(RuntimeFilterRole::PRODUCER),
               _expr_order(-1),
               registration_time_(MonotonicMillis()),
-              _wait_infinitely(_state->runtime_filter_wait_infinitely),
-              _rf_wait_time_ms(_state->runtime_filter_wait_time_ms),
+              
_wait_infinitely(_state->get_query_ctx()->runtime_filter_wait_infinitely()),
+              
_rf_wait_time_ms(_state->get_query_ctx()->runtime_filter_wait_time_ms()),
               _runtime_filter_type(get_runtime_filter_type(desc)),
               _profile(
                       new RuntimeProfile(fmt::format("RuntimeFilter: (id = {}, 
type = {})",
@@ -333,7 +333,7 @@ public:
     int32_t wait_time_ms() const {
         int32_t res = 0;
         if (wait_infinitely()) {
-            res = _state->execution_timeout;
+            res = _state->get_query_ctx()->execution_timeout();
             // Convert to ms
             res *= 1000;
         } else {
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp 
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index 7efeb7692d4..0e498661c47 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -92,8 +92,7 @@ Status HashJoinBuildSinkLocalState::init(RuntimeState* state, 
LocalSinkStateInfo
     _runtime_filters.resize(p._runtime_filter_descs.size());
     for (size_t i = 0; i < p._runtime_filter_descs.size(); i++) {
         RETURN_IF_ERROR(state->register_producer_runtime_filter(
-                p._runtime_filter_descs[i], p._need_local_merge, 
&_runtime_filters[i],
-                _build_expr_ctxs.size() == 1));
+                p._runtime_filter_descs[i], &_runtime_filters[i], 
_build_expr_ctxs.size() == 1));
     }
 
     _runtime_filter_slots =
@@ -423,8 +422,7 @@ void 
HashJoinBuildSinkLocalState::_hash_table_init(RuntimeState* state) {
 
 HashJoinBuildSinkOperatorX::HashJoinBuildSinkOperatorX(ObjectPool* pool, int 
operator_id,
                                                        const TPlanNode& tnode,
-                                                       const DescriptorTbl& 
descs,
-                                                       bool need_local_merge)
+                                                       const DescriptorTbl& 
descs)
         : JoinBuildSinkOperatorX(pool, operator_id, tnode, descs),
           _join_distribution(tnode.hash_join_node.__isset.dist_type ? 
tnode.hash_join_node.dist_type
                                                                     : 
TJoinDistributionType::NONE),
@@ -432,8 +430,7 @@ 
HashJoinBuildSinkOperatorX::HashJoinBuildSinkOperatorX(ObjectPool* pool, int ope
                              tnode.hash_join_node.is_broadcast_join),
           _partition_exprs(tnode.__isset.distribute_expr_lists && 
!_is_broadcast_join
                                    ? tnode.distribute_expr_lists[1]
-                                   : std::vector<TExpr> {}),
-          _need_local_merge(need_local_merge) {}
+                                   : std::vector<TExpr> {}) {}
 
 Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* 
state) {
     RETURN_IF_ERROR(JoinBuildSinkOperatorX::init(tnode, state));
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h 
b/be/src/pipeline/exec/hashjoin_build_sink.h
index 930d3761791..4833bee5488 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -108,7 +108,7 @@ class HashJoinBuildSinkOperatorX final
         : public JoinBuildSinkOperatorX<HashJoinBuildSinkLocalState> {
 public:
     HashJoinBuildSinkOperatorX(ObjectPool* pool, int operator_id, const 
TPlanNode& tnode,
-                               const DescriptorTbl& descs, bool use_global_rf);
+                               const DescriptorTbl& descs);
     Status init(const TDataSink& tsink) override {
         return Status::InternalError("{} should not init with TDataSink",
                                      
JoinBuildSinkOperatorX<HashJoinBuildSinkLocalState>::_name);
@@ -166,8 +166,6 @@ private:
 
     vectorized::SharedHashTableContextPtr _shared_hash_table_context = nullptr;
     const std::vector<TExpr> _partition_exprs;
-
-    const bool _need_local_merge;
 };
 
 template <class HashTableContext>
diff --git a/be/src/pipeline/exec/nested_loop_join_build_operator.cpp 
b/be/src/pipeline/exec/nested_loop_join_build_operator.cpp
index 6c164417822..1b1723fe2c1 100644
--- a/be/src/pipeline/exec/nested_loop_join_build_operator.cpp
+++ b/be/src/pipeline/exec/nested_loop_join_build_operator.cpp
@@ -66,8 +66,8 @@ Status NestedLoopJoinBuildSinkLocalState::init(RuntimeState* 
state, LocalSinkSta
     _shared_state->join_op_variants = p._join_op_variants;
     _runtime_filters.resize(p._runtime_filter_descs.size());
     for (size_t i = 0; i < p._runtime_filter_descs.size(); i++) {
-        RETURN_IF_ERROR(state->register_producer_runtime_filter(
-                p._runtime_filter_descs[i], p._need_local_merge, 
&_runtime_filters[i], false));
+        
RETURN_IF_ERROR(state->register_producer_runtime_filter(p._runtime_filter_descs[i],
+                                                                
&_runtime_filters[i], false));
     }
     return Status::OK();
 }
@@ -87,11 +87,9 @@ Status NestedLoopJoinBuildSinkLocalState::open(RuntimeState* 
state) {
 NestedLoopJoinBuildSinkOperatorX::NestedLoopJoinBuildSinkOperatorX(ObjectPool* 
pool,
                                                                    int 
operator_id,
                                                                    const 
TPlanNode& tnode,
-                                                                   const 
DescriptorTbl& descs,
-                                                                   bool 
need_local_merge)
+                                                                   const 
DescriptorTbl& descs)
         : JoinBuildSinkOperatorX<NestedLoopJoinBuildSinkLocalState>(pool, 
operator_id, tnode,
                                                                     descs),
-          _need_local_merge(need_local_merge),
           
_is_output_left_side_only(tnode.nested_loop_join_node.__isset.is_output_left_side_only
 &&
                                     
tnode.nested_loop_join_node.is_output_left_side_only),
           _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {}
diff --git a/be/src/pipeline/exec/nested_loop_join_build_operator.h 
b/be/src/pipeline/exec/nested_loop_join_build_operator.h
index d6e72799f97..5c41088a705 100644
--- a/be/src/pipeline/exec/nested_loop_join_build_operator.h
+++ b/be/src/pipeline/exec/nested_loop_join_build_operator.h
@@ -59,7 +59,7 @@ class NestedLoopJoinBuildSinkOperatorX final
         : public JoinBuildSinkOperatorX<NestedLoopJoinBuildSinkLocalState> {
 public:
     NestedLoopJoinBuildSinkOperatorX(ObjectPool* pool, int operator_id, const 
TPlanNode& tnode,
-                                     const DescriptorTbl& descs, bool 
need_local_merge);
+                                     const DescriptorTbl& descs);
     Status init(const TDataSink& tsink) override {
         return Status::InternalError(
                 "{} should not init with TDataSink",
@@ -85,7 +85,6 @@ private:
 
     vectorized::VExprContextSPtrs _filter_src_expr_ctxs;
 
-    bool _need_local_merge;
     const bool _is_output_left_side_only;
     RowDescriptor _row_descriptor;
 };
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp 
b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
index a7297be493f..32af13ba548 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
@@ -393,9 +393,11 @@ void PartitionedHashJoinSinkLocalState::_spill_to_disk(
     }
 }
 
-PartitionedHashJoinSinkOperatorX::PartitionedHashJoinSinkOperatorX(
-        ObjectPool* pool, int operator_id, const TPlanNode& tnode, const 
DescriptorTbl& descs,
-        bool use_global_rf, uint32_t partition_count)
+PartitionedHashJoinSinkOperatorX::PartitionedHashJoinSinkOperatorX(ObjectPool* 
pool,
+                                                                   int 
operator_id,
+                                                                   const 
TPlanNode& tnode,
+                                                                   const 
DescriptorTbl& descs,
+                                                                   uint32_t 
partition_count)
         : JoinBuildSinkOperatorX<PartitionedHashJoinSinkLocalState>(pool, 
operator_id, tnode,
                                                                     descs),
           _join_distribution(tnode.hash_join_node.__isset.dist_type ? 
tnode.hash_join_node.dist_type
diff --git a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h 
b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
index 8e89763b50a..d1fe30e06f2 100644
--- a/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
+++ b/be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
@@ -82,8 +82,7 @@ class PartitionedHashJoinSinkOperatorX
         : public JoinBuildSinkOperatorX<PartitionedHashJoinSinkLocalState> {
 public:
     PartitionedHashJoinSinkOperatorX(ObjectPool* pool, int operator_id, const 
TPlanNode& tnode,
-                                     const DescriptorTbl& descs, bool 
use_global_rf,
-                                     uint32_t partition_count);
+                                     const DescriptorTbl& descs, uint32_t 
partition_count);
 
     Status init(const TDataSink& tsink) override {
         return Status::InternalError("{} should not init with TDataSink",
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp 
b/be/src/pipeline/pipeline_fragment_context.cpp
index 75623b8f3e9..b7aa8518422 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -296,8 +296,6 @@ Status PipelineFragmentContext::prepare(const 
doris::TPipelineFragmentParams& re
         if (local_params.__isset.topn_filter_descs) {
             
_query_ctx->init_runtime_predicates(local_params.topn_filter_descs);
         }
-
-        _need_local_merge = request.__isset.parallel_instances;
     }
 
     {
@@ -369,29 +367,9 @@ Status 
PipelineFragmentContext::_build_pipeline_tasks(const doris::TPipelineFrag
         auto fragment_instance_id = local_params.fragment_instance_id;
         _fragment_instance_ids[i] = fragment_instance_id;
 
-        auto filterparams = std::make_unique<RuntimeFilterParamsContext>();
-
-        {
-            filterparams->runtime_filter_wait_infinitely =
-                    _runtime_state->runtime_filter_wait_infinitely();
-            filterparams->runtime_filter_wait_time_ms =
-                    _runtime_state->runtime_filter_wait_time_ms();
-            filterparams->execution_timeout = 
_runtime_state->execution_timeout();
-
-            filterparams->exec_env = ExecEnv::GetInstance();
-            filterparams->query_id.set_hi(_runtime_state->query_id().hi);
-            filterparams->query_id.set_lo(_runtime_state->query_id().lo);
-
-            filterparams->be_exec_version = _runtime_state->be_exec_version();
-            filterparams->query_ctx = _query_ctx.get();
-        }
-
-        auto runtime_filter_mgr = std::make_unique<RuntimeFilterMgr>(
-                request.query_id, filterparams.get(), 
_query_ctx->query_mem_tracker);
-
-        filterparams->runtime_filter_mgr = runtime_filter_mgr.get();
-
-        _runtime_filter_states[i] = std::move(filterparams);
+        _runtime_filter_states[i] = 
RuntimeFilterParamsContext::create(_query_ctx.get());
+        std::unique_ptr<RuntimeFilterMgr> runtime_filter_mgr = 
std::make_unique<RuntimeFilterMgr>(
+                request.query_id, _runtime_filter_states[i], 
_query_ctx->query_mem_tracker, false);
         std::map<PipelineId, PipelineTask*> pipeline_id_to_task;
         auto get_local_exchange_state = [&](PipelinePtr pipeline)
                 -> std::map<int, 
std::pair<std::shared_ptr<LocalExchangeSharedState>,
@@ -423,6 +401,7 @@ Status PipelineFragmentContext::_build_pipeline_tasks(const 
doris::TPipelineFrag
                         request.fragment_id, request.query_options, 
_query_ctx->query_globals,
                         _exec_env, _query_ctx.get());
                 auto& task_runtime_state = _task_runtime_states[pip_idx][i];
+                _runtime_filter_states[i]->set_state(task_runtime_state.get());
                 {
                     // Initialize runtime state for this task
                     
task_runtime_state->set_query_mem_tracker(_query_ctx->query_mem_tracker);
@@ -454,9 +433,8 @@ Status PipelineFragmentContext::_build_pipeline_tasks(const 
doris::TPipelineFrag
                     
task_runtime_state->set_load_stream_per_node(request.load_stream_per_node);
                     
task_runtime_state->set_total_load_streams(request.total_load_streams);
                     
task_runtime_state->set_num_local_sink(request.num_local_sink);
-                    DCHECK(_runtime_filter_states[i]->runtime_filter_mgr);
-                    task_runtime_state->set_runtime_filter_mgr(
-                            _runtime_filter_states[i]->runtime_filter_mgr);
+
+                    
task_runtime_state->set_runtime_filter_mgr(runtime_filter_mgr.get());
                 }
                 auto cur_task_id = _total_tasks++;
                 task_runtime_state->set_task_id(cur_task_id);
@@ -1377,8 +1355,8 @@ Status 
PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
             const uint32_t partition_count = 32;
             auto inner_probe_operator =
                     std::make_shared<HashJoinProbeOperatorX>(pool, tnode_, 0, 
descs);
-            auto inner_sink_operator = 
std::make_shared<HashJoinBuildSinkOperatorX>(
-                    pool, 0, tnode_, descs, _need_local_merge);
+            auto inner_sink_operator =
+                    std::make_shared<HashJoinBuildSinkOperatorX>(pool, 0, 
tnode_, descs);
 
             RETURN_IF_ERROR(inner_probe_operator->init(tnode_, 
_runtime_state.get()));
             RETURN_IF_ERROR(inner_sink_operator->init(tnode_, 
_runtime_state.get()));
@@ -1398,8 +1376,7 @@ Status 
PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
             _dag[downstream_pipeline_id].push_back(build_side_pipe->id());
 
             auto sink_operator = 
std::make_shared<PartitionedHashJoinSinkOperatorX>(
-                    pool, next_sink_operator_id(), tnode_, descs, 
_need_local_merge,
-                    partition_count);
+                    pool, next_sink_operator_id(), tnode_, descs, 
partition_count);
             sink_operator->set_inner_operators(inner_sink_operator, 
inner_probe_operator);
             DataSinkOperatorPtr sink = std::move(sink_operator);
             sink->set_dests_id({op->operator_id()});
@@ -1423,8 +1400,7 @@ Status 
PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
             _dag[downstream_pipeline_id].push_back(build_side_pipe->id());
 
             DataSinkOperatorPtr sink;
-            sink.reset(new HashJoinBuildSinkOperatorX(pool, 
next_sink_operator_id(), tnode, descs,
-                                                      _need_local_merge));
+            sink.reset(new HashJoinBuildSinkOperatorX(pool, 
next_sink_operator_id(), tnode, descs));
             sink->set_dests_id({op->operator_id()});
             RETURN_IF_ERROR(build_side_pipe->set_sink(sink));
             RETURN_IF_ERROR(build_side_pipe->sink()->init(tnode, 
_runtime_state.get()));
@@ -1451,8 +1427,8 @@ Status 
PipelineFragmentContext::_create_operator(ObjectPool* pool, const TPlanNo
         _dag[downstream_pipeline_id].push_back(build_side_pipe->id());
 
         DataSinkOperatorPtr sink;
-        sink.reset(new NestedLoopJoinBuildSinkOperatorX(pool, 
next_sink_operator_id(), tnode, descs,
-                                                        _need_local_merge));
+        sink.reset(
+                new NestedLoopJoinBuildSinkOperatorX(pool, 
next_sink_operator_id(), tnode, descs));
         sink->set_dests_id({op->operator_id()});
         RETURN_IF_ERROR(build_side_pipe->set_sink(sink));
         RETURN_IF_ERROR(build_side_pipe->sink()->init(tnode, 
_runtime_state.get()));
diff --git a/be/src/pipeline/pipeline_fragment_context.h 
b/be/src/pipeline/pipeline_fragment_context.h
index 2e75aeb414e..08ef05da5fa 100644
--- a/be/src/pipeline/pipeline_fragment_context.h
+++ b/be/src/pipeline/pipeline_fragment_context.h
@@ -228,8 +228,6 @@ private:
     // this is a [n * m] matrix. n is parallelism of pipeline engine and m is 
the number of pipelines.
     std::vector<std::vector<std::unique_ptr<PipelineTask>>> _tasks;
 
-    bool _need_local_merge = false;
-
     // TODO: remove the _sink and _multi_cast_stream_sink_senders to set both
     // of it in pipeline task not the fragment_context
 #ifdef __clang__
@@ -301,7 +299,7 @@ private:
      */
     std::vector<std::vector<std::unique_ptr<RuntimeState>>> 
_task_runtime_states;
 
-    std::vector<std::unique_ptr<RuntimeFilterParamsContext>> 
_runtime_filter_states;
+    std::vector<RuntimeFilterParamsContext*> _runtime_filter_states;
 
     // Total instance num running on all BEs
     int _total_instances = -1;
diff --git a/be/src/runtime/query_context.cpp b/be/src/runtime/query_context.cpp
index ece4c76a17a..f4d4256e66f 100644
--- a/be/src/runtime/query_context.cpp
+++ b/be/src/runtime/query_context.cpp
@@ -87,7 +87,7 @@ QueryContext::QueryContext(TUniqueId query_id, ExecEnv* 
exec_env,
     _shared_hash_table_controller.reset(new 
vectorized::SharedHashTableController());
     _execution_dependency = pipeline::Dependency::create_unique(-1, -1, 
"ExecutionDependency");
     _runtime_filter_mgr = std::make_unique<RuntimeFilterMgr>(
-            TUniqueId(), RuntimeFilterParamsContext::create(this), 
query_mem_tracker);
+            TUniqueId(), RuntimeFilterParamsContext::create(this), 
query_mem_tracker, true);
 
     _timeout_second = query_options.execution_timeout;
 
diff --git a/be/src/runtime/runtime_filter_mgr.cpp 
b/be/src/runtime/runtime_filter_mgr.cpp
index a4631cfaba7..31b9ec3b0c2 100644
--- a/be/src/runtime/runtime_filter_mgr.cpp
+++ b/be/src/runtime/runtime_filter_mgr.cpp
@@ -44,12 +44,12 @@
 namespace doris {
 
 RuntimeFilterMgr::RuntimeFilterMgr(const UniqueId& query_id, 
RuntimeFilterParamsContext* state,
-                                   const std::shared_ptr<MemTrackerLimiter>& 
query_mem_tracker) {
-    _state = state;
-    _state->runtime_filter_mgr = this;
-    _query_mem_tracker = query_mem_tracker;
-    _tracker = std::make_unique<MemTracker>("RuntimeFilterMgr(experimental)");
-}
+                                   const std::shared_ptr<MemTrackerLimiter>& 
query_mem_tracker,
+                                   const bool is_global)
+        : _is_global(is_global),
+          _state(state),
+          
_tracker(std::make_unique<MemTracker>("RuntimeFilterMgr(experimental)")),
+          _query_mem_tracker(query_mem_tracker) {}
 
 RuntimeFilterMgr::~RuntimeFilterMgr() {
     CHECK(_query_mem_tracker != nullptr);
@@ -59,6 +59,7 @@ RuntimeFilterMgr::~RuntimeFilterMgr() {
 
 Status RuntimeFilterMgr::get_consume_filters(
         const int filter_id, std::vector<std::shared_ptr<IRuntimeFilter>>& 
consumer_filters) {
+    DCHECK(_is_global);
     std::lock_guard<std::mutex> l(_lock);
     auto iter = _consumer_map.find(filter_id);
     if (iter == _consumer_map.end()) {
@@ -71,6 +72,20 @@ Status RuntimeFilterMgr::get_consume_filters(
     return Status::OK();
 }
 
+std::vector<std::shared_ptr<IRuntimeFilter>> 
RuntimeFilterMgr::get_consume_filters(
+        const int filter_id) {
+    std::lock_guard<std::mutex> l(_lock);
+    auto iter = _consumer_map.find(filter_id);
+    if (iter == _consumer_map.end()) {
+        return {};
+    }
+    std::vector<std::shared_ptr<IRuntimeFilter>> consumer_filters;
+    for (auto& holder : iter->second) {
+        consumer_filters.emplace_back(holder.filter);
+    }
+    return consumer_filters;
+}
+
 Status RuntimeFilterMgr::register_consumer_filter(const TRuntimeFilterDesc& 
desc,
                                                   const TQueryOptions& 
options, int node_id,
                                                   
std::shared_ptr<IRuntimeFilter>* consumer_filter,
@@ -89,6 +104,8 @@ Status RuntimeFilterMgr::register_consumer_filter(const 
TRuntimeFilterDesc& desc
         }
     }
 
+    DCHECK(!(_is_global xor need_local_merge))
+            << " _is_global: " << _is_global << " need_local_merge: " << 
need_local_merge;
     if (!has_exist) {
         std::shared_ptr<IRuntimeFilter> filter;
         RETURN_IF_ERROR(IRuntimeFilter::create(_state, &desc, &options, 
RuntimeFilterRole::CONSUMER,
@@ -106,6 +123,7 @@ Status RuntimeFilterMgr::register_consumer_filter(const 
TRuntimeFilterDesc& desc
 Status RuntimeFilterMgr::register_local_merge_producer_filter(
         const doris::TRuntimeFilterDesc& desc, const doris::TQueryOptions& 
options,
         std::shared_ptr<IRuntimeFilter>* producer_filter, bool 
build_bf_exactly) {
+    DCHECK(_is_global);
     SCOPED_CONSUME_MEM_TRACKER(_tracker.get());
     int32_t key = desc.filter_id;
 
@@ -141,6 +159,7 @@ Status 
RuntimeFilterMgr::register_local_merge_producer_filter(
 
 Status RuntimeFilterMgr::get_local_merge_producer_filters(
         int filter_id, doris::LocalMergeFilters** local_merge_filters) {
+    DCHECK(_is_global);
     std::lock_guard<std::mutex> l(_lock);
     auto iter = _local_merge_producer_map.find(filter_id);
     if (iter == _local_merge_producer_map.end()) {
@@ -158,6 +177,7 @@ Status RuntimeFilterMgr::register_producer_filter(const 
TRuntimeFilterDesc& desc
                                                   const TQueryOptions& options,
                                                   
std::shared_ptr<IRuntimeFilter>* producer_filter,
                                                   bool build_bf_exactly) {
+    DCHECK(!_is_global);
     SCOPED_CONSUME_MEM_TRACKER(_tracker.get());
     int32_t key = desc.filter_id;
     std::lock_guard<std::mutex> l(_lock);
@@ -341,9 +361,10 @@ Status 
RuntimeFilterMergeControllerEntity::send_filter_size(const PSendFilterSiz
                                   
DummyBrpcCallback<PSyncFilterSizeResponse>::create_shared());
 
             auto* pquery_id = closure->request_->mutable_query_id();
-            pquery_id->set_hi(_state->query_id.hi());
-            pquery_id->set_lo(_state->query_id.lo());
-            closure->cntl_->set_timeout_ms(std::min(3600, 
_state->execution_timeout) * 1000);
+            pquery_id->set_hi(_state->get_query_ctx()->query_id().hi);
+            pquery_id->set_lo(_state->get_query_ctx()->query_id().lo);
+            closure->cntl_->set_timeout_ms(
+                    std::min(3600, 
_state->get_query_ctx()->execution_timeout()) * 1000);
 
             closure->request_->set_filter_id(filter_id);
             closure->request_->set_filter_size(cnt_val->global_size);
@@ -455,7 +476,8 @@ Status RuntimeFilterMergeControllerEntity::merge(const 
PMergeFilterRequest* requ
             if (has_attachment) {
                 
closure->cntl_->request_attachment().append(request_attachment);
             }
-            closure->cntl_->set_timeout_ms(std::min(3600, 
_state->execution_timeout) * 1000);
+            closure->cntl_->set_timeout_ms(
+                    std::min(3600, 
_state->get_query_ctx()->execution_timeout()) * 1000);
             // set fragment-id
             if (target.__isset.target_fragment_ids) {
                 for (auto& target_fragment_id : target.target_fragment_ids) {
@@ -514,31 +536,22 @@ void RuntimeFilterMergeController::remove_entity(UniqueId 
query_id) {
 RuntimeFilterParamsContext* RuntimeFilterParamsContext::create(RuntimeState* 
state) {
     RuntimeFilterParamsContext* params =
             state->get_query_ctx()->obj_pool.add(new 
RuntimeFilterParamsContext());
-    params->runtime_filter_wait_infinitely = 
state->runtime_filter_wait_infinitely();
-    params->runtime_filter_wait_time_ms = state->runtime_filter_wait_time_ms();
-    params->execution_timeout = state->execution_timeout();
-    params->runtime_filter_mgr = state->local_runtime_filter_mgr();
-    params->exec_env = state->exec_env();
-    params->query_id.set_hi(state->query_id().hi);
-    params->query_id.set_lo(state->query_id().lo);
-
-    params->be_exec_version = state->be_exec_version();
-    params->query_ctx = state->get_query_ctx();
+    params->_query_ctx = state->get_query_ctx();
+    params->_state = state;
     return params;
 }
 
+RuntimeFilterMgr* RuntimeFilterParamsContext::global_runtime_filter_mgr() {
+    return _query_ctx->runtime_filter_mgr();
+}
+
+RuntimeFilterMgr* RuntimeFilterParamsContext::local_runtime_filter_mgr() {
+    return _state->local_runtime_filter_mgr();
+}
+
 RuntimeFilterParamsContext* RuntimeFilterParamsContext::create(QueryContext* 
query_ctx) {
     RuntimeFilterParamsContext* params = query_ctx->obj_pool.add(new 
RuntimeFilterParamsContext());
-    params->runtime_filter_wait_infinitely = 
query_ctx->runtime_filter_wait_infinitely();
-    params->runtime_filter_wait_time_ms = 
query_ctx->runtime_filter_wait_time_ms();
-    params->execution_timeout = query_ctx->execution_timeout();
-    params->runtime_filter_mgr = query_ctx->runtime_filter_mgr();
-    params->exec_env = query_ctx->exec_env();
-    params->query_id.set_hi(query_ctx->query_id().hi);
-    params->query_id.set_lo(query_ctx->query_id().lo);
-
-    params->be_exec_version = query_ctx->be_exec_version();
-    params->query_ctx = query_ctx;
+    params->_query_ctx = query_ctx;
     return params;
 }
 
diff --git a/be/src/runtime/runtime_filter_mgr.h 
b/be/src/runtime/runtime_filter_mgr.h
index b0aea7568cf..53520e43a55 100644
--- a/be/src/runtime/runtime_filter_mgr.h
+++ b/be/src/runtime/runtime_filter_mgr.h
@@ -77,12 +77,14 @@ struct LocalMergeFilters {
 class RuntimeFilterMgr {
 public:
     RuntimeFilterMgr(const UniqueId& query_id, RuntimeFilterParamsContext* 
state,
-                     const std::shared_ptr<MemTrackerLimiter>& 
query_mem_tracker);
+                     const std::shared_ptr<MemTrackerLimiter>& 
query_mem_tracker,
+                     const bool is_global);
 
     ~RuntimeFilterMgr();
 
     Status get_consume_filters(const int filter_id,
                                std::vector<std::shared_ptr<IRuntimeFilter>>& 
consumer_filters);
+    std::vector<std::shared_ptr<IRuntimeFilter>> get_consume_filters(const int 
filter_id);
 
     std::shared_ptr<IRuntimeFilter> try_get_product_filter(const int 
filter_id) {
         std::lock_guard<std::mutex> l(_lock);
@@ -124,6 +126,18 @@ private:
         int node_id;
         std::shared_ptr<IRuntimeFilter> filter;
     };
+    /**
+     * `_is_global = true` means this runtime filter manager manages query-level runtime filters.
+     * If so, all consumers in this query share the same RF with the same ID. For producers, all
+     * RFs produced should be merged.
+     *
+     * If `_is_global` is false, an RF will be produced and consumed in a single-producer-single-consumer mode.
+     * This usually happens in a co-located join when the scan operators are not serial.
+     *
+     * `_local_merge_producer_map` is used only if `_is_global` is true. That is, RFs produced by
+     * different producers need to be merged only if it is a global RF.
+     */
+    const bool _is_global;
     // RuntimeFilterMgr is owned by RuntimeState, so we only
     // use filter_id as key
     // key: "filter-id"
@@ -267,24 +281,22 @@ private:
     FilterControllerMap _filter_controller_map[kShardNum];
 };
 
-//There are two types of runtime filters:
-// one is global, originating from QueryContext,
-// and the other is local, originating from RuntimeState.
-// In practice, we have already distinguished between them through 
UpdateRuntimeFilterParamsV2/V1.
-// RuntimeState/QueryContext is only used to store 
runtime_filter_wait_time_ms...
+// There are two types of runtime filters:
+// 1. Global runtime filter: managed by QueryContext's RuntimeFilterMgr. Such an RF is produced by multiple producers (whose results are merged) and shared by multiple consumers.
+// 2. Local runtime filter: managed by RuntimeState's RuntimeFilterMgr and used in 1-producer-1-consumer mode.
 struct RuntimeFilterParamsContext {
-    RuntimeFilterParamsContext() = default;
     static RuntimeFilterParamsContext* create(RuntimeState* state);
     static RuntimeFilterParamsContext* create(QueryContext* query_ctx);
 
-    bool runtime_filter_wait_infinitely;
-    int32_t runtime_filter_wait_time_ms;
-    int32_t execution_timeout;
-    RuntimeFilterMgr* runtime_filter_mgr;
-    ExecEnv* exec_env;
-    PUniqueId query_id;
-    int be_exec_version;
-    QueryContext* query_ctx;
-    QueryContext* get_query_ctx() const { return query_ctx; }
+    QueryContext* get_query_ctx() const { return _query_ctx; }
+    void set_state(RuntimeState* state) { _state = state; }
+    RuntimeFilterMgr* global_runtime_filter_mgr();
+    RuntimeFilterMgr* local_runtime_filter_mgr();
+
+private:
+    RuntimeFilterParamsContext() = default;
+
+    QueryContext* _query_ctx;
+    RuntimeState* _state;
 };
 } // namespace doris
diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp
index d4e3cba36cd..34b3866febf 100644
--- a/be/src/runtime/runtime_state.cpp
+++ b/be/src/runtime/runtime_state.cpp
@@ -516,15 +516,12 @@ RuntimeFilterMgr* 
RuntimeState::global_runtime_filter_mgr() {
 }
 
 Status RuntimeState::register_producer_runtime_filter(
-        const TRuntimeFilterDesc& desc, bool need_local_merge,
-        std::shared_ptr<IRuntimeFilter>* producer_filter, bool 
build_bf_exactly) {
-    if (desc.has_remote_targets || need_local_merge) {
-        return 
global_runtime_filter_mgr()->register_local_merge_producer_filter(
-                desc, query_options(), producer_filter, build_bf_exactly);
-    } else {
-        return local_runtime_filter_mgr()->register_producer_filter(
-                desc, query_options(), producer_filter, build_bf_exactly);
-    }
+        const TRuntimeFilterDesc& desc, std::shared_ptr<IRuntimeFilter>* 
producer_filter,
+        bool build_bf_exactly) {
+    
RETURN_IF_ERROR(global_runtime_filter_mgr()->register_local_merge_producer_filter(
+            desc, query_options(), producer_filter, build_bf_exactly));
+    return local_runtime_filter_mgr()->register_producer_filter(desc, 
query_options(),
+                                                                
producer_filter, build_bf_exactly);
 }
 
 Status RuntimeState::register_consumer_runtime_filter(
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index b44aba5e731..782008ec075 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -554,7 +554,6 @@ public:
     }
 
     Status register_producer_runtime_filter(const doris::TRuntimeFilterDesc& 
desc,
-                                            bool need_local_merge,
                                             std::shared_ptr<IRuntimeFilter>* 
producer_filter,
                                             bool build_bf_exactly);
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to