This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
     new 255d80cf981 [Feature](exec) Support group by limit opt in BE code (#29641)

255d80cf981 is described below

commit 255d80cf981a573fe3c6753c5f2da089f6ba2479
Author: HappenLee <happen...@hotmail.com>
AuthorDate: Mon Jun 3 20:19:37 2024 +0800

    [Feature](exec) Support group by limit opt in BE code (#29641)

    ## Proposed changes

    Do the group by limit optimization in the BE: when the plan carries sort
    info on the group by keys, do TopN pruning during aggregation instead of
    aggregating every group.
---
 be/src/pipeline/dependency.cpp                     |  76 ++++++
 be/src/pipeline/dependency.h                       |  68 +++++
 be/src/pipeline/exec/aggregation_sink_operator.cpp | 288 ++++++++++++++++-----
 be/src/pipeline/exec/aggregation_sink_operator.h   |  11 +-
 .../pipeline/exec/aggregation_source_operator.cpp  |  12 +-
 be/src/pipeline/exec/aggregation_source_operator.h |   1 +
 be/src/pipeline/exec/operator.cpp                  |   3 +-
 be/src/vec/columns/column_nullable.cpp             |   1 +
 be/src/vec/columns/column_string.cpp               |  10 +-
 be/src/vec/columns/column_string.h                 |   1 +
 be/src/vec/columns/column_vector.cpp               |  12 +-
 be/src/vec/core/block.cpp                          |  13 +
 be/src/vec/core/block.h                            |   3 +-
 be/src/vec/exec/vaggregation_node.cpp              | 243 ++++++++++++++++-
 be/src/vec/exec/vaggregation_node.h                | 205 +++++++++------
 15 files changed, 773 insertions(+), 174 deletions(-)
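For orientation before the per-file diffs: the optimization applies to queries of the shape `SELECT k, agg(...) FROM t GROUP BY k ORDER BY k LIMIT n`. The operator keeps a heap of the best n group keys seen so far; once the heap is full, any incoming row whose key sorts after the heap top can be dropped before it is hashed and aggregated. The sketch below is a minimal model of that idea with a single int64 ascending key and plain STL containers; the names are illustrative and none of them are Doris APIs.

```cpp
#include <cstddef>
#include <cstdint>
#include <queue>
#include <unordered_map>

// Minimal model of the group-by-limit optimization for
//   SELECT k, count(*) FROM t GROUP BY k ORDER BY k LIMIT n
// with one int64 ascending key. `heap` is a max-heap over the best `limit`
// keys seen so far, so heap.top() is the boundary: any unseen key that
// sorts after it can never reach the final TopN.
struct GroupByLimitSketch {
    size_t limit;
    std::unordered_map<int64_t, int64_t> counts; // key -> aggregate state
    std::priority_queue<int64_t> heap;           // top() == worst kept key

    void add_row(int64_t key) {
        bool known = counts.count(key) > 0;
        if (!known && heap.size() == limit && key > heap.top()) {
            return; // prune: skip hashing and aggregation for this row
        }
        ++counts[key]; // aggregate (here: count)
        if (!known) {
            heap.push(key);
            if (heap.size() > limit) {
                heap.pop(); // evict the displaced boundary key
            }
        }
    }
};
```

The real code below generalizes this to multiple typed key columns, asc/desc and nulls-first/last directions, and batched (vectorized) filtering.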
diff --git a/be/src/pipeline/dependency.cpp b/be/src/pipeline/dependency.cpp
index 8cf025274af..e7159b2df35 100644
--- a/be/src/pipeline/dependency.cpp
+++ b/be/src/pipeline/dependency.cpp
@@ -196,6 +196,82 @@ LocalExchangeSharedState::LocalExchangeSharedState(int num_instances) {
     mem_trackers.resize(num_instances, nullptr);
 }
 
+vectorized::MutableColumns AggSharedState::_get_keys_hash_table() {
+    return std::visit(
+            vectorized::Overload {
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return vectorized::MutableColumns();
+                    },
+                    [&](auto&& agg_method) -> vectorized::MutableColumns {
+                        vectorized::MutableColumns key_columns;
+                        for (int i = 0; i < probe_expr_ctxs.size(); ++i) {
+                            key_columns.emplace_back(
+                                    probe_expr_ctxs[i]->root()->data_type()->create_column());
+                        }
+                        auto& data = *agg_method.hash_table;
+                        bool has_null_key = data.has_null_key_data();
+                        const auto size = data.size() - has_null_key;
+                        using KeyType = std::decay_t<decltype(agg_method.iterator->get_first())>;
+                        std::vector<KeyType> keys(size);
+
+                        size_t num_rows = 0;
+                        auto iter = aggregate_data_container->begin();
+                        {
+                            while (iter != aggregate_data_container->end()) {
+                                keys[num_rows] = iter.get_key<KeyType>();
+                                ++iter;
+                                ++num_rows;
+                            }
+                        }
+                        agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                        if (has_null_key) {
+                            key_columns[0]->insert_data(nullptr, 0);
+                        }
+                        return key_columns;
+                    }},
+            agg_data->method_variant);
+}
+
+void AggSharedState::build_limit_heap(size_t hash_table_size) {
+    limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        limit_heap.emplace(i, limit_columns, order_directions, null_directions);
+    }
+    while (hash_table_size > limit) {
+        limit_heap.pop();
+        hash_table_size--;
+    }
+    limit_columns_min = limit_heap.top()._row_id;
+}
+
+bool AggSharedState::do_limit_filter(vectorized::Block* block, size_t num_rows) {
+    if (num_rows) {
+        cmp_res.resize(num_rows);
+        need_computes.resize(num_rows);
+        memset(need_computes.data(), 0, need_computes.size());
+        memset(cmp_res.data(), 0, cmp_res.size());
+
+        const auto key_size = null_directions.size();
+        for (int i = 0; i < key_size; i++) {
+            block->get_by_position(i).column->compare_internal(
+                    limit_columns_min, *limit_columns[i], null_directions[i], order_directions[i],
+                    cmp_res, need_computes.data());
+        }
+
+        auto set_computes_arr = [](auto* __restrict res, auto* __restrict computes, int rows) {
+            for (int i = 0; i < rows; ++i) {
+                computes[i] = computes[i] == res[i];
+            }
+        };
+        set_computes_arr(cmp_res.data(), need_computes.data(), num_rows);
+
+        return std::find(need_computes.begin(), need_computes.end(), 0) != need_computes.end();
+    }
+
+    return false;
+}
+
 Status AggSharedState::reset_hash_table() {
     return std::visit(
             vectorized::Overload {
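The subtle part of `do_limit_filter` above is the final pass that turns the two flag arrays into a keep/drop decision. After the per-column `compare_internal` loop, `cmp_res[i]` is 1 iff row i differs from the heap-top boundary row on some order-by column (the comparison is decided), and `need_computes[i]` is 1 iff the first deciding column says row i sorts strictly better. A plain-vector restatement of the semantics, under a hypothetical helper name:

```cpp
#include <cstdint>
#include <vector>

// Restates the `computes[i] = computes[i] == res[i]` step on plain vectors.
// A row is kept when it is strictly better than the boundary (1 == 1) or
// ties it on every column (0 == 0); it is dropped when it is decided but
// worse (need_computes 0, cmp_res 1).
inline void finalize_keep_flags(const std::vector<uint8_t>& cmp_res,
                                std::vector<uint8_t>& need_computes) {
    for (size_t i = 0; i < cmp_res.size(); ++i) {
        need_computes[i] = (need_computes[i] == cmp_res[i]);
    }
}
```

The function then returns true only when at least one row was dropped, so callers can skip the block filter entirely when everything survives.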
diff --git a/be/src/pipeline/dependency.h b/be/src/pipeline/dependency.h
index d7084f85d5d..e32f5a1c0d6 100644
--- a/be/src/pipeline/dependency.h
+++ b/be/src/pipeline/dependency.h
@@ -311,6 +311,9 @@ public:
 
     Status reset_hash_table();
 
+    bool do_limit_filter(vectorized::Block* block, size_t num_rows);
+    void build_limit_heap(size_t hash_table_size);
+
     // We should call this function only at 1st phase.
     // 1st phase: is_merge=true, only have one SlotRef.
     // 2nd phase: is_merge=false, maybe have multiple exprs.
@@ -346,8 +349,73 @@ public:
     MemoryRecord mem_usage_record;
     bool agg_data_created_without_key = false;
     bool enable_spill = false;
+    bool reach_limit = false;
+
+    int64_t limit = -1;
+    bool do_sort_limit = false;
+    vectorized::MutableColumns limit_columns;
+    int limit_columns_min = -1;
+    vectorized::PaddedPODArray<uint8_t> need_computes;
+    std::vector<uint8_t> cmp_res;
+    std::vector<int> order_directions;
+    std::vector<int> null_directions;
+
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, vectorized::MutableColumns& limit_columns,
+                        std::vector<int>& order_directions, std::vector<int>& null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (int i = 0; i < _limit_columns.size(); ++i) {
+                const auto& _limit_column = _limit_columns[i];
+                auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column,
+                                                     _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        vectorized::MutableColumns& _limit_columns;
+        std::vector<int>& _order_directions;
+        std::vector<int>& _null_directions;
+    };
+
+    std::priority_queue<HeapLimitCursor> limit_heap;
 
 private:
+    vectorized::MutableColumns _get_keys_hash_table();
+
     void _close_with_serialized_key() {
         std::visit(vectorized::Overload {[&](std::monostate& arg) -> void {
                        // Do nothing
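`HeapLimitCursor` above stores only a row id plus references to shared columns, which is also why its assignment operators copy just `_row_id`. Because `operator<` returns true when `*this` sorts before `rhs` in the query order, `std::priority_queue` (a max-heap) keeps the worst retained row at `top()`, which is exactly the boundary row that `do_limit_filter` compares against. A simplified, compilable stand-in with one int column (the `Cursor` type here is hypothetical, not the Doris one):

```cpp
#include <cstdio>
#include <queue>
#include <vector>

// Simplified stand-in for HeapLimitCursor: one int column, ascending order.
// operator< returns true when *this sorts before rhs in the ORDER BY, so
// std::priority_queue (a max-heap) surfaces the worst retained row on top().
struct Cursor {
    int row_id;
    const std::vector<int>* column;
    bool operator<(const Cursor& rhs) const {
        return (*column)[row_id] < (*column)[rhs.row_id];
    }
};

int main() {
    std::vector<int> keys = {42, 7, 19, 3, 25};
    std::priority_queue<Cursor> heap;
    const size_t limit = 3;
    for (int i = 0; i < (int)keys.size(); ++i) {
        heap.push({i, &keys});
        if (heap.size() > limit) heap.pop(); // evict the worst key
    }
    // The largest of the 3 smallest keys is the boundary.
    std::printf("boundary key = %d\n", keys[heap.top().row_id]);
}
```

Running this prints `boundary key = 19`, matching what `build_limit_heap` leaves at `limit_heap.top()` after its pop-down loop.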
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp
index 79f5b5af083..a3ac73a5d85 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp
@@ -69,6 +69,7 @@ Status AggSinkLocalState::init(RuntimeState* state, LocalSinkStateInfo& info) {
     _serialize_data_timer = ADD_TIMER(Base::profile(), "SerializeDataTime");
     _deserialize_data_timer = ADD_TIMER(Base::profile(), "DeserializeAndMergeTime");
     _hash_table_compute_timer = ADD_TIMER(Base::profile(), "HashTableComputeTime");
+    _hash_table_limit_compute_timer = ADD_TIMER(Base::profile(), "DoLimitComputeTime");
     _hash_table_emplace_timer = ADD_TIMER(Base::profile(), "HashTableEmplaceTime");
     _hash_table_input_counter = ADD_COUNTER(Base::profile(), "HashTableInputCount", TUnit::UNIT);
     _max_row_size_counter = ADD_COUNTER(Base::profile(), "MaxRowSizeInBytes", TUnit::UNIT);
@@ -86,6 +87,11 @@ Status AggSinkLocalState::open(RuntimeState* state) {
     Base::_shared_state->offsets_of_aggregate_states = p._offsets_of_aggregate_states;
     Base::_shared_state->make_nullable_keys = p._make_nullable_keys;
     Base::_shared_state->probe_expr_ctxs.resize(p._probe_expr_ctxs.size());
+
+    Base::_shared_state->limit = p._limit;
+    Base::_shared_state->do_sort_limit = p._do_sort_limit;
+    Base::_shared_state->null_directions = p._null_directions;
+    Base::_shared_state->order_directions = p._order_directions;
     for (size_t i = 0; i < Base::_shared_state->probe_expr_ctxs.size(); i++) {
         RETURN_IF_ERROR(
                 p._probe_expr_ctxs[i]->clone(state, Base::_shared_state->probe_expr_ctxs[i]));
@@ -132,7 +138,6 @@ Status AggSinkLocalState::open(RuntimeState* state) {
         _should_limit_output = p._limit != -1 &&       // has limit
                                (!p._have_conjuncts) && // no having conjunct
-                               p._needs_finalize &&    // agg's finalize step
                                !Base::_shared_state->enable_spill;
     }
     for (auto& evaluator : p._aggregate_evaluators) {
@@ -183,7 +188,7 @@ Status AggSinkLocalState::_execute_without_key(vectorized::Block* block) {
 }
 
 Status AggSinkLocalState::_merge_with_serialized_key(vectorized::Block* block) {
-    if (_reach_limit) {
+    if (_shared_state->reach_limit) {
         return _merge_with_serialized_key_helper<true, false>(block);
     } else {
         return _merge_with_serialized_key_helper<false, false>(block);
@@ -260,12 +265,14 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
 
     size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
     vectorized::ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
 
     for (size_t i = 0; i < key_size; ++i) {
         if constexpr (for_spill) {
             key_columns[i] = block->get_by_position(i).column.get();
+            key_locs[i] = i;
         } else {
-            int result_column_id = -1;
+            int& result_column_id = key_locs[i];
             RETURN_IF_ERROR(
                     Base::_shared_state->probe_expr_ctxs[i]->execute(block, &result_column_id));
             block->replace_by_position_if_const(result_column_id);
@@ -278,7 +285,7 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
         _places.resize(rows);
     }
 
-    if constexpr (limit) {
+    if (limit && !_shared_state->do_sort_limit) {
         _find_in_hash_table(_places.data(), key_columns, rows);
 
         for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
@@ -318,52 +325,66 @@ Status AggSinkLocalState::_merge_with_serialized_key_helper(vectorized::Block* b
             }
         }
     } else {
-        _emplace_into_hash_table(_places.data(), key_columns, rows);
+        bool need_do_agg = true;
+        if (limit) {
+            need_do_agg = _emplace_into_hash_table_limit(_places.data(), block, key_locs,
+                                                         key_columns, rows);
+        } else {
+            _emplace_into_hash_table(_places.data(), key_columns, rows);
+        }
 
-        for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
-            if (Base::_shared_state->aggregate_evaluators[i]->is_merge() || for_spill) {
-                int col_id = 0;
-                if constexpr (for_spill) {
-                    col_id = Base::_shared_state->probe_expr_ctxs.size() + i;
+        if (need_do_agg) {
+            for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
+                if (Base::_shared_state->aggregate_evaluators[i]->is_merge() || for_spill) {
+                    int col_id = 0;
+                    if constexpr (for_spill) {
+                        col_id = Base::_shared_state->probe_expr_ctxs.size() + i;
+                    } else {
+                        col_id = AggSharedState::get_slot_column_id(
+                                Base::_shared_state->aggregate_evaluators[i]);
+                    }
+                    auto column = block->get_by_position(col_id).column;
+                    if (column->is_nullable()) {
+                        column = ((vectorized::ColumnNullable*)column.get())
+                                         ->get_nested_column_ptr();
+                    }
+
+                    size_t buffer_size = Base::_shared_state->aggregate_evaluators[i]
+                                                 ->function()
+                                                 ->size_of_data() *
+                                         rows;
+                    if (_deserialize_buffer.size() < buffer_size) {
+                        _deserialize_buffer.resize(buffer_size);
+                    }
+
+                    {
+                        SCOPED_TIMER(_deserialize_data_timer);
+                        Base::_shared_state->aggregate_evaluators[i]
+                                ->function()
+                                ->deserialize_and_merge_vec(
+                                        _places.data(),
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._offsets_of_aggregate_states[i],
+                                        _deserialize_buffer.data(), column.get(), _agg_arena_pool,
+                                        rows);
+                    }
                 } else {
-                    col_id = AggSharedState::get_slot_column_id(
-                            Base::_shared_state->aggregate_evaluators[i]);
-                }
-                auto column = block->get_by_position(col_id).column;
-                if (column->is_nullable()) {
-                    column = ((vectorized::ColumnNullable*)column.get())->get_nested_column_ptr();
-                }
-
-                size_t buffer_size =
-                        Base::_shared_state->aggregate_evaluators[i]->function()->size_of_data() *
-                        rows;
-                if (_deserialize_buffer.size() < buffer_size) {
-                    _deserialize_buffer.resize(buffer_size);
-                }
-
-                {
-                    SCOPED_TIMER(_deserialize_data_timer);
-                    Base::_shared_state->aggregate_evaluators[i]
-                            ->function()
-                            ->deserialize_and_merge_vec(
-                                    _places.data(),
-                                    Base::_parent->template cast<AggSinkOperatorX>()
-                                            ._offsets_of_aggregate_states[i],
-                                    _deserialize_buffer.data(), column.get(), _agg_arena_pool,
-                                    rows);
+                    RETURN_IF_ERROR(
+                            Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+                                    block,
+                                    Base::_parent->template cast<AggSinkOperatorX>()
+                                            ._offsets_of_aggregate_states[i],
+                                    _places.data(), _agg_arena_pool));
                 }
-            } else {
-                RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
-                        block,
-                        Base::_parent->template cast<AggSinkOperatorX>()
-                                ._offsets_of_aggregate_states[i],
-                        _places.data(), _agg_arena_pool));
             }
         }
 
-        if (_should_limit_output) {
-            _reach_limit = _get_hash_table_size() >=
-                           Base::_parent->template cast<AggSinkOperatorX>()._limit;
+        if (!limit && _should_limit_output) {
+            const size_t hash_table_size = _get_hash_table_size();
+            _shared_state->reach_limit =
+                    hash_table_size >= Base::_parent->template cast<AggSinkOperatorX>()._limit;
+            if (_shared_state->do_sort_limit && _shared_state->reach_limit) {
+                _shared_state->build_limit_heap(hash_table_size);
+            }
         }
     }
 
@@ -410,7 +431,7 @@ void AggSinkLocalState::_update_memusage_without_key() {
 }
 
 Status AggSinkLocalState::_execute_with_serialized_key(vectorized::Block* block) {
-    if (_reach_limit) {
+    if (_shared_state->reach_limit) {
         return _execute_with_serialized_key_helper<true>(block);
     } else {
         return _execute_with_serialized_key_helper<false>(block);
@@ -424,10 +445,11 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
 
     size_t key_size = Base::_shared_state->probe_expr_ctxs.size();
     vectorized::ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
     {
         SCOPED_TIMER(_expr_timer);
         for (size_t i = 0; i < key_size; ++i) {
-            int result_column_id = -1;
+            int& result_column_id = key_locs[i];
             RETURN_IF_ERROR(
                     Base::_shared_state->probe_expr_ctxs[i]->execute(block, &result_column_id));
             block->get_by_position(result_column_id).column =
@@ -442,7 +464,7 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
         _places.resize(rows);
     }
 
-    if constexpr (limit) {
+    if (limit && !_shared_state->do_sort_limit) {
         _find_in_hash_table(_places.data(), key_columns, rows);
 
         for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
@@ -454,27 +476,48 @@ Status AggSinkLocalState::_execute_with_serialized_key_helper(vectorized::Block*
                     _places.data(), _agg_arena_pool));
         }
     } else {
-        _emplace_into_hash_table(_places.data(), key_columns, rows);
-
-        for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
-            RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
-                    block,
-                    Base::_parent->template cast<AggSinkOperatorX>()
-                            ._offsets_of_aggregate_states[i],
-                    _places.data(), _agg_arena_pool));
-        }
+        auto do_aggregate_evaluators = [&] {
+            for (int i = 0; i < Base::_shared_state->aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(Base::_shared_state->aggregate_evaluators[i]->execute_batch_add(
+                        block,
+                        Base::_parent->template cast<AggSinkOperatorX>()
+                                ._offsets_of_aggregate_states[i],
+                        _places.data(), _agg_arena_pool));
+            }
+            return Status::OK();
+        };
 
-        if (_should_limit_output) {
-            _reach_limit = _get_hash_table_size() >=
-                           Base::_parent->template cast<AggSinkOperatorX>()._limit;
-            if (_reach_limit &&
-                Base::_parent->template cast<AggSinkOperatorX>()._can_short_circuit) {
-                Base::_dependency->set_ready_to_read();
-                return Status::Error<ErrorCode::END_OF_FILE>("");
+        if constexpr (limit) {
+            if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, key_columns,
+                                               rows)) {
+                RETURN_IF_ERROR(do_aggregate_evaluators());
+            }
+        } else {
+            _emplace_into_hash_table(_places.data(), key_columns, rows);
+            RETURN_IF_ERROR(do_aggregate_evaluators());
+
+            if (_should_limit_output && !Base::_shared_state->enable_spill) {
+                const size_t hash_table_size = _get_hash_table_size();
+                if (Base::_parent->template cast<AggSinkOperatorX>()._can_short_circuit) {
+                    _shared_state->reach_limit =
+                            hash_table_size >=
+                            Base::_parent->template cast<AggSinkOperatorX>()._limit;
+                    if (_shared_state->reach_limit) {
+                        Base::_dependency->set_ready_to_read();
+                        return Status::Error<ErrorCode::END_OF_FILE>("");
+                    }
+                } else {
+                    _shared_state->reach_limit =
+                            hash_table_size >= _shared_state->do_sort_limit
+                                    ? Base::_parent->template cast<AggSinkOperatorX>()._limit * 5
+                                    : Base::_parent->template cast<AggSinkOperatorX>()._limit;
+                    if (_shared_state->reach_limit && _shared_state->do_sort_limit) {
+                        _shared_state->build_limit_heap(hash_table_size);
+                    }
+                }
             }
         }
     }
-
     return Status::OK();
 }
 
@@ -535,6 +578,108 @@ void AggSinkLocalState::_emplace_into_hash_table(vectorized::AggregateDataPtr* p
             _agg_data->method_variant);
 }
 
+bool AggSinkLocalState::_emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                                       vectorized::Block* block,
+                                                       const std::vector<int>& key_locs,
+                                                       vectorized::ColumnRawPtrs& key_columns,
+                                                       size_t num_rows) {
+    return std::visit(
+            vectorized::Overload {
+                    [&](std::monostate& arg) {
+                        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                        return true;
+                    },
+                    [&](auto&& agg_method) -> bool {
+                        SCOPED_TIMER(_hash_table_compute_timer);
+                        using HashMethodType = std::decay_t<decltype(agg_method)>;
+                        using AggState = typename HashMethodType::State;
+
+                        bool need_filter = false;
+                        {
+                            SCOPED_TIMER(_hash_table_limit_compute_timer);
+                            need_filter = _shared_state->do_limit_filter(block, num_rows);
+                        }
+
+                        auto& need_computes = _shared_state->need_computes;
+                        if (auto need_agg =
+                                    std::find(need_computes.begin(), need_computes.end(), 1);
+                            need_agg != need_computes.end()) {
+                            if (need_filter) {
+                                vectorized::Block::filter_block_internal(block, need_computes);
+                                for (int i = 0; i < key_locs.size(); ++i) {
+                                    key_columns[i] =
+                                            block->get_by_position(key_locs[i]).column.get();
+                                }
+                                num_rows = block->rows();
+                            }
+
+                            AggState state(key_columns);
+                            agg_method.init_serialized_keys(key_columns, num_rows);
+                            size_t i = 0;
+
+                            auto refresh_top_limit = [&, this]() {
+                                _shared_state->limit_heap.pop();
+                                for (int j = 0; j < key_columns.size(); ++j) {
+                                    _shared_state->limit_columns[j]->insert_from(*key_columns[j],
+                                                                                 i);
+                                }
+                                _shared_state->limit_heap.emplace(
+                                        _shared_state->limit_columns[0]->size() - 1,
+                                        _shared_state->limit_columns,
+                                        _shared_state->order_directions,
+                                        _shared_state->null_directions);
+                                _shared_state->limit_columns_min =
+                                        _shared_state->limit_heap.top()._row_id;
+                            };
+
+                            auto creator = [this, refresh_top_limit](const auto& ctor, auto& key,
+                                                                     auto& origin) {
+                                try {
+                                    HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                              *_agg_arena_pool);
+                                    auto mapped =
+                                            _shared_state->aggregate_data_container->append_data(
+                                                    origin);
+                                    auto st = _create_agg_status(mapped);
+                                    if (!st) {
+                                        throw Exception(st.code(), st.to_string());
+                                    }
+                                    ctor(key, mapped);
+                                    refresh_top_limit();
+                                } catch (...) {
+                                    // Exception-safety - if it can not allocate memory or create status,
+                                    // the destructors will not be called.
+                                    ctor(key, nullptr);
+                                    throw;
+                                }
+                            };
+
+                            auto creator_for_null_key = [this, refresh_top_limit](auto& mapped) {
+                                mapped = _agg_arena_pool->aligned_alloc(
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._total_size_of_aggregate_states,
+                                        Base::_parent->template cast<AggSinkOperatorX>()
+                                                ._align_aggregate_states);
+                                auto st = _create_agg_status(mapped);
+                                if (!st) {
+                                    throw Exception(st.code(), st.to_string());
+                                }
+                                refresh_top_limit();
+                            };
+
+                            SCOPED_TIMER(_hash_table_emplace_timer);
+                            for (i = 0; i < num_rows; ++i) {
+                                places[i] = agg_method.lazy_emplace(state, i, creator,
+                                                                    creator_for_null_key);
+                            }
+                            COUNTER_UPDATE(_hash_table_input_counter, num_rows);
+                            return true;
+                        }
+                        return false;
+                    }},
+            _agg_data->method_variant);
+}
+
 void AggSinkLocalState::_find_in_hash_table(vectorized::AggregateDataPtr* places,
                                             vectorized::ColumnRawPtrs& key_columns,
                                             size_t num_rows) {
@@ -616,6 +761,21 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
     _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
                             [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
 
+    if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+        _do_sort_limit = true;
+        const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+        DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());
+
+        const int order_by_key_size = agg_sort_info.is_asc_order.size();
+        _order_directions.resize(order_by_key_size);
+        _null_directions.resize(order_by_key_size);
+        for (int i = 0; i < order_by_key_size; ++i) {
+            _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+            _null_directions[i] =
+                    agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
+        }
+    }
+
     return Status::OK();
 }
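`AggSinkOperatorX::init` above folds the planner's `is_asc_order`/`nulls_first` flags into two ±1 vectors so the comparison kernels never branch on sort options: `order_directions[i]` flips the sign of a `compare_at` result, and `null_directions[i]` is passed as the null-direction hint that decides where NULLs sort. A standalone restatement of that mapping (the free function here is hypothetical):

```cpp
#include <vector>

// How is_asc_order / nulls_first collapse into the two direction vectors
// consumed as compare_at(..., null_direction) * order_direction.
// Example: ORDER BY k1 ASC NULLS FIRST, k2 DESC NULLS LAST yields
//   order_directions = { 1, -1 }, null_directions = { -1, -1 }.
void make_directions(const std::vector<bool>& is_asc_order,
                     const std::vector<bool>& nulls_first,
                     std::vector<int>& order_directions,
                     std::vector<int>& null_directions) {
    const size_t n = is_asc_order.size();
    order_directions.resize(n);
    null_directions.resize(n);
    for (size_t i = 0; i < n; ++i) {
        order_directions[i] = is_asc_order[i] ? 1 : -1;
        // NULLS FIRST must still sort "smaller" after the asc/desc flip,
        // so the null hint is tied to the order direction's sign.
        null_directions[i] = nulls_first[i] ? -order_directions[i] : order_directions[i];
    }
}
```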
diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h b/be/src/pipeline/exec/aggregation_sink_operator.h
index d48debc2d83..39fee1707e4 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.h
+++ b/be/src/pipeline/exec/aggregation_sink_operator.h
@@ -85,6 +85,9 @@ protected:
                             vectorized::ColumnRawPtrs& key_columns, size_t num_rows);
     void _emplace_into_hash_table(vectorized::AggregateDataPtr* places,
                                   vectorized::ColumnRawPtrs& key_columns, size_t num_rows);
+    bool _emplace_into_hash_table_limit(vectorized::AggregateDataPtr* places,
+                                        vectorized::Block* block, const std::vector<int>& key_locs,
+                                        vectorized::ColumnRawPtrs& key_columns, size_t num_rows);
 
     size_t _get_hash_table_size() const;
 
     template <bool limit, bool for_spill = false>
@@ -96,6 +99,7 @@ protected:
 
     RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
     RuntimeProfile::Counter* _build_timer = nullptr;
     RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -109,7 +113,6 @@ protected:
     RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage = nullptr;
 
     bool _should_limit_output = false;
-    bool _reach_limit = false;
 
     vectorized::PODArray<vectorized::AggregateDataPtr> _places;
     std::vector<char> _deserialize_buffer;
@@ -191,8 +194,12 @@ protected:
     ObjectPool* _pool = nullptr;
     std::vector<size_t> _make_nullable_keys;
    int64_t _limit; // -1: no limit
-    bool _have_conjuncts;
 
+    // do sort limit and directions
+    bool _do_sort_limit = false;
+    std::vector<int> _order_directions;
+    std::vector<int> _null_directions;
+
+    bool _have_conjuncts;
     const std::vector<TExpr> _partition_exprs;
     const bool _is_colocate;
diff --git a/be/src/pipeline/exec/aggregation_source_operator.cpp b/be/src/pipeline/exec/aggregation_source_operator.cpp
index b94f076bdbf..cca9fefbdb2 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_source_operator.cpp
@@ -22,7 +22,6 @@
 
 #include "common/exception.h"
 #include "pipeline/exec/operator.h"
-#include "vec//utils/util.hpp"
 
 namespace doris::pipeline {
 
@@ -444,10 +443,19 @@ Status AggSourceOperatorX::get_block(RuntimeState* state, vectorized::Block* blo
     local_state.make_nullable_output_key(block);
     // dispose the having clause, should not be execute in prestreaming agg
     RETURN_IF_ERROR(vectorized::VExprContext::filter_block(_conjuncts, block, block->columns()));
-    local_state.reached_limit(block, eos);
+    local_state.do_agg_limit(block, eos);
     return Status::OK();
 }
 
+void AggLocalState::do_agg_limit(vectorized::Block* block, bool* eos) {
+    if (_shared_state->reach_limit) {
+        if (_shared_state->do_sort_limit && _shared_state->do_limit_filter(block, block->rows())) {
+            vectorized::Block::filter_block_internal(block, _shared_state->need_computes);
+        }
+        reached_limit(block, eos);
+    }
+}
+
 void AggLocalState::make_nullable_output_key(vectorized::Block* block) {
     if (block->rows() != 0) {
         for (auto cid : _shared_state->make_nullable_keys) {
diff --git a/be/src/pipeline/exec/aggregation_source_operator.h b/be/src/pipeline/exec/aggregation_source_operator.h
index c4ea6c6ccde..a3824a381eb 100644
--- a/be/src/pipeline/exec/aggregation_source_operator.h
+++ b/be/src/pipeline/exec/aggregation_source_operator.h
@@ -41,6 +41,7 @@ public:
     void make_nullable_output_key(vectorized::Block* block);
     template <bool limit>
     Status merge_with_serialized_key_helper(vectorized::Block* block);
+    void do_agg_limit(vectorized::Block* block, bool* eos);
 
 protected:
     friend class AggSourceOperatorX;
diff --git a/be/src/pipeline/exec/operator.cpp b/be/src/pipeline/exec/operator.cpp
index 938eb22f253..455f11fa9f1 100644
--- a/be/src/pipeline/exec/operator.cpp
+++ b/be/src/pipeline/exec/operator.cpp
@@ -389,8 +389,7 @@ std::shared_ptr<BasicSharedState> DataSinkOperatorX<LocalStateType>::create_shar
         LOG(FATAL) << "should not reach here!";
         return nullptr;
     } else {
-        std::shared_ptr<BasicSharedState> ss = nullptr;
-        ss = LocalStateType::SharedStateType::create_shared();
+        auto ss = LocalStateType::SharedStateType::create_shared();
         ss->id = operator_id();
         for (auto& dest : dests_id()) {
             ss->related_op_ids.insert(dest);
diff --git a/be/src/vec/columns/column_nullable.cpp b/be/src/vec/columns/column_nullable.cpp
index 6efa690d7db..c516b96b72f 100644
--- a/be/src/vec/columns/column_nullable.cpp
+++ b/be/src/vec/columns/column_nullable.cpp
@@ -422,6 +422,7 @@ int ColumnNullable::compare_at(size_t n, size_t m, const IColumn& rhs_,
     return get_nested_column().compare_at(n, m, nullable_rhs.get_nested_column(),
                                           null_direction_hint);
 }
+
 void ColumnNullable::compare_internal(size_t rhs_row_id, const IColumn& rhs,
                                       int nan_direction_hint, int direction,
                                       std::vector<uint8>& cmp_res, uint8* __restrict filter) const {
diff --git a/be/src/vec/columns/column_string.cpp b/be/src/vec/columns/column_string.cpp
index 446fd283b1c..919854a42d9 100644
--- a/be/src/vec/columns/column_string.cpp
+++ b/be/src/vec/columns/column_string.cpp
@@ -544,7 +544,7 @@ template <typename T>
 void ColumnStr<T>::compare_internal(size_t rhs_row_id, const IColumn& rhs, int nan_direction_hint,
                                     int direction, std::vector<uint8>& cmp_res,
                                     uint8* __restrict filter) const {
-    auto sz = this->size();
+    auto sz = offsets.size();
     DCHECK(cmp_res.size() == sz);
     const auto& cmp_base = assert_cast<const ColumnStr<T>&>(rhs).get_data_at(rhs_row_id);
     size_t begin = simd::find_zero(cmp_res, 0);
@@ -554,12 +554,8 @@ void ColumnStr<T>::compare_internal(size_t rhs_row_id, const IColumn& rhs, int n
             auto value_a = get_data_at(row_id);
             int res = memcmp_small_allow_overflow15(value_a.data, value_a.size, cmp_base.data,
                                                     cmp_base.size);
-            if (res * direction < 0) {
-                filter[row_id] = 1;
-                cmp_res[row_id] = 1;
-            } else if (res * direction > 0) {
-                cmp_res[row_id] = 1;
-            }
+            cmp_res[row_id] = res != 0;
+            filter[row_id] = res * direction < 0;
         }
         begin = simd::find_zero(cmp_res, end + 1);
     }
diff --git a/be/src/vec/columns/column_string.h b/be/src/vec/columns/column_string.h
index d0994607a46..22dcd612d3a 100644
--- a/be/src/vec/columns/column_string.h
+++ b/be/src/vec/columns/column_string.h
@@ -549,6 +549,7 @@ public:
     void compare_internal(size_t rhs_row_id, const IColumn& rhs, int nan_direction_hint,
                           int direction, std::vector<uint8>& cmp_res,
                           uint8* __restrict filter) const override;
+
     MutableColumnPtr get_shinked_column() const {
         auto shrinked_column = ColumnStr<T>::create();
         for (int i = 0; i < size(); i++) {
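The `compare_internal` implementations (string above, numeric below) share a scan pattern: `simd::find_zero`/`simd::find_one` jump across runs of rows whose comparison is already decided in `cmp_res`, so each later ORDER BY column only touches the ties left by earlier columns. A scalar model of the loop, with plain-loop stand-ins for the SIMD helpers (all names here are hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar stand-ins for simd::find_zero / simd::find_one.
static size_t find_zero(const std::vector<uint8_t>& v, size_t start) {
    while (start < v.size() && v[start] != 0) ++start;
    return start;
}
static size_t find_one(const std::vector<uint8_t>& v, size_t start) {
    while (start < v.size() && v[start] == 0) ++start;
    return start;
}

// Compare every still-undecided row against one boundary value `base`.
// cmp_res[i] becomes 1 once row i is decided (differs from base); filter[i]
// is set when row i sorts strictly better than base under `direction` (+/-1).
void compare_internal_sketch(const std::vector<int>& column, int base, int direction,
                             std::vector<uint8_t>& cmp_res, uint8_t* filter) {
    size_t begin = find_zero(cmp_res, 0);
    while (begin < column.size()) {
        size_t end = find_one(cmp_res, begin + 1); // end of this undecided run
        for (size_t row = begin; row < end; ++row) {
            int res = (column[row] > base) - (column[row] < base); // sign of compare
            cmp_res[row] = (res != 0);
            filter[row] = (res * direction < 0);
        }
        begin = find_zero(cmp_res, end + 1);
    }
}
```

The branchless `cmp_res[row] = res != 0; filter[row] = res * direction < 0;` form this commit introduces is equivalent to the old if/else-if chain, since `filter[row]` starts zeroed.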
diff --git a/be/src/vec/columns/column_vector.cpp b/be/src/vec/columns/column_vector.cpp
index 60a75420405..14d52045943 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -141,21 +141,17 @@ void ColumnVector<T>::compare_internal(size_t rhs_row_id, const IColumn& rhs,
                                        int nan_direction_hint, int direction,
                                        std::vector<uint8>& cmp_res,
                                        uint8* __restrict filter) const {
-    auto sz = this->size();
+    const auto sz = data.size();
     DCHECK(cmp_res.size() == sz);
     const auto& cmp_base = assert_cast<const ColumnVector<T>&>(rhs).get_data()[rhs_row_id];
     size_t begin = simd::find_zero(cmp_res, 0);
     while (begin < sz) {
         size_t end = simd::find_one(cmp_res, begin + 1);
         for (size_t row_id = begin; row_id < end; row_id++) {
-            auto value_a = get_data()[row_id];
+            auto value_a = data[row_id];
             int res = value_a > cmp_base ? 1 : (value_a < cmp_base ? -1 : 0);
-            if (res * direction < 0) {
-                filter[row_id] = 1;
-                cmp_res[row_id] = 1;
-            } else if (res * direction > 0) {
-                cmp_res[row_id] = 1;
-            }
+            cmp_res[row_id] = (res != 0);
+            filter[row_id] = (res * direction < 0);
         }
         begin = simd::find_zero(cmp_res, end + 1);
     }
diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp
index 7595ffb6620..95af060dfc7 100644
--- a/be/src/vec/core/block.cpp
+++ b/be/src/vec/core/block.cpp
@@ -797,6 +797,19 @@ void Block::filter_block_internal(Block* block, const IColumn::Filter& filter,
     filter_block_internal(block, columns_to_filter, filter);
 }
 
+void Block::filter_block_internal(Block* block, const IColumn::Filter& filter) {
+    const size_t count =
+            filter.size() - simd::count_zero_num((int8_t*)filter.data(), filter.size());
+    for (int i = 0; i < block->columns(); ++i) {
+        auto& column = block->get_by_position(i).column;
+        if (column->is_exclusive()) {
+            column->assume_mutable()->filter(filter);
+        } else {
+            column = column->filter(filter, count);
+        }
+    }
+}
+
 Block Block::copy_block(const std::vector<int>& column_offset) const {
     ColumnsWithTypeAndName columns_with_type_and_name;
     for (auto offset : column_offset) {
diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h
index 593d37f7ff2..3611252ea59 100644
--- a/be/src/vec/core/block.h
+++ b/be/src/vec/core/block.h
@@ -281,10 +281,11 @@ public:
     // need exception safety
     static void filter_block_internal(Block* block, const std::vector<uint32_t>& columns_to_filter,
                                       const IColumn::Filter& filter);
-    // need exception safety
     static void filter_block_internal(Block* block, const IColumn::Filter& filter,
                                       uint32_t column_to_keep);
+    // need exception safety
+    static void filter_block_internal(Block* block, const IColumn::Filter& filter);
 
     static Status filter_block(Block* block, const std::vector<uint32_t>& columns_to_filter,
                                int filter_column_id, int column_to_keep);
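The new single-argument `Block::filter_block_internal` overload above filters every column of the block: it derives the surviving row count once from the filter bytes, filters exclusively owned columns in place, and otherwise builds a filtered copy. The same shape over plain vectors (a sketch under those assumptions, not the Doris implementation):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of filtering all "columns" of a block with one shared filter:
// count survivors once, then shrink every column to that size. Doris counts
// with simd::count_zero_num and filters refcount-1 columns in place.
template <typename T>
void filter_columns(std::vector<std::vector<T>>& columns,
                    const std::vector<uint8_t>& filter) {
    size_t count = 0;
    for (uint8_t keep : filter) count += (keep != 0); // surviving rows

    for (auto& col : columns) {
        std::vector<T> out;
        out.reserve(count); // the precomputed count avoids reallocation
        for (size_t i = 0; i < filter.size(); ++i) {
            if (filter[i]) out.push_back(col[i]);
        }
        col = std::move(out);
    }
}
```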
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index 1845382a2b4..f009802d5dd 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -158,6 +158,21 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
     _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
                             [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
 
+    if (tnode.agg_node.__isset.agg_sort_info_by_group_key) {
+        _do_sort_limit = true;
+        const auto& agg_sort_info = tnode.agg_node.agg_sort_info_by_group_key;
+        DCHECK_EQ(agg_sort_info.nulls_first.size(), agg_sort_info.is_asc_order.size());
+
+        const int order_by_key_size = agg_sort_info.is_asc_order.size();
+        _order_directions.resize(order_by_key_size);
+        _null_directions.resize(order_by_key_size);
+        for (int i = 0; i < order_by_key_size; ++i) {
+            _order_directions[i] = agg_sort_info.is_asc_order[i] ? 1 : -1;
+            _null_directions[i] =
+                    agg_sort_info.nulls_first[i] ? -_order_directions[i] : _order_directions[i];
+        }
+    }
     return Status::OK();
 }
 
@@ -183,6 +198,7 @@ Status AggregationNode::prepare_profile(RuntimeState* state) {
     _deserialize_data_timer = ADD_TIMER(runtime_profile(), "DeserializeAndMergeTime");
     _hash_table_compute_timer = ADD_TIMER(runtime_profile(), "HashTableComputeTime");
     _hash_table_emplace_timer = ADD_TIMER(runtime_profile(), "HashTableEmplaceTime");
+    _hash_table_limit_compute_timer = ADD_TIMER(runtime_profile(), "DoLimitComputeTime");
     _hash_table_iterate_timer = ADD_TIMER(runtime_profile(), "HashTableIterateTime");
     _insert_keys_to_column_timer = ADD_TIMER(runtime_profile(), "InsertKeysToColumnTime");
     _streaming_agg_timer = ADD_TIMER(runtime_profile(), "StreamingAggTime");
@@ -315,9 +331,8 @@ Status AggregationNode::prepare_profile(RuntimeState* state) {
                 std::bind<void>(&AggregationNode::_update_memusage_with_serialized_key, this);
         _executor.close = std::bind<void>(&AggregationNode::_close_with_serialized_key, this);
 
-        _should_limit_output = _limit != -1 &&       // has limit
-                               _conjuncts.empty() && // no having conjunct
-                               _needs_finalize;      // agg's finalize step
+        _should_limit_output = _limit != -1 && // has limit
+                               _conjuncts.empty();
     }
 
     fmt::memory_buffer msg;
@@ -436,8 +451,12 @@ Status AggregationNode::pull(doris::RuntimeState* state, vectorized::Block* bloc
     _make_nullable_output_key(block);
     // dispose the having clause, should not be execute in prestreaming agg
     RETURN_IF_ERROR(VExprContext::filter_block(_conjuncts, block, block->columns()));
-    reached_limit(block, eos);
-
+    if (_reach_limit) {
+        if (_do_sort_limit && _do_limit_filter(block, _order_directions.size(), block->rows())) {
+            Block::filter_block_internal(block, _need_computes);
+        }
+        reached_limit(block, eos);
+    }
     return Status::OK();
 }
@@ -775,6 +794,158 @@ size_t AggregationNode::_get_hash_table_size() {
             _agg_data->method_variant);
 }
 
+template <bool limit>
+Status AggregationNode::_execute_with_serialized_key_helper(Block* block) {
+    DCHECK(!_probe_expr_ctxs.empty());
+
+    size_t key_size = _probe_expr_ctxs.size();
+    ColumnRawPtrs key_columns(key_size);
+    std::vector<int> key_locs(key_size);
+    {
+        SCOPED_TIMER(_expr_timer);
+        for (size_t i = 0; i < key_size; ++i) {
+            auto& result_column_id = key_locs[i];
+            RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
+            block->get_by_position(result_column_id).column =
+                    block->get_by_position(result_column_id)
+                            .column->convert_to_full_column_if_const();
+            key_columns[i] = block->get_by_position(result_column_id).column.get();
+        }
+    }
+
+    int rows = block->rows();
+    if (_places.size() < rows) {
+        _places.resize(rows);
+    }
+
+    if constexpr (limit) {
+        if (_emplace_into_hash_table_limit(_places.data(), block, key_locs, key_columns, rows)) {
+            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                        block, _offsets_of_aggregate_states[i], _places.data(),
+                        _agg_arena_pool.get()));
+            }
+        }
+    } else {
+        _emplace_into_hash_table(_places.data(), key_columns, rows);
+
+        for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+            RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                    block, _offsets_of_aggregate_states[i], _places.data(),
+                    _agg_arena_pool.get()));
+        }
+
+        if (_should_limit_output && !_reach_limit) {
+            auto size = _get_hash_table_size();
+            _reach_limit = size >= _limit * 5;
+            if (_reach_limit) {
+                _build_limit_heap(size);
+            }
+        }
+    }
+
+    return Status::OK();
+}
+
+void AggregationNode::_build_limit_heap(size_t hash_table_size) {
+    _limit_columns = _get_keys_hash_table();
+    for (size_t i = 0; i < hash_table_size; ++i) {
+        _limit_heap.emplace(i, _limit_columns, _order_directions, _null_directions);
+    }
+    while (hash_table_size > _limit) {
+        _limit_heap.pop();
+        hash_table_size--;
+    }
+    _limit_columns_min = _limit_heap.top()._row_id;
+}
+
+bool AggregationNode::_emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+                                                     const std::vector<int>& key_locs,
+                                                     ColumnRawPtrs& key_columns, size_t num_rows) {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                          return true;
+                      },
+                      [&](auto&& agg_method) -> bool {
+                          SCOPED_TIMER(_hash_table_compute_timer);
+                          using HashMethodType = std::decay_t<decltype(agg_method)>;
+                          using AggState = typename HashMethodType::State;
+
+                          bool need_filter = false;
+                          {
+                              SCOPED_TIMER(_hash_table_limit_compute_timer);
+                              need_filter = _do_limit_filter(block, key_columns.size(), num_rows);
+                          }
+
+                          if (auto need_agg =
+                                      std::find(_need_computes.begin(), _need_computes.end(), 1);
+                              need_agg != _need_computes.end()) {
+                              if (need_filter) {
+                                  Block::filter_block_internal(block, _need_computes);
+                                  for (int i = 0; i < key_locs.size(); ++i) {
+                                      key_columns[i] =
+                                              block->get_by_position(key_locs[i]).column.get();
+                                  }
+                                  num_rows = block->rows();
+                              }
+
+                              AggState state(key_columns);
+                              agg_method.init_serialized_keys(key_columns, num_rows);
+                              size_t i = 0;
+
+                              auto refresh_top_limit = [&, this] {
+                                  _limit_heap.pop();
+                                  for (int j = 0; j < key_columns.size(); ++j) {
+                                      _limit_columns[j]->insert_from(*key_columns[j], i);
+                                  }
+                                  _limit_heap.emplace(_limit_columns[0]->size() - 1, _limit_columns,
+                                                      _order_directions, _null_directions);
+                                  _limit_columns_min = _limit_heap.top()._row_id;
+                              };
+
+                              auto creator = [this, refresh_top_limit](const auto& ctor, auto& key,
+                                                                       auto& origin) {
+                                  try {
+                                      HashMethodType::try_presis_key_and_origin(key, origin,
+                                                                                *_agg_arena_pool);
+                                      auto mapped = _aggregate_data_container->append_data(origin);
+                                      auto st = _create_agg_status(mapped);
+                                      if (!st) {
+                                          throw Exception(st.code(), st.to_string());
+                                      }
+                                      ctor(key, mapped);
+                                      refresh_top_limit();
+                                  } catch (...) {
+                                      // Exception-safety - if it can not allocate memory or create status,
+                                      // the destructors will not be called.
+                                      ctor(key, nullptr);
+                                      throw;
+                                  }
+                              };
+
+                              auto creator_for_null_key = [this, refresh_top_limit](auto& mapped) {
+                                  mapped = _agg_arena_pool->aligned_alloc(
+                                          _total_size_of_aggregate_states,
+                                          _align_aggregate_states);
+                                  auto st = _create_agg_status(mapped);
+                                  if (!st) {
+                                      throw Exception(st.code(), st.to_string());
+                                  }
+                                  refresh_top_limit();
+                              };
+
+                              SCOPED_TIMER(_hash_table_emplace_timer);
+                              for (i = 0; i < num_rows; ++i) {
+                                  places[i] = agg_method.lazy_emplace(state, i, creator,
+                                                                      creator_for_null_key);
+                              }
+                              COUNTER_UPDATE(_hash_table_input_counter, num_rows);
+                              return true;
+                          }
+                          return false;
+                      }},
+            _agg_data->method_variant);
+}
+
 void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
                                                const size_t num_rows) {
     std::visit(Overload {[&](std::monostate& arg) {
@@ -785,6 +956,7 @@ void AggregationNode::_emplace_into_hash_table(AggregateDataPtr* places, ColumnR
                          SCOPED_TIMER(_hash_table_compute_timer);
                          using HashMethodType = std::decay_t<decltype(agg_method)>;
                          using AggState = typename HashMethodType::State;
+
                          AggState state(key_columns);
                          agg_method.init_serialized_keys(key_columns, num_rows);
 
@@ -850,6 +1022,42 @@ void AggregationNode::_find_in_hash_table(AggregateDataPtr* places, ColumnRawPtr
             _agg_data->method_variant);
 }
 
+MutableColumns AggregationNode::_get_keys_hash_table() {
+    return std::visit(
+            Overload {[&](std::monostate& arg) {
+                          throw doris::Exception(ErrorCode::INTERNAL_ERROR, "uninited hash table");
+                          return MutableColumns();
+                      },
+                      [&](auto&& agg_method) -> MutableColumns {
+                          MutableColumns key_columns;
+                          for (int i = 0; i < _probe_expr_ctxs.size(); ++i) {
+                              key_columns.emplace_back(
+                                      _probe_expr_ctxs[i]->root()->data_type()->create_column());
+                          }
+                          auto& data = *agg_method.hash_table;
+                          bool has_null_key = data.has_null_key_data();
+                          const auto size = data.size() - has_null_key;
+                          using KeyType = std::decay_t<decltype(agg_method.iterator->get_first())>;
+                          std::vector<KeyType> keys(size);
+
+                          size_t num_rows = 0;
+                          auto iter = _aggregate_data_container->begin();
+                          {
+                              while (iter != _aggregate_data_container->end()) {
+                                  keys[num_rows] = iter.get_key<KeyType>();
+                                  ++iter;
+                                  ++num_rows;
+                              }
+                          }
+                          agg_method.insert_keys_into_columns(keys, key_columns, num_rows);
+                          if (has_null_key) {
+                              key_columns[0]->insert_data(nullptr, 0);
+                          }
+                          return key_columns;
+                      }},
+            _agg_data->method_variant);
+}
+
 Status AggregationNode::_pre_agg_with_serialized_key(doris::vectorized::Block* in_block,
                                                      doris::vectorized::Block* out_block) {
     DCHECK(!_probe_expr_ctxs.empty());
@@ -1455,14 +1663,37 @@ Status AggregationNode::_serialize_with_serialized_key_result_non_spill(RuntimeS
                                              _probe_expr_ctxs[i]->root()->data_type(),
                                              _probe_expr_ctxs[i]->root()->expr_name());
         }
+
         for (int i = 0; i < agg_size; ++i) {
             columns_with_schema.emplace_back(std::move(value_columns[i]), value_data_types[i], "");
+            *block = Block(columns_with_schema);
         }
-        *block = Block(columns_with_schema);
     }
+
     return Status::OK();
 }
 
+bool AggregationNode::_do_limit_filter(Block* block, int key_size, size_t num_rows) {
+    if (num_rows) {
+        _cmp_res.resize(num_rows);
+        _need_computes.resize(num_rows);
+        memset(_need_computes.data(), 0, _need_computes.size());
+        memset(_cmp_res.data(), 0, _cmp_res.size());
+
+        for (int i = 0; i < key_size; i++) {
+            block->get_by_position(i).column->compare_internal(
+                    _limit_columns_min, *_limit_columns[i], _null_directions[i],
+                    _order_directions[i], _cmp_res, _need_computes.data());
+        }
+
+        for (int i = 0; i < num_rows; ++i) {
+            _need_computes[i] = _need_computes[i] == _cmp_res[i];
+        }
+        return std::find(_need_computes.begin(), _need_computes.end(), 0) != _need_computes.end();
+    }
+    return false;
+}
+
 Status AggregationNode::_merge_with_serialized_key(Block* block) {
     if (_reach_limit) {
         return _merge_with_serialized_key_helper<true, false>(block);
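One detail worth isolating from `_execute_with_serialized_key_helper` above: with sorted-limit pruning the node deliberately lets the hash table grow to five times the limit before paying for `_build_limit_heap`, so small inputs never enter the pruning path; without sorting, reaching `limit` groups is already enough. The factor 5 appears to be a heuristic constant chosen in this commit. The apparent intent, as a standalone predicate (hypothetical function):

```cpp
#include <cstddef>
#include <cstdint>

// When to flip into "reach_limit" mode: plain limits stop growing the hash
// table at `limit` groups; sorted limits wait for 5x `limit` groups before
// building the TopN heap, amortizing the heap-build cost over larger inputs.
bool should_build_limit_heap(size_t hash_table_size, int64_t limit, bool do_sort_limit) {
    const size_t threshold =
            do_sort_limit ? static_cast<size_t>(limit) * 5 : static_cast<size_t>(limit);
    return hash_table_size >= threshold;
}
```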
diff --git a/be/src/vec/exec/vaggregation_node.h b/be/src/vec/exec/vaggregation_node.h
index 70222f9ccf7..cd7aedebead 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -438,6 +438,7 @@ protected:
     // nullable diff. so we need make nullable of it.
     std::vector<size_t> _make_nullable_keys;
     RuntimeProfile::Counter* _hash_table_compute_timer = nullptr;
+    RuntimeProfile::Counter* _hash_table_limit_compute_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_emplace_timer = nullptr;
     RuntimeProfile::Counter* _hash_table_input_counter = nullptr;
     RuntimeProfile::Counter* _expr_timer = nullptr;
@@ -491,8 +492,70 @@ private:
     RuntimeProfile::HighWaterMarkCounter* _serialize_key_arena_memory_usage = nullptr;
 
     bool _should_expand_hash_table = true;
+    bool _should_limit_output = false;
     bool _reach_limit = false;
 
+    bool _do_sort_limit = false;
+    MutableColumns _limit_columns;
+    int _limit_columns_min = -1;
+    PaddedPODArray<uint8_t> _need_computes;
+    std::vector<uint8_t> _cmp_res;
+    std::vector<int> _order_directions;
+    std::vector<int> _null_directions;
+
+    struct HeapLimitCursor {
+        HeapLimitCursor(int row_id, MutableColumns& limit_columns,
+                        std::vector<int>& order_directions, std::vector<int>& null_directions)
+                : _row_id(row_id),
+                  _limit_columns(limit_columns),
+                  _order_directions(order_directions),
+                  _null_directions(null_directions) {}
+
+        HeapLimitCursor(const HeapLimitCursor& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor(HeapLimitCursor&& other) noexcept
+                : _row_id(other._row_id),
+                  _limit_columns(other._limit_columns),
+                  _order_directions(other._order_directions),
+                  _null_directions(other._null_directions) {}
+
+        HeapLimitCursor& operator=(const HeapLimitCursor& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        HeapLimitCursor& operator=(HeapLimitCursor&& other) noexcept {
+            _row_id = other._row_id;
+            return *this;
+        }
+
+        bool operator<(const HeapLimitCursor& rhs) const {
+            for (int i = 0; i < _limit_columns.size(); ++i) {
+                const auto& _limit_column = _limit_columns[i];
+                auto res = _limit_column->compare_at(_row_id, rhs._row_id, *_limit_column,
+                                                     _null_directions[i]) *
+                           _order_directions[i];
+                if (res < 0) {
+                    return true;
+                } else if (res > 0) {
+                    return false;
+                }
+            }
+            return false;
+        }
+
+        int _row_id;
+        MutableColumns& _limit_columns;
+        std::vector<int>& _order_directions;
+        std::vector<int>& _null_directions;
+    };
+
+    std::priority_queue<HeapLimitCursor> _limit_heap;
+
     bool _agg_data_created_without_key = false;
 
     PODArray<AggregateDataPtr> _places;
@@ -535,52 +598,7 @@ private:
     Status _init_hash_method(const VExprContextSPtrs& probe_exprs);
 
     template <bool limit>
-    Status _execute_with_serialized_key_helper(Block* block) {
-        DCHECK(!_probe_expr_ctxs.empty());
-
-        size_t key_size = _probe_expr_ctxs.size();
-        ColumnRawPtrs key_columns(key_size);
-        {
-            SCOPED_TIMER(_expr_timer);
-            for (size_t i = 0; i < key_size; ++i) {
-                int result_column_id = -1;
-                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
-                block->get_by_position(result_column_id).column =
-                        block->get_by_position(result_column_id)
-                                .column->convert_to_full_column_if_const();
-                key_columns[i] = block->get_by_position(result_column_id).column.get();
-            }
-        }
-
-        int rows = block->rows();
-        if (_places.size() < rows) {
-            _places.resize(rows);
-        }
-
-        if constexpr (limit) {
-            _find_in_hash_table(_places.data(), key_columns, rows);
-
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add_selected(
-                        block, _offsets_of_aggregate_states[i], _places.data(),
-                        _agg_arena_pool.get()));
-            }
-        } else {
-            _emplace_into_hash_table(_places.data(), key_columns, rows);
-
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
-                        block, _offsets_of_aggregate_states[i], _places.data(),
-                        _agg_arena_pool.get()));
-            }
-
-            if (_should_limit_output) {
-                _reach_limit = _get_hash_table_size() >= _limit;
-            }
-        }
-
-        return Status::OK();
-    }
+    Status _execute_with_serialized_key_helper(Block* block);
 
     // We should call this function only at 1st phase.
     // 1st phase: is_merge=true, only have one SlotRef.
@@ -599,15 +617,16 @@ private:
 
         size_t key_size = _probe_expr_ctxs.size();
         ColumnRawPtrs key_columns(key_size);
+        std::vector<int> key_locs(key_size);
 
         for (size_t i = 0; i < key_size; ++i) {
            if constexpr (for_spill) {
                 key_columns[i] = block->get_by_position(i).column.get();
+                key_locs[i] = i;
             } else {
-                int result_column_id = -1;
-                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &result_column_id));
-                block->replace_by_position_if_const(result_column_id);
-                key_columns[i] = block->get_by_position(result_column_id).column.get();
+                RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(block, &key_locs[i]));
+                block->replace_by_position_if_const(key_locs[i]);
+                key_columns[i] = block->get_by_position(key_locs[i]).column.get();
             }
         }
@@ -616,7 +635,7 @@ private:
             _places.resize(rows);
         }
 
-        if constexpr (limit) {
+        if (limit && !_do_sort_limit) {
             _find_in_hash_table(_places.data(), key_columns, rows);
 
             for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
@@ -647,43 +666,55 @@ private:
                 }
             }
         } else {
-            _emplace_into_hash_table(_places.data(), key_columns, rows);
+            bool need_do_agg = true;
+            if (limit) {
+                need_do_agg = _emplace_into_hash_table_limit(_places.data(), block, key_locs,
+                                                             key_columns, rows);
+            } else {
+                _emplace_into_hash_table(_places.data(), key_columns, rows);
+            }
 
-            for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-                if (_aggregate_evaluators[i]->is_merge() || for_spill) {
-                    int col_id;
-                    if constexpr (for_spill) {
-                        col_id = _probe_expr_ctxs.size() + i;
+            if (need_do_agg) {
+                for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
+                    if (_aggregate_evaluators[i]->is_merge() || for_spill) {
+                        int col_id;
+                        if constexpr (for_spill) {
+                            col_id = _probe_expr_ctxs.size() + i;
+                        } else {
+                            col_id = _get_slot_column_id(_aggregate_evaluators[i]);
+                        }
+                        auto column = block->get_by_position(col_id).column;
+                        if (column->is_nullable()) {
+                            column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
+                        }
+
+                        size_t buffer_size =
+                                _aggregate_evaluators[i]->function()->size_of_data() * rows;
+                        if (_deserialize_buffer.size() < buffer_size) {
+                            _deserialize_buffer.resize(buffer_size);
+                        }
+
+                        {
+                            SCOPED_TIMER(_deserialize_data_timer);
+                            _aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
+                                    _places.data(), _offsets_of_aggregate_states[i],
+                                    _deserialize_buffer.data(), column.get(),
+                                    _agg_arena_pool.get(), rows);
+                        }
                     } else {
-                        col_id = _get_slot_column_id(_aggregate_evaluators[i]);
-                    }
-                    auto column = block->get_by_position(col_id).column;
-                    if (column->is_nullable()) {
-                        column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
-                    }
-
-                    size_t buffer_size =
-                            _aggregate_evaluators[i]->function()->size_of_data() * rows;
-                    if (_deserialize_buffer.size() < buffer_size) {
-                        _deserialize_buffer.resize(buffer_size);
-                    }
-
-                    {
-                        SCOPED_TIMER(_deserialize_data_timer);
-                        _aggregate_evaluators[i]->function()->deserialize_and_merge_vec(
-                                _places.data(), _offsets_of_aggregate_states[i],
-                                _deserialize_buffer.data(), column.get(), _agg_arena_pool.get(),
-                                rows);
+                        RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
+                                block, _offsets_of_aggregate_states[i], _places.data(),
+                                _agg_arena_pool.get()));
                     }
-                } else {
-                    RETURN_IF_ERROR(_aggregate_evaluators[i]->execute_batch_add(
-                            block, _offsets_of_aggregate_states[i], _places.data(),
-                            _agg_arena_pool.get()));
                 }
             }
 
-            if (_should_limit_output) {
-                _reach_limit = _get_hash_table_size() >= _limit;
+            if (!limit && _should_limit_output) {
+                const size_t hash_table_size = _get_hash_table_size();
+                _reach_limit = hash_table_size >= _limit;
+                if (_do_sort_limit && _reach_limit) {
+                    _build_limit_heap(hash_table_size);
+                }
             }
         }
 
@@ -693,6 +724,10 @@ private:
     void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs& key_columns,
                                   const size_t num_rows);
 
+    bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
+                                        const std::vector<int>& key_locs,
+                                        ColumnRawPtrs& key_columns, size_t num_rows);
+
     size_t _memory_usage() const;
 
     Status _reset_hash_table();
@@ -736,6 +771,12 @@ private:
     };
 
     MemoryRecord _mem_usage_record;
+
+    MutableColumns _get_keys_hash_table();
+
+    bool _do_limit_filter(Block* block, int key_size, size_t num_rows);
+
+    void _build_limit_heap(size_t hash_table_size);
 };
 
 } // namespace vectorized

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org