This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 9b54e282acf656748a83ec704db345c8a08a2ec3
Author: Gabriel <[email protected]>
AuthorDate: Thu Aug 22 20:47:04 2024 +0800

    [refactor](local exchange) Refactor logics (#39771)
---
 be/src/pipeline/local_exchange/local_exchanger.cpp | 56 +++++++++++++++-------
 be/src/pipeline/local_exchange/local_exchanger.h   | 23 +++------
 2 files changed, 45 insertions(+), 34 deletions(-)

diff --git a/be/src/pipeline/local_exchange/local_exchanger.cpp b/be/src/pipeline/local_exchange/local_exchanger.cpp
index 6d504666522..118ea29062a 100644
--- a/be/src/pipeline/local_exchange/local_exchanger.cpp
+++ b/be/src/pipeline/local_exchange/local_exchanger.cpp
@@ -43,7 +43,8 @@ void Exchanger<BlockType>::_enqueue_data_and_set_ready(int channel_id,
     }
     std::unique_lock l(_m);
     local_state._shared_state->add_mem_usage(channel_id, allocated_bytes,
-                                             !std::is_same_v<PartitionedBlock, BlockType>);
+                                             !std::is_same_v<PartitionedBlock, BlockType> &&
+                                                     !std::is_same_v<BroadcastBlock, BlockType>);
     if (_data_queue[channel_id].enqueue(std::move(block))) {
         local_state._shared_state->set_ready_to_read(channel_id);
     } else {
@@ -54,6 +55,7 @@ void Exchanger<BlockType>::_enqueue_data_and_set_ready(int channel_id,
             block.first->unref(local_state._shared_state, allocated_bytes);
         } else {
             block->unref(local_state._shared_state, allocated_bytes);
+            DCHECK_EQ(block->ref_value(), 0);
         }
     }
 }
@@ -79,6 +81,7 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState& local_st
                                                      block->data_block.allocated_bytes());
             data_block->swap(block->data_block);
             block->unref(local_state._shared_state, data_block->allocated_bytes());
+            DCHECK_EQ(block->ref_value(), 0);
         }
         return true;
     } else if (all_finished) {
@@ -94,6 +97,7 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState& local_st
                                                          block->data_block.allocated_bytes());
                 data_block->swap(block->data_block);
                 block->unref(local_state._shared_state, data_block->allocated_bytes());
+                DCHECK_EQ(block->ref_value(), 0);
             }
             return true;
         }
