This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 255d80cf981 [Feature](exec) Support group by limit opt in BE code 
(#29641)
255d80cf981 is described below

commit 255d80cf981a573fe3c6753c5f2da089f6ba2479
Author: HappenLee <happen...@hotmail.com>
AuthorDate: Mon Jun 3 20:19:37 2024 +0800

    [Feature](exec) Support group by limit opt in BE code (#29641)
    
    ## Proposed changes
    
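    Optimize GROUP BY ... LIMIT in the BE aggregation: once the hash table holds enough
    groups, stop adding new ones, and when the query orders by the group-by keys, keep a
    top-N heap of keys and filter out rows whose keys cannot make it into the top N.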
---
 be/src/pipeline/dependency.cpp                     |  76 ++++++
 be/src/pipeline/dependency.h                       |  68 +++++
 be/src/pipeline/exec/aggregation_sink_operator.cpp | 288 ++++++++++++++++-----
 be/src/pipeline/exec/aggregation_sink_operator.h   |  11 +-
 .../pipeline/exec/aggregation_source_operator.cpp  |  12 +-
 be/src/pipeline/exec/aggregation_source_operator.h |   1 +
 be/src/pipeline/exec/operator.cpp                  |   3 +-
 be/src/vec/columns/column_nullable.cpp             |   1 +
 be/src/vec/columns/column_string.cpp               |  10 +-
 be/src/vec/columns/column_string.h                 |   1 +
 be/src/vec/columns/column_vector.cpp               |  12 +-
 be/src/vec/core/block.cpp                          |  13 +
 be/src/vec/core/block.h                            |   3 +-
 be/src/vec/exec/vaggregation_node.cpp              | 243 ++++++++++++++++-
 be/src/vec/exec/vaggregation_node.h                | 205 +++++++++------
 15 files changed, 773 insertions(+), 174 deletions(-)

diff --git a/be/src/pipeline/dependency.cpp b/be/src/pipeline/dependency.cpp
index 8cf025274af..e7159b2df35 100644
--- a/be/src/pipeline/dependency.cpp
+++ b/be/src/pipeline/dependency.cpp
@@ -196,6 +196,102 @@ LocalExchangeSharedState::LocalExchangeSharedState(int num_instances) {
     mem_trackers.resize(num_instances, nullptr);
 }
 
+// Materializes every group-by key currently held by the aggregation hash table into fresh
+// key columns (one per probe expr). Non-null keys come out in aggregate_data_container order;
+// a null key, if present, is appended last and only to key_columns[0], which presumably
+// relies on null-key hash tables having a single key column (TODO confirm).
+vectorized::MutableColumns AggSharedState::_get_keys_hash_table() {
+    return std::visit(
+            vectorized::Overload {
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return vectorized::MutableColumns();
+                    },
+                    [&](auto&& agg_method) -> vectorized::MutableColumns {
+                        vectorized::MutableColumns key_columns;
+                        for (size_t i = 0; i < probe_expr_ctxs.size(); ++i) {
+                            key_columns.emplace_back(
+                                    probe_expr_ctxs[i]->root()->data_type()->create_column());
+                        }
+                        auto& data = *agg_method.hash_table;
+                        bool has_null_key = data.has_null_key_data();
+                        const auto size = data.size() - has_null_key;
+                        using KeyType = std::decay_t<decltype(agg_method.iterator->get_first())>;
+                        // `size` excludes the null key: this assumes its state never lands in
+                        // aggregate_data_container (creator_for_null_key arena-allocates it) — confirm.
+                        std::vector<KeyType> keys(size);
+
+                        size_t num_rows = 0;
+                        auto iter = aggregate_data_container->begin();
+                        {
+                            while (iter != aggregate_data_container->end()) {
+                                keys[num_rows] = iter.get_key<KeyType>();
+                                ++iter;
+                                ++num_rows;
+                            }
+                        }
+                        agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                        if (has_null_key) {
+                            key_columns[0]->insert_data(nullptr, 0);
+                        }
+                        return key_columns;
+                    }},
+            agg_data->method_variant);
+}
+
+// Builds the top-N heap over the current hash table: snapshots every key into limit_columns,
+// keeps the `limit` rows that rank first under order_directions / null_directions, and records
+// the last-ranked of them (the heap top) in limit_columns_min as the filter boundary.
+void AggSharedState::build_limit_heap(size_t hash_table_size) {
+    // Row ids held by cursors from an earlier build refer to the old limit_columns contents.
+    limit_heap = std::priority_queue<HeapLimitCursor>();
+    limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        limit_heap.emplace(i, limit_columns, order_directions, null_directions);
+    }
+    // std::priority_queue is a max-heap, so each pop drops the currently last-ranked row.
+    while (hash_table_size > limit) {
+        limit_heap.pop();
+        hash_table_size--;
+    }
+    DCHECK(!limit_heap.empty());
+    limit_columns_min = limit_heap.top()._row_id;
+}
+
+// Compares the first num_rows rows of `block` against the boundary row limit_columns_min.
+// Key i is read from block position i, so the key columns must occupy positions
+// [0, key_size) in group-by key order. The verdict is left in need_computes, where 1 marks a
+// row that callers keep (they hand need_computes to filter_block_internal). Returns true if
+// any row is marked 0; an empty block returns false and leaves need_computes untouched.
+bool AggSharedState::do_limit_filter(vectorized::Block* block, size_t num_rows) {
+    if (num_rows) {
+        cmp_res.resize(num_rows);
+        need_computes.resize(num_rows);
+        memset(need_computes.data(), 0, need_computes.size());
+        memset(cmp_res.data(), 0, cmp_res.size());
+
+        const auto key_size = null_directions.size();
+        for (size_t i = 0; i < key_size; i++) {
+            block->get_by_position(i).column->compare_internal(
+                    limit_columns_min, *limit_columns[i], null_directions[i], order_directions[i],
+                    cmp_res, need_computes.data());
+        }
+
+        // Fold the per-row comparison state into one keep(1)/drop(0) flag per row; see
+        // IColumn::compare_internal for what cmp_res and the filter output mean.
+        auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, int rows) {
+            for (int i = 0; i < rows; ++i) {
+                computes[i] = computes[i] == res[i];
+            }
+        };
+        set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
+
+        return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
+    }
+
+    return false;
+}
+
 Status AggSharedState::reset_hash_table() {
     return std::visit(
             vectorized::Overload {
diff --git a/be/src/pipeline/dependency.h b/be/src/pipeline/dependency.h
index d7084f85d5d..e32f5a1c0d6 100644
--- a/be/src/pipeline/dependency.h
+++ b/be/src/pipeline/dependency.h
@@ -311,6 +311,9 @@ public:
 
     Status reset_hash_table();
 
+    bool do_limit_filter(vectorized::Block* block, size_t num_rows);
+    void build_limit_heap(size_t hash_table_size);
+
     // We should call this function only at 1st phase.
     // 1st phase: is_merge=true, only have one SlotRef.
     // 2nd phase: is_merge=false, maybe have multiple exprs.
@@ -346,8 +349,87 @@ public:
     MemoryRecord mem_usage_record;
     bool agg_data_created_without_key = false;
     bool enable_spill = false;
+    // Set by the sink once the hash table holds enough groups for the LIMIT. From then on the
+    // sink no longer grows the table freely, and the source applies the limit to its output.
+    bool reach_limit = false;
+
+    // LIMIT of the aggregation (-1: no limit).
+    int64_t limit = -1;
+    // True when the plan orders by the group-by keys, so a top-N heap decides which groups stay.
+    bool do_sort_limit = false;
+    // Key rows indexed by limit_heap's cursors: the keys snapshotted by build_limit_heap() plus
+    // every key admitted to the heap afterwards.
+    vectorized::MutableColumns limit_columns;
+    // Row in limit_columns of the current boundary, i.e. the last-ranked row of the top N.
+    int limit_columns_min = -1;
+    // Scratch buffers of do_limit_filter(); need_computes[i] == 1 marks a row to keep.
+    vectorized::PaddedPODArray<uint8_t> need_computes;
+    std::vector<uint8_t> cmp_res;
+    // Per-key sort direction (1 asc, -1 desc) and the null-ordering hint passed to compare_at.
+    std::vector<int> order_directions;
+    std::vector<int> null_directions;
+
+    // Heap entry naming one row of limit_columns. operator< orders rows by the group-by sort
+    // keys, so in std::priority_queue (a max-heap) the last-ranked row sits on top. Cursors are
+    // expected to reference this state's columns/directions, so assignment copies only _row_id.
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, vectorized::MutableColumns& limit_columns,
+                        std::vector<int>& order_directions, std::vector<int>& null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        // Lexicographic comparison over the key columns, each scaled by its order direction.
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (size_t i = 0; i < _limit_columns.size(); ++i) {
+                const auto& _limit_column = _limit_columns[i];
+                auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column,
+                                                     _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        vectorized::MutableColumns& _limit_columns;
+        std::vector<int>& _order_directions;
+        std::vector<int>& _null_directions;
+    };
+
+    // Holds `limit` cursors once built; top() is the boundary row.
+    std::priority_queue<HeapLimitCursor> limit_heap;
 
 private:
+    vectorized::MutableColumns _get_keys_hash_table();
+
     void _close_with_serialized_key() {
         std::visit(vectorized::Overload {[&](std::monostate& arg) -> void {
                                              // Do nothing
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp 
b/be/src/pipeline/exec/aggregation_sink_operator.cpp
index 79f5b5af083..a3ac73a5d85 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp
@@ -69,6 +69,7 @@ Status AggSinkLocalState::init(RuntimeState* state, 
LocalSinkStateInfo& info) {
     _serialize_data_timer = ADD_TIMER(Base::profile(), "SerializeDataTime");
     _deserialize_data_timer = ADD_TIMER(Base::profile(), 
"DeserializeAndMergeTime");
     _hash_table_compute_timer = ADD_TIMER(Base::profile(), 
"HashTableComputeTime");
+    _hash_table_limit_compute_timer = ADD_TIMER(Base::profile(), 
"DoLimitComputeTime");
     _hash_table_emplace_timer = ADD_TIMER(Base::profile(), 
"HashTableEmplaceTime");
     _hash_table_input_counter = ADD_COUNTER(Base::profile(), 
"HashTableInputCount", TUnit::UNIT);
     _max_row_size_counter = ADD_COUNTER(Base::profile(), "MaxRowSizeInBytes", 
TUnit::UNIT);
@@ -86,6 +87,11 @@ Status AggSinkLocalState::open(RuntimeState* state) {
     Base::_shared_state->offsets_of_aggregate_states = 
p._offsets_of_aggregate_states;
     Base::_shared_state->make_nullable_keys = p._make_nullable_keys;
     Base::_shared_state->probe_expr_ctxs.resize(p._probe_expr_ctxs.size());
+
+    Base::_shared_state->limit = p._limit;
+    Base::_shared_state->do_sort_limit = p._do_sort_limit;
+    Base::_shared_state->null_directions = p._null_directions;
+    Base::_shared_state->order_directions = p._order_directions;
     for (size_t i = 0; i < Base::_shared_state->probe_expr_ctxs.size(); i++) {
         RETURN_IF_ERROR(
                 p._probe_expr_ctxs[i]->clone(state, 
Base::_shared_state->probe_expr_ctxs[i]));
@@ -132,7 +138,6 @@ Status AggSinkLocalState::open(RuntimeState* state) {
 
         _should_limit_output = p._limit != -1 &&       // has limit
                                (!p._have_conjuncts) && // no having conjunct
-                               p._needs_finalize &&    // agg's finalize step
                                !Base::_shared_state->enable_spill;
     }
     for (auto& evaluator : p._aggregate_evaluators) {
@@ -183,7 +188,7 @@ Status 
AggSinkLocalState::_execute_without_key(vectorized::Block* block) {
 }
 
 Status AggSinkLocalState::_merge_with_serialized_key(vectorized::Block* block) 
{
-    if (_reach_limit) {
+    if (_shared_state->reach_limit) {
         return _merge_with_serialized_key_helper<true, false>(block);
     } else {
         return _merge_with_serialized_key_helper<false, false>(block);
@@ -260,12 +265,14 @@ Status 
AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
 
     size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
     vectorized::ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
 
     for (size_t i = 0; i < key_size; ++i) {
         if constexpr (for_spill) {
             key_columns[i] = block->get_by_position(i).column.get();
+            key_locs[i] = i;
         } else {
-            int result_column_id = -1;
+            int& result_column_id = key_locs[i];
             RETURN_IF_ERROR(
                     Base::_shared_state->probe_expr_ctxs[i]->execute(block, 
&result_column_id));
             block->replace_by_position_if_const(result_column_id);
@@ -278,7 +285,7 @@ Status 
AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
         _places.resize(rows);
     }
 
-    if constexpr (limit) {
+    if (limit && !_shared_state->do_sort_limit) {
         _find_in_hash_table(_places.data(), key_columns, rows);
 
         for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); 
++i) {
@@ -318,52 +325,69 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
             }
         }
     } else {
-        _emplace_into_hash_table(_places.data(), key_columns, rows);
+        bool need_do_agg = true;
+        if (limit) {
+            need_do_agg = _emplace_into_hash_table_limit(_places.data(), block, key_locs,
+                                                         key_columns, rows);
+        } else {
+            _emplace_into_hash_table(_places.data(), key_columns, rows);
+        }
 
-        for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
-            if (Base::_shared_state->aggregate_evaluators[i]->is_merge() || for_spill) {
-                int col_id = 0;
-                if constexpr (for_spill) {
-                    col_id = Base::_shared_state->probe_expr_ctxs.size() + i;
+        if (need_do_agg) {
+            // _emplace_into_hash_table_limit() may filter `block` in place, leaving `rows`
+            // stale; merge exactly the rows still in the block (those whose places were set).
+            const size_t agg_rows = limit ? block->rows() : rows;
+            for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
+                if (Base::_shared_state->aggregate_evaluators[i]->is_merge() || for_spill) {
+                    int col_id = 0;
+                    if constexpr (for_spill) {
+                        col_id = Base::_shared_state->probe_expr_ctxs.size() + i;
+                    } else {
+                        col_id = AggSharedState::get_slot_column_id(
+                                Base::_shared_state->aggregate_evaluators[i]);
+                    }
+                    auto column = block->get_by_position(col_id).column;
+                    if (column->is_nullable()) {
+                        column = ((vectorized::ColumnNullable*)column.get())
+                                         ->get_nested_column_ptr();
+                    }
+
+                    size_t buffer_size = Base::_shared_state->aggregate_evaluators[i]
+                                                 ->function()
+                                                 ->size_of_data() *
+                                         agg_rows;
+                    if (_deserialize_buffer.size() < buffer_size) {
+                        _deserialize_buffer.resize(buffer_size);
+                    }
+
+                    {
+                        SCOPED_TIMER(_deserialize_data_timer);
+                        Base::_shared_state->aggregate_evaluators[i]
+                                ->function()
+                                ->deserialize_and_merge_vec(
+                                        _places.data(),
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._offsets_of_aggregate_states[i],
+                                        _deserialize_buffer.data(), column.get(), _agg_arena_pool,
+                                        agg_rows);
+                    }
                 } else {
-                    col_id = AggSharedState::get_slot_column_id(
-                            Base::_shared_state->aggregate_evaluators[i]);
-                }
-                auto column = block->get_by_position(col_id).column;
-                if (column->is_nullable()) {
-                    column = ((vectorized::ColumnNullable*)column.get())->get_nested_column_ptr();
-                }
-
-                size_t buffer_size =
-                        Base::_shared_state->aggregate_evaluators[i]->function()->size_of_data() *
-                        rows;
-                if (_deserialize_buffer.size() < buffer_size) {
-                    _deserialize_buffer.resize(buffer_size);
-                }
-
-                {
-                    SCOPED_TIMER(_deserialize_data_timer);
-                    Base::_shared_state->aggregate_evaluators[i]
-                            ->function()
-                            ->deserialize_and_merge_vec(
-                                    _places.data(),
-                                    Base::_parent->template cast<AggSinkOperatorX>()
-                                            ._offsets_of_aggregate_states[i],
-                                    _deserialize_buffer.data(), column.get(), _agg_arena_pool,
-                                    rows);
+                    RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+                            block,
+                            Base::_parent->template cast<AggSinkOperatorX>()
+                                    ._offsets_of_aggregate_states[i],
+                            _places.data(), _agg_arena_pool));
                 }
-            } else {
-                RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
-                        block,
-                        Base::_parent->template cast<AggSinkOperatorX>()
-                                ._offsets_of_aggregate_states[i],
-                        _places.data(), _agg_arena_pool));
             }
         }
 
-        if (_should_limit_output) {
-            _reach_limit = _get_hash_table_size() >=
-                           Base::_parent->template cast<AggSinkOperatorX>()._limit;
+        if (!limit && _should_limit_output) {
+            const size_t hash_table_size = _get_hash_table_size();
+            _shared_state->reach_limit =
+                    hash_table_size >= Base::_parent->template cast<AggSinkOperatorX>()._limit;
+            if (_shared_state->do_sort_limit && _shared_state->reach_limit) {
+                _shared_state->build_limit_heap(hash_table_size);
+            }
         }
     }
 
@@ -410,7 +431,7 @@ void AggSinkLocalState::_update_memusage_without_key() {
 }
 
 Status AggSinkLocalState::_execute_with_serialized_key(vectorized::Block* 
block) {
-    if (_reach_limit) {
+    if (_shared_state->reach_limit) {
         return _execute_with_serialized_key_helper<true>(block);
     } else {
         return _execute_with_serialized_key_helper<false>(block);
@@ -424,10 +445,11 @@ Status 
AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
 
     size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
     vectorized::ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
     {
         SCOPED_TIMER(_expr_timer);
         for (size_t i = 0; i < key_size; ++i) {
-            int result_column_id = -1;
+            int& result_column_id = key_locs[i];
             RETURN_IF_ERROR(
                     Base::_shared_state->probe_expr_ctxs[i]->execute(block, 
&result_column_id));
             block->get_by_position(result_column_id).column =
@@ -442,7 +464,7 @@ Status 
AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
         _places.resize(rows);
     }
 
-    if constexpr (limit) {
+    if (limit && !_shared_state->do_sort_limit) {
         _find_in_hash_table(_places.data(), key_columns, rows);
 
         for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); 
++i) {
@@ -454,27 +476,52 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
                             _places.data(), _agg_arena_pool));
         }
     } else {
-        _emplace_into_hash_table(_places.data(), key_columns, rows);
-
-        for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
-            RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
-                    block,
-                    Base::_parent->template cast<AggSinkOperatorX>()
-                            ._offsets_of_aggregate_states[i],
-                    _places.data(), _agg_arena_pool));
-        }
+        auto do_aggregate_evaluators = [&] {
+            for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+                        block,
+                        Base::_parent->template cast<AggSinkOperatorX>()
+                                ._offsets_of_aggregate_states[i],
+                        _places.data(), _agg_arena_pool));
+            }
+            return Status::OK();
+        };
 
-        if (_should_limit_output) {
-            _reach_limit = _get_hash_table_size() >=
-                           Base::_parent->template cast<AggSinkOperatorX>()._limit;
-            if (_reach_limit &&
-                Base::_parent->template cast<AggSinkOperatorX>()._can_short_circuit) {
-                Base::_dependency->set_ready_to_read();
-                return Status::Error<ErrorCode::END_OF_FILE>("");
+        if constexpr (limit) {
+            if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, key_columns,
+                                               rows)) {
+                RETURN_IF_ERROR(do_aggregate_evaluators());
+            }
+        } else {
+            _emplace_into_hash_table(_places.data(), key_columns, rows);
+            RETURN_IF_ERROR(do_aggregate_evaluators());
+
+            if (_should_limit_output && !Base::_shared_state->enable_spill) {
+                const size_t hash_table_size = _get_hash_table_size();
+                if (Base::_parent->template cast<AggSinkOperatorX>()._can_short_circuit) {
+                    _shared_state->reach_limit =
+                            hash_table_size >=
+                            Base::_parent->template cast<AggSinkOperatorX>()._limit;
+                    if (_shared_state->reach_limit) {
+                        Base::_dependency->set_ready_to_read();
+                        return Status::Error<ErrorCode::END_OF_FILE>("");
+                    }
+                } else {
+                    // With a sort limit, let the table grow to 5x the limit before building the
+                    // top-N heap. Keep the parentheses: `>=` binds tighter than `?:`, so without
+                    // them the condition would compare hash_table_size with do_sort_limit.
+                    _shared_state->reach_limit =
+                            hash_table_size >=
+                            (_shared_state->do_sort_limit
+                                     ? Base::_parent->template cast<AggSinkOperatorX>()._limit * 5
+                                     : Base::_parent->template cast<AggSinkOperatorX>()._limit);
+                    if (_shared_state->reach_limit && _shared_state->do_sort_limit) {
+                        _shared_state->build_limit_heap(hash_table_size);
+                    }
+                }
             }
         }
     }
-
     return Status::OK();
 }
 
@@ -535,6 +578,131 @@ void AggSinkLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* p
                _agg_data->method_variant);
 }
 
+// Limit-aware counterpart of _emplace_into_hash_table(), used by the sink once reach_limit is
+// set and the plan orders by the group-by keys. Rows whose keys rank after the current top-N
+// boundary are filtered out of `block` in place (key_columns is re-pointed at the filtered
+// columns); every surviving row is emplaced, and each newly created group is offered to the
+// top-N heap. Returns false when no row survives, i.e. there is nothing to aggregate.
+bool AggSinkLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                                       vectorized::Block* block,
+                                                       const std::vector<int>& key_locs,
+                                                       vectorized::ColumnRawPtrs& key_columns,
+                                                       size_t num_rows) {
+    return std::visit(
+            vectorized::Overload {
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return true;
+                    },
+                    [&](auto&& agg_method) -> bool {
+                        SCOPED_TIMER(_hash_table_compute_timer);
+                        using HashMethodType = std::decay_t<decltype(agg_method)>;
+                        using AggState = typename HashMethodType::State;
+
+                        // do_limit_filter() leaves need_computes untouched for an empty block,
+                        // so never read it (possibly stale from an earlier block) in that case.
+                        if (num_rows == 0) {
+                            return false;
+                        }
+
+                        bool need_filter = false;
+                        {
+                            SCOPED_TIMER(_hash_table_limit_compute_timer);
+                            // do_limit_filter() reads key i from block position i, but here the
+                            // group-by keys sit at key_locs (wherever the key exprs put them), so
+                            // give it a block holding only the key columns, in key order.
+                            vectorized::Block key_block;
+                            for (int loc : key_locs) {
+                                key_block.insert(block->get_by_position(loc));
+                            }
+                            need_filter = _shared_state->do_limit_filter(&key_block, num_rows);
+                        }
+
+                        auto& need_computes = _shared_state->need_computes;
+                        if (auto need_agg =
+                                    std::find(need_computes.begin(), need_computes.end(), 1);
+                            need_agg != need_computes.end()) {
+                            if (need_filter) {
+                                vectorized::Block::filter_block_internal(block, need_computes);
+                                for (size_t i = 0; i < key_locs.size(); ++i) {
+                                    key_columns[i] =
+                                            block->get_by_position(key_locs[i]).column.get();
+                                }
+                                num_rows = block->rows();
+                            }
+
+                            AggState state(key_columns);
+                            agg_method.init_serialized_keys(key_columns, num_rows);
+                            size_t i = 0;
+
+                            // Offers the key of row `i` (a group just created) to the top-N
+                            // heap. Push first, then pop: the heap keeps the best `limit` keys
+                            // seen so far, and a new key ranking after all of them is the one
+                            // evicted. (Rows were filtered against the boundary as of the start of
+                            // this block, so a later row can rank after the tightened boundary.)
+                            auto refresh_top_limit = [&, this]() {
+                                for (size_t j = 0; j < key_columns.size(); ++j) {
+                                    _shared_state->limit_columns[j]->insert_from(*key_columns[j],
+                                                                                 i);
+                                }
+                                _shared_state->limit_heap.emplace(
+                                        _shared_state->limit_columns[0]->size() - 1,
+                                        _shared_state->limit_columns,
+                                        _shared_state->order_directions,
+                                        _shared_state->null_directions);
+                                _shared_state->limit_heap.pop();
+                                _shared_state->limit_columns_min =
+                                        _shared_state->limit_heap.top()._row_id;
+                            };
+
+                            auto creator = [this, refresh_top_limit](const auto& ctor, auto& key,
+                                                                     auto& origin) {
+                                try {
+                                    HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                              *_agg_arena_pool);
+                                    auto mapped =
+                                            _shared_state->aggregate_data_container->append_data(
+                                                    origin);
+                                    auto st = _create_agg_status(mapped);
+                                    if (!st) {
+                                        throw Exception(st.code(), st.to_string());
+                                    }
+                                    ctor(key, mapped);
+                                    refresh_top_limit();
+                                } catch (...) {
+                                    // Exception-safety - if it can not allocate memory or create status,
+                                    // the destructors will not be called.
+                                    ctor(key, nullptr);
+                                    throw;
+                                }
+                            };
+
+                            auto creator_for_null_key = [this, refresh_top_limit](auto& mapped) {
+                                mapped = _agg_arena_pool->aligned_alloc(
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._total_size_of_aggregate_states,
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._align_aggregate_states);
+                                auto st = _create_agg_status(mapped);
+                                if (!st) {
+                                    throw Exception(st.code(), st.to_string());
+                                }
+                                refresh_top_limit();
+                            };
+
+                            SCOPED_TIMER(_hash_table_emplace_timer);
+                            for (i = 0; i < num_rows; ++i) {
+                                places[i] = agg_method.lazy_emplace(state, i, creator,
+                                                                    creator_for_null_key);
+                            }
+                            COUNTER_UPDATE(_hash_table_input_counter, num_rows);
+                            return true;
+                        }
+                        return false;
+                    }},
+            _agg_data->method_variant);
+}
+
 void AggSinkLocalState::_find_in_hash_table(vectorized::AggregateDataPtr* places,
                                             vectorized::ColumnRawPtrs& key_columns,
                                             size_t num_rows) {
@@ -616,6 +761,21 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, 
RuntimeState* state) {
     _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
                             [](const auto& e) { return 
e.nodes[0].agg_expr.is_merge_agg; });
 
+    if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+        _do_sort_limit = true;
+        const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+        DCHECK_EQ(agg_sort_info.nulls_first.size(), 
agg_sort_info.is_asc_order.size());
+
+        const int order_by_key_size = agg_sort_info.is_asc_order.size();
+        _order_directions.resize(order_by_key_size);
+        _null_directions.resize(order_by_key_size);
+        for (int i = 0; i < order_by_key_size; ++i) {
+            _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+            _null_directions[i] =
+                    agg_sort_info.nulls_first[i] ? -_order_directions[i] : 
_order_directions[i];
+        }
+    }
+
     return Status::OK();
 }
 
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h 
b/be/src/pipeline/exec/aggregation_sink_operator.h
index d48debc2d83..39fee1707e4 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/aggregation_sink_operator.h
@@ -85,6 +85,9 @@ protected:
                              vectorized::ColumnRawPtrs& key_columns, size_t 
num_rows);
     void _emplace_into_hash_table(vectorized::AggregateDataPtr* places,
                                   vectorized::ColumnRawPtrs& key_columns, 
size_t num_rows);
+    bool _emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                        vectorized::Block* block, const 
std::vector<int>& key_locs,
+                                        vectorized::ColumnRawPtrs& 
key_columns, size_t num_rows);
     size_t _get_hash_table_size() const;
 
     template <bool limit, bool for_spill = false>
@@ -96,6 +99,7 @@ protected:
 
     RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
     RuntimeProfile::Counter* _build_timer = nullptr;
     RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -109,7 +113,6 @@ protected:
     RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage = 
nullptr;
 
     bool _should_limit_output = false;
-    bool _reach_limit = false;
 
     vectorized::PODArray<vectorized::AggregateDataPtr> _places;
     std::vector<char> _deserialize_buffer;
@@ -191,8 +194,12 @@ protected:
     ObjectPool* _pool = nullptr;
     std::vector<size_t> _make_nullable_keys;
     int64_t _limit; // -1: no limit
-    bool _have_conjuncts;
+    // do sort limit and directions
+    bool _do_sort_limit = false;
+    std::vector<int> _order_directions;
+    std::vector<int> _null_directions;
 
+    bool _have_conjuncts;
     const std::vector<TExpr> _partition_exprs;
     const bool _is_colocate;
 
diff --git a/be/src/pipeline/exec/aggregation_source_operator.cpp 
b/be/src/pipeline/exec/aggregation_source_operator.cpp
index b94f076bdbf..cca9fefbdb2 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_source_operator.cpp
@@ -22,7 +22,6 @@
 
 #include "common/exception.h"
 #include "pipeline/exec/operator.h"
-#include "vec//utils/util.hpp"
 
 namespace doris::pipeline {
 
@@ -444,10 +443,19 @@ Status AggSourceOperatorX::get_block(RuntimeState* state, 
vectorized::Block* blo
     local_state.make_nullable_output_key(block);
     // dispose the having clause, should not be execute in prestreaming agg
     RETURN_IF_ERROR(vectorized::VExprContext::filter_block(_conjuncts, block, 
block->columns()));
-    local_state.reached_limit(block, eos);
+    local_state.do_agg_limit(block, eos);
     return Status::OK();
 }
 
+void AggLocalState::do_agg_limit(vectorized::Block* block, bool* eos) {
+    if (_shared_state->reach_limit) {
+        if (_shared_state->do_sort_limit && 
_shared_state->do_limit_filter(block, block->rows())) {
+            vectorized::Block::filter_block_internal(block, 
_shared_state->need_computes);
+        }
+        reached_limit(block, eos);
+    }
+}
+
 void AggLocalState::make_nullable_output_key(vectorized::Block* block) {
     if (block->rows() != 0) {
         for (auto cid : _shared_state->make_nullable_keys) {
diff --git a/be/src/pipeline/exec/aggregation_source_operator.h 
b/be/src/pipeline/exec/aggregation_source_operator.h
index c4ea6c6ccde..a3824a381eb 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.h
+++ b/be/src/pipeline/exec/aggregation_source_operator.h
@@ -41,6 +41,7 @@ public:
     void make_nullable_output_key(vectorized::Block* block);
     template <bool limit>
     Status merge_with_serialized_key_helper(vectorized::Block* block);
+    void do_agg_limit(vectorized::Block* block, bool* eos);
 
 protected:
     friend class AggSourceOperatorX;
diff --git a/be/src/pipeline/exec/operator.cpp 
b/be/src/pipeline/exec/operator.cpp
index 938eb22f253..455f11fa9f1 100644
--- a/be/src/pipeline/exec/operator.cpp
+++ b/be/src/pipeline/exec/operator.cpp
@@ -389,8 +389,7 @@ std::shared_ptr<BasicSharedState> 
DataSinkOperatorX<LocalStateType>::create_shar
         LOG(FATAL) << "should not reach here!";
         return nullptr;
     } else {
-        std::shared_ptr<BasicSharedState> ss = nullptr;
-        ss = LocalStateType::SharedStateType::create_shared();
+        auto ss = LocalStateType::SharedStateType::create_shared();
         ss->id = operator_id();
         for (auto& dest : dests_id()) {
             ss->related_op_ids.insert(dest);
diff --git a/be/src/vec/columns/column_nullable.cpp 
b/be/src/vec/columns/column_nullable.cpp
index 6efa690d7db..c516b96b72f 100644
--- a/be/src/vec/columns/column_nullable.cpp
+++ b/be/src/vec/columns/column_nullable.cpp
@@ -422,6 +422,7 @@ int ColumnNullable::compare_at(size_t n, size_t m, const 
IColumn& rhs_,
     return get_nested_column().compare_at(n, m, 
nullable_rhs.get_nested_column(),
                                           null_direction_hint);
 }
+
 void ColumnNullable::compare_internal(size_t rhs_row_id, const IColumn& rhs, 
int nan_direction_hint,
                                       int direction, std::vector<uint8>& 
cmp_res,
                                       uint8* __restrict filter) const {
diff --git a/be/src/vec/columns/column_string.cpp 
b/be/src/vec/columns/column_string.cpp
index 446fd283b1c..919854a42d9 100644
--- a/be/src/vec/columns/column_string.cpp
+++ b/be/src/vec/columns/column_string.cpp
@@ -544,7 +544,7 @@ template <typename T>
 void ColumnStr<T>::compare_internal(size_t rhs_row_id, const IColumn& rhs, int 
nan_direction_hint,
                                     int direction, std::vector<uint8>& cmp_res,
                                     uint8* __restrict filter) const {
-    auto sz = this->size();
+    auto sz = offsets.size();
     DCHECK(cmp_res.size() == sz);
     const auto& cmp_base = assert_cast<const 
ColumnStr<T>&>(rhs).get_data_at(rhs_row_id);
     size_t begin = simd::find_zero(cmp_res, 0);
@@ -554,12 +554,8 @@ void ColumnStr<T>::compare_internal(size_t rhs_row_id, 
const IColumn& rhs, int n
             auto value_a = get_data_at(row_id);
             int res = memcmp_small_allow_overflow15(value_a.data, 
value_a.size, cmp_base.data,
                                                     cmp_base.size);
-            if (res * direction < 0) {
-                filter[row_id] = 1;
-                cmp_res[row_id] = 1;
-            } else if (res * direction > 0) {
-                cmp_res[row_id] = 1;
-            }
+            cmp_res[row_id] = res != 0;
+            filter[row_id] = res * direction < 0;
         }
         begin = simd::find_zero(cmp_res, end + 1);
     }
diff --git a/be/src/vec/columns/column_string.h 
b/be/src/vec/columns/column_string.h
index d0994607a46..22dcd612d3a 100644
--- a/be/src/vec/columns/column_string.h
+++ b/be/src/vec/columns/column_string.h
@@ -549,6 +549,7 @@ public:
     void compare_internal(size_t rhs_row_id, const IColumn& rhs, int 
nan_direction_hint,
                           int direction, std::vector<uint8>& cmp_res,
                           uint8* __restrict filter) const override;
+
     MutableColumnPtr get_shinked_column() const {
         auto shrinked_column = ColumnStr<T>::create();
         for (int i = 0; i < size(); i++) {
diff --git a/be/src/vec/columns/column_vector.cpp 
b/be/src/vec/columns/column_vector.cpp
index 60a75420405..14d52045943 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -141,21 +141,17 @@ void ColumnVector<T>::compare_internal(size_t rhs_row_id, 
const IColumn& rhs,
                                        int nan_direction_hint, int direction,
                                        std::vector<uint8>& cmp_res,
                                        uint8* __restrict filter) const {
-    auto sz = this->size();
+    const auto sz = data.size();
     DCHECK(cmp_res.size() == sz);
     const auto& cmp_base = assert_cast<const 
ColumnVector<T>&>(rhs).get_data()[rhs_row_id];
     size_t begin = simd::find_zero(cmp_res, 0);
     while (begin < sz) {
         size_t end = simd::find_one(cmp_res, begin + 1);
         for (size_t row_id = begin; row_id < end; row_id++) {
-            auto value_a = get_data()[row_id];
+            auto value_a = data[row_id];
             int res = value_a > cmp_base ? 1 : (value_a < cmp_base ? -1 : 0);
-            if (res * direction < 0) {
-                filter[row_id] = 1;
-                cmp_res[row_id] = 1;
-            } else if (res * direction > 0) {
-                cmp_res[row_id] = 1;
-            }
+            cmp_res[row_id] = (res != 0);
+            filter[row_id] = (res * direction < 0);
         }
         begin = simd::find_zero(cmp_res, end + 1);
     }
diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp
index 7595ffb6620..95af060dfc7 100644
--- a/be/src/vec/core/block.cpp
+++ b/be/src/vec/core/block.cpp
@@ -797,6 +797,19 @@ void Block::filter_block_internal(Block* block, const 
IColumn::Filter& filter,
     filter_block_internal(block, columns_to_filter, filter);
 }
 
+void Block::filter_block_internal(Block* block, const IColumn::Filter& filter) 
{
+    const size_t count =
+            filter.size() - simd::count_zero_num((int8_t*)filter.data(), 
filter.size());
+    for (int i = 0; i < block->columns(); ++i) {
+        auto& column = block->get_by_position(i).column;
+        if (column->is_exclusive()) {
+            column->assume_mutable()->filter(filter);
+        } else {
+            column = column->filter(filter, count);
+        }
+    }
+}
+
 Block Block::copy_block(const std::vector<int>& column_offset) const {
     ColumnsWithTypeAndName columns_with_type_and_name;
     for (auto offset : column_offset) {
diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h
index 593d37f7ff2..3611252ea59 100644
--- a/be/src/vec/core/block.h
+++ b/be/src/vec/core/block.h
@@ -281,10 +281,11 @@ public:
     // need exception safety
     static void filter_block_internal(Block* block, const 
std::vector<uint32_t>& columns_to_filter,
                                       const IColumn::Filter& filter);
-
     // need exception safety
     static void filter_block_internal(Block* block, const IColumn::Filter& 
filter,
                                       uint32_t column_to_keep);
+    // need exception safety
+    static void filter_block_internal(Block* block, const IColumn::Filter& 
filter);
 
     static Status filter_block(Block* block, const std::vector<uint32_t>& 
columns_to_filter,
                                int filter_column_id, int column_to_keep);
diff --git a/be/src/vec/exec/vaggregation_node.cpp 
b/be/src/vec/exec/vaggregation_node.cpp
index 1845382a2b4..f009802d5dd 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -158,6 +158,21 @@ Status AggregationNode::init(const TPlanNode& tnode, 
RuntimeState* state) {
 
     _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
                             [](const auto& e) { return 
e.nodes[0].agg_expr.is_merge_agg; });
+
+    if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+        _do_sort_limit = true;
+        const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+        DCHECK_EQ(agg_sort_info.nulls_first.size(), 
agg_sort_info.is_asc_order.size());
+
+        const int order_by_key_size = agg_sort_info.is_asc_order.size();
+        _order_directions.resize(order_by_key_size);
+        _null_directions.resize(order_by_key_size);
+        for (int i = 0; i < order_by_key_size; ++i) {
+            _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+            _null_directions[i] =
+                    agg_sort_info.nulls_first[i] ? -_order_directions[i] : 
_order_directions[i];
+        }
+    }
     return Status::OK();
 }
 
@@ -183,6 +198,7 @@ Status AggregationNode::prepare_profile(RuntimeState* 
state) {
     _deserialize_data_timer = ADD_TIMER(runtime_profile(), 
"DeserializeAndMergeTime");
     _hash_table_compute_timer = ADD_TIMER(runtime_profile(), 
"HashTableComputeTime");
     _hash_table_emplace_timer = ADD_TIMER(runtime_profile(), 
"HashTableEmplaceTime");
+    _hash_table_limit_compute_timer = ADD_TIMER(runtime_profile(), 
"DoLimitComputeTime");
     _hash_table_iterate_timer = ADD_TIMER(runtime_profile(), 
"HashTableIterateTime");
     _insert_keys_to_column_timer = ADD_TIMER(runtime_profile(), 
"InsertKeysToColumnTime");
     _streaming_agg_timer = ADD_TIMER(runtime_profile(), "StreamingAggTime");
@@ -315,9 +331,8 @@ Status AggregationNode::prepare_profile(RuntimeState* 
state) {
                 
std::bind<void>(&AggregationNode::_update_memusage_with_serialized_key, this);
         _executor.close = 
std::bind<void>(&AggregationNode::_close_with_serialized_key, this);
 
-        _should_limit_output = _limit != -1 &&       // has limit
-                               _conjuncts.empty() && // no having conjunct
-                               _needs_finalize;      // agg's finalize step
+        _should_limit_output = _limit != -1 && // has limit
+                               _conjuncts.empty();
     }
 
     fmt::memory_buffer msg;
@@ -436,8 +451,12 @@ Status AggregationNode::pull(doris::RuntimeState* state, 
vectorized::Block* bloc
     _make_nullable_output_key(block);
     // dispose the having clause, should not be execute in prestreaming agg
     RETURN_IF_ERROR(VExprContext::filter_block(_conjuncts, block, 
block->columns()));
-    reached_limit(block, eos);
-
+    if (_reach_limit) {
+        if (_do_sort_limit && _do_limit_filter(block, 
_order_directions.size(), block->rows())) {
+            Block::filter_block_internal(block, _need_computes);
+        }
+        reached_limit(block, eos);
+    }
     return Status::OK();
 }
 
@@ -775,6 +794,158 @@ size_t AggregationNode::_get_hash_table_size() {
                       _agg_data->method_variant);
 }
 
+template <bool limit>
+Status AggregationNode::_execute_with_serialized_key_helper(Block* block) {
+    DCHECK(!_probe_expr_ctxs.empty());
+
+    size_t key_size = _probe_expr_ctxs.size();
+    ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
+    {
+        SCOPED_TIMER(_expr_timer);
+        for (size_t i = 0; i < key_size; ++i) {
+            auto& result_column_id = key_locs[i];
+            RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, 
&result_column_id));
+            block->get_by_position(result_column_id).column =
+                    block->get_by_position(result_column_id)
+                            .column->convert_to_full_column_if_const();
+            key_columns[i] = 
block->get_by_position(result_column_id).column.get();
+        }
+    }
+
+    int rows = block->rows();
+    if (_places.size() < rows) {
+        _places.resize(rows);
+    }
+
+    if constexpr (limit) {
+        if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, 
key_columns, rows)) {
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                        block, _offsets_of_aggregate_states[i], _places.data(),
+                        _agg_arena_pool.get()));
+            }
+        }
+    } else {
+        _emplace_into_hash_table(_places.data(), key_columns, rows);
+
+        for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+            RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                    block, _offsets_of_aggregate_states[i], _places.data(), 
_agg_arena_pool.get()));
+        }
+
+        if (_should_limit_output && !_reach_limit) {
+            auto size = _get_hash_table_size();
+            _reach_limit = size >= _limit * 5;
+            if (_reach_limit) {
+                _build_limit_heap(size);
+            }
+        }
+    }
+
+    return Status::OK();
+}
+
+void AggregationNode::_build_limit_heap(size_t hash_table_size) {
+    _limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        _limit_heap.emplace(i, _limit_columns, _order_directions, 
_null_directions);
+    }
+    while (hash_table_size > _limit) {
+        _limit_heap.pop();
+        hash_table_size--;
+    }
+    _limit_columns_min = _limit_heap.top()._row_id;
+}
+
+bool AggregationNode::_emplace_into_hash_table_limit(AggregateDataPtr* places, 
Block* block,
+                                                     const std::vector<int>& 
key_locs,
+                                                     ColumnRawPtrs& 
key_columns, size_t num_rows) {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, 
"uninited hash table");
+                          return true;
+                      },
+                      [&](auto&& agg_method) -> bool {
+                          SCOPED_TIMER(_hash_table_compute_timer);
+                          using HashMethodType = 
std::decay_t<decltype(agg_method)>;
+                          using AggState = typename HashMethodType::State;
+
+                          bool need_filter = false;
+                          {
+                              SCOPED_TIMER(_hash_table_limit_compute_timer);
+                              need_filter = _do_limit_filter(block, 
key_columns.size(), num_rows);
+                          }
+
+                          if (auto need_agg =
+                                      std::find(_need_computes.begin(), 
_need_computes.end(), 1);
+                              need_agg != _need_computes.end()) {
+                              if (need_filter) {
+                                  Block::filter_block_internal(block, 
_need_computes);
+                                  for (int i = 0; i < key_locs.size(); ++i) {
+                                      key_columns[i] =
+                                              
block->get_by_position(key_locs[i]).column.get();
+                                  }
+                                  num_rows = block->rows();
+                              }
+
+                              AggState state(key_columns);
+                              agg_method.init_serialized_keys(key_columns, 
num_rows);
+                              size_t i = 0;
+
+                              auto refresh_top_limit = [&, this] {
+                                  _limit_heap.pop();
+                                  for (int j = 0; j < key_columns.size(); ++j) 
{
+                                      
_limit_columns[j]->insert_from(*key_columns[j], i);
+                                  }
+                                  
_limit_heap.emplace(_limit_columns[0]->size() - 1, _limit_columns,
+                                                      _order_directions, 
_null_directions);
+                                  _limit_columns_min = 
_limit_heap.top()._row_id;
+                              };
+
+                              auto creator = [this, refresh_top_limit](const 
auto& ctor, auto& key,
+                                                                       auto& 
origin) {
+                                  try {
+                                      
HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                               
 *_agg_arena_pool);
+                                      auto mapped = 
_aggregate_data_container->append_data(origin);
+                                      auto st = _create_agg_status(mapped);
+                                      if (!st) {
+                                          throw Exception(st.code(), 
st.to_string());
+                                      }
+                                      ctor(key, mapped);
+                                      refresh_top_limit();
+                                  } catch (...) {
+                                      // Exception-safety - if it can not 
allocate memory or create status,
+                                      // the destructors will not be called.
+                                      ctor(key, nullptr);
+                                      throw;
+                                  }
+                              };
+
+                              auto creator_for_null_key = [this, 
refresh_top_limit](auto& mapped) {
+                                  mapped = _agg_arena_pool->aligned_alloc(
+                                          _total_size_of_aggregate_states, 
_align_aggregate_states);
+                                  auto st = _create_agg_status(mapped);
+                                  if (!st) {
+                                      throw Exception(st.code(), 
st.to_string());
+                                  }
+                                  refresh_top_limit();
+                              };
+
+                              SCOPED_TIMER(_hash_table_emplace_timer);
+                              for (i = 0; i < num_rows; ++i) {
+                                  places[i] = agg_method.lazy_emplace(state, 
i, creator,
+                                                                      
creator_for_null_key);
+                              }
+                              COUNTER_UPDATE(_hash_table_input_counter, 
num_rows);
+                              return true;
+                          }
+                          return false;
+                      }},
+            _agg_data->method_variant);
+}
+
 void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, 
ColumnRawPtrs& key_columns,
                                                const size_t num_rows) {
     std::visit(Overload {[&](std::monostate& arg) {
@@ -785,6 +956,7 @@ void 
AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnR
                              SCOPED_TIMER(_hash_table_compute_timer);
                              using HashMethodType = 
std::decay_t<decltype(agg_method)>;
                              using AggState = typename HashMethodType::State;
+
                              AggState state(key_columns);
                              agg_method.init_serialized_keys(key_columns, 
num_rows);
 
@@ -850,6 +1022,42 @@ void 
AggregationNode::_find_in_hash_table(AggregateDataPtr* places, ColumnRawPtr
                _agg_data->method_variant);
 }
 
+MutableColumns AggregationNode::_get_keys_hash_table() {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, 
"uninited hash table");
+                          return MutableColumns();
+                      },
+                      [&](auto&& agg_method) -> MutableColumns {
+                          MutableColumns key_columns;
+                          for (int i = 0; i < _probe_expr_ctxs.size(); ++i) {
+                              key_columns.emplace_back(
+                                      
_probe_expr_ctxs[i]->root()->data_type()->create_column());
+                          }
+                          auto& data = *agg_method.hash_table;
+                          bool has_null_key = data.has_null_key_data();
+                          const auto size = data.size() - has_null_key;
+                          using KeyType = 
std::decay_t<decltype(agg_method.iterator->get_first())>;
+                          std::vector<KeyType> keys(size);
+
+                          size_t num_rows = 0;
+                          auto iter = _aggregate_data_container->begin();
+                          {
+                              while (iter != _aggregate_data_container->end()) 
{
+                                  keys[num_rows] = iter.get_key<KeyType>();
+                                  ++iter;
+                                  ++num_rows;
+                              }
+                          }
+                          agg_method.insert_keys_into_columns(keys, 
key_columns, num_rows);
+                          if (has_null_key) {
+                              key_columns[0]->insert_data(nullptr, 0);
+                          }
+                          return key_columns;
+                      }},
+            _agg_data->method_variant);
+}
+
 Status AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block* 
in_block,
                                                      doris::vectorized::Block* 
out_block) {
     DCHECK(!_probe_expr_ctxs.empty());
@@ -1455,14 +1663,37 @@ Status 
AggregationNode::_serialize_with_serialized_key_result_non_spill(RuntimeS
                                              
_probe_expr_ctxs[i]->root()->data_type(),
                                              
_probe_expr_ctxs[i]->root()->expr_name());
         }
+
         for (int i = 0; i < agg_size; ++i) {
             columns_with_schema.emplace_back(std::move(value_columns[i]), 
value_data_types[i], "");
+            *block = Block(columns_with_schema);
         }
-        *block = Block(columns_with_schema);
     }
+
     return Status::OK();
 }
 
+bool AggregationNode::_do_limit_filter(Block* block, int key_size, size_t 
num_rows) {
+    if (num_rows) {
+        _cmp_res.resize(num_rows);
+        _need_computes.resize(num_rows);
+        memset(_need_computes.data(), 0, _need_computes.size());
+        memset(_cmp_res.data(), 0, _cmp_res.size());
+
+        for (int i = 0; i < key_size; i++) {
+            block->get_by_position(i).column->compare_internal(
+                    _limit_columns_min, *_limit_columns[i], 
_null_directions[i],
+                    _order_directions[i], _cmp_res, _need_computes.data());
+        }
+
+        for (int i = 0; i < num_rows; ++i) {
+            _need_computes[i] = _need_computes[i] == _cmp_res[i];
+        }
+        return std::find(_need_computes.begin(), _need_computes.end(), 0) != 
_need_computes.end();
+    }
+    return false;
+}
+
 Status AggregationNode::_merge_with_serialized_key(Block* block) {
     if (_reach_limit) {
         return _merge_with_serialized_key_helper<true, false>(block);
diff --git a/be/src/vec/exec/vaggregation_node.h 
b/be/src/vec/exec/vaggregation_node.h
index 70222f9ccf7..cd7aedebead 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -438,6 +438,7 @@ protected:
     // nullable diff. so we need make nullable of it.
     std::vector<size_t> _make_nullable_keys;
     RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
     RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -491,8 +492,70 @@ private:
     RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage = 
nullptr;
 
     bool _should_expand_hash_table = true;
+
     bool _should_limit_output = false;
     bool _reach_limit = false;
+    bool _do_sort_limit = false;
+    MutableColumns _limit_columns;
+    int _limit_columns_min = -1;
+    PaddedPODArray<uint8_t> _need_computes;
+    std::vector<uint8_t> _cmp_res;
+    std::vector<int> _order_directions;
+    std::vector<int> _null_directions;
+
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, MutableColumns& limit_columns,
+                        std::vector<int>& order_directions, std::vector<int>& 
null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (int i = 0; i < _limit_columns.size(); ++i) {
+                const auto& _limit_column = _limit_columns[i];
+                auto res = _limit_column->compare_at(_row_id, rhs._row_id, 
*_limit_column,
+                                                     _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        MutableColumns& _limit_columns;
+        std::vector<int>& _order_directions;
+        std::vector<int>& _null_directions;
+    };
+
+    std::priority_queue<HeapLimitCursor> _limit_heap;
+
     bool _agg_data_created_without_key = false;
 
     PODArray<AggregateDataPtr> _places;
@@ -535,52 +598,7 @@ private:
     Status _init_hash_method(const VExprContextSPtrs& probe_exprs);
 
     template <bool limit>
-    Status _execute_with_serialized_key_helper(Block* block) {
-        DCHECK(!_probe_expr_ctxs.empty());
-
-        size_t key_size = _probe_expr_ctxs.size();
-        ColumnRawPtrs key_columns(key_size);
-        {
-            SCOPED_TIMER(_expr_timer);
-            for (size_t i = 0; i < key_size; ++i) {
-                int result_column_id = -1;
-                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, 
&result_column_id));
-                block->get_by_position(result_column_id).column =
-                        block->get_by_position(result_column_id)
-                                .column->convert_to_full_column_if_const();
-                key_columns[i] = 
block->get_by_position(result_column_id).column.get();
-            }
-        }
-
-        int rows = block->rows();
-        if (_places.size() < rows) {
-            _places.resize(rows);
-        }
-
-        if constexpr (limit) {
-            _find_in_hash_table(_places.data(), key_columns, rows);
-
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add_selected(
-                        block, _offsets_of_aggregate_states[i], _places.data(),
-                        _agg_arena_pool.get()));
-            }
-        } else {
-            _emplace_into_hash_table(_places.data(), key_columns, rows);
-
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
-                        block, _offsets_of_aggregate_states[i], _places.data(),
-                        _agg_arena_pool.get()));
-            }
-
-            if (_should_limit_output) {
-                _reach_limit = _get_hash_table_size() >= _limit;
-            }
-        }
-
-        return Status::OK();
-    }
+    Status _execute_with_serialized_key_helper(Block* block);
 
     // We should call this function only at 1st phase.
     // 1st phase: is_merge=true, only have one SlotRef.
@@ -599,15 +617,16 @@ private:
 
         size_t key_size = _probe_expr_ctxs.size();
         ColumnRawPtrs key_columns(key_size);
+        std::vector<int> key_locs(key_size);
 
         for (size_t i = 0; i < key_size; ++i) {
             if constexpr (for_spill) {
                 key_columns[i] = block->get_by_position(i).column.get();
+                key_locs[i] = i;
             } else {
-                int result_column_id = -1;
-                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, 
&result_column_id));
-                block->replace_by_position_if_const(result_column_id);
-                key_columns[i] = 
block->get_by_position(result_column_id).column.get();
+                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, 
&key_locs[i]));
+                block->replace_by_position_if_const(key_locs[i]);
+                key_columns[i] = 
block->get_by_position(key_locs[i]).column.get();
             }
         }
 
@@ -616,7 +635,7 @@ private:
             _places.resize(rows);
         }
 
-        if constexpr (limit) {
+        if (limit && !_do_sort_limit) {
             _find_in_hash_table(_places.data(), key_columns, rows);
 
             for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
@@ -647,43 +666,55 @@ private:
                 }
             }
         } else {
-            _emplace_into_hash_table(_places.data(), key_columns, rows);
+            bool need_do_agg = true;
+            if (limit) {
+                need_do_agg = _emplace_into_hash_table_limit(_places.data(), 
block, key_locs,
+                                                             key_columns, 
rows);
+            } else {
+                _emplace_into_hash_table(_places.data(), key_columns, rows);
+            }
 
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                if (_aggregate_evaluators[i]->is_merge() || for_spill) {
-                    int col_id;
-                    if constexpr (for_spill) {
-                        col_id = _probe_expr_ctxs.size() + i;
+            if (need_do_agg) {
+                for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                    if (_aggregate_evaluators[i]->is_merge() || for_spill) {
+                        int col_id;
+                        if constexpr (for_spill) {
+                            col_id = _probe_expr_ctxs.size() + i;
+                        } else {
+                            col_id = 
_get_slot_column_id(_aggregate_evaluators[i]);
+                        }
+                        auto column = block->get_by_position(col_id).column;
+                        if (column->is_nullable()) {
+                            column = 
((ColumnNullable*)column.get())->get_nested_column_ptr();
+                        }
+
+                        size_t buffer_size =
+                                
_aggregate_evaluators[i]->function()->size_of_data() * rows;
+                        if (_deserialize_buffer.size() < buffer_size) {
+                            _deserialize_buffer.resize(buffer_size);
+                        }
+
+                        {
+                            SCOPED_TIMER(_deserialize_data_timer);
+                            
_aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
+                                    _places.data(), 
_offsets_of_aggregate_states[i],
+                                    _deserialize_buffer.data(), column.get(), 
_agg_arena_pool.get(),
+                                    rows);
+                        }
                     } else {
-                        col_id = _get_slot_column_id(_aggregate_evaluators[i]);
-                    }
-                    auto column = block->get_by_position(col_id).column;
-                    if (column->is_nullable()) {
-                        column = 
((ColumnNullable*)column.get())->get_nested_column_ptr();
-                    }
-
-                    size_t buffer_size =
-                            
_aggregate_evaluators[i]->function()->size_of_data() * rows;
-                    if (_deserialize_buffer.size() < buffer_size) {
-                        _deserialize_buffer.resize(buffer_size);
-                    }
-
-                    {
-                        SCOPED_TIMER(_deserialize_data_timer);
-                        
_aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
-                                _places.data(), 
_offsets_of_aggregate_states[i],
-                                _deserialize_buffer.data(), column.get(), 
_agg_arena_pool.get(),
-                                rows);
+                        
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                                block, _offsets_of_aggregate_states[i], 
_places.data(),
+                                _agg_arena_pool.get()));
                     }
-                } else {
-                    
RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
-                            block, _offsets_of_aggregate_states[i], 
_places.data(),
-                            _agg_arena_pool.get()));
                 }
             }
 
-            if (_should_limit_output) {
-                _reach_limit = _get_hash_table_size() >= _limit;
+            if (!limit && _should_limit_output) {
+                const size_t hash_table_size = _get_hash_table_size();
+                _reach_limit = hash_table_size >= _limit;
+                if (_do_sort_limit && _reach_limit) {
+                    _build_limit_heap(hash_table_size);
+                }
             }
         }
 
@@ -693,6 +724,10 @@ private:
     void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& 
key_columns,
                                   const size_t num_rows);
 
+    bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+                                        const std::vector<int>& key_locs,
+                                        ColumnRawPtrs& key_columns, size_t 
num_rows);
+
     size_t _memory_usage() const;
 
     Status _reset_hash_table();
@@ -736,6 +771,12 @@ private:
     };
 
     MemoryRecord _mem_usage_record;
+
+    MutableColumns _get_keys_hash_table();
+
+    bool _do_limit_filter(Block* block, int key_size, size_t num_rows);
+
+    void _build_limit_heap(size_t hash_table_size);
 };
 } // namespace vectorized
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to