This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bc011f33a7 [feature](exchange) enable shared exchange sink buffer to reduce RPC concurrency. (#44850)
2bc011f33a7 is described below

commit 2bc011f33a7c350832c6fc266d0dfbd81a3a64b6
Author: Mryange <yanxuech...@selectdb.com>
AuthorDate: Mon Dec 9 12:40:18 2024 +0800

    [feature](exchange) enable shared exchange sink buffer to reduce RPC concurrency. (#44850)
    
    ### What problem does this PR solve?
    
    In the past, each exchange sink had its own sink buffer.
    If the query concurrency is n, there can be up to n * n RPCs running
    concurrently in a typical shuffle scenario (each sender instance can send
    data to all downstream instances).
    Here, we introduce support for shared sink buffers.
    This does not reduce the total number of RPCs, but it can limit the number
    of concurrent RPCs.
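
    To illustrate the idea, here is a minimal, hypothetical C++ sketch of the
    per-channel ping-pong scheme described above (names such as SharedSinkBuffer,
    SendFn, and Channel are illustrative only, not the actual Doris types): each
    destination keeps one queue and at most one in-flight RPC, so n senders
    sharing the buffer keep at most one concurrent RPC per destination instead of
    one per sender-destination pair.

    #include <cstdint>
    #include <functional>
    #include <mutex>
    #include <queue>
    #include <string>
    #include <unordered_map>
    #include <utility>

    using DestId = int64_t;
    using Payload = std::string; // stands in for a serialized block

    class SharedSinkBuffer {
    public:
        // send_fn(dest, payload, on_done) starts one asynchronous RPC and
        // invokes on_done() when that RPC completes.
        using SendFn = std::function<void(DestId, Payload, std::function<void()>)>;

        explicit SharedSinkBuffer(SendFn send_fn) : _send_fn(std::move(send_fn)) {}

        // Called by any of the senders sharing this buffer.
        void add_block(DestId dest, Payload payload) {
            bool send_now = false;
            {
                std::lock_guard<std::mutex> lock(_mutex);
                auto& ch = _channels[dest];
                send_now = !ch.rpc_in_flight; // ping-pong: start only if the channel is idle
                ch.rpc_in_flight = true;
                ch.queue.push(std::move(payload));
            }
            if (send_now) {
                _send_next(dest);
            }
        }

    private:
        struct Channel {
            std::queue<Payload> queue;
            bool rpc_in_flight = false;
        };

        void _send_next(DestId dest) {
            Payload payload;
            {
                std::lock_guard<std::mutex> lock(_mutex);
                auto& ch = _channels[dest];
                if (ch.queue.empty()) { // nothing left, the channel goes idle
                    ch.rpc_in_flight = false;
                    return;
                }
                payload = std::move(ch.queue.front());
                ch.queue.pop();
            }
            // The completion callback chains the next send on the same channel,
            // so at most one RPC per destination is ever outstanding.
            _send_fn(dest, std::move(payload), [this, dest] { _send_next(dest); });
        }

        std::mutex _mutex;
        std::unordered_map<DestId, Channel> _channels;
        SendFn _send_fn;
    };

    In the real ExchangeSinkBuffer below, the per-destination package queues, the
    _rpc_channel_is_idle flags, and the RPC completion callbacks that chain the
    next _send_rpc play the roles sketched here.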
---
 be/src/pipeline/exec/exchange_sink_buffer.cpp      | 144 ++++++++-------
 be/src/pipeline/exec/exchange_sink_buffer.h        | 115 +++++++++---
 be/src/pipeline/exec/exchange_sink_operator.cpp    |  85 +++++++--
 be/src/pipeline/exec/exchange_sink_operator.h      |  34 +++-
 be/src/pipeline/pipeline_fragment_context.cpp      |  24 ++-
 be/src/runtime/runtime_state.h                     |   5 +
 be/src/vec/sink/vdata_stream_sender.h              |  10 +-
 be/test/vec/exec/exchange_sink_test.cpp            | 196 +++++++++++++++++++++
 be/test/vec/exec/exchange_sink_test.h              | 180 +++++++++++++++++++
 .../java/org/apache/doris/qe/SessionVariable.java  |   7 +
 gensrc/thrift/PaloInternalService.thrift           |   2 +-
 11 files changed, 683 insertions(+), 119 deletions(-)

diff --git a/be/src/pipeline/exec/exchange_sink_buffer.cpp b/be/src/pipeline/exec/exchange_sink_buffer.cpp
index 6e6108d13a9..65e76987370 100644
--- a/be/src/pipeline/exec/exchange_sink_buffer.cpp
+++ b/be/src/pipeline/exec/exchange_sink_buffer.cpp
@@ -87,19 +87,22 @@ void BroadcastPBlockHolderMemLimiter::release(const BroadcastPBlockHolder& holde
 } // namespace vectorized
 
 namespace pipeline {
-ExchangeSinkBuffer::ExchangeSinkBuffer(PUniqueId query_id, PlanNodeId dest_node_id, int send_id,
-                                       int be_number, RuntimeState* state,
-                                       ExchangeSinkLocalState* parent)
+ExchangeSinkBuffer::ExchangeSinkBuffer(PUniqueId query_id, PlanNodeId dest_node_id,
+                                       RuntimeState* state,
+                                       const std::vector<InstanceLoId>& sender_ins_ids)
         : HasTaskExecutionCtx(state),
           _queue_capacity(0),
-          _is_finishing(false),
+          _is_failed(false),
           _query_id(std::move(query_id)),
           _dest_node_id(dest_node_id),
-          _sender_id(send_id),
-          _be_number(be_number),
           _state(state),
           _context(state->get_query_ctx()),
-          _parent(parent) {}
+          _exchange_sink_num(sender_ins_ids.size()) {
+    for (auto sender_ins_id : sender_ins_ids) {
+        _queue_deps.emplace(sender_ins_id, nullptr);
+        _parents.emplace(sender_ins_id, nullptr);
+    }
+}
 
 void ExchangeSinkBuffer::close() {
     // Could not clear the queue here, because there maybe a running rpc want to
@@ -110,8 +113,8 @@ void ExchangeSinkBuffer::close() {
     //_instance_to_request.clear();
 }
 
-void ExchangeSinkBuffer::register_sink(TUniqueId fragment_instance_id) {
-    if (_is_finishing) {
+void ExchangeSinkBuffer::construct_request(TUniqueId fragment_instance_id) {
+    if (_is_failed) {
         return;
     }
     auto low_id = fragment_instance_id.lo;
@@ -129,22 +132,27 @@ void ExchangeSinkBuffer::register_sink(TUniqueId fragment_instance_id) {
     finst_id.set_hi(fragment_instance_id.hi);
     finst_id.set_lo(fragment_instance_id.lo);
     _rpc_channel_is_idle[low_id] = true;
-    _instance_to_receiver_eof[low_id] = false;
+    _rpc_channel_is_turn_off[low_id] = false;
     _instance_to_rpc_stats_vec.emplace_back(std::make_shared<RpcInstanceStatistics>(low_id));
     _instance_to_rpc_stats[low_id] = _instance_to_rpc_stats_vec.back().get();
-    _construct_request(low_id, finst_id);
+    _instance_to_request[low_id] = std::make_shared<PTransmitDataParams>();
+    _instance_to_request[low_id]->mutable_finst_id()->CopyFrom(finst_id);
+    _instance_to_request[low_id]->mutable_query_id()->CopyFrom(_query_id);
+
+    _instance_to_request[low_id]->set_node_id(_dest_node_id);
+    _running_sink_count[low_id] = _exchange_sink_num;
 }
 
 Status ExchangeSinkBuffer::add_block(TransmitInfo&& request) {
-    if (_is_finishing) {
+    if (_is_failed) {
         return Status::OK();
     }
-    auto ins_id = request.channel->_fragment_instance_id.lo;
+    auto ins_id = request.channel->dest_ins_id();
     if (!_instance_to_package_queue_mutex.contains(ins_id)) {
         return Status::InternalError("fragment_instance_id {} not do register_sink",
                                      print_id(request.channel->_fragment_instance_id));
     }
-    if (_is_receiver_eof(ins_id)) {
+    if (_rpc_channel_is_turn_off[ins_id]) {
         return Status::EndOfFile("receiver eof");
     }
     bool send_now = false;
@@ -158,12 +166,15 @@ Status ExchangeSinkBuffer::add_block(TransmitInfo&& request) {
         if (request.block) {
             RETURN_IF_ERROR(
                     BeExecVersionManager::check_be_exec_version(request.block->be_exec_version()));
-            COUNTER_UPDATE(_parent->memory_used_counter(), request.block->ByteSizeLong());
+            COUNTER_UPDATE(request.channel->_parent->memory_used_counter(),
+                           request.block->ByteSizeLong());
         }
         _instance_to_package_queue[ins_id].emplace(std::move(request));
         _total_queue_size++;
-        if (_queue_dependency && _total_queue_size > _queue_capacity) {
-            _queue_dependency->block();
+        if (_total_queue_size > _queue_capacity) {
+            for (auto& [_, dep] : _queue_deps) {
+                dep->block();
+            }
         }
     }
     if (send_now) {
@@ -174,15 +185,15 @@ Status ExchangeSinkBuffer::add_block(TransmitInfo&& request) {
 }
 
 Status ExchangeSinkBuffer::add_block(BroadcastTransmitInfo&& request) {
-    if (_is_finishing) {
+    if (_is_failed) {
         return Status::OK();
     }
-    auto ins_id = request.channel->_fragment_instance_id.lo;
+    auto ins_id = request.channel->dest_ins_id();
     if (!_instance_to_package_queue_mutex.contains(ins_id)) {
         return Status::InternalError("fragment_instance_id {} not do register_sink",
                                      print_id(request.channel->_fragment_instance_id));
     }
-    if (_is_receiver_eof(ins_id)) {
+    if (_rpc_channel_is_turn_off[ins_id]) {
         return Status::EndOfFile("receiver eof");
     }
     bool send_now = false;
@@ -209,16 +220,17 @@ Status ExchangeSinkBuffer::add_block(BroadcastTransmitInfo&& request) {
 Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
     std::unique_lock<std::mutex> lock(*_instance_to_package_queue_mutex[id]);
 
-    DCHECK(_rpc_channel_is_idle[id] == false);
-
     std::queue<TransmitInfo, std::list<TransmitInfo>>& q = _instance_to_package_queue[id];
     std::queue<BroadcastTransmitInfo, std::list<BroadcastTransmitInfo>>& broadcast_q =
             _instance_to_broadcast_package_queue[id];
 
-    if (_is_finishing) {
+    if (_is_failed) {
         _turn_off_channel(id, lock);
         return Status::OK();
     }
+    if (_rpc_channel_is_turn_off[id]) {
+        return Status::OK();
+    }
 
     if (!q.empty()) {
         // If we have data to shuffle which is not broadcasted
@@ -226,6 +238,8 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
         auto& brpc_request = _instance_to_request[id];
         brpc_request->set_eos(request.eos);
         brpc_request->set_packet_seq(_instance_to_seq[id]++);
+        brpc_request->set_sender_id(request.channel->_parent->sender_id());
+        brpc_request->set_be_number(request.channel->_parent->be_number());
         if (request.block && !request.block->column_metas().empty()) {
             brpc_request->set_allocated_block(request.block.get());
         }
@@ -271,14 +285,16 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
             } else if (!s.ok()) {
                 _failed(id,
                         fmt::format("exchange req success but status isn't ok: 
{}", s.to_string()));
+                return;
             } else if (eos) {
                 _ended(id);
-            } else {
-                s = _send_rpc(id);
-                if (!s) {
-                    _failed(id, fmt::format("exchange req success but status isn't ok: {}",
-                                            s.to_string()));
-                }
+            }
+            // The eos here only indicates that the current exchange sink has reached eos.
+            // However, the queue still contains data from other exchange sinks, so RPCs need to continue being sent.
+            s = _send_rpc(id);
+            if (!s) {
+                _failed(id,
+                        fmt::format("exchange req success but status isn't ok: {}", s.to_string()));
             }
         });
         {
@@ -296,13 +312,16 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
             }
         }
         if (request.block) {
-            COUNTER_UPDATE(_parent->memory_used_counter(), -request.block->ByteSizeLong());
+            COUNTER_UPDATE(request.channel->_parent->memory_used_counter(),
+                           -request.block->ByteSizeLong());
             static_cast<void>(brpc_request->release_block());
         }
         q.pop();
         _total_queue_size--;
-        if (_queue_dependency && _total_queue_size <= _queue_capacity) {
-            _queue_dependency->set_ready();
+        if (_total_queue_size <= _queue_capacity) {
+            for (auto& [_, dep] : _queue_deps) {
+                dep->set_ready();
+            }
         }
     } else if (!broadcast_q.empty()) {
         // If we have data to shuffle which is broadcasted
@@ -310,6 +329,8 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
         auto& brpc_request = _instance_to_request[id];
         brpc_request->set_eos(request.eos);
         brpc_request->set_packet_seq(_instance_to_seq[id]++);
+        brpc_request->set_sender_id(request.channel->_parent->sender_id());
+        brpc_request->set_be_number(request.channel->_parent->be_number());
         if (request.block_holder->get_block() &&
             !request.block_holder->get_block()->column_metas().empty()) {
             brpc_request->set_allocated_block(request.block_holder->get_block());
@@ -352,14 +373,17 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
             } else if (!s.ok()) {
                 _failed(id,
                         fmt::format("exchange req success but status isn't ok: 
{}", s.to_string()));
+                return;
             } else if (eos) {
                 _ended(id);
-            } else {
-                s = _send_rpc(id);
-                if (!s) {
-                    _failed(id, fmt::format("exchange req success but status isn't ok: {}",
-                                            s.to_string()));
-                }
+            }
+
+            // The eos here only indicates that the current exchange sink has reached eos.
+            // However, the queue still contains data from other exchange sinks, so RPCs need to continue being sent.
+            s = _send_rpc(id);
+            if (!s) {
+                _failed(id,
+                        fmt::format("exchange req success but status isn't ok: {}", s.to_string()));
             }
         });
         {
@@ -387,16 +411,6 @@ Status ExchangeSinkBuffer::_send_rpc(InstanceLoId id) {
     return Status::OK();
 }
 
-void ExchangeSinkBuffer::_construct_request(InstanceLoId id, PUniqueId finst_id) {
-    _instance_to_request[id] = std::make_shared<PTransmitDataParams>();
-    _instance_to_request[id]->mutable_finst_id()->CopyFrom(finst_id);
-    _instance_to_request[id]->mutable_query_id()->CopyFrom(_query_id);
-
-    _instance_to_request[id]->set_node_id(_dest_node_id);
-    _instance_to_request[id]->set_sender_id(_sender_id);
-    _instance_to_request[id]->set_be_number(_be_number);
-}
-
 void ExchangeSinkBuffer::_ended(InstanceLoId id) {
     if (!_instance_to_package_queue_mutex.template contains(id)) {
         std::stringstream ss;
@@ -411,24 +425,29 @@ void ExchangeSinkBuffer::_ended(InstanceLoId id) {
         __builtin_unreachable();
     } else {
         std::unique_lock<std::mutex> lock(*_instance_to_package_queue_mutex[id]);
-        _turn_off_channel(id, lock);
+        _running_sink_count[id]--;
+        if (_running_sink_count[id] == 0) {
+            _turn_off_channel(id, lock);
+        }
     }
 }
 
 void ExchangeSinkBuffer::_failed(InstanceLoId id, const std::string& err) {
-    _is_finishing = true;
+    _is_failed = true;
     _context->cancel(Status::Cancelled(err));
 }
 
 void ExchangeSinkBuffer::_set_receiver_eof(InstanceLoId id) {
     std::unique_lock<std::mutex> lock(*_instance_to_package_queue_mutex[id]);
-    _instance_to_receiver_eof[id] = true;
+    // When the receiving side reaches eof, it means the receiver has finished early.
+    // The remaining data in the current rpc_channel does not need to be sent,
+    // and the rpc_channel should be turned off immediately.
     _turn_off_channel(id, lock);
     std::queue<BroadcastTransmitInfo, std::list<BroadcastTransmitInfo>>& broadcast_q =
             _instance_to_broadcast_package_queue[id];
     for (; !broadcast_q.empty(); broadcast_q.pop()) {
         if (broadcast_q.front().block_holder->get_block()) {
-            COUNTER_UPDATE(_parent->memory_used_counter(),
+            COUNTER_UPDATE(broadcast_q.front().channel->_parent->memory_used_counter(),
                            -broadcast_q.front().block_holder->get_block()->ByteSizeLong());
         }
     }
@@ -440,7 +459,8 @@ void ExchangeSinkBuffer::_set_receiver_eof(InstanceLoId id) {
     std::queue<TransmitInfo, std::list<TransmitInfo>>& q = _instance_to_package_queue[id];
     for (; !q.empty(); q.pop()) {
         if (q.front().block) {
-            COUNTER_UPDATE(_parent->memory_used_counter(), -q.front().block->ByteSizeLong());
+            COUNTER_UPDATE(q.front().channel->_parent->memory_used_counter(),
+                           -q.front().block->ByteSizeLong());
         }
     }
 
@@ -450,22 +470,22 @@ void ExchangeSinkBuffer::_set_receiver_eof(InstanceLoId id) {
     }
 }
 
-bool ExchangeSinkBuffer::_is_receiver_eof(InstanceLoId id) {
-    std::unique_lock<std::mutex> lock(*_instance_to_package_queue_mutex[id]);
-    return _instance_to_receiver_eof[id];
-}
-
 // The unused parameter `with_lock` is to ensure that the function is called when the lock is held.
 void ExchangeSinkBuffer::_turn_off_channel(InstanceLoId id,
                                            std::unique_lock<std::mutex>& /*with_lock*/) {
     if (!_rpc_channel_is_idle[id]) {
         _rpc_channel_is_idle[id] = true;
     }
-    _instance_to_receiver_eof[id] = true;
-
+    // Ensure that each RPC is turned off only once.
+    if (_rpc_channel_is_turn_off[id]) {
+        return;
+    }
+    _rpc_channel_is_turn_off[id] = true;
     auto weak_task_ctx = weak_task_exec_ctx();
     if (auto pip_ctx = weak_task_ctx.lock()) {
-        _parent->on_channel_finished(id);
+        for (auto& [_, parent] : _parents) {
+            parent->on_channel_finished(id);
+        }
     }
 }
 
@@ -509,7 +529,7 @@ void ExchangeSinkBuffer::update_profile(RuntimeProfile* profile) {
     auto* _max_rpc_timer = ADD_TIMER_WITH_LEVEL(profile, "RpcMaxTime", 1);
     auto* _min_rpc_timer = ADD_TIMER(profile, "RpcMinTime");
     auto* _sum_rpc_timer = ADD_TIMER(profile, "RpcSumTime");
-    auto* _count_rpc = ADD_COUNTER_WITH_LEVEL(profile, "RpcCount", TUnit::UNIT, 1);
+    auto* _count_rpc = ADD_COUNTER(profile, "RpcCount", TUnit::UNIT);
     auto* _avg_rpc_timer = ADD_TIMER(profile, "RpcAvgTime");
 
     int64_t max_rpc_time = 0, min_rpc_time = 0;
diff --git a/be/src/pipeline/exec/exchange_sink_buffer.h b/be/src/pipeline/exec/exchange_sink_buffer.h
index 22a1452f8d5..b2eb32414fe 100644
--- a/be/src/pipeline/exec/exchange_sink_buffer.h
+++ b/be/src/pipeline/exec/exchange_sink_buffer.h
@@ -169,13 +169,61 @@ private:
     bool _eos;
 };
 
-// Each ExchangeSinkOperator have one ExchangeSinkBuffer
-class ExchangeSinkBuffer final : public HasTaskExecutionCtx {
+// ExchangeSinkBuffer can either be shared among multiple ExchangeSinkLocalState instances
+// or be individually owned by each ExchangeSinkLocalState.
+// The following describes the scenario where ExchangeSinkBuffer is shared among multiple ExchangeSinkLocalState instances.
+// Of course, individual ownership can be seen as a special case where only one ExchangeSinkLocalState shares the buffer.
+
+// A sink buffer contains multiple rpc_channels.
+// Each rpc_channel corresponds to a target instance on the receiving side.
+// Data is sent using a ping-pong mode within each rpc_channel,
+// meaning that at most one RPC can exist in a single rpc_channel at a time.
+// The next RPC can only be sent after the previous one has completed.
+//
+// Each exchange sink sends data to all target instances on the receiving side.
+// If the concurrency is 3, a single rpc_channel will be used simultaneously by three exchange sinks.
+
+/*
+                          +-----------+          +-----------+        +-----------+
+                          |dest ins id|          |dest ins id|        |dest ins id|
+                          |           |          |           |        |           |
+                          +----+------+          +-----+-----+        +------+----+
+                               |                       |                     |
+                               |                       |                     |
+                      +----------------+      +----------------+     +----------------+
+                      |                |      |                |     |                |
+ sink buffer -------- |   rpc_channel  |      |  rpc_channel   |     |  rpc_channel   |
+                      |                |      |                |     |                |
+                      +-------+--------+      +----------------+     +----------------+
+                              |                        |                      |
+                              |------------------------+----------------------+
+                              |                        |                      |
+                              |                        |                      |
+                     +-----------------+       +-------+---------+    +-------+---------+
+                     |                 |       |                 |    |                 |
+                     |  exchange sink  |       |  exchange sink  |    |  exchange sink  |
+                     |                 |       |                 |    |                 |
+                     +-----------------+       +-----------------+    +-----------------+
+*/
+
+#ifdef BE_TEST
+void transmit_blockv2(PBackendService_Stub& stub,
+                      std::unique_ptr<AutoReleaseClosure<PTransmitDataParams,
+                                                         ExchangeSendCallback<PTransmitDataResult>>>
+                              closure);
+#endif
+class ExchangeSinkBuffer : public HasTaskExecutionCtx {
 public:
-    ExchangeSinkBuffer(PUniqueId query_id, PlanNodeId dest_node_id, int send_id, int be_number,
-                       RuntimeState* state, ExchangeSinkLocalState* parent);
+    ExchangeSinkBuffer(PUniqueId query_id, PlanNodeId dest_node_id, RuntimeState* state,
+                       const std::vector<InstanceLoId>& sender_ins_ids);
+
+#ifdef BE_TEST
+    ExchangeSinkBuffer(RuntimeState* state, int64_t sinknum)
+            : HasTaskExecutionCtx(state), _exchange_sink_num(sinknum) {};
+#endif
     ~ExchangeSinkBuffer() override = default;
-    void register_sink(TUniqueId);
+
+    void construct_request(TUniqueId);
 
     Status add_block(TransmitInfo&& request);
     Status add_block(BroadcastTransmitInfo&& request);
@@ -183,17 +231,18 @@ public:
     void update_rpc_time(InstanceLoId id, int64_t start_rpc_time, int64_t receive_rpc_time);
     void update_profile(RuntimeProfile* profile);
 
-    void set_dependency(std::shared_ptr<Dependency> queue_dependency,
-                        std::shared_ptr<Dependency> finish_dependency) {
-        _queue_dependency = queue_dependency;
-        _finish_dependency = finish_dependency;
-    }
-
-    void set_broadcast_dependency(std::shared_ptr<Dependency> broadcast_dependency) {
-        _broadcast_dependency = broadcast_dependency;
+    void set_dependency(InstanceLoId sender_ins_id, std::shared_ptr<Dependency> queue_dependency,
+                        ExchangeSinkLocalState* local_state) {
+        DCHECK(_queue_deps.contains(sender_ins_id));
+        DCHECK(_parents.contains(sender_ins_id));
+        _queue_deps[sender_ins_id] = queue_dependency;
+        _parents[sender_ins_id] = local_state;
     }
-
+#ifdef BE_TEST
+public:
+#else
 private:
+#endif
     friend class ExchangeSinkLocalState;
 
     phmap::flat_hash_map<InstanceLoId, std::unique_ptr<std::mutex>>
@@ -214,7 +263,10 @@ private:
     // One channel is corresponding to a downstream instance.
     phmap::flat_hash_map<InstanceLoId, bool> _rpc_channel_is_idle;
 
-    phmap::flat_hash_map<InstanceLoId, bool> _instance_to_receiver_eof;
+    // There could be multiple situations that cause an rpc_channel to be turned off,
+    // such as receiving the eof, manual cancellation by the user, or all sinks reaching eos.
+    // Therefore, it is necessary to prevent an rpc_channel from being turned off multiple times.
+    phmap::flat_hash_map<InstanceLoId, bool> _rpc_channel_is_turn_off;
     struct RpcInstanceStatistics {
         RpcInstanceStatistics(InstanceLoId id) : inst_lo_id(id) {}
         InstanceLoId inst_lo_id;
@@ -226,32 +278,43 @@ private:
     std::vector<std::shared_ptr<RpcInstanceStatistics>> _instance_to_rpc_stats_vec;
     phmap::flat_hash_map<InstanceLoId, RpcInstanceStatistics*> _instance_to_rpc_stats;
 
-    std::atomic<bool> _is_finishing;
+    // It is set to true only when an RPC fails. Currently, we do not have an error retry mechanism.
+    // If an RPC error occurs, the query will be canceled.
+    std::atomic<bool> _is_failed;
     PUniqueId _query_id;
     PlanNodeId _dest_node_id;
-    // Sender instance id, unique within a fragment. StreamSender save the variable
-    int _sender_id;
-    int _be_number;
     std::atomic<int64_t> _rpc_count = 0;
     RuntimeState* _state = nullptr;
     QueryContext* _context = nullptr;
 
     Status _send_rpc(InstanceLoId);
-    // must hold the _instance_to_package_queue_mutex[id] mutex to opera
-    void _construct_request(InstanceLoId id, PUniqueId);
+
+#ifndef BE_TEST
     inline void _ended(InstanceLoId id);
     inline void _failed(InstanceLoId id, const std::string& err);
     inline void _set_receiver_eof(InstanceLoId id);
-    inline bool _is_receiver_eof(InstanceLoId id);
     inline void _turn_off_channel(InstanceLoId id, std::unique_lock<std::mutex>& with_lock);
+
+#else
+    virtual void _ended(InstanceLoId id);
+    virtual void _failed(InstanceLoId id, const std::string& err);
+    virtual void _set_receiver_eof(InstanceLoId id);
+    virtual void _turn_off_channel(InstanceLoId id, std::unique_lock<std::mutex>& with_lock);
+#endif
+
     void get_max_min_rpc_time(int64_t* max_time, int64_t* min_time);
     int64_t get_sum_rpc_time();
 
     std::atomic<int> _total_queue_size = 0;
-    std::shared_ptr<Dependency> _queue_dependency = nullptr;
-    std::shared_ptr<Dependency> _finish_dependency = nullptr;
-    std::shared_ptr<Dependency> _broadcast_dependency = nullptr;
-    ExchangeSinkLocalState* _parent = nullptr;
+
+    // _running_sink_count is used to track how many sinks have not finished yet.
+    // It is only decremented when eos is reached.
+    phmap::flat_hash_map<InstanceLoId, int64_t> _running_sink_count;
+    // _queue_deps is used for memory control.
+    phmap::flat_hash_map<InstanceLoId, std::shared_ptr<Dependency>> _queue_deps;
+    // The ExchangeSinkLocalState in _parents is only used in _turn_off_channel.
+    phmap::flat_hash_map<InstanceLoId, ExchangeSinkLocalState*> _parents;
+    const int64_t _exchange_sink_num;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/exec/exchange_sink_operator.cpp b/be/src/pipeline/exec/exchange_sink_operator.cpp
index dfa6df392b7..04b9653e9c8 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.cpp
+++ b/be/src/pipeline/exec/exchange_sink_operator.cpp
@@ -32,6 +32,8 @@
 #include "pipeline/exec/operator.h"
 #include "pipeline/exec/sort_source_operator.h"
 #include "pipeline/local_exchange/local_exchange_sink_operator.h"
+#include "pipeline/local_exchange/local_exchange_source_operator.h"
+#include "pipeline/pipeline_fragment_context.h"
 #include "util/runtime_profile.h"
 #include "util/uid_util.h"
 #include "vec/columns/column_const.h"
@@ -100,6 +102,24 @@ Status ExchangeSinkLocalState::init(RuntimeState* state, LocalSinkStateInfo& inf
                 fmt::format("WaitForLocalExchangeBuffer{}", i), TUnit 
::TIME_NS, timer_name, 1));
     }
     _wait_broadcast_buffer_timer = ADD_CHILD_TIMER(_profile, "WaitForBroadcastBuffer", timer_name);
+
+    size_t local_size = 0;
+    for (int i = 0; i < channels.size(); ++i) {
+        if (channels[i]->is_local()) {
+            local_size++;
+            _last_local_channel_idx = i;
+        }
+    }
+    only_local_exchange = local_size == channels.size();
+
+    if (!only_local_exchange) {
+        _sink_buffer = p.get_sink_buffer(state->fragment_instance_id().lo);
+        register_channels(_sink_buffer.get());
+        _queue_dependency = Dependency::create_shared(_parent->operator_id(), _parent->node_id(),
+                                                      "ExchangeSinkQueueDependency", true);
+        _sink_buffer->set_dependency(state->fragment_instance_id().lo, _queue_dependency, this);
+    }
+
     return Status::OK();
 }
 
@@ -149,20 +169,10 @@ Status ExchangeSinkLocalState::open(RuntimeState* state) {
     id.set_hi(_state->query_id().hi);
     id.set_lo(_state->query_id().lo);
 
-    if (!only_local_exchange) {
-        _sink_buffer = std::make_unique<ExchangeSinkBuffer>(id, p._dest_node_id, _sender_id,
-                                                            _state->be_number(), state, this);
-        register_channels(_sink_buffer.get());
-        _queue_dependency = Dependency::create_shared(_parent->operator_id(), _parent->node_id(),
-                                                      "ExchangeSinkQueueDependency", true);
-        _sink_buffer->set_dependency(_queue_dependency, _finish_dependency);
-    }
-
     if ((_part_type == TPartitionType::UNPARTITIONED || channels.size() == 1) &&
         !only_local_exchange) {
         _broadcast_dependency = Dependency::create_shared(
                 _parent->operator_id(), _parent->node_id(), "BroadcastDependency", true);
-        _sink_buffer->set_broadcast_dependency(_broadcast_dependency);
         _broadcast_pb_mem_limiter =
                 vectorized::BroadcastPBlockHolderMemLimiter::create_shared(_broadcast_dependency);
     } else if (local_size > 0) {
@@ -301,7 +311,8 @@ segment_v2::CompressionTypePB ExchangeSinkLocalState::compression_type() const {
 
 ExchangeSinkOperatorX::ExchangeSinkOperatorX(
         RuntimeState* state, const RowDescriptor& row_desc, int operator_id,
-        const TDataStreamSink& sink, const std::vector<TPlanFragmentDestination>& destinations)
+        const TDataStreamSink& sink, const std::vector<TPlanFragmentDestination>& destinations,
+        const std::vector<TUniqueId>& fragment_instance_ids)
         : DataSinkOperatorX(operator_id, sink.dest_node_id),
           _texprs(sink.output_partition.partition_exprs),
           _row_desc(row_desc),
@@ -315,7 +326,8 @@ ExchangeSinkOperatorX::ExchangeSinkOperatorX(
           _tablet_sink_tuple_id(sink.tablet_sink_tuple_id),
           _tablet_sink_txn_id(sink.tablet_sink_txn_id),
           _t_tablet_sink_exprs(&sink.tablet_sink_exprs),
-          _enable_local_merge_sort(state->enable_local_merge_sort()) {
+          _enable_local_merge_sort(state->enable_local_merge_sort()),
+          _fragment_instance_ids(fragment_instance_ids) {
     DCHECK_GT(destinations.size(), 0);
     DCHECK(sink.output_partition.type == TPartitionType::UNPARTITIONED ||
            sink.output_partition.type == TPartitionType::HASH_PARTITIONED ||
@@ -360,6 +372,11 @@ Status ExchangeSinkOperatorX::open(RuntimeState* state) {
         }
         RETURN_IF_ERROR(vectorized::VExpr::open(_tablet_sink_expr_ctxs, state));
     }
+    std::vector<InstanceLoId> ins_ids;
+    for (auto fragment_instance_id : _fragment_instance_ids) {
+        ins_ids.push_back(fragment_instance_id.lo);
+    }
+    _sink_buffer = _create_buffer(ins_ids);
     return Status::OK();
 }
 
@@ -620,7 +637,7 @@ Status ExchangeSinkOperatorX::sink(RuntimeState* state, vectorized::Block* block
 
 void ExchangeSinkLocalState::register_channels(pipeline::ExchangeSinkBuffer* buffer) {
     for (auto& channel : channels) {
-        channel->register_exchange_buffer(buffer);
+        channel->set_exchange_buffer(buffer);
     }
 }
 
@@ -669,8 +686,8 @@ std::string ExchangeSinkLocalState::debug_string(int indentation_level) const {
         fmt::format_to(debug_string_buffer,
                        ", Sink Buffer: (_is_finishing = {}, blocks in queue: 
{}, queue capacity: "
                        "{}, queue dep: {}), _reach_limit: {}, working 
channels: {}",
-                       _sink_buffer->_is_finishing.load(), 
_sink_buffer->_total_queue_size,
-                       _sink_buffer->_queue_capacity, 
(void*)_sink_buffer->_queue_dependency.get(),
+                       _sink_buffer->_is_failed.load(), 
_sink_buffer->_total_queue_size,
+                       _sink_buffer->_queue_capacity, 
(void*)_queue_dependency.get(),
                        _reach_limit.load(), _working_channels_count.load());
     }
     return fmt::to_string(debug_string_buffer);
@@ -724,4 +741,42 @@ DataDistribution ExchangeSinkOperatorX::required_data_distribution() const {
     return DataSinkOperatorX<ExchangeSinkLocalState>::required_data_distribution();
 }
 
+std::shared_ptr<ExchangeSinkBuffer> ExchangeSinkOperatorX::_create_buffer(
+        const std::vector<InstanceLoId>& sender_ins_ids) {
+    PUniqueId id;
+    id.set_hi(_state->query_id().hi);
+    id.set_lo(_state->query_id().lo);
+    auto sink_buffer =
+            std::make_unique<ExchangeSinkBuffer>(id, _dest_node_id, state(), sender_ins_ids);
+    for (const auto& _dest : _dests) {
+        sink_buffer->construct_request(_dest.fragment_instance_id);
+    }
+    return sink_buffer;
+}
+
+// For a normal shuffle scenario, if the concurrency is n,
+// there can be up to n * n RPCs in the current fragment.
+// Therefore, a shared sink buffer is used here to limit the number of concurrent RPCs.
+// (Note: This does not reduce the total number of RPCs.)
+// In a merge sort scenario, there are only n RPCs, so a shared sink buffer is not needed.
+/// TODO: Modify this to let FE handle the judgment instead of BE.
+std::shared_ptr<ExchangeSinkBuffer> ExchangeSinkOperatorX::get_sink_buffer(
+        InstanceLoId sender_ins_id) {
+    if (!_child) {
+        throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+                               "ExchangeSinkOperatorX did not correctly set 
the child.");
+    }
+    // When the child is SortSourceOperatorX or LocalExchangeSourceOperatorX,
+    // it is an order-by scenario.
+    // In this case, there is only one target instance, and no n * n RPC concurrency will occur.
+    // Therefore, sharing a sink buffer is not necessary.
+    if (std::dynamic_pointer_cast<SortSourceOperatorX>(_child) ||
+        std::dynamic_pointer_cast<LocalExchangeSourceOperatorX>(_child)) {
+        return _create_buffer({sender_ins_id});
+    }
+    if (_state->enable_shared_exchange_sink_buffer()) {
+        return _sink_buffer;
+    }
+    return _create_buffer({sender_ins_id});
+}
 } // namespace doris::pipeline
diff --git a/be/src/pipeline/exec/exchange_sink_operator.h b/be/src/pipeline/exec/exchange_sink_operator.h
index 91ee1bd27a6..8d094b43f61 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.h
+++ b/be/src/pipeline/exec/exchange_sink_operator.h
@@ -61,6 +61,14 @@ public:
                                              parent->get_name() + "_FINISH_DEPENDENCY", false);
     }
 
+#ifdef BE_TEST
+    ExchangeSinkLocalState(RuntimeState* state) : Base(nullptr, state) {
+        _profile = state->obj_pool()->add(new RuntimeProfile("mock"));
+        _memory_used_counter =
+                _profile->AddHighWaterMarkCounter("MemoryUsage", TUnit::BYTES, "", 1);
+    }
+#endif
+
     std::vector<Dependency*> dependencies() const override {
         std::vector<Dependency*> dep_vec;
         if (_queue_dependency) {
@@ -88,7 +96,12 @@ public:
     bool is_finished() const override { return _reach_limit.load(); }
     void set_reach_limit() { _reach_limit = true; };
 
+    // sender_id indicates which instance within a fragment, while be_number indicates which instance
+    // across all fragments. For example, with 3 BEs and 8 instances, the range of sender_id would be 0 to 24,
+    // and the range of be_number would be from n + 0 to n + 24.
+    // Since be_number is a required field, it still needs to be set for compatibility with older code.
     [[nodiscard]] int sender_id() const { return _sender_id; }
+    [[nodiscard]] int be_number() const { return _state->be_number(); }
 
     std::string name_suffix() override;
     segment_v2::CompressionTypePB compression_type() const;
@@ -112,7 +125,7 @@ private:
     friend class vectorized::Channel;
     friend class vectorized::BlockSerializer;
 
-    std::unique_ptr<ExchangeSinkBuffer> _sink_buffer = nullptr;
+    std::shared_ptr<ExchangeSinkBuffer> _sink_buffer = nullptr;
     RuntimeProfile::Counter* _serialize_batch_timer = nullptr;
     RuntimeProfile::Counter* _compress_timer = nullptr;
     RuntimeProfile::Counter* _bytes_sent_counter = nullptr;
@@ -197,7 +210,8 @@ class ExchangeSinkOperatorX final : public DataSinkOperatorX<ExchangeSinkLocalSt
 public:
     ExchangeSinkOperatorX(RuntimeState* state, const RowDescriptor& row_desc, int operator_id,
                           const TDataStreamSink& sink,
-                          const std::vector<TPlanFragmentDestination>& destinations);
+                          const std::vector<TPlanFragmentDestination>& destinations,
+                          const std::vector<TUniqueId>& fragment_instance_ids);
     Status init(const TDataSink& tsink) override;
 
     RuntimeState* state() { return _state; }
@@ -209,6 +223,14 @@ public:
     DataDistribution required_data_distribution() const override;
     bool is_serial_operator() const override { return true; }
 
+    // For a normal shuffle scenario, if the concurrency is n,
+    // there can be up to n * n RPCs in the current fragment.
+    // Therefore, a shared sink buffer is used here to limit the number of concurrent RPCs.
+    // (Note: This does not reduce the total number of RPCs.)
+    // In a merge sort scenario, there are only n RPCs, so a shared sink buffer is not needed.
+    /// TODO: Modify this to let FE handle the judgment instead of BE.
+    std::shared_ptr<ExchangeSinkBuffer> get_sink_buffer(InstanceLoId sender_ins_id);
+
 private:
     friend class ExchangeSinkLocalState;
 
@@ -225,6 +247,13 @@ private:
                                      size_t num_channels,
                                      std::vector<std::vector<uint32_t>>& channel2rows,
                                      vectorized::Block* block, bool eos);
+
+    // Use ExchangeSinkOperatorX to create a sink buffer.
+    // The sink buffer can be shared among multiple ExchangeSinkLocalState instances,
+    // or each ExchangeSinkLocalState can have its own sink buffer.
+    std::shared_ptr<ExchangeSinkBuffer> _create_buffer(
+            const std::vector<InstanceLoId>& sender_ins_ids);
+    std::shared_ptr<ExchangeSinkBuffer> _sink_buffer = nullptr;
     RuntimeState* _state = nullptr;
 
     const std::vector<TExpr> _texprs;
@@ -264,6 +293,7 @@ private:
     size_t _data_processed = 0;
     int _writer_count = 1;
     const bool _enable_local_merge_sort;
+    const std::vector<TUniqueId>& _fragment_instance_ids;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/pipeline_fragment_context.cpp b/be/src/pipeline/pipeline_fragment_context.cpp
index 8ceb63eb993..8ab0f1d1515 100644
--- a/be/src/pipeline/pipeline_fragment_context.cpp
+++ b/be/src/pipeline/pipeline_fragment_context.cpp
@@ -260,7 +260,7 @@ Status PipelineFragmentContext::prepare(const doris::TPipelineFragmentParams& re
         _runtime_state = RuntimeState::create_unique(
                 request.query_id, request.fragment_id, request.query_options,
                 _query_ctx->query_globals, _exec_env, _query_ctx.get());
-
+        _runtime_state->set_task_execution_context(shared_from_this());
         SCOPED_SWITCH_THREAD_MEM_TRACKER_LIMITER(_runtime_state->query_mem_tracker());
         if (request.__isset.backend_id) {
             _runtime_state->set_backend_id(request.backend_id);
@@ -296,6 +296,14 @@ Status PipelineFragmentContext::prepare(const doris::TPipelineFragmentParams& re
         if (local_params.__isset.topn_filter_descs) {
             _query_ctx->init_runtime_predicates(local_params.topn_filter_descs);
         }
+
+        // init fragment_instance_ids
+        const auto target_size = request.local_params.size();
+        _fragment_instance_ids.resize(target_size);
+        for (size_t i = 0; i < request.local_params.size(); i++) {
+            auto fragment_instance_id = request.local_params[i].fragment_instance_id;
+            _fragment_instance_ids[i] = fragment_instance_id;
+        }
     }
 
     {
@@ -353,7 +361,6 @@ Status PipelineFragmentContext::_build_pipeline_tasks(const doris::TPipelineFrag
     _total_tasks = 0;
     const auto target_size = request.local_params.size();
     _tasks.resize(target_size);
-    _fragment_instance_ids.resize(target_size);
     _runtime_filter_states.resize(target_size);
     _task_runtime_states.resize(_pipelines.size());
     for (size_t pip_idx = 0; pip_idx < _pipelines.size(); pip_idx++) {
@@ -365,8 +372,6 @@ Status PipelineFragmentContext::_build_pipeline_tasks(const doris::TPipelineFrag
     auto pre_and_submit = [&](int i, PipelineFragmentContext* ctx) {
         const auto& local_params = request.local_params[i];
         auto fragment_instance_id = local_params.fragment_instance_id;
-        _fragment_instance_ids[i] = fragment_instance_id;
-
         _runtime_filter_states[i] = RuntimeFilterParamsContext::create(_query_ctx.get());
         std::unique_ptr<RuntimeFilterMgr> runtime_filter_mgr = std::make_unique<RuntimeFilterMgr>(
                 request.query_id, _runtime_filter_states[i], _query_ctx->query_mem_tracker, false);
@@ -1007,7 +1012,8 @@ Status PipelineFragmentContext::_create_data_sink(ObjectPool* pool, const TDataS
             return Status::InternalError("Missing data stream sink.");
         }
         _sink.reset(new ExchangeSinkOperatorX(state, row_desc, next_sink_operator_id(),
-                                              thrift_sink.stream_sink, params.destinations));
+                                              thrift_sink.stream_sink, params.destinations,
+                                              _fragment_instance_ids));
         break;
     }
     case TDataSinkType::RESULT_SINK: {
@@ -1134,10 +1140,10 @@ Status PipelineFragmentContext::_create_data_sink(ObjectPool* pool, const TDataS
             // 2. create and set sink operator of data stream sender for new pipeline
 
             DataSinkOperatorPtr sink_op;
-            sink_op.reset(
-                    new ExchangeSinkOperatorX(state, *_row_desc, next_sink_operator_id(),
-                                              thrift_sink.multi_cast_stream_sink.sinks[i],
-                                              thrift_sink.multi_cast_stream_sink.destinations[i]));
+            sink_op.reset(new ExchangeSinkOperatorX(
+                    state, *_row_desc, next_sink_operator_id(),
+                    thrift_sink.multi_cast_stream_sink.sinks[i],
+                    thrift_sink.multi_cast_stream_sink.destinations[i], _fragment_instance_ids));
 
             RETURN_IF_ERROR(new_pipeline->set_sink(sink_op));
             {
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index a49567109a3..1e7c1e579f7 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -587,6 +587,11 @@ public:
                _query_options.enable_local_merge_sort;
     }
 
+    bool enable_shared_exchange_sink_buffer() const {
+        return _query_options.__isset.enable_shared_exchange_sink_buffer &&
+               _query_options.enable_shared_exchange_sink_buffer;
+    }
+
     int64_t min_revocable_mem() const {
         if (_query_options.__isset.min_revocable_mem) {
             return std::max(_query_options.min_revocable_mem, (int64_t)1);
diff --git a/be/src/vec/sink/vdata_stream_sender.h b/be/src/vec/sink/vdata_stream_sender.h
index 4999602fdf4..16ea49e443c 100644
--- a/be/src/vec/sink/vdata_stream_sender.h
+++ b/be/src/vec/sink/vdata_stream_sender.h
@@ -76,6 +76,9 @@ namespace vectorized {
 class BlockSerializer {
 public:
     BlockSerializer(pipeline::ExchangeSinkLocalState* parent, bool is_local = true);
+#ifdef BE_TEST
+    BlockSerializer() : _batch_size(0) {};
+#endif
     Status next_serialized_block(Block* src, PBlock* dest, size_t num_receivers, bool* serialized,
                                  bool eos, const std::vector<uint32_t>* rows = nullptr);
     Status serialize_block(PBlock* dest, size_t num_receivers = 1);
@@ -165,10 +168,9 @@ public:
         return Status::OK();
     }
 
-    void register_exchange_buffer(pipeline::ExchangeSinkBuffer* buffer) {
-        _buffer = buffer;
-        _buffer->register_sink(_fragment_instance_id);
-    }
+    void set_exchange_buffer(pipeline::ExchangeSinkBuffer* buffer) { _buffer = buffer; }
+
+    InstanceLoId dest_ins_id() const { return _fragment_instance_id.lo; }
 
     std::shared_ptr<pipeline::ExchangeSendCallback<PTransmitDataResult>> get_send_callback(
             InstanceLoId id, bool eos) {
diff --git a/be/test/vec/exec/exchange_sink_test.cpp b/be/test/vec/exec/exchange_sink_test.cpp
new file mode 100644
index 00000000000..9576ed71ee2
--- /dev/null
+++ b/be/test/vec/exec/exchange_sink_test.cpp
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "exchange_sink_test.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "pipeline/exec/exchange_sink_buffer.h"
+
+namespace doris::vectorized {
+using namespace pipeline;
+TEST_F(ExchangeSInkTest, test_normal_end) {
+    {
+        auto state = create_runtime_state();
+        auto buffer = create_buffer(state);
+
+        auto sink1 = create_sink(state, buffer);
+        auto sink2 = create_sink(state, buffer);
+        auto sink3 = create_sink(state, buffer);
+
+        EXPECT_EQ(sink1.add_block(dest_ins_id_1, true), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_3, true), Status::OK());
+
+        EXPECT_EQ(sink2.add_block(dest_ins_id_1, true), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_3, true), Status::OK());
+
+        EXPECT_EQ(sink3.add_block(dest_ins_id_1, true), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_3, true), Status::OK());
+
+        for (auto [id, count] : buffer->_running_sink_count) {
+            EXPECT_EQ(count, 3) << "id : " << id;
+        }
+
+        for (auto [id, is_turn_off] : buffer->_rpc_channel_is_turn_off) {
+            EXPECT_EQ(is_turn_off, false) << "id : " << id;
+        }
+
+        pop_block(dest_ins_id_1, PopState::accept);
+        pop_block(dest_ins_id_1, PopState::accept);
+        pop_block(dest_ins_id_1, PopState::accept);
+
+        pop_block(dest_ins_id_2, PopState::accept);
+        pop_block(dest_ins_id_2, PopState::accept);
+        pop_block(dest_ins_id_2, PopState::accept);
+
+        pop_block(dest_ins_id_3, PopState::accept);
+        pop_block(dest_ins_id_3, PopState::accept);
+        pop_block(dest_ins_id_3, PopState::accept);
+
+        for (auto [id, count] : buffer->_running_sink_count) {
+            EXPECT_EQ(count, 0) << "id : " << id;
+        }
+
+        for (auto [id, is_turn_off] : buffer->_rpc_channel_is_turn_off) {
+            EXPECT_EQ(is_turn_off, true) << "id : " << id;
+        }
+        clear_all_done();
+    }
+}
+
+TEST_F(ExchangeSInkTest, test_eof_end) {
+    {
+        auto state = create_runtime_state();
+        auto buffer = create_buffer(state);
+
+        auto sink1 = create_sink(state, buffer);
+        auto sink2 = create_sink(state, buffer);
+        auto sink3 = create_sink(state, buffer);
+
+        EXPECT_EQ(sink1.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(sink2.add_block(dest_ins_id_1, true), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_3, true), Status::OK());
+
+        EXPECT_EQ(sink3.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_3, false), Status::OK());
+
+        for (auto [id, count] : buffer->_running_sink_count) {
+            EXPECT_EQ(count, 3) << "id : " << id;
+        }
+
+        for (auto [id, is_turn_off] : buffer->_rpc_channel_is_turn_off) {
+            EXPECT_EQ(is_turn_off, false) << "id : " << id;
+        }
+
+        pop_block(dest_ins_id_1, PopState::eof);
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_1], true);
+        EXPECT_TRUE(buffer->_instance_to_package_queue[dest_ins_id_1].empty());
+
+        pop_block(dest_ins_id_2, PopState::accept);
+        pop_block(dest_ins_id_2, PopState::accept);
+        pop_block(dest_ins_id_2, PopState::accept);
+
+        pop_block(dest_ins_id_3, PopState::accept);
+        pop_block(dest_ins_id_3, PopState::accept);
+        pop_block(dest_ins_id_3, PopState::accept);
+
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_1], true);
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_2], false) << "not all eos";
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_3], false) << " not all eos";
+
+        EXPECT_TRUE(sink1.add_block(dest_ins_id_1, true).is<ErrorCode::END_OF_FILE>());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_2, true), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_3, true), Status::OK());
+        pop_block(dest_ins_id_2, PopState::accept);
+        pop_block(dest_ins_id_3, PopState::accept);
+
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_1], true);
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_2], true);
+        EXPECT_EQ(buffer->_rpc_channel_is_turn_off[dest_ins_id_3], false);
+        EXPECT_EQ(buffer->_running_sink_count[dest_ins_id_3], 1);
+
+        clear_all_done();
+    }
+}
+
+TEST_F(ExchangeSInkTest, test_error_end) {
+    {
+        auto state = create_runtime_state();
+        auto buffer = create_buffer(state);
+
+        auto sink1 = create_sink(state, buffer);
+        auto sink2 = create_sink(state, buffer);
+        auto sink3 = create_sink(state, buffer);
+
+        EXPECT_EQ(sink1.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(sink2.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(sink3.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_3, false), Status::OK());
+
+        for (auto [id, count] : buffer->_running_sink_count) {
+            EXPECT_EQ(count, 3) << "id : " << id;
+        }
+
+        for (auto [id, is_turn_off] : buffer->_rpc_channel_is_turn_off) {
+            EXPECT_EQ(is_turn_off, false) << "id : " << id;
+        }
+
+        pop_block(dest_ins_id_2, PopState::error);
+
+        auto orgin_queue_1_size = done_map[dest_ins_id_1].size();
+        auto orgin_queue_2_size = done_map[dest_ins_id_2].size();
+        auto orgin_queue_3_size = done_map[dest_ins_id_3].size();
+
+        EXPECT_EQ(sink1.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink1.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(sink2.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink2.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(sink3.add_block(dest_ins_id_1, false), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_2, false), Status::OK());
+        EXPECT_EQ(sink3.add_block(dest_ins_id_3, false), Status::OK());
+
+        EXPECT_EQ(orgin_queue_1_size, done_map[dest_ins_id_1].size());
+        EXPECT_EQ(orgin_queue_2_size, done_map[dest_ins_id_2].size());
+        EXPECT_EQ(orgin_queue_3_size, done_map[dest_ins_id_3].size());
+
+        clear_all_done();
+    }
+}
+
+} // namespace doris::vectorized
diff --git a/be/test/vec/exec/exchange_sink_test.h b/be/test/vec/exec/exchange_sink_test.h
new file mode 100644
index 00000000000..253d7b267f9
--- /dev/null
+++ b/be/test/vec/exec/exchange_sink_test.h
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "common/status.h"
+#include "pipeline/exec/exchange_sink_buffer.h"
+#include "pipeline/exec/exchange_sink_operator.h"
+#include "runtime/runtime_state.h"
+#include "udf/udf.h"
+#include "vec/sink/writer/vhive_utils.h"
+
+namespace doris::pipeline {
+
+std::map<int64_t, std::queue<AutoReleaseClosure<PTransmitDataParams,
+                                                ExchangeSendCallback<PTransmitDataResult>>*>>
+
+        done_map;
+
+void add_request(int64_t id, auto* done) {
+    done_map[id].push(done);
+}
+
+void clear_all_done() {
+    for (auto& [id, dones] : done_map) {
+        while (!dones.empty()) {
+            dones.front()->Run();
+            dones.pop();
+        }
+    }
+}
+
+enum PopState : int {
+    eof,
+    error,
+    accept,
+};
+
+void pop_block(int64_t id, PopState state) {
+    if (done_map[id].empty()) {
+        return;
+    }
+    auto* done = done_map[id].front();
+    done_map[id].pop();
+    switch (state) {
+    case PopState::eof: {
+        Status st = Status::EndOfFile("Mock eof");
+        st.to_protobuf(done->response_->mutable_status());
+        done->Run();
+        break;
+    }
+    case error: {
+        done->cntl_->SetFailed("Mock error");
+        done->Run();
+        break;
+    }
+    case accept: {
+        done->Run();
+        break;
+    }
+    }
+}
+void transmit_blockv2(PBackendService_Stub& stub,
+                      std::unique_ptr<AutoReleaseClosure<PTransmitDataParams,
+                                                         ExchangeSendCallback<PTransmitDataResult>>>
+                              closure) {
+    std::cout << "mock transmit_blockv2 dest ins id :" << closure->request_->finst_id().lo()
+              << "\n";
+    add_request(closure->request_->finst_id().lo(), closure.release());
+}
+}; // namespace doris::pipeline
+
+namespace doris::vectorized {
+
+using namespace pipeline;
+class ExchangeSInkTest : public testing::Test {
+public:
+    ExchangeSInkTest() = default;
+    ~ExchangeSInkTest() override = default;
+};
+
+class MockContext : public TaskExecutionContext {};
+
+std::shared_ptr<MockContext> _mock_context = std::make_shared<MockContext>();
+
+auto create_runtime_state() {
+    auto state = RuntimeState::create_shared();
+
+    state->set_task_execution_context(_mock_context);
+    return state;
+}
+constexpr int64_t recvr_fragment_id = 2;
+constexpr int64_t sender_fragment_id = 2;
+
+TUniqueId create_TUniqueId(int64_t hi, int64_t lo) {
+    TUniqueId t {};
+    t.hi = hi;
+    t.lo = lo;
+    return t;
+}
+
+const auto dest_fragment_ins_id_1 = create_TUniqueId(recvr_fragment_id, 1);
+const auto dest_fragment_ins_id_2 = create_TUniqueId(recvr_fragment_id, 2);
+const auto dest_fragment_ins_id_3 = create_TUniqueId(recvr_fragment_id, 3);
+const auto dest_ins_id_1 = dest_fragment_ins_id_1.lo;
+const auto dest_ins_id_2 = dest_fragment_ins_id_2.lo;
+const auto dest_ins_id_3 = dest_fragment_ins_id_3.lo;
+
+class MockSinkBuffer : public ExchangeSinkBuffer {
+public:
+    MockSinkBuffer(RuntimeState* state, int64_t sinknum) : ExchangeSinkBuffer(state, sinknum) {};
+    void _failed(InstanceLoId id, const std::string& err) override {
+        _is_failed = true;
+        std::cout << "_failed\n";
+    }
+};
+
+struct SinkWithChannel {
+    std::shared_ptr<ExchangeSinkLocalState> sink;
+    std::shared_ptr<MockSinkBuffer> buffer;
+    std::map<int64_t, std::shared_ptr<Channel>> channels;
+    Status add_block(int64_t id, bool eos) {
+        auto channel = channels[id];
+        TransmitInfo transmitInfo {.channel = channel.get(),
+                                   .block = std::make_unique<PBlock>(),
+                                   .eos = eos,
+                                   .exec_status = Status::OK()};
+        return buffer->add_block(std::move(transmitInfo));
+    }
+};
+
+auto create_buffer(std::shared_ptr<RuntimeState> state) {
+    auto sink_buffer = std::make_shared<MockSinkBuffer>(state.get(), 3);
+
+    sink_buffer->construct_request(dest_fragment_ins_id_1);
+    sink_buffer->construct_request(dest_fragment_ins_id_2);
+    sink_buffer->construct_request(dest_fragment_ins_id_3);
+    return sink_buffer;
+}
+
+auto create_sink(std::shared_ptr<RuntimeState> state, std::shared_ptr<MockSinkBuffer> sink_buffer) {
+    SinkWithChannel sink_with_channel;
+    sink_with_channel.sink = ExchangeSinkLocalState::create_shared(state.get());
+    sink_with_channel.buffer = sink_buffer;
+    {
+        auto channel = std::make_shared<vectorized::Channel>(
+                sink_with_channel.sink.get(), TNetworkAddress {}, dest_fragment_ins_id_1, 0);
+        sink_with_channel.channels[channel->dest_ins_id()] = channel;
+    }
+    {
+        auto channel = std::make_shared<vectorized::Channel>(
+                sink_with_channel.sink.get(), TNetworkAddress {}, dest_fragment_ins_id_2, 0);
+        sink_with_channel.channels[channel->dest_ins_id()] = channel;
+    }
+    {
+        auto channel = std::make_shared<vectorized::Channel>(
+                sink_with_channel.sink.get(), TNetworkAddress {}, dest_fragment_ins_id_3, 0);
+        sink_with_channel.channels[channel->dest_ins_id()] = channel;
+    }
+    return sink_with_channel;
+}
+
+} // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index d07a8f022d7..380c758e575 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -284,6 +284,8 @@ public class SessionVariable implements Serializable, Writable {
 
     public static final String ENABLE_LOCAL_MERGE_SORT = "enable_local_merge_sort";
 
+    public static final String ENABLE_SHARED_EXCHANGE_SINK_BUFFER = "enable_shared_exchange_sink_buffer";
+
     public static final String ENABLE_AGG_STATE = "enable_agg_state";
 
     public static final String ENABLE_RPC_OPT_FOR_PIPELINE = "enable_rpc_opt_for_pipeline";
@@ -1135,6 +1137,9 @@ public class SessionVariable implements Serializable, Writable {
     @VariableMgr.VarAttr(name = ENABLE_LOCAL_MERGE_SORT)
     private boolean enableLocalMergeSort = true;
 
+    @VariableMgr.VarAttr(name = ENABLE_SHARED_EXCHANGE_SINK_BUFFER, fuzzy = true)
+    private boolean enableSharedExchangeSinkBuffer = true;
+
     @VariableMgr.VarAttr(name = ENABLE_AGG_STATE, fuzzy = false, varType = VariableAnnotation.EXPERIMENTAL,
             needForward = true)
     public boolean enableAggState = false;
@@ -2370,6 +2375,7 @@ public class SessionVariable implements Serializable, Writable {
         this.parallelPrepareThreshold = random.nextInt(32) + 1;
         this.enableCommonExprPushdown = random.nextBoolean();
         this.enableLocalExchange = random.nextBoolean();
+        this.enableSharedExchangeSinkBuffer = random.nextBoolean();
         this.useSerialExchange = random.nextBoolean();
         // This will cause be dead loop, disable it first
         // this.disableJoinReorder = random.nextBoolean();
@@ -3965,6 +3971,7 @@ public class SessionVariable implements Serializable, Writable {
         tResult.setDataQueueMaxBlocks(dataQueueMaxBlocks);
 
         tResult.setEnableLocalMergeSort(enableLocalMergeSort);
+        tResult.setEnableSharedExchangeSinkBuffer(enableSharedExchangeSinkBuffer);
         tResult.setEnableParallelResultSink(enableParallelResultSink);
         tResult.setEnableParallelOutfile(enableParallelOutfile);
         tResult.setEnableShortCircuitQueryAccessColumnStore(enableShortCircuitQueryAcessColumnStore);
diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift
index 745cb8f21fb..85abbf9b66d 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -359,7 +359,7 @@ struct TQueryOptions {
 
   141: optional bool ignore_runtime_filter_error = false;
   142: optional bool enable_fixed_len_to_uint32_v2 = false;
-
+  143: optional bool enable_shared_exchange_sink_buffer = true;
   // For cloud, to control if the content would be written into file cache
   // In write path, to control if the content would be written into file cache.
   // In read path, read from file cache or remote storage when execute query.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
