This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 092a394782 [improvement](agg)limit the output of agg node (#11461)
092a394782 is described below

commit 092a3947822b316d3e84ee7681f53a9d03044b75
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Fri Aug 5 07:53:55 2022 +0800

    [improvement](agg)limit the output of agg node (#11461)
    
    * [improvement](agg)limit the output of agg node
---
 .../vec/aggregate_functions/aggregate_function.h   |  31 +++++
 be/src/vec/common/columns_hashing.h                |  18 +++
 be/src/vec/common/columns_hashing_impl.h           |  39 ++++++
 be/src/vec/common/hash_table/hash_map.h            |   2 +
 be/src/vec/common/hash_table/ph_hash_map.h         |  14 ++
 be/src/vec/exec/vaggregation_node.cpp              | 144 +++++++++++----------
 be/src/vec/exec/vaggregation_node.h                | 138 ++++++++++++++++++++
 be/src/vec/exprs/vectorized_agg_fn.cpp             |   7 +
 be/src/vec/exprs/vectorized_agg_fn.h               |   3 +
 9 files changed, 326 insertions(+), 70 deletions(-)
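
A quick note on the mechanism (not part of the patch text): the aggregation node now
records whether its output can be capped (_should_limit_output: the node has a limit,
no HAVING conjunct, and performs the finalize step). Once the hash table already holds
at least _limit groups (_reach_limit), incoming blocks take a find-only path: rows whose
keys already have a group keep updating their aggregate states through the new
*_selected functions, while rows with unseen keys are skipped instead of creating new
groups. A standalone toy sketch of the idea in plain C++ (not Doris code; the single
counting aggregate and the key set are illustrative assumptions):

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
        // Pretend the query is: SELECT k, count(*) FROM t GROUP BY k LIMIT 2
        // with no HAVING clause, so at most two groups can ever be returned.
        const size_t limit = 2;
        std::unordered_map<std::string, int64_t> groups; // group key -> count

        std::vector<std::string> keys = {"a", "b", "a", "c", "b", "d", "a"};
        for (const auto& k : keys) {
            if (groups.size() >= limit) {
                // "reach limit" path: find only, never create a new group.
                auto it = groups.find(k);
                if (it != groups.end()) ++it->second;
            } else {
                // normal path: create the group if the key is unseen.
                ++groups[k];
            }
        }

        for (const auto& [k, v] : groups) std::cout << k << " -> " << v << "\n";
        // Only "a" and "b" ever become groups; "c" and "d" are dropped early.
    }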

diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h
index 3e20a4db51..30b2e2c959 100644
--- a/be/src/vec/aggregate_functions/aggregate_function.h
+++ b/be/src/vec/aggregate_functions/aggregate_function.h
@@ -94,6 +94,11 @@ public:
     virtual void merge_vec(const AggregateDataPtr* places, size_t offset, ConstAggregateDataPtr rhs,
                            Arena* arena, const size_t num_rows) const = 0;
 
+    // same as merge_vec, but only call "merge" function when place is not nullptr
+    virtual void merge_vec_selected(const AggregateDataPtr* places, size_t offset,
+                                    ConstAggregateDataPtr rhs, Arena* arena,
+                                    const size_t num_rows) const = 0;
+
     /// Serializes state (to transmit it over the network, for example).
     virtual void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const = 0;
 
@@ -132,6 +137,11 @@ public:
     virtual void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
                            const IColumn** columns, Arena* arena) const = 0;
 
+    // same as add_batch, but only call "add" function when place is not nullptr
+    virtual void add_batch_selected(size_t batch_size, AggregateDataPtr* places,
+                                    size_t place_offset, const IColumn** columns,
+                                    Arena* arena) const = 0;
+
     /** The same for single place.
       */
     virtual void add_batch_single_place(size_t batch_size, AggregateDataPtr place,
@@ -169,6 +179,15 @@ public:
         }
     }
 
+    void add_batch_selected(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
+                            const IColumn** columns, Arena* arena) const override {
+        for (size_t i = 0; i < batch_size; ++i) {
+            if (places[i]) {
+                static_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena);
+            }
+        }
+    }
+
     void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                 Arena* arena) const override {
         for (size_t i = 0; i < batch_size; ++i) {
@@ -228,6 +247,18 @@ public:
                                                      arena);
         }
     }
+
+    void merge_vec_selected(const AggregateDataPtr* places, size_t offset,
+                            ConstAggregateDataPtr rhs, Arena* arena,
+                            const size_t num_rows) const override {
+        const auto size_of_data = static_cast<const Derived*>(this)->size_of_data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            if (places[i]) {
+                static_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i,
+                                                         arena);
+            }
+        }
+    }
 };
 
 /// Implements several methods for manipulation with data. T - type of structure with data for aggregation.
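
The *_selected variants added above share the contract of their plain counterparts,
except that rows whose aggregate-state pointer is nullptr are skipped; the limited
aggregation path below passes nullptr for rows whose group was not found. A standalone
toy illustration of that contract in plain C++ (not Doris code; CountState and the
free function here are hypothetical stand-ins):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct CountState { long sum = 0; };

    // Mirrors the add_batch_selected contract: walk every row of the batch,
    // but only touch rows whose state pointer is non-null, i.e. rows whose
    // group already exists in the hash table.
    void add_batch_selected(size_t batch_size, CountState** places, const int* column) {
        for (size_t i = 0; i < batch_size; ++i) {
            if (places[i]) {
                places[i]->sum += column[i];
            }
        }
    }

    int main() {
        CountState a, b;
        std::vector<CountState*> places = {&a, nullptr, &b, nullptr}; // two known groups
        std::vector<int> col = {1, 10, 2, 20};
        add_batch_selected(places.size(), places.data(), col.data());
        std::cout << a.sum << " " << b.sum << "\n"; // prints "1 2"; rows 1 and 3 skipped
    }
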
diff --git a/be/src/vec/common/columns_hashing.h b/be/src/vec/common/columns_hashing.h
index 580588c480..844fc26ca7 100644
--- a/be/src/vec/common/columns_hashing.h
+++ b/be/src/vec/common/columns_hashing.h
@@ -257,6 +257,24 @@ struct HashMethodSingleLowNullableColumn : public SingleColumnMethod {
         } else
             return EmplaceResult(inserted);
     }
+
+    template <typename Data>
+    ALWAYS_INLINE FindResult find_key(Data& data, size_t row, Arena& pool) {
+        if (key_columns[0]->is_null_at(row)) {
+            bool has_null_key = data.has_null_key_data();
+            if constexpr (has_mapped)
+                return FindResult(&data.get_null_key_data(), has_null_key);
+            else
+                return FindResult(has_null_key);
+        }
+        auto key_holder = Base::get_key_holder(row, pool);
+        auto key = key_holder_get_key(key_holder);
+        auto it = data.find(key);
+        if constexpr (has_mapped)
+            return FindResult(it ? lookup_result_get_mapped(it) : nullptr, it != nullptr);
+        else
+            return FindResult(it != nullptr);
+    }
 };
 
 } // namespace ColumnsHashing
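
The find_key overload added above mirrors how emplace_key already treats a nullable
single-key column: a NULL group key is not looked up in the hash table at all but in
the dedicated null-key slot of the data structure. A standalone toy of that lookup
shape in plain C++ (not Doris code; NullableKeyMap is a hypothetical stand-in):

    #include <iostream>
    #include <optional>
    #include <string>
    #include <unordered_map>

    struct NullableKeyMap {
        std::unordered_map<std::string, long> data; // non-NULL keys
        std::optional<long> null_key_data;          // dedicated slot for the NULL key

        long* find(const std::optional<std::string>& key) {
            if (!key) {
                // NULL key: only check whether the side slot was ever populated.
                return null_key_data ? &*null_key_data : nullptr;
            }
            auto it = data.find(*key);
            return it == data.end() ? nullptr : &it->second;
        }
    };

    int main() {
        NullableKeyMap m;
        m.data["x"] = 7;
        m.null_key_data = 42;
        std::cout << *m.find(std::string("x")) << " " << *m.find(std::nullopt) << "\n"; // 7 42
        std::cout << (m.find(std::string("y")) == nullptr) << "\n";                     // 1
    }
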
diff --git a/be/src/vec/common/columns_hashing_impl.h b/be/src/vec/common/columns_hashing_impl.h
index 2abfa3b8e8..5b5f13b86a 100644
--- a/be/src/vec/common/columns_hashing_impl.h
+++ b/be/src/vec/common/columns_hashing_impl.h
@@ -146,6 +146,12 @@ public:
         return find_key_impl(key_holder_get_key(key_holder), data);
     }
 
+    template <typename Data>
+    ALWAYS_INLINE FindResult find_key(Data& data, size_t hash_value, size_t row, Arena& pool) {
+        auto key_holder = static_cast<Derived&>(*this).get_key_holder(row, pool);
+        return find_key_impl(key_holder_get_key(key_holder), hash_value, data);
+    }
+
     template <typename Data>
     ALWAYS_INLINE size_t get_hash(const Data& data, size_t row, Arena& pool) {
         auto key_holder = static_cast<Derived&>(*this).get_key_holder(row, pool);
@@ -290,6 +296,39 @@ protected:
         else
             return FindResult(it != nullptr);
     }
+
+    template <typename Data, typename Key>
+    ALWAYS_INLINE FindResult find_key_impl(Key key, size_t hash_value, Data& data) {
+        if constexpr (Cache::consecutive_keys_optimization) {
+            if (cache.check(key)) {
+                if constexpr (has_mapped)
+                    return FindResult(&cache.value.second, cache.found);
+                else
+                    return FindResult(cache.found);
+            }
+        }
+
+        auto it = data.find(key, hash_value);
+
+        if constexpr (consecutive_keys_optimization) {
+            cache.found = it != nullptr;
+            cache.empty = false;
+
+            if constexpr (has_mapped) {
+                cache.value.first = key;
+                if (it) {
+                    cache.value.second = *lookup_result_get_mapped(it);
+                }
+            } else {
+                cache.value = key;
+            }
+        }
+
+        if constexpr (has_mapped)
+            return FindResult(it ? lookup_result_get_mapped(it) : nullptr, it != nullptr);
+        else
+            return FindResult(it != nullptr);
+    }
 };
 
 template <typename T>
diff --git a/be/src/vec/common/hash_table/hash_map.h b/be/src/vec/common/hash_table/hash_map.h
index 0b849c13fd..25e1b74c51 100644
--- a/be/src/vec/common/hash_table/hash_map.h
+++ b/be/src/vec/common/hash_table/hash_map.h
@@ -68,6 +68,8 @@ struct HashMapCell {
     const value_type& get_value() const { return value; }
 
     static const Key& get_key(const value_type& value) { return value.first; }
+    Mapped& get_mapped() { return value.second; }
+    const Mapped& get_mapped() const { return value.second; }
 
     bool key_equals(const Key& key_) const { return value.first == key_; }
     bool key_equals(const Key& key_, size_t /*hash_*/) const { return value.first == key_; }
diff --git a/be/src/vec/common/hash_table/ph_hash_map.h b/be/src/vec/common/hash_table/ph_hash_map.h
index 5f06f6a09f..3a56ce7988 100644
--- a/be/src/vec/common/hash_table/ph_hash_map.h
+++ b/be/src/vec/common/hash_table/ph_hash_map.h
@@ -140,6 +140,20 @@ public:
         }
     }
 
+    template <typename KeyHolder>
+    LookupResult ALWAYS_INLINE find(KeyHolder&& key_holder) {
+        const auto& key = key_holder_get_key(key_holder);
+        auto it = _hash_map.find(key);
+        return it != _hash_map.end() ? &*it : nullptr;
+    }
+
+    template <typename KeyHolder>
+    LookupResult ALWAYS_INLINE find(KeyHolder&& key_holder, size_t hash_value) {
+        const auto& key = key_holder_get_key(key_holder);
+        auto it = _hash_map.find(key, hash_value);
+        return it != _hash_map.end() ? &*it : nullptr;
+    }
+
     size_t hash(const Key& x) const { return _hash_map.hash(x); }
 
     void ALWAYS_INLINE prefetch_by_hash(size_t hash_value) {
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index 80817a0095..1bdccbd614 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -27,7 +27,6 @@
 #include "vec/data_types/data_type_string.h"
 #include "vec/exprs/vexpr.h"
 #include "vec/exprs/vexpr_context.h"
-#include "vec/exprs/vslot_ref.h"
 #include "vec/utils/util.hpp"
 
 namespace doris::vectorized {
@@ -393,6 +392,10 @@ Status AggregationNode::prepare(RuntimeState* state) {
         _executor.update_memusage =
                 std::bind<void>(&AggregationNode::_update_memusage_with_serialized_key, this);
         _executor.close = std::bind<void>(&AggregationNode::_close_with_serialized_key, this);
+
+        _should_limit_output = _limit != -1 &&        // has limit
+                               !_vconjunct_ctx_ptr && // no having conjunct
+                               _needs_finalize;       // agg's finalize step
     }
 
     return Status::OK();
@@ -704,6 +707,11 @@ bool AggregationNode::_should_expand_preagg_hash_tables() {
             _agg_data._aggregated_method_variant);
 }
 
+size_t AggregationNode::_get_hash_table_size() {
+    return std::visit([&](auto&& agg_method) { return agg_method.data.size(); },
+                      _agg_data._aggregated_method_variant);
+}
+
 void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
                                                const size_t num_rows) {
     std::visit(
@@ -774,6 +782,63 @@ void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnR
             _agg_data._aggregated_method_variant);
 }
 
+void AggregationNode::_find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
+                                          size_t rows) {
+    std::visit(
+            [&](auto&& agg_method) -> void {
+                using HashMethodType = std::decay_t<decltype(agg_method)>;
+                using HashTableType = std::decay_t<decltype(agg_method.data)>;
+                using AggState = typename HashMethodType::State;
+                AggState state(key_columns, _probe_key_sz, nullptr);
+
+                _pre_serialize_key_if_need(state, agg_method, key_columns, rows);
+
+                std::vector<size_t> hash_values;
+
+                if constexpr (HashTableTraits<HashTableType>::is_phmap) {
+                    if (hash_values.size() < rows) hash_values.resize(rows);
+                    if constexpr (ColumnsHashing::IsPreSerializedKeysHashMethodTraits<
+                                          AggState>::value) {
+                        for (size_t i = 0; i < rows; ++i) {
+                            hash_values[i] = agg_method.data.hash(agg_method.keys[i]);
+                        }
+                    } else {
+                        for (size_t i = 0; i < rows; ++i) {
+                            hash_values[i] =
+                                    agg_method.data.hash(state.get_key_holder(i, _agg_arena_pool));
+                        }
+                    }
+                }
+
+                /// For all rows.
+                for (size_t i = 0; i < rows; ++i) {
+                    auto find_result = [&]() {
+                        if constexpr (HashTableTraits<HashTableType>::is_phmap) {
+                            if (LIKELY(i + HASH_MAP_PREFETCH_DIST < rows)) {
+                                if constexpr (HashTableTraits<HashTableType>::is_parallel_phmap) {
+                                    agg_method.data.prefetch_by_key(state.get_key_holder(
+                                            i + HASH_MAP_PREFETCH_DIST, _agg_arena_pool));
+                                } else
+                                    agg_method.data.prefetch_by_hash(
+                                            hash_values[i + HASH_MAP_PREFETCH_DIST]);
+                            }
+
+                            return state.find_key(agg_method.data, hash_values[i], i,
+                                                  _agg_arena_pool);
+                        } else {
+                            return state.find_key(agg_method.data, i, _agg_arena_pool);
+                        }
+                    }();
+
+                    if (find_result.is_found()) {
+                        places[i] = find_result.get_mapped();
+                    } else
+                        places[i] = nullptr;
+                }
+            },
+            _agg_data._aggregated_method_variant);
+}
+
 Status AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block* in_block,
                                                      doris::vectorized::Block* out_block) {
     SCOPED_TIMER(_build_timer);
@@ -896,34 +961,11 @@ Status AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block* i
 }
 
 Status AggregationNode::_execute_with_serialized_key(Block* block) {
-    SCOPED_TIMER(_build_timer);
-    DCHECK(!_probe_expr_ctxs.empty());
-
-    size_t key_size = _probe_expr_ctxs.size();
-    ColumnRawPtrs key_columns(key_size);
-    {
-        SCOPED_TIMER(_expr_timer);
-        for (size_t i = 0; i < key_size; ++i) {
-            int result_column_id = -1;
-            RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
-            block->get_by_position(result_column_id).column =
-                    block->get_by_position(result_column_id)
-                            .column->convert_to_full_column_if_const();
-            key_columns[i] = block->get_by_position(result_column_id).column.get();
-        }
-    }
-
-    int rows = block->rows();
-    PODArray<AggregateDataPtr> places(rows);
-
-    _emplace_into_hash_table(places.data(), key_columns, rows);
-
-    for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-        _aggregate_evaluators[i]->execute_batch_add(block, _offsets_of_aggregate_states[i],
-                                                    places.data(), &_agg_arena_pool);
+    if (_reach_limit) {
+        return _execute_with_serialized_key_helper<true>(block);
+    } else {
+        return _execute_with_serialized_key_helper<false>(block);
     }
-
-    return Status::OK();
 }
 
 Status AggregationNode::_get_with_serialized_key_result(RuntimeState* state, Block* block,
@@ -1111,49 +1153,11 @@ Status AggregationNode::_serialize_with_serialized_key_result(RuntimeState* stat
 }
 
 Status AggregationNode::_merge_with_serialized_key(Block* block) {
-    SCOPED_TIMER(_merge_timer);
-
-    size_t key_size = _probe_expr_ctxs.size();
-    ColumnRawPtrs key_columns(key_size);
-
-    for (size_t i = 0; i < key_size; ++i) {
-        int result_column_id = -1;
-        RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
-        key_columns[i] = block->get_by_position(result_column_id).column.get();
-    }
-
-    int rows = block->rows();
-    PODArray<AggregateDataPtr> places(rows);
-
-    _emplace_into_hash_table(places.data(), key_columns, rows);
-
-    for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-        DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
-               _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
-        int col_id =
-                ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())->column_id();
-        if (_aggregate_evaluators[i]->is_merge()) {
-            auto column = block->get_by_position(col_id).column;
-            if (column->is_nullable()) {
-                column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
-            }
-
-            std::unique_ptr<char[]> deserialize_buffer(
-                    new char[_aggregate_evaluators[i]->function()->size_of_data() * rows]);
-
-            _aggregate_evaluators[i]->function()->deserialize_vec(deserialize_buffer.get(),
-                                                                  (ColumnString*)(column.get()),
-                                                                  &_agg_arena_pool, rows);
-            _aggregate_evaluators[i]->function()->merge_vec(
-                    places.data(), _offsets_of_aggregate_states[i], deserialize_buffer.get(),
-                    &_agg_arena_pool, rows);
-
-        } else {
-            _aggregate_evaluators[i]->execute_batch_add(block, _offsets_of_aggregate_states[i],
-                                                        places.data(), &_agg_arena_pool);
-        }
+    if (_reach_limit) {
+        return _merge_with_serialized_key_helper<true>(block);
+    } else {
+        return _merge_with_serialized_key_helper<false>(block);
     }
-    return Status::OK();
 }
 
 void AggregationNode::_update_memusage_with_serialized_key() {
diff --git a/be/src/vec/exec/vaggregation_node.h b/be/src/vec/exec/vaggregation_node.h
index ba7cf24b70..7702102e6a 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -26,6 +26,7 @@
 #include "vec/common/columns_hashing.h"
 #include "vec/common/hash_table/fixed_hash_map.h"
 #include "vec/exprs/vectorized_agg_fn.h"
+#include "vec/exprs/vslot_ref.h"
 
 namespace doris {
 class TPlanNode;
@@ -674,11 +675,16 @@ private:
     bool _should_expand_hash_table = true;
     std::vector<char*> _streaming_pre_places;
 
+    bool _should_limit_output = false;
+    bool _reach_limit = false;
+
 private:
     /// Return true if we should keep expanding hash tables in the preagg. If false,
     /// the preagg should pass through any rows it can't fit in its tables.
     bool _should_expand_preagg_hash_tables();
 
+    size_t _get_hash_table_size();
+
     void _make_nullable_output_key(Block* block);
 
     Status _create_agg_status(AggregateDataPtr data);
@@ -710,9 +716,141 @@ private:
         }
     }
 
+    template <bool limit>
+    Status _execute_with_serialized_key_helper(Block* block) {
+        SCOPED_TIMER(_build_timer);
+        DCHECK(!_probe_expr_ctxs.empty());
+
+        size_t key_size = _probe_expr_ctxs.size();
+        ColumnRawPtrs key_columns(key_size);
+        {
+            SCOPED_TIMER(_expr_timer);
+            for (size_t i = 0; i < key_size; ++i) {
+                int result_column_id = -1;
+                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
+                block->get_by_position(result_column_id).column =
+                        block->get_by_position(result_column_id)
+                                .column->convert_to_full_column_if_const();
+                key_columns[i] = block->get_by_position(result_column_id).column.get();
+            }
+        }
+
+        int rows = block->rows();
+        PODArray<AggregateDataPtr> places(rows);
+
+        if constexpr (limit) {
+            _find_in_hash_table(places.data(), key_columns, rows);
+
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                _aggregate_evaluators[i]->execute_batch_add_selected(
+                        block, _offsets_of_aggregate_states[i], places.data(), &_agg_arena_pool);
+            }
+        } else {
+            _emplace_into_hash_table(places.data(), key_columns, rows);
+
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                _aggregate_evaluators[i]->execute_batch_add(block, _offsets_of_aggregate_states[i],
+                                                            places.data(), &_agg_arena_pool);
+            }
+
+            if (_should_limit_output) {
+                _reach_limit = _get_hash_table_size() >= _limit;
+            }
+        }
+
+        return Status::OK();
+    }
+
+    template <bool limit>
+    Status _merge_with_serialized_key_helper(Block* block) {
+        SCOPED_TIMER(_merge_timer);
+
+        size_t key_size = _probe_expr_ctxs.size();
+        ColumnRawPtrs key_columns(key_size);
+
+        for (size_t i = 0; i < key_size; ++i) {
+            int result_column_id = -1;
+            RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
+            key_columns[i] = block->get_by_position(result_column_id).column.get();
+        }
+
+        int rows = block->rows();
+        PODArray<AggregateDataPtr> places(rows);
+
+        if constexpr (limit) {
+            _find_in_hash_table(places.data(), key_columns, rows);
+
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
+                       _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
+                int col_id = ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())
+                                     ->column_id();
+                if (_aggregate_evaluators[i]->is_merge()) {
+                    auto column = block->get_by_position(col_id).column;
+                    if (column->is_nullable()) {
+                        column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
+                    }
+
+                    std::unique_ptr<char[]> deserialize_buffer(
+                            new char[_aggregate_evaluators[i]->function()->size_of_data() * rows]);
+
+                    _aggregate_evaluators[i]->function()->deserialize_vec(
+                            deserialize_buffer.get(), (ColumnString*)(column.get()),
+                            &_agg_arena_pool, rows);
+                    _aggregate_evaluators[i]->function()->merge_vec_selected(
+                            places.data(), _offsets_of_aggregate_states[i],
+                            deserialize_buffer.get(), &_agg_arena_pool, rows);
+
+                } else {
+                    _aggregate_evaluators[i]->execute_batch_add_selected(
+                            block, _offsets_of_aggregate_states[i], places.data(),
+                            &_agg_arena_pool);
+                }
+            }
+        } else {
+            _emplace_into_hash_table(places.data(), key_columns, rows);
+
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
+                       _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
+                int col_id = ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())
+                                     ->column_id();
+                if (_aggregate_evaluators[i]->is_merge()) {
+                    auto column = block->get_by_position(col_id).column;
+                    if (column->is_nullable()) {
+                        column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
+                    }
+
+                    std::unique_ptr<char[]> deserialize_buffer(
+                            new char[_aggregate_evaluators[i]->function()->size_of_data() * rows]);
+
+                    _aggregate_evaluators[i]->function()->deserialize_vec(
+                            deserialize_buffer.get(), (ColumnString*)(column.get()),
+                            &_agg_arena_pool, rows);
+                    _aggregate_evaluators[i]->function()->merge_vec(
+                            places.data(), _offsets_of_aggregate_states[i],
+                            deserialize_buffer.get(), &_agg_arena_pool, rows);
+
+                } else {
+                    _aggregate_evaluators[i]->execute_batch_add(block,
+                                                                _offsets_of_aggregate_states[i],
+                                                                places.data(), &_agg_arena_pool);
+                }
+            }
+
+            if (_should_limit_output) {
+                _reach_limit = _get_hash_table_size() >= _limit;
+            }
+        }
+
+        return Status::OK();
+    }
+
     void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
                                   const size_t num_rows);
 
+    void _find_in_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns, size_t num_rows);
+
     void release_tracker();
 
     using vectorized_execute = std::function<Status(Block* block)>;
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 527cb40e18..1cd7198ec0 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -156,6 +156,13 @@ void AggFnEvaluator::execute_batch_add(Block* block, size_t offset, AggregateDat
     _function->add_batch(block->rows(), places, offset, _agg_columns.data(), arena);
 }
 
+void AggFnEvaluator::execute_batch_add_selected(Block* block, size_t offset,
+                                                AggregateDataPtr* places, Arena* arena) {
+    _calc_argment_columns(block);
+    SCOPED_TIMER(_exec_timer);
+    _function->add_batch_selected(block->rows(), places, offset, _agg_columns.data(), arena);
+}
+
 void AggFnEvaluator::insert_result_info(AggregateDataPtr place, IColumn* column) {
     _function->insert_result_into(place, *column);
 }
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h
index 52098f0d8f..e9d86b4e28 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -58,6 +58,9 @@ public:
     void execute_batch_add(Block* block, size_t offset, AggregateDataPtr* places,
                            Arena* arena = nullptr);
 
+    void execute_batch_add_selected(Block* block, size_t offset, AggregateDataPtr* places,
+                                    Arena* arena = nullptr);
+
     void insert_result_info(AggregateDataPtr place, IColumn* column);
 
     void insert_result_info_vec(const std::vector<AggregateDataPtr>& place, size_t offset,