@@ -105,6 +109,9 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState& local_st
 
 Status ShuffleExchanger::sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
                               LocalExchangeSinkLocalState& local_state) {
+    if (in_block->empty()) {
+        return Status::OK();
+    }
     {
         SCOPED_TIMER(local_state._compute_hash_value_timer);
         RETURN_IF_ERROR(local_state._partitioner->do_partitioning(state, in_block,
@@ -114,7 +121,7 @@ Status ShuffleExchanger::sink(RuntimeState* state, vectorized::Block* in_block,
         SCOPED_TIMER(local_state._distribute_timer);
         RETURN_IF_ERROR(_split_rows(state,
                                     local_state._partitioner->get_channel_ids().get<uint32_t>(),
-                                    in_block, eos, local_state));
+                                    in_block, local_state));
     }
 
     return Status::OK();
@@ -158,7 +165,7 @@ Status ShuffleExchanger::get_block(RuntimeState* state, vectorized::Block* block
 }
 
 Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __restrict channel_ids,
-                                     vectorized::Block* block, bool eos,
+                                     vectorized::Block* block,
                                      LocalExchangeSinkLocalState& local_state) {
     const auto rows = block->rows();
     auto row_idx = std::make_shared<vectorized::PODArray<uint32_t>>(rows);
@@ -267,6 +274,9 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest
 
 Status PassthroughExchanger::sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
                                   LocalExchangeSinkLocalState& local_state) {
+    if (in_block->empty()) {
+        return Status::OK();
+    }
     vectorized::Block new_block;
     BlockWrapperSPtr wrapper;
     if (!_free_blocks.try_dequeue(new_block)) {
@@ -309,7 +319,13 @@ Status PassthroughExchanger::get_block(RuntimeState* state, vectorized::Block* b
 
 Status PassToOneExchanger::sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
                                 LocalExchangeSinkLocalState& local_state) {
-    vectorized::Block new_block(in_block->clone_empty());
+    if (in_block->empty()) {
+        return Status::OK();
+    }
+    vectorized::Block new_block;
+    if (!_free_blocks.try_dequeue(new_block)) {
+        new_block = {in_block->clone_empty()};
+    }
     new_block.swap(*in_block);
 
     BlockWrapperSPtr wrapper = BlockWrapper::create_shared(std::move(new_block));
@@ -331,15 +347,17 @@ Status PassToOneExchanger::get_block(RuntimeState* state, vectorized::Block* blo
 
 Status LocalMergeSortExchanger::sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
                                      LocalExchangeSinkLocalState& local_state) {
-    vectorized::Block new_block;
-    if (!_free_blocks.try_dequeue(new_block)) {
-        new_block = {in_block->clone_empty()};
-    }
-    DCHECK_LE(local_state._channel_id, _data_queue.size());
+    if (!in_block->empty()) {
+        vectorized::Block new_block;
+        if (!_free_blocks.try_dequeue(new_block)) {
+            new_block = {in_block->clone_empty()};
+        }
+        DCHECK_LE(local_state._channel_id, _data_queue.size());
 
-    new_block.swap(*in_block);
-    _enqueue_data_and_set_ready(local_state._channel_id, local_state,
-                                BlockWrapper::create_shared(std::move(new_block)));
+        new_block.swap(*in_block);
+        _enqueue_data_and_set_ready(local_state._channel_id, local_state,
+                                    BlockWrapper::create_shared(std::move(new_block)));
+    }
     if (eos) {
         local_state._shared_state->source_deps[local_state._channel_id]->set_always_ready();
     }
@@ -440,7 +458,7 @@ Status BroadcastExchanger::get_block(RuntimeState* state, vectorized::Block* blo
 }
 
 Status AdaptivePassthroughExchanger::_passthrough_sink(RuntimeState* state,
-                                                       vectorized::Block* in_block, bool eos,
+                                                       vectorized::Block* in_block,
                                                        LocalExchangeSinkLocalState& local_state) {
     vectorized::Block new_block;
     if (!_free_blocks.try_dequeue(new_block)) {
@@ -455,7 +473,6 @@ Status AdaptivePassthroughExchanger::_passthrough_sink(RuntimeState* state,
 }
 
 Status AdaptivePassthroughExchanger::_shuffle_sink(RuntimeState* state, vectorized::Block* block,
-                                                   bool eos,
                                                    LocalExchangeSinkLocalState& local_state) {
     std::vector<uint32_t> channel_ids;
     const auto num_rows = block->rows();
@@ -471,12 +488,12 @@ Status AdaptivePassthroughExchanger::_shuffle_sink(RuntimeState* state, vectoriz
             std::iota(channel_ids.begin() + i, channel_ids.end(), 0);
         }
     }
-    return _split_rows(state, channel_ids.data(), block, eos, local_state);
+    return _split_rows(state, channel_ids.data(), block, local_state);
 }
 
 Status AdaptivePassthroughExchanger::_split_rows(RuntimeState* state,
                                                  const uint32_t* __restrict channel_ids,
-                                                 vectorized::Block* block, bool eos,
+                                                 vectorized::Block* block,
                                                  LocalExchangeSinkLocalState& local_state) {
     const auto rows = block->rows();
     auto row_idx = std::make_shared<std::vector<uint32_t>>(rows);
@@ -513,13 +530,16 @@ Status AdaptivePassthroughExchanger::_split_rows(RuntimeState* state,
 
 Status AdaptivePassthroughExchanger::sink(RuntimeState* state, vectorized::Block* in_block,
                                           bool eos, LocalExchangeSinkLocalState& local_state) {
+    if (in_block->empty()) {
+        return Status::OK();
+    }
     if (_is_pass_through) {
-        return _passthrough_sink(state, in_block, eos, local_state);
+        return _passthrough_sink(state, in_block, local_state);
     } else {
         if (_total_block++ > _num_partitions) {
             _is_pass_through = true;
         }
-        return _shuffle_sink(state, in_block, eos, local_state);
+        return _shuffle_sink(state, in_block, local_state);
     }
 }
 
diff --git a/be/src/pipeline/local_exchange/local_exchanger.h b/be/src/pipeline/local_exchange/local_exchanger.h
index 8832aca02c2..834f74d216d 100644
--- a/be/src/pipeline/local_exchange/local_exchanger.h
+++ b/be/src/pipeline/local_exchange/local_exchanger.h
@@ -172,6 +172,7 @@ struct BlockWrapper {
     void ref(int delta) { ref_count += delta; }
     void unref(LocalExchangeSharedState* shared_state, size_t allocated_bytes) {
         if (ref_count.fetch_sub(1) == 1) {
+            DCHECK_GT(allocated_bytes, 0);
             shared_state->sub_total_mem_usage(allocated_bytes);
             if (shared_state->exchanger->_free_block_limit == 0 ||
                 shared_state->exchanger->_free_blocks.size_approx() <
@@ -183,17 +184,9 @@ struct BlockWrapper {
         }
     }
     void unref(LocalExchangeSharedState* shared_state) {
-        if (ref_count.fetch_sub(1) == 1) {
-            shared_state->sub_total_mem_usage(data_block.allocated_bytes());
-            if (shared_state->exchanger->_free_block_limit == 0 ||
-                shared_state->exchanger->_free_blocks.size_approx() <
-                        shared_state->exchanger->_free_block_limit *
-                                shared_state->exchanger->_num_sources) {
-                data_block.clear_column_data();
-                shared_state->exchanger->_free_blocks.enqueue(std::move(data_block));
-            }
-        }
+        unref(shared_state, data_block.allocated_bytes());
     }
+    int ref_value() const { return ref_count.load(); }
     std::atomic<int> ref_count = 0;
     vectorized::Block data_block;
 };
@@ -224,8 +217,7 @@ protected:
         _data_queue.resize(num_partitions);
     }
     Status _split_rows(RuntimeState* state, const uint32_t* __restrict channel_ids,
-                       vectorized::Block* block, bool eos,
-                       LocalExchangeSinkLocalState& local_state);
+                       vectorized::Block* block, LocalExchangeSinkLocalState& local_state);
 
     const bool _ignore_source_data_distribution = false;
 };
@@ -343,13 +335,12 @@ public:
     void close(LocalExchangeSourceLocalState& local_state) override;
 
 private:
-    Status _passthrough_sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
+    Status _passthrough_sink(RuntimeState* state, vectorized::Block* in_block,
                              LocalExchangeSinkLocalState& local_state);
-    Status _shuffle_sink(RuntimeState* state, vectorized::Block* in_block, bool eos,
+    Status _shuffle_sink(RuntimeState* state, vectorized::Block* in_block,
                          LocalExchangeSinkLocalState& local_state);
     Status _split_rows(RuntimeState* state, const uint32_t* __restrict channel_ids,
-                       vectorized::Block* block, bool eos,
-                       LocalExchangeSinkLocalState& local_state);
+                       vectorized::Block* block, LocalExchangeSinkLocalState& local_state);
 
     std::atomic_bool _is_pass_through = false;
     std::atomic_int32_t _total_block = 0;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to