This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new d3f65aa746 [Improvement](join) remove unnecessary state for join (#13472) d3f65aa746 is described below commit d3f65aa746c7ab27c72a39b8d4869dd6f83979ba Author: Gabriel <gabrielleeb...@gmail.com> AuthorDate: Fri Oct 21 09:59:34 2022 +0800 [Improvement](join) remove unnecessary state for join (#13472) --- be/src/vec/exec/join/join_op.h | 243 +++++--- be/src/vec/exec/join/vhash_join_node.cpp | 947 +++++++++++++++++-------------- be/src/vec/exec/join/vhash_join_node.h | 116 ++-- be/src/vec/exec/vset_operation_node.cpp | 32 +- be/src/vec/exec/vset_operation_node.h | 133 +++-- 5 files changed, 858 insertions(+), 613 deletions(-) diff --git a/be/src/vec/exec/join/join_op.h b/be/src/vec/exec/join/join_op.h index 5c50238592..c760e6da9a 100644 --- a/be/src/vec/exec/join/join_op.h +++ b/be/src/vec/exec/join/join_op.h @@ -22,111 +22,210 @@ #include "vec/core/block.h" namespace doris::vectorized { -/// Reference to the row in block. +/** + * Now we have different kinds of RowRef for join operation. Overall, RowRef is the base class and + * the class inheritance is below: + * RowRef + * | + * --------------------------------------------------------- + * | | | + * RowRefListWithFlag RowRefList RowRefWithFlag + * | + * RowRefListWithFlags + * + * RowRef is a basic representation for a row which contains only row_num and block_offset. + * + * RowRefList is a list of many RowRefs. It used for join operations which doesn't need any flags to represent whether a row has already been visited. + * + * RowRefListWithFlag is a list of many RowRefs and an extra visited flag. It used for join operations which all RowRefs in a list has the same visited flag. + * + * RowRefWithFlag is a basic representation for a row with an extra visited flag. + * + * RowRefListWithFlags is a list of many RowRefWithFlags. This means each row will have different visited flags. It's used for join operation which has `other_conjuncts`. 
+ */ struct RowRef { using SizeT = uint32_t; /// Do not use size_t cause of memory economy SizeT row_num = 0; uint8_t block_offset; - // Use in right join to mark row is visited - // TODO: opt the variable to use it only need - bool visited = false; - RowRef() {} - RowRef(size_t row_num_count, uint8_t block_offset_, bool is_visited = false) - : row_num(row_num_count), block_offset(block_offset_), visited(is_visited) {} + RowRef() = default; + RowRef(size_t row_num_count, uint8_t block_offset_) + : row_num(row_num_count), block_offset(block_offset_) {} }; -/// Single linked list of references to rows. Used for ALL JOINs (non-unique JOINs) -struct RowRefList : RowRef { - /// Portion of RowRefs, 16 * (MAX_SIZE + 1) bytes sized. - struct Batch { - static constexpr size_t MAX_SIZE = 7; /// Adequate values are 3, 7, 15, 31. +struct RowRefWithFlag : public RowRef { + bool visited; - SizeT size = 0; /// It's smaller than size_t but keeps align in Arena. - Batch* next; - RowRef row_refs[MAX_SIZE]; + RowRefWithFlag() = default; + RowRefWithFlag(size_t row_num_count, uint8_t block_offset_, bool is_visited = false) + : RowRef(row_num_count, block_offset_), visited(is_visited) {} +}; - Batch(Batch* parent) : next(parent) {} +/// Portion of RowRefs, 16 * (MAX_SIZE + 1) bytes sized. +template <typename RowRefType> +struct Batch { + static constexpr size_t MAX_SIZE = 7; /// Adequate values are 3, 7, 15, 31. - bool full() const { return size == MAX_SIZE; } + RowRef::SizeT size = 0; /// It's smaller than size_t but keeps align in Arena. 
+ Batch<RowRefType>* next; + RowRefType row_refs[MAX_SIZE]; - Batch* insert(RowRef&& row_ref, Arena& pool) { - if (full()) { - auto batch = pool.alloc<Batch>(); - *batch = Batch(this); - batch->insert(std::move(row_ref), pool); - return batch; - } + Batch(Batch<RowRefType>* parent) : next(parent) {} + + bool full() const { return size == MAX_SIZE; } - row_refs[size++] = std::move(row_ref); - return this; + Batch<RowRefType>* insert(RowRefType&& row_ref, Arena& pool) { + if (full()) { + auto batch = pool.alloc<Batch<RowRefType>>(); + *batch = Batch<RowRefType>(this); + batch->insert(std::move(row_ref), pool); + return batch; } - }; - class ForwardIterator { - public: - ForwardIterator(RowRefList* begin) - : root(begin), first(true), batch(root->next), position(0) {} + row_refs[size++] = std::move(row_ref); + return this; + } +}; - RowRef& operator*() { - if (first) return *root; - return batch->row_refs[position]; - } - RowRef* operator->() { return &(**this); } +template <typename RowRefListType> +class ForwardIterator { +public: + using RowRefType = typename RowRefListType::RowRefType; + ForwardIterator(RowRefListType* begin) + : root(begin), first(true), batch(root->next), position(0) {} - bool operator==(const ForwardIterator& rhs) const { - if (ok() != rhs.ok()) { - return false; - } - if (first && rhs.first) { - return true; - } - return batch == rhs.batch && position == rhs.position; + RowRefType& operator*() { + if (first) return *root; + return batch->row_refs[position]; + } + RowRefType* operator->() { return &(**this); } + + bool operator==(const ForwardIterator<RowRefListType>& rhs) const { + if (ok() != rhs.ok()) { + return false; + } + if (first && rhs.first) { + return true; } - bool operator!=(const ForwardIterator& rhs) const { return !(*this == rhs); } + return batch == rhs.batch && position == rhs.position; + } + bool operator!=(const ForwardIterator<RowRefListType>& rhs) const { return !(*this == rhs); } - void operator++() { - if (first) { - 
first = false; - return; - } + void operator++() { + if (first) { + first = false; + return; + } - if (batch) { - ++position; - if (position >= batch->size) { - batch = batch->next; - position = 0; - } + if (batch) { + ++position; + if (position >= batch->size) { + batch = batch->next; + position = 0; } } + } - bool ok() const { return first || batch; } + bool ok() const { return first || batch; } - static ForwardIterator end() { return ForwardIterator(); } + static ForwardIterator<RowRefListType> end() { return ForwardIterator(); } - private: - RowRefList* root; - bool first; - Batch* batch; - size_t position; +private: + RowRefListType* root; + bool first; + Batch<RowRefType>* batch; + size_t position; + + ForwardIterator() : root(nullptr), first(false), batch(nullptr), position(0) {} +}; - ForwardIterator() : root(nullptr), first(false), batch(nullptr), position(0) {} - }; +struct RowRefList : RowRef { + using RowRefType = RowRef; - RowRefList() {} + RowRefList() = default; RowRefList(size_t row_num_, uint8_t block_offset_) : RowRef(row_num_, block_offset_) {} - ForwardIterator begin() { return ForwardIterator(this); } - static ForwardIterator end() { return ForwardIterator::end(); } + ForwardIterator<RowRefList> begin() { return ForwardIterator<RowRefList>(this); } + static ForwardIterator<RowRefList> end() { return ForwardIterator<RowRefList>::end(); } + + /// insert element after current one + void insert(RowRefType&& row_ref, Arena& pool) { + row_count++; + + if (!next) { + next = pool.alloc<Batch<RowRefType>>(); + *next = Batch<RowRefType>(nullptr); + } + next = next->insert(std::move(row_ref), pool); + } + + uint32_t get_row_count() { return row_count; } + +private: + friend class ForwardIterator<RowRefList>; + + Batch<RowRefType>* next = nullptr; + uint32_t row_count = 1; +}; + +struct RowRefListWithFlag : RowRef { + using RowRefType = RowRef; + + RowRefListWithFlag() = default; + RowRefListWithFlag(size_t row_num_, uint8_t block_offset_) : 
RowRef(row_num_, block_offset_) {} + + ForwardIterator<RowRefListWithFlag> begin() { + return ForwardIterator<RowRefListWithFlag>(this); + } + + static ForwardIterator<RowRefListWithFlag> end() { + return ForwardIterator<RowRefListWithFlag>::end(); + } /// insert element after current one void insert(RowRef&& row_ref, Arena& pool) { row_count++; if (!next) { - next = pool.alloc<Batch>(); - *next = Batch(nullptr); + next = pool.alloc<Batch<RowRefType>>(); + *next = Batch<RowRefType>(nullptr); + } + next = next->insert(std::move(row_ref), pool); + } + + uint32_t get_row_count() { return row_count; } + + bool visited = false; + +private: + friend class ForwardIterator<RowRefListWithFlag>; + + Batch<RowRefType>* next = nullptr; + uint32_t row_count = 1; +}; + +struct RowRefListWithFlags : RowRefWithFlag { + using RowRefType = RowRefWithFlag; + + RowRefListWithFlags() = default; + RowRefListWithFlags(size_t row_num_, uint8_t block_offset_) + : RowRefWithFlag(row_num_, block_offset_) {} + + ForwardIterator<RowRefListWithFlags> begin() { + return ForwardIterator<RowRefListWithFlags>(this); + } + static ForwardIterator<RowRefListWithFlags> end() { + return ForwardIterator<RowRefListWithFlags>::end(); + } + + /// insert element after current one + void insert(RowRefWithFlag&& row_ref, Arena& pool) { + row_count++; + + if (!next) { + next = pool.alloc<Batch<RowRefType>>(); + *next = Batch<RowRefType>(nullptr); } next = next->insert(std::move(row_ref), pool); } @@ -134,7 +233,9 @@ struct RowRefList : RowRef { uint32_t get_row_count() { return row_count; } private: - Batch* next = nullptr; + friend class ForwardIterator<RowRefListWithFlags>; + + Batch<RowRefType>* next = nullptr; uint32_t row_count = 1; }; diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp index 8126e7732d..fad0eb68cf 100644 --- a/be/src/vec/exec/join/vhash_join_node.cpp +++ b/be/src/vec/exec/join/vhash_join_node.cpp @@ -34,6 +34,16 @@ namespace doris::vectorized { 
// SQL hint to allow users to tune by hand. static constexpr int PREFETCH_STEP = 64; +template Status HashJoinNode::_extract_join_column<true>( + Block&, COW<IColumn>::mutable_ptr<ColumnVector<unsigned char>>&, + std::vector<IColumn const*, std::allocator<IColumn const*>>&, + std::vector<int, std::allocator<int>> const&); + +template Status HashJoinNode::_extract_join_column<false>( + Block&, COW<IColumn>::mutable_ptr<ColumnVector<unsigned char>>&, + std::vector<IColumn const*, std::allocator<IColumn const*>>&, + std::vector<int, std::allocator<int>> const&); + using ProfileCounter = RuntimeProfile::Counter; template <class HashTableContext> struct ProcessHashTableBuild { @@ -48,7 +58,8 @@ struct ProcessHashTableBuild { _offset(offset), _build_side_compute_hash_timer(join_node->_build_side_compute_hash_timer) {} - template <bool ignore_null, bool build_unique, bool has_runtime_filter> + template <bool need_null_map_for_build, bool ignore_null, bool build_unique, + bool has_runtime_filter> void run(HashTableContext& hash_table_ctx, ConstNullMapPtr null_map) { using KeyGetter = typename HashTableContext::State; using Mapped = typename HashTableContext::Mapped; @@ -81,7 +92,7 @@ struct ProcessHashTableBuild { { SCOPED_TIMER(_build_side_compute_hash_timer); for (size_t k = 0; k < _rows; ++k) { - if constexpr (ignore_null) { + if constexpr (ignore_null && need_null_map_for_build) { if ((*null_map)[k]) { continue; } @@ -97,7 +108,7 @@ struct ProcessHashTableBuild { } for (size_t k = 0; k < _rows; ++k) { - if constexpr (ignore_null) { + if constexpr (ignore_null && need_null_map_for_build) { if ((*null_map)[k]) { continue; } @@ -132,15 +143,6 @@ struct ProcessHashTableBuild { hash_table_ctx.hash_table.get_resize_timer_value()); } - template <bool ignore_null, bool build_unique, bool has_runtime_filter> - struct Reducer { - template <typename... TArgs> - static void run(ProcessHashTableBuild<HashTableContext>& build, TArgs&&... 
args) { - build.template run<ignore_null, build_unique, has_runtime_filter>( - std::forward<TArgs>(args)...); - } - }; - private: const int _rows; int _skip_rows; @@ -308,7 +310,7 @@ void ProcessHashTableProbe<JoinOpType, ignore_null>::probe_side_output_column( } template <class JoinOpType, bool ignore_null> -template <typename HashTableType> +template <bool need_null_map_for_probe, typename HashTableType> Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& hash_table_ctx, ConstNullMapPtr null_map, MutableBlock& mutable_block, @@ -335,11 +337,6 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& auto& mcol = mutable_block.mutable_columns(); int current_offset = 0; - constexpr auto need_to_set_visited = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN || - JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || - JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN || - JoinOpType::value == TJoinOp::FULL_OUTER_JOIN; - constexpr auto is_right_semi_anti_join = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN || JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN; @@ -351,7 +348,7 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& { SCOPED_TIMER(_search_hashtable_timer); while (probe_index < probe_rows) { - if constexpr (ignore_null) { + if constexpr (ignore_null && need_null_map_for_probe) { if ((*null_map)[probe_index]) { _items_counts[probe_index++] = (uint32_t)0; all_match_one = false; @@ -360,7 +357,9 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& } int last_offset = current_offset; auto find_result = - (*null_map)[probe_index] + !need_null_map_for_probe + ? key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena) + : (*null_map)[probe_index] ? 
decltype(key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena)) {nullptr, false} : key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena); @@ -382,7 +381,9 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& // TODO: Iterators are currently considered to be a heavy operation and have a certain impact on performance. // We should rethink whether to use this iterator mode in the future. Now just opt the one row case if (mapped.get_row_count() == 1) { - if constexpr (need_to_set_visited) mapped.visited = true; + if constexpr (std::is_same_v<Mapped, RowRefListWithFlag>) { + mapped.visited = true; + } if constexpr (!is_right_semi_anti_join) { if (LIKELY(current_offset < _build_block_rows.size())) { @@ -406,7 +407,9 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& } ++current_offset; } - if constexpr (need_to_set_visited) it->visited = true; + } + if constexpr (std::is_same_v<Mapped, RowRefListWithFlag>) { + mapped.visited = true; } } } else { @@ -452,8 +455,8 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process(HashTableType& } template <class JoinOpType, bool ignore_null> -template <typename HashTableType> -Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process_with_other_join_conjunts( +template <bool need_null_map_for_probe, typename HashTableType> +Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process_with_other_join_conjuncts( HashTableType& hash_table_ctx, ConstNullMapPtr null_map, MutableBlock& mutable_block, Block* output_block, size_t probe_rows) { auto& probe_index = _join_node->_probe_index; @@ -468,232 +471,246 @@ Status ProcessHashTableProbe<JoinOpType, ignore_null>::do_process_with_other_joi using KeyGetter = typename HashTableType::State; using Mapped = typename HashTableType::Mapped; - KeyGetter key_getter(probe_raw_ptrs, _join_node->_probe_key_sz, nullptr); + if constexpr (std::is_same_v<Mapped, RowRefListWithFlags>) { + 
KeyGetter key_getter(probe_raw_ptrs, _join_node->_probe_key_sz, nullptr); - int right_col_idx = _join_node->_left_table_data_types.size(); - int right_col_len = _join_node->_right_table_data_types.size(); + int right_col_idx = _join_node->_left_table_data_types.size(); + int right_col_len = _join_node->_right_table_data_types.size(); - auto& mcol = mutable_block.mutable_columns(); - // use in right join to change visited state after - // exec the vother join conjunt - std::vector<bool*> visited_map; - visited_map.reserve(1.2 * _batch_size); + auto& mcol = mutable_block.mutable_columns(); + // use in right join to change visited state after + // exec the vother join conjunt + std::vector<bool*> visited_map; + visited_map.reserve(1.2 * _batch_size); - std::vector<bool> same_to_prev; - same_to_prev.reserve(1.2 * _batch_size); + std::vector<bool> same_to_prev; + same_to_prev.reserve(1.2 * _batch_size); - int current_offset = 0; + int current_offset = 0; - bool all_match_one = true; - int last_probe_index = probe_index; - while (probe_index < probe_rows) { - // ignore null rows - if constexpr (ignore_null) { - if ((*null_map)[probe_index]) { - _items_counts[probe_index++] = (uint32_t)0; - continue; + bool all_match_one = true; + int last_probe_index = probe_index; + while (probe_index < probe_rows) { + // ignore null rows + if constexpr (ignore_null && need_null_map_for_probe) { + if ((*null_map)[probe_index]) { + _items_counts[probe_index++] = (uint32_t)0; + continue; + } } - } - auto last_offset = current_offset; - auto find_result = - (*null_map)[probe_index] - ? 
decltype(key_getter.find_key(hash_table_ctx.hash_table, probe_index, - _arena)) {nullptr, false} - : key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena); - if (probe_index + PREFETCH_STEP < probe_rows) - key_getter.template prefetch<true>(hash_table_ctx.hash_table, - probe_index + PREFETCH_STEP, _arena); - if (find_result.is_found()) { - auto& mapped = find_result.get_mapped(); - auto origin_offset = current_offset; - // TODO: Iterators are currently considered to be a heavy operation and have a certain impact on performance. - // We should rethink whether to use this iterator mode in the future. Now just opt the one row case - if (mapped.get_row_count() == 1) { - if (LIKELY(current_offset < _build_block_rows.size())) { - _build_block_offsets[current_offset] = mapped.block_offset; - _build_block_rows[current_offset] = mapped.row_num; - } else { - _build_block_offsets.emplace_back(mapped.block_offset); - _build_block_rows.emplace_back(mapped.row_num); - } - ++current_offset; - visited_map.emplace_back(&mapped.visited); - } else { - for (auto it = mapped.begin(); it.ok(); ++it) { + auto last_offset = current_offset; + auto find_result = + !need_null_map_for_probe + ? key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena) + : (*null_map)[probe_index] + ? decltype(key_getter.find_key(hash_table_ctx.hash_table, probe_index, + _arena)) {nullptr, false} + : key_getter.find_key(hash_table_ctx.hash_table, probe_index, _arena); + if (probe_index + PREFETCH_STEP < probe_rows) + key_getter.template prefetch<true>(hash_table_ctx.hash_table, + probe_index + PREFETCH_STEP, _arena); + if (find_result.is_found()) { + auto& mapped = find_result.get_mapped(); + auto origin_offset = current_offset; + // TODO: Iterators are currently considered to be a heavy operation and have a certain impact on performance. + // We should rethink whether to use this iterator mode in the future. 
Now just opt the one row case + if (mapped.get_row_count() == 1) { if (LIKELY(current_offset < _build_block_rows.size())) { - _build_block_offsets[current_offset] = it->block_offset; - _build_block_rows[current_offset] = it->row_num; + _build_block_offsets[current_offset] = mapped.block_offset; + _build_block_rows[current_offset] = mapped.row_num; } else { - _build_block_offsets.emplace_back(it->block_offset); - _build_block_rows.emplace_back(it->row_num); + _build_block_offsets.emplace_back(mapped.block_offset); + _build_block_rows.emplace_back(mapped.row_num); } ++current_offset; - visited_map.emplace_back(&it->visited); + visited_map.emplace_back(&mapped.visited); + } else { + for (auto it = mapped.begin(); it.ok(); ++it) { + if (LIKELY(current_offset < _build_block_rows.size())) { + _build_block_offsets[current_offset] = it->block_offset; + _build_block_rows[current_offset] = it->row_num; + } else { + _build_block_offsets.emplace_back(it->block_offset); + _build_block_rows.emplace_back(it->row_num); + } + ++current_offset; + visited_map.emplace_back(&it->visited); + } } - } - same_to_prev.emplace_back(false); - for (int i = 0; i < current_offset - origin_offset - 1; ++i) { - same_to_prev.emplace_back(true); - } - } else if constexpr (JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN || - JoinOpType::value == TJoinOp::FULL_OUTER_JOIN || - JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) { - same_to_prev.emplace_back(false); - visited_map.emplace_back(nullptr); - // only full outer / left outer need insert the data of right table - // left anti use -1 use a default value - if (LIKELY(current_offset < _build_block_rows.size())) { - _build_block_offsets[current_offset] = -1; - _build_block_rows[current_offset] = -1; + same_to_prev.emplace_back(false); + for (int i = 0; i < current_offset - origin_offset - 1; ++i) { + same_to_prev.emplace_back(true); + } + } else if constexpr (JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN || + JoinOpType::value == TJoinOp::FULL_OUTER_JOIN 
|| + JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) { + same_to_prev.emplace_back(false); + visited_map.emplace_back(nullptr); + // only full outer / left outer need insert the data of right table + // left anti use -1 use a default value + if (LIKELY(current_offset < _build_block_rows.size())) { + _build_block_offsets[current_offset] = -1; + _build_block_rows[current_offset] = -1; + } else { + _build_block_offsets.emplace_back(-1); + _build_block_rows.emplace_back(-1); + } + ++current_offset; } else { - _build_block_offsets.emplace_back(-1); - _build_block_rows.emplace_back(-1); + // other join, no nothing + } + uint32_t count = (uint32_t)(current_offset - last_offset); + _items_counts[probe_index++] = count; + all_match_one &= (count == 1); + if (current_offset >= _batch_size && !all_match_one) { + break; } - ++current_offset; - } else { - // other join, no nothing } - uint32_t count = (uint32_t)(current_offset - last_offset); - _items_counts[probe_index++] = count; - all_match_one &= (count == 1); - if (current_offset >= _batch_size && !all_match_one) { - break; + { + SCOPED_TIMER(_build_side_output_timer); + build_side_output_column<true>(mcol, right_col_idx, right_col_len, + _join_node->_right_output_slot_flags, current_offset); } - } - - { - SCOPED_TIMER(_build_side_output_timer); - build_side_output_column<true>(mcol, right_col_idx, right_col_len, - _join_node->_right_output_slot_flags, current_offset); - } - { - SCOPED_TIMER(_probe_side_output_timer); - probe_side_output_column<true>(mcol, _join_node->_left_output_slot_flags, current_offset, - last_probe_index, probe_index - last_probe_index, - all_match_one); - } - output_block->swap(mutable_block.to_block()); - - // dispose the other join conjunt exec - if (output_block->rows()) { - int result_column_id = -1; - int orig_columns = output_block->columns(); - (*_join_node->_vother_join_conjunct_ptr)->execute(output_block, &result_column_id); + { + SCOPED_TIMER(_probe_side_output_timer); + 
probe_side_output_column<true>(mcol, _join_node->_left_output_slot_flags, + current_offset, last_probe_index, + probe_index - last_probe_index, all_match_one); + } + output_block->swap(mutable_block.to_block()); + + // dispose the other join conjunt exec + if (output_block->rows()) { + int result_column_id = -1; + int orig_columns = output_block->columns(); + (*_join_node->_vother_join_conjunct_ptr)->execute(output_block, &result_column_id); + + auto column = output_block->get_by_position(result_column_id).column; + if constexpr (JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN || + JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) { + auto new_filter_column = ColumnVector<UInt8>::create(); + auto& filter_map = new_filter_column->get_data(); + + auto null_map_column = ColumnVector<UInt8>::create(column->size(), 0); + auto* __restrict null_map_data = null_map_column->get_data().data(); + + for (int i = 0; i < column->size(); ++i) { + auto join_hit = visited_map[i] != nullptr; + auto other_hit = column->get_bool(i); + + if (!other_hit) { + for (size_t j = 0; j < right_col_len; ++j) { + typeid_cast<ColumnNullable*>( + std::move(*output_block->get_by_position(j + right_col_idx) + .column) + .assume_mutable() + .get()) + ->get_null_map_data()[i] = true; + } + } + null_map_data[i] = !join_hit || !other_hit; + + if (join_hit) { + *visited_map[i] |= other_hit; + filter_map.push_back(other_hit || !same_to_prev[i] || + (!column->get_bool(i - 1) && filter_map.back())); + // Here to keep only hit join conjunt and other join conjunt is true need to be output. 
+ // if not, only some key must keep one row will output will null right table column + if (same_to_prev[i] && filter_map.back() && !column->get_bool(i - 1)) + filter_map[i - 1] = false; + } else { + filter_map.push_back(true); + } + } - auto column = output_block->get_by_position(result_column_id).column; - if constexpr (JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN || - JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) { - auto new_filter_column = ColumnVector<UInt8>::create(); - auto& filter_map = new_filter_column->get_data(); - - auto null_map_column = ColumnVector<UInt8>::create(column->size(), 0); - auto* __restrict null_map_data = null_map_column->get_data().data(); - - for (int i = 0; i < column->size(); ++i) { - auto join_hit = visited_map[i] != nullptr; - auto other_hit = column->get_bool(i); - - if (!other_hit) { - for (size_t j = 0; j < right_col_len; ++j) { - typeid_cast<ColumnNullable*>( - std::move(*output_block->get_by_position(j + right_col_idx).column) - .assume_mutable() - .get()) - ->get_null_map_data()[i] = true; + for (int i = 0; i < column->size(); ++i) { + if (filter_map[i]) { + _tuple_is_null_right_flags->emplace_back(null_map_data[i]); } } - null_map_data[i] = !join_hit || !other_hit; - - if (join_hit) { - *visited_map[i] |= other_hit; - filter_map.push_back(other_hit || !same_to_prev[i] || - (!column->get_bool(i - 1) && filter_map.back())); - // Here to keep only hit join conjunt and other join conjunt is true need to be output. 
- // if not, only some key must keep one row will output will null right table column - if (same_to_prev[i] && filter_map.back() && !column->get_bool(i - 1)) - filter_map[i - 1] = false; - } else { - filter_map.push_back(true); + output_block->get_by_position(result_column_id).column = + std::move(new_filter_column); + } else if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN) { + auto new_filter_column = ColumnVector<UInt8>::create(); + auto& filter_map = new_filter_column->get_data(); + + if (!column->empty()) { + filter_map.emplace_back(column->get_bool(0)); + } + for (int i = 1; i < column->size(); ++i) { + if (column->get_bool(i) || (same_to_prev[i] && filter_map[i - 1])) { + // Only last same element is true, output last one + filter_map.push_back(true); + filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1]; + } else { + filter_map.push_back(false); + } } - } - for (int i = 0; i < column->size(); ++i) { - if (filter_map[i]) { - _tuple_is_null_right_flags->emplace_back(null_map_data[i]); + output_block->get_by_position(result_column_id).column = + std::move(new_filter_column); + } else if constexpr (JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) { + auto new_filter_column = ColumnVector<UInt8>::create(); + auto& filter_map = new_filter_column->get_data(); + + if (!column->empty()) { + filter_map.emplace_back(column->get_bool(0) && visited_map[0]); } - } - output_block->get_by_position(result_column_id).column = std::move(new_filter_column); - } else if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN) { - auto new_filter_column = ColumnVector<UInt8>::create(); - auto& filter_map = new_filter_column->get_data(); - - if (!column->empty()) filter_map.emplace_back(column->get_bool(0)); - for (int i = 1; i < column->size(); ++i) { - if (column->get_bool(i) || (same_to_prev[i] && filter_map[i - 1])) { - // Only last same element is true, output last one - filter_map.push_back(true); - filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1]; - } 
else { - filter_map.push_back(false); + for (int i = 1; i < column->size(); ++i) { + if ((visited_map[i] && column->get_bool(i)) || + (same_to_prev[i] && filter_map[i - 1])) { + filter_map.push_back(true); + filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1]; + } else { + filter_map.push_back(false); + } } - } - output_block->get_by_position(result_column_id).column = std::move(new_filter_column); - } else if constexpr (JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) { - auto new_filter_column = ColumnVector<UInt8>::create(); - auto& filter_map = new_filter_column->get_data(); - - if (!column->empty()) filter_map.emplace_back(column->get_bool(0) && visited_map[0]); - for (int i = 1; i < column->size(); ++i) { - if ((visited_map[i] && column->get_bool(i)) || - (same_to_prev[i] && filter_map[i - 1])) { - filter_map.push_back(true); - filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1]; - } else { - filter_map.push_back(false); + // Same to the semi join, but change the last value to opposite value + for (int i = 1; i < same_to_prev.size(); ++i) { + if (!same_to_prev[i]) filter_map[i - 1] = !filter_map[i - 1]; + } + filter_map[same_to_prev.size() - 1] = !filter_map[same_to_prev.size() - 1]; + + output_block->get_by_position(result_column_id).column = + std::move(new_filter_column); + } else if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || + JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN) { + for (int i = 0; i < column->size(); ++i) { + DCHECK(visited_map[i]); + *visited_map[i] |= column->get_bool(i); + } + } else if constexpr (JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN) { + auto filter_size = 0; + for (int i = 0; i < column->size(); ++i) { + DCHECK(visited_map[i]); + auto result = column->get_bool(i); + *visited_map[i] |= result; + filter_size += result; } + _tuple_is_null_left_flags->resize_fill(filter_size, 0); + } else { + // inner join do nothing } - // Same to the semi join, but change the last value to opposite value - for (int i = 1; 
i < same_to_prev.size(); ++i) { - if (!same_to_prev[i]) filter_map[i - 1] = !filter_map[i - 1]; - } - filter_map[same_to_prev.size() - 1] = !filter_map[same_to_prev.size() - 1]; - - output_block->get_by_position(result_column_id).column = std::move(new_filter_column); - } else if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || - JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN) { - for (int i = 0; i < column->size(); ++i) { - DCHECK(visited_map[i]); - *visited_map[i] |= column->get_bool(i); - } - } else if constexpr (JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN) { - auto filter_size = 0; - for (int i = 0; i < column->size(); ++i) { - DCHECK(visited_map[i]); - auto result = column->get_bool(i); - *visited_map[i] |= result; - filter_size += result; + if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || + JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN) { + output_block->clear(); + } else { + if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN || + JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) + orig_columns = right_col_idx; + Block::filter_block(output_block, result_column_id, orig_columns); } - _tuple_is_null_left_flags->resize_fill(filter_size, 0); - } else { - // inner join do nothing } - if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || - JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN) { - output_block->clear(); - } else { - if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN || - JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) - orig_columns = right_col_idx; - Block::filter_block(output_block, result_column_id, orig_columns); - } + return Status::OK(); + } else { + LOG(FATAL) << "Invalid RowRefList"; + return Status::InvalidArgument("Invalid RowRefList"); } - - return Status::OK(); } template <class JoinOpType, bool ignore_null> @@ -701,57 +718,80 @@ template <typename HashTableType> Status ProcessHashTableProbe<JoinOpType, ignore_null>::process_data_in_hashtable( HashTableType& hash_table_ctx, MutableBlock& 
mutable_block, Block* output_block, bool* eos) { - hash_table_ctx.init_once(); - auto& mcol = mutable_block.mutable_columns(); - - bool right_semi_anti_without_other = - _join_node->_is_right_semi_anti && !_join_node->_have_other_join_conjunct; - int right_col_idx = - right_semi_anti_without_other ? 0 : _join_node->_left_table_data_types.size(); - int right_col_len = _join_node->_right_table_data_types.size(); - - auto& iter = hash_table_ctx.iter; - auto block_size = 0; - - auto insert_from_hash_table = [&](uint8_t offset, uint32_t row_num) { - block_size++; - for (size_t j = 0; j < right_col_len; ++j) { - auto& column = *_build_blocks[offset].get_by_position(j).column; - mcol[j + right_col_idx]->insert_from(column, row_num); - } - }; - - for (; iter != hash_table_ctx.hash_table.end() && block_size < _batch_size; ++iter) { - auto& mapped = iter->get_second(); - for (auto it = mapped.begin(); it.ok(); ++it) { - if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN) { - if (it->visited) insert_from_hash_table(it->block_offset, it->row_num); + using Mapped = typename HashTableType::Mapped; + if constexpr (std::is_same_v<Mapped, RowRefListWithFlag> || + std::is_same_v<Mapped, RowRefListWithFlags>) { + hash_table_ctx.init_once(); + auto& mcol = mutable_block.mutable_columns(); + + bool right_semi_anti_without_other = + _join_node->_is_right_semi_anti && !_join_node->_have_other_join_conjunct; + int right_col_idx = + right_semi_anti_without_other ? 
0 : _join_node->_left_table_data_types.size(); + int right_col_len = _join_node->_right_table_data_types.size(); + + auto& iter = hash_table_ctx.iter; + auto block_size = 0; + + auto insert_from_hash_table = [&](uint8_t offset, uint32_t row_num) { + block_size++; + for (size_t j = 0; j < right_col_len; ++j) { + auto& column = *_build_blocks[offset].get_by_position(j).column; + mcol[j + right_col_idx]->insert_from(column, row_num); + } + }; + + for (; iter != hash_table_ctx.hash_table.end() && block_size < _batch_size; ++iter) { + auto& mapped = iter->get_second(); + if constexpr (std::is_same_v<Mapped, RowRefListWithFlag>) { + if (mapped.visited) { + for (auto it = mapped.begin(); it.ok(); ++it) { + if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN) { + insert_from_hash_table(it->block_offset, it->row_num); + } + } + } else { + for (auto it = mapped.begin(); it.ok(); ++it) { + if constexpr (JoinOpType::value != TJoinOp::RIGHT_SEMI_JOIN) { + insert_from_hash_table(it->block_offset, it->row_num); + } + } + } } else { - if (!it->visited) insert_from_hash_table(it->block_offset, it->row_num); + for (auto it = mapped.begin(); it.ok(); ++it) { + if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN) { + if (it->visited) insert_from_hash_table(it->block_offset, it->row_num); + } else { + if (!it->visited) insert_from_hash_table(it->block_offset, it->row_num); + } + } } } - } - // just resize the left table column in case with other conjunct to make block size is not zero - if (_join_node->_is_right_semi_anti && _join_node->_have_other_join_conjunct) { - auto target_size = mcol[right_col_idx]->size(); - for (int i = 0; i < right_col_idx; ++i) { - mcol[i]->resize(target_size); + // just resize the left table column in case with other conjunct to make block size is not zero + if (_join_node->_is_right_semi_anti && _join_node->_have_other_join_conjunct) { + auto target_size = mcol[right_col_idx]->size(); + for (int i = 0; i < right_col_idx; ++i) { + 
mcol[i]->resize(target_size); + } } - } - // right outer join / full join need insert data of left table - if constexpr (JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN || - JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) { - for (int i = 0; i < right_col_idx; ++i) { - assert_cast<ColumnNullable*>(mcol[i].get())->insert_many_defaults(block_size); + // right outer join / full join need insert data of left table + if constexpr (JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN || + JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) { + for (int i = 0; i < right_col_idx; ++i) { + assert_cast<ColumnNullable*>(mcol[i].get())->insert_many_defaults(block_size); + } + _tuple_is_null_left_flags->resize_fill(block_size, 1); } - _tuple_is_null_left_flags->resize_fill(block_size, 1); + *eos = iter == hash_table_ctx.hash_table.end(); + output_block->swap( + mutable_block.to_block(right_semi_anti_without_other ? right_col_idx : 0)); + return Status::OK(); + } else { + LOG(FATAL) << "Invalid RowRefList"; + return Status::InvalidArgument("Invalid RowRefList"); } - *eos = iter == hash_table_ctx.hash_table.end(); - - output_block->swap(mutable_block.to_block(right_semi_anti_without_other ? 
right_col_idx : 0)); - return Status::OK(); } HashJoinNode::HashJoinNode(ObjectPool* pool, const TPlanNode& tnode, const DescriptorTbl& descs) @@ -832,6 +872,7 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { _probe_not_ignore_null.emplace_back( null_aware || (_probe_expr_ctxs.back()->root()->is_nullable() && probe_dispose_null)); + _build_side_ignore_null |= !_build_not_ignore_null.back(); } for (size_t i = 0; i < _probe_expr_ctxs.size(); ++i) { _probe_ignore_null |= !_probe_not_ignore_null[i]; @@ -840,9 +881,9 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { _probe_column_disguise_null.reserve(eq_join_conjuncts.size()); if (tnode.hash_join_node.__isset.vother_join_conjunct) { - _vother_join_conjunct_ptr.reset(new doris::vectorized::VExprContext*); - RETURN_IF_ERROR(doris::vectorized::VExpr::create_expr_tree( - _pool, tnode.hash_join_node.vother_join_conjunct, _vother_join_conjunct_ptr.get())); + _vother_join_conjunct_ptr.reset(new VExprContext*); + RETURN_IF_ERROR(VExpr::create_expr_tree(_pool, tnode.hash_join_node.vother_join_conjunct, + _vother_join_conjunct_ptr.get())); // If LEFT SEMI JOIN/LEFT ANTI JOIN with not equal predicate, // build table should not be deduplicated. @@ -1003,19 +1044,29 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo int probe_expr_ctxs_sz = _probe_expr_ctxs.size(); _probe_columns.resize(probe_expr_ctxs_sz); - if (_null_map_column == nullptr) { - _null_map_column = ColumnUInt8::create(); + + std::vector<int> res_col_ids(probe_expr_ctxs_sz); + RETURN_IF_ERROR(_do_evaluate(_probe_block, _probe_expr_ctxs, *_probe_expr_call_timer, + res_col_ids)); + // TODO: Now we are not sure whether a column is nullable only by ExecNode's `row_desc` + // so we have to initialize this flag by the first probe block. 
+ if (!_has_set_need_null_map_for_probe) { + _has_set_need_null_map_for_probe = true; + _need_null_map_for_probe = _need_null_map<false>(_probe_block, res_col_ids); + } + if (_need_null_map_for_probe) { + if (_null_map_column == nullptr) { + _null_map_column = ColumnUInt8::create(); + } + _null_map_column->get_data().assign(probe_rows, (uint8_t)0); } - _null_map_column->get_data().assign(probe_rows, (uint8_t)0); Status st = std::visit( [&](auto&& arg) -> Status { using HashTableCtxType = std::decay_t<decltype(arg)>; if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { - auto& null_map_val = _null_map_column->get_data(); - return _extract_probe_join_column(_probe_block, null_map_val, - _probe_columns, - *_probe_expr_call_timer); + return _extract_join_column<false>(_probe_block, _null_map_column, + _probe_columns, res_col_ids); } else { LOG(FATAL) << "FATAL: uninited hash table"; } @@ -1033,24 +1084,35 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo Block temp_block; if (_probe_index < _probe_block.rows()) { + DCHECK(_has_set_need_null_map_for_probe); std::visit( - [&](auto&& arg, auto&& process_hashtable_ctx, auto have_other_join_conjunct) { + [&](auto&& arg, auto&& process_hashtable_ctx, auto have_other_join_conjunct, + auto need_null_map_for_probe) { using HashTableProbeType = std::decay_t<decltype(process_hashtable_ctx)>; if constexpr (!std::is_same_v<HashTableProbeType, std::monostate>) { using HashTableCtxType = std::decay_t<decltype(arg)>; if constexpr (have_other_join_conjunct) { if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { - st = process_hashtable_ctx.do_process_with_other_join_conjunts( - arg, &_null_map_column->get_data(), mutable_join_block, - &temp_block, probe_rows); + st = process_hashtable_ctx + .template do_process_with_other_join_conjuncts< + need_null_map_for_probe>( + arg, + need_null_map_for_probe + ? 
&_null_map_column->get_data() + : nullptr, + mutable_join_block, &temp_block, probe_rows); } else { LOG(FATAL) << "FATAL: uninited hash table"; } } else { if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { - st = process_hashtable_ctx.do_process( - arg, &_null_map_column->get_data(), mutable_join_block, - &temp_block, probe_rows); + st = process_hashtable_ctx + .template do_process<need_null_map_for_probe>( + arg, + need_null_map_for_probe + ? &_null_map_column->get_data() + : nullptr, + mutable_join_block, &temp_block, probe_rows); } else { LOG(FATAL) << "FATAL: uninited hash table"; } @@ -1060,7 +1122,8 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo } }, _hash_table_variants, _process_hashtable_ctx_variants, - make_bool_variant(_have_other_join_conjunct)); + make_bool_variant(_have_other_join_conjunct), + make_bool_variant(_need_null_map_for_probe)); } else if (_probe_eos) { if (_is_right_semi_anti || (_is_outer_join && _join_op != TJoinOp::LEFT_OUTER_JOIN)) { std::visit( @@ -1237,35 +1300,31 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) { _hash_table_variants); } -// TODO:: unify the code of extract probe join column -Status HashJoinNode::_extract_build_join_column(Block& block, NullMap& null_map, - ColumnRawPtrs& raw_ptrs, bool& ignore_null, - RuntimeProfile::Counter& expr_call_timer) { +template <bool BuildSide> +Status HashJoinNode::_extract_join_column(Block& block, ColumnUInt8::MutablePtr& null_map, + ColumnRawPtrs& raw_ptrs, + const std::vector<int>& res_col_ids) { + DCHECK_EQ(_build_expr_ctxs.size(), _probe_expr_ctxs.size()); for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) { - int result_col_id = -1; - // execute build column - { - SCOPED_TIMER(&expr_call_timer); - RETURN_IF_ERROR(_build_expr_ctxs[i]->execute(&block, &result_col_id)); - } - - // TODO: opt the column is const - block.get_by_position(result_col_id).column = - 
block.get_by_position(result_col_id).column->convert_to_full_column_if_const(); - if (_is_null_safe_eq_join[i]) { - raw_ptrs[i] = block.get_by_position(result_col_id).column.get(); + raw_ptrs[i] = block.get_by_position(res_col_ids[i]).column.get(); } else { - auto column = block.get_by_position(result_col_id).column.get(); + auto column = block.get_by_position(res_col_ids[i]).column.get(); if (auto* nullable = check_and_get_column<ColumnNullable>(*column)) { auto& col_nested = nullable->get_nested_column(); auto& col_nullmap = nullable->get_null_map_data(); - ignore_null |= !_build_not_ignore_null[i]; + if constexpr (!BuildSide) { + DCHECK(null_map != nullptr); + VectorizedUtils::update_null_map(null_map->get_data(), col_nullmap); + } if (_build_not_ignore_null[i]) { raw_ptrs[i] = nullable; } else { - VectorizedUtils::update_null_map(null_map, col_nullmap); + if constexpr (BuildSide) { + DCHECK(null_map != nullptr); + VectorizedUtils::update_null_map(null_map->get_data(), col_nullmap); + } raw_ptrs[i] = &col_nested; } } else { @@ -1276,49 +1335,45 @@ Status HashJoinNode::_extract_build_join_column(Block& block, NullMap& null_map, return Status::OK(); } -Status HashJoinNode::_extract_probe_join_column(Block& block, NullMap& null_map, - ColumnRawPtrs& raw_ptrs, - RuntimeProfile::Counter& expr_call_timer) { - for (size_t i = 0; i < _probe_expr_ctxs.size(); ++i) { +Status HashJoinNode::_do_evaluate(Block& block, std::vector<VExprContext*>& exprs, + RuntimeProfile::Counter& expr_call_timer, + std::vector<int>& res_col_ids) { + for (size_t i = 0; i < exprs.size(); ++i) { int result_col_id = -1; // execute build column { SCOPED_TIMER(&expr_call_timer); - RETURN_IF_ERROR(_probe_expr_ctxs[i]->execute(&block, &result_col_id)); + RETURN_IF_ERROR(exprs[i]->execute(&block, &result_col_id)); } // TODO: opt the column is const block.get_by_position(result_col_id).column = block.get_by_position(result_col_id).column->convert_to_full_column_if_const(); + res_col_ids[i] = 
result_col_id; + } + return Status::OK(); +} - if (_is_null_safe_eq_join[i]) { - raw_ptrs[i] = block.get_by_position(result_col_id).column.get(); - } else { - auto column = block.get_by_position(result_col_id).column.get(); - if (auto* nullable = check_and_get_column<ColumnNullable>(*column)) { - auto& col_nested = nullable->get_nested_column(); - auto& col_nullmap = nullable->get_null_map_data(); - - VectorizedUtils::update_null_map(null_map, col_nullmap); - if (_build_not_ignore_null[i]) { - raw_ptrs[i] = nullable; - } else { - raw_ptrs[i] = &col_nested; +template <bool BuildSide> +bool HashJoinNode::_need_null_map(Block& block, const std::vector<int>& res_col_ids) { + DCHECK_EQ(_build_expr_ctxs.size(), _probe_expr_ctxs.size()); + for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) { + if (!_is_null_safe_eq_join[i]) { + auto column = block.get_by_position(res_col_ids[i]).column.get(); + if constexpr (BuildSide) { + if (check_and_get_column<ColumnNullable>(*column)) { + if (!_build_not_ignore_null[i]) { + return true; + } } } else { - if (_build_not_ignore_null[i]) { - auto column_ptr = - make_nullable(block.get_by_position(result_col_id).column, false); - _probe_column_disguise_null.emplace_back(block.columns()); - block.insert({column_ptr, - make_nullable(block.get_by_position(result_col_id).type), ""}); - column = column_ptr.get(); + if (check_and_get_column<ColumnNullable>(*column)) { + return true; } - raw_ptrs[i] = column; } } } - return Status::OK(); + return false; } Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uint8_t offset) { @@ -1334,17 +1389,26 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uin ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size()); - NullMap null_map_val(rows); - null_map_val.assign(rows, (uint8_t)0); - bool has_null = false; + ColumnUInt8::MutablePtr null_map_val; + std::vector<int> res_col_ids(_build_expr_ctxs.size()); + RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, 
*_build_expr_call_timer, res_col_ids)); + // TODO: Now we are not sure whether a column is nullable only by ExecNode's `row_desc` + // so we have to initialize this flag by the first build block. + if (!_has_set_need_null_map_for_build) { + _has_set_need_null_map_for_build = true; + _need_null_map_for_build = _need_null_map<true>(block, res_col_ids); + } + if (_need_null_map_for_build) { + null_map_val = ColumnUInt8::create(); + null_map_val->get_data().assign(rows, (uint8_t)0); + } // Get the key column that needs to be built Status st = std::visit( [&](auto&& arg) -> Status { using HashTableCtxType = std::decay_t<decltype(arg)>; if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { - return _extract_build_join_column(block, null_map_val, raw_ptrs, has_null, - *_build_expr_call_timer); + return _extract_join_column<true>(block, null_map_val, raw_ptrs, res_col_ids); } else { LOG(FATAL) << "FATAL: uninited hash table"; } @@ -1355,126 +1419,147 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uin bool has_runtime_filter = !_runtime_filter_descs.empty(); std::visit( - [&](auto&& arg) { + [&](auto&& arg, auto has_null_value, auto build_unique, auto has_runtime_filter_value, + auto need_null_map_for_build) { using HashTableCtxType = std::decay_t<decltype(arg)>; if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { ProcessHashTableBuild<HashTableCtxType> hash_table_build_process( rows, block, raw_ptrs, this, state->batch_size(), offset); - - constexpr_3_bool_match<ProcessHashTableBuild< - HashTableCtxType>::template Reducer>::run(has_null, _build_unique, - has_runtime_filter, - hash_table_build_process, arg, - &null_map_val); + hash_table_build_process.template run<need_null_map_for_build, has_null_value, + build_unique, has_runtime_filter_value>( + arg, need_null_map_for_build ? 
&null_map_val->get_data() : nullptr); } else { LOG(FATAL) << "FATAL: uninited hash table"; } }, - _hash_table_variants); + _hash_table_variants, make_bool_variant(_build_side_ignore_null), + make_bool_variant(_build_unique), make_bool_variant(has_runtime_filter), + make_bool_variant(_need_null_map_for_build)); return st; } void HashJoinNode::_hash_table_init() { - if (_build_expr_ctxs.size() == 1 && !_build_not_ignore_null[0]) { - // Single column optimization - switch (_build_expr_ctxs[0]->root()->result_type()) { - case TYPE_BOOLEAN: - case TYPE_TINYINT: - _hash_table_variants.emplace<I8HashTableContext>(); - break; - case TYPE_SMALLINT: - _hash_table_variants.emplace<I16HashTableContext>(); - break; - case TYPE_INT: - case TYPE_FLOAT: - case TYPE_DATEV2: - _hash_table_variants.emplace<I32HashTableContext>(); - break; - case TYPE_BIGINT: - case TYPE_DOUBLE: - case TYPE_DATETIME: - case TYPE_DATE: - case TYPE_DATETIMEV2: - _hash_table_variants.emplace<I64HashTableContext>(); - break; - case TYPE_LARGEINT: - case TYPE_DECIMALV2: - case TYPE_DECIMAL32: - case TYPE_DECIMAL64: - case TYPE_DECIMAL128: { - DataTypePtr& type_ptr = _build_expr_ctxs[0]->root()->data_type(); - TypeIndex idx = _build_expr_ctxs[0]->root()->is_nullable() - ? 
assert_cast<const DataTypeNullable&>(*type_ptr) - .get_nested_type() - ->get_type_id() - : type_ptr->get_type_id(); - WhichDataType which(idx); - if (which.is_decimal32()) { - _hash_table_variants.emplace<I32HashTableContext>(); - } else if (which.is_decimal64()) { - _hash_table_variants.emplace<I64HashTableContext>(); - } else { - _hash_table_variants.emplace<I128HashTableContext>(); - } - break; - } - default: - _hash_table_variants.emplace<SerializedHashTableContext>(); - } - return; - } + std::visit( + [&](auto&& join_op_variants, auto have_other_join_conjunct) { + using JoinOpType = std::decay_t<decltype(join_op_variants)>; + using RowRefListType = std::conditional_t< + have_other_join_conjunct, RowRefListWithFlags, + std::conditional_t<JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN || + JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN || + JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN || + JoinOpType::value == TJoinOp::FULL_OUTER_JOIN, + RowRefListWithFlag, RowRefList>>; + if (_build_expr_ctxs.size() == 1 && !_build_not_ignore_null[0]) { + // Single column optimization + switch (_build_expr_ctxs[0]->root()->result_type()) { + case TYPE_BOOLEAN: + case TYPE_TINYINT: + _hash_table_variants.emplace<I8HashTableContext<RowRefListType>>(); + break; + case TYPE_SMALLINT: + _hash_table_variants.emplace<I16HashTableContext<RowRefListType>>(); + break; + case TYPE_INT: + case TYPE_FLOAT: + case TYPE_DATEV2: + _hash_table_variants.emplace<I32HashTableContext<RowRefListType>>(); + break; + case TYPE_BIGINT: + case TYPE_DOUBLE: + case TYPE_DATETIME: + case TYPE_DATE: + case TYPE_DATETIMEV2: + _hash_table_variants.emplace<I64HashTableContext<RowRefListType>>(); + break; + case TYPE_LARGEINT: + case TYPE_DECIMALV2: + case TYPE_DECIMAL32: + case TYPE_DECIMAL64: + case TYPE_DECIMAL128: { + DataTypePtr& type_ptr = _build_expr_ctxs[0]->root()->data_type(); + TypeIndex idx = _build_expr_ctxs[0]->root()->is_nullable() + ? 
assert_cast<const DataTypeNullable&>(*type_ptr) + .get_nested_type() + ->get_type_id() + : type_ptr->get_type_id(); + WhichDataType which(idx); + if (which.is_decimal32()) { + _hash_table_variants.emplace<I32HashTableContext<RowRefListType>>(); + } else if (which.is_decimal64()) { + _hash_table_variants.emplace<I64HashTableContext<RowRefListType>>(); + } else { + _hash_table_variants.emplace<I128HashTableContext<RowRefListType>>(); + } + break; + } + default: + _hash_table_variants.emplace<SerializedHashTableContext<RowRefListType>>(); + } + return; + } - bool use_fixed_key = true; - bool has_null = false; - int key_byte_size = 0; + bool use_fixed_key = true; + bool has_null = false; + int key_byte_size = 0; - _probe_key_sz.resize(_probe_expr_ctxs.size()); - _build_key_sz.resize(_build_expr_ctxs.size()); + _probe_key_sz.resize(_probe_expr_ctxs.size()); + _build_key_sz.resize(_build_expr_ctxs.size()); - for (int i = 0; i < _build_expr_ctxs.size(); ++i) { - const auto vexpr = _build_expr_ctxs[i]->root(); - const auto& data_type = vexpr->data_type(); + for (int i = 0; i < _build_expr_ctxs.size(); ++i) { + const auto vexpr = _build_expr_ctxs[i]->root(); + const auto& data_type = vexpr->data_type(); - if (!data_type->have_maximum_size_of_value()) { - use_fixed_key = false; - break; - } + if (!data_type->have_maximum_size_of_value()) { + use_fixed_key = false; + break; + } - auto is_null = data_type->is_nullable(); - has_null |= is_null; - _build_key_sz[i] = data_type->get_maximum_size_of_value_in_memory() - (is_null ? 1 : 0); - _probe_key_sz[i] = _build_key_sz[i]; - key_byte_size += _probe_key_sz[i]; - } + auto is_null = data_type->is_nullable(); + has_null |= is_null; + _build_key_sz[i] = + data_type->get_maximum_size_of_value_in_memory() - (is_null ? 
1 : 0); + _probe_key_sz[i] = _build_key_sz[i]; + key_byte_size += _probe_key_sz[i]; + } - if (std::tuple_size<KeysNullMap<UInt256>>::value + key_byte_size > sizeof(UInt256)) { - use_fixed_key = false; - } + if (std::tuple_size<KeysNullMap<UInt256>>::value + key_byte_size > + sizeof(UInt256)) { + use_fixed_key = false; + } - if (use_fixed_key) { - // TODO: may we should support uint256 in the future - if (has_null) { - if (std::tuple_size<KeysNullMap<UInt64>>::value + key_byte_size <= sizeof(UInt64)) { - _hash_table_variants.emplace<I64FixedKeyHashTableContext<true>>(); - } else if (std::tuple_size<KeysNullMap<UInt128>>::value + key_byte_size <= - sizeof(UInt128)) { - _hash_table_variants.emplace<I128FixedKeyHashTableContext<true>>(); - } else { - _hash_table_variants.emplace<I256FixedKeyHashTableContext<true>>(); - } - } else { - if (key_byte_size <= sizeof(UInt64)) { - _hash_table_variants.emplace<I64FixedKeyHashTableContext<false>>(); - } else if (key_byte_size <= sizeof(UInt128)) { - _hash_table_variants.emplace<I128FixedKeyHashTableContext<false>>(); - } else { - _hash_table_variants.emplace<I256FixedKeyHashTableContext<false>>(); - } - } - } else { - _hash_table_variants.emplace<SerializedHashTableContext>(); - } + if (use_fixed_key) { + // TODO: may we should support uint256 in the future + if (has_null) { + if (std::tuple_size<KeysNullMap<UInt64>>::value + key_byte_size <= + sizeof(UInt64)) { + _hash_table_variants + .emplace<I64FixedKeyHashTableContext<true, RowRefListType>>(); + } else if (std::tuple_size<KeysNullMap<UInt128>>::value + key_byte_size <= + sizeof(UInt128)) { + _hash_table_variants + .emplace<I128FixedKeyHashTableContext<true, RowRefListType>>(); + } else { + _hash_table_variants + .emplace<I256FixedKeyHashTableContext<true, RowRefListType>>(); + } + } else { + if (key_byte_size <= sizeof(UInt64)) { + _hash_table_variants + .emplace<I64FixedKeyHashTableContext<false, RowRefListType>>(); + } else if (key_byte_size <= sizeof(UInt128)) { + 
_hash_table_variants + .emplace<I128FixedKeyHashTableContext<false, RowRefListType>>(); + } else { + _hash_table_variants + .emplace<I256FixedKeyHashTableContext<false, RowRefListType>>(); + } + } + } else { + _hash_table_variants.emplace<SerializedHashTableContext<RowRefListType>>(); + } + }, + _join_op_variants, make_bool_variant(_have_other_join_conjunct)); } void HashJoinNode::_process_hashtable_ctx_variants_init(RuntimeState* state) { @@ -1548,7 +1633,7 @@ Status HashJoinNode::_build_output_block(Block* origin_block, Block* output_bloc return Status::OK(); } -void HashJoinNode::_add_tuple_is_null_column(doris::vectorized::Block* block) { +void HashJoinNode::_add_tuple_is_null_column(Block* block) { if (_is_outer_join) { auto p0 = _tuple_is_null_left_flag_column->assume_mutable(); auto p1 = _tuple_is_null_right_flag_column->assume_mutable(); diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h index 2be992d5bd..5de6b0be9c 100644 --- a/be/src/vec/exec/join/vhash_join_node.h +++ b/be/src/vec/exec/join/vhash_join_node.h @@ -32,8 +32,9 @@ namespace doris { namespace vectorized { +template <typename RowRefListType> struct SerializedHashTableContext { - using Mapped = RowRefList; + using Mapped = RowRefListType; using HashTable = HashMap<StringRef, Mapped>; using State = ColumnsHashing::HashMethodSerialized<typename HashTable::value_type, Mapped>; using Iter = typename HashTable::iterator; @@ -61,9 +62,9 @@ struct IsSerializedHashTableContextTraits<ColumnsHashing::HashMethodSerialized<V }; // T should be UInt32 UInt64 UInt128 -template <class T> +template <class T, typename RowRefListType> struct PrimaryTypeHashTableContext { - using Mapped = RowRefList; + using Mapped = RowRefListType; using HashTable = HashMap<T, Mapped, HashCRC32<T>>; using State = ColumnsHashing::HashMethodOneNumber<typename HashTable::value_type, Mapped, T, false>; @@ -82,16 +83,22 @@ struct PrimaryTypeHashTableContext { }; // TODO: use FixedHashTable 
instead of HashTable -using I8HashTableContext = PrimaryTypeHashTableContext<UInt8>; -using I16HashTableContext = PrimaryTypeHashTableContext<UInt16>; -using I32HashTableContext = PrimaryTypeHashTableContext<UInt32>; -using I64HashTableContext = PrimaryTypeHashTableContext<UInt64>; -using I128HashTableContext = PrimaryTypeHashTableContext<UInt128>; -using I256HashTableContext = PrimaryTypeHashTableContext<UInt256>; - -template <class T, bool has_null> +template <typename RowRefListType> +using I8HashTableContext = PrimaryTypeHashTableContext<UInt8, RowRefListType>; +template <typename RowRefListType> +using I16HashTableContext = PrimaryTypeHashTableContext<UInt16, RowRefListType>; +template <typename RowRefListType> +using I32HashTableContext = PrimaryTypeHashTableContext<UInt32, RowRefListType>; +template <typename RowRefListType> +using I64HashTableContext = PrimaryTypeHashTableContext<UInt64, RowRefListType>; +template <typename RowRefListType> +using I128HashTableContext = PrimaryTypeHashTableContext<UInt128, RowRefListType>; +template <typename RowRefListType> +using I256HashTableContext = PrimaryTypeHashTableContext<UInt256, RowRefListType>; + +template <class T, bool has_null, typename RowRefListType> struct FixedKeyHashTableContext { - using Mapped = RowRefList; + using Mapped = RowRefListType; using HashTable = HashMap<T, Mapped, HashCRC32<T>>; using State = ColumnsHashing::HashMethodKeysFixed<typename HashTable::value_type, T, Mapped, has_null, false>; @@ -109,22 +116,45 @@ struct FixedKeyHashTableContext { } }; -template <bool has_null> -using I64FixedKeyHashTableContext = FixedKeyHashTableContext<UInt64, has_null>; - -template <bool has_null> -using I128FixedKeyHashTableContext = FixedKeyHashTableContext<UInt128, has_null>; - -template <bool has_null> -using I256FixedKeyHashTableContext = FixedKeyHashTableContext<UInt256, has_null>; - -using HashTableVariants = - std::variant<std::monostate, SerializedHashTableContext, I8HashTableContext, - 
I16HashTableContext, I32HashTableContext, I64HashTableContext, - I128HashTableContext, I256HashTableContext, I64FixedKeyHashTableContext<true>, - I64FixedKeyHashTableContext<false>, I128FixedKeyHashTableContext<true>, - I128FixedKeyHashTableContext<false>, I256FixedKeyHashTableContext<true>, - I256FixedKeyHashTableContext<false>>; +template <bool has_null, typename RowRefListType> +using I64FixedKeyHashTableContext = FixedKeyHashTableContext<UInt64, has_null, RowRefListType>; + +template <bool has_null, typename RowRefListType> +using I128FixedKeyHashTableContext = FixedKeyHashTableContext<UInt128, has_null, RowRefListType>; + +template <bool has_null, typename RowRefListType> +using I256FixedKeyHashTableContext = FixedKeyHashTableContext<UInt256, has_null, RowRefListType>; + +using HashTableVariants = std::variant< + std::monostate, SerializedHashTableContext<RowRefList>, I8HashTableContext<RowRefList>, + I16HashTableContext<RowRefList>, I32HashTableContext<RowRefList>, + I64HashTableContext<RowRefList>, I128HashTableContext<RowRefList>, + I256HashTableContext<RowRefList>, I64FixedKeyHashTableContext<true, RowRefList>, + I64FixedKeyHashTableContext<false, RowRefList>, + I128FixedKeyHashTableContext<true, RowRefList>, + I128FixedKeyHashTableContext<false, RowRefList>, + I256FixedKeyHashTableContext<true, RowRefList>, + I256FixedKeyHashTableContext<false, RowRefList>, + SerializedHashTableContext<RowRefListWithFlag>, I8HashTableContext<RowRefListWithFlag>, + I16HashTableContext<RowRefListWithFlag>, I32HashTableContext<RowRefListWithFlag>, + I64HashTableContext<RowRefListWithFlag>, I128HashTableContext<RowRefListWithFlag>, + I256HashTableContext<RowRefListWithFlag>, + I64FixedKeyHashTableContext<true, RowRefListWithFlag>, + I64FixedKeyHashTableContext<false, RowRefListWithFlag>, + I128FixedKeyHashTableContext<true, RowRefListWithFlag>, + I128FixedKeyHashTableContext<false, RowRefListWithFlag>, + I256FixedKeyHashTableContext<true, RowRefListWithFlag>, + 
I256FixedKeyHashTableContext<false, RowRefListWithFlag>, + SerializedHashTableContext<RowRefListWithFlags>, I8HashTableContext<RowRefListWithFlags>, + I16HashTableContext<RowRefListWithFlags>, I32HashTableContext<RowRefListWithFlags>, + I64HashTableContext<RowRefListWithFlags>, I128HashTableContext<RowRefListWithFlags>, + I256HashTableContext<RowRefListWithFlags>, + I64FixedKeyHashTableContext<true, RowRefListWithFlags>, + I64FixedKeyHashTableContext<false, RowRefListWithFlags>, + I128FixedKeyHashTableContext<true, RowRefListWithFlags>, + I128FixedKeyHashTableContext<false, RowRefListWithFlags>, + I256FixedKeyHashTableContext<true, RowRefListWithFlags>, + I256FixedKeyHashTableContext<false, RowRefListWithFlags>>; using JoinOpVariants = std::variant<std::integral_constant<TJoinOp::type, TJoinOp::INNER_JOIN>, @@ -170,18 +200,18 @@ struct ProcessHashTableProbe { // the output block struct is same with mutable block. we can do more opt on it and simplify // the logic of probe // TODO: opt the visited here to reduce the size of hash table - template <typename HashTableType> + template <bool need_null_map_for_probe, typename HashTableType> Status do_process(HashTableType& hash_table_ctx, ConstNullMapPtr null_map, MutableBlock& mutable_block, Block* output_block, size_t probe_rows); // In the presence of other join conjunt, the process of join become more complicated. // each matching join column need to be processed by other join conjunt. 
so the sturct of mutable block // and output block may be different // The output result is determined by the other join conjunt result and same_to_prev struct - template <typename HashTableType> - Status do_process_with_other_join_conjunts(HashTableType& hash_table_ctx, - ConstNullMapPtr null_map, - MutableBlock& mutable_block, Block* output_block, - size_t probe_rows); + template <bool need_null_map_for_probe, typename HashTableType> + Status do_process_with_other_join_conjuncts(HashTableType& hash_table_ctx, + ConstNullMapPtr null_map, + MutableBlock& mutable_block, Block* output_block, + size_t probe_rows); // Process full outer join/ right join / right semi/anti join to output the join result // in hash table @@ -325,10 +355,16 @@ private: Block _probe_block; ColumnRawPtrs _probe_columns; ColumnUInt8::MutablePtr _null_map_column; + bool _need_null_map_for_probe = false; + bool _has_set_need_null_map_for_probe = false; + bool _need_null_map_for_build = false; + bool _has_set_need_null_map_for_build = false; bool _probe_ignore_null = false; int _probe_index = -1; bool _probe_eos = false; + bool _build_side_ignore_null = false; + Sizes _probe_key_sz; Sizes _build_key_sz; @@ -359,11 +395,15 @@ private: Status _process_build_block(RuntimeState* state, Block& block, uint8_t offset); - Status _extract_build_join_column(Block& block, NullMap& null_map, ColumnRawPtrs& raw_ptrs, - bool& ignore_null, RuntimeProfile::Counter& expr_call_timer); + Status _do_evaluate(Block& block, std::vector<VExprContext*>& exprs, + RuntimeProfile::Counter& expr_call_timer, std::vector<int>& res_col_ids); + + template <bool BuildSide> + Status _extract_join_column(Block& block, ColumnUInt8::MutablePtr& null_map, + ColumnRawPtrs& raw_ptrs, const std::vector<int>& res_col_ids); - Status _extract_probe_join_column(Block& block, NullMap& null_map, ColumnRawPtrs& raw_ptrs, - RuntimeProfile::Counter& expr_call_timer); + template <bool BuildSide> + bool _need_null_map(Block& block, const 
std::vector<int>& res_col_ids); void _hash_table_init(); void _process_hashtable_ctx_variants_init(RuntimeState* state); diff --git a/be/src/vec/exec/vset_operation_node.cpp b/be/src/vec/exec/vset_operation_node.cpp index b95714513c..736c9786d9 100644 --- a/be/src/vec/exec/vset_operation_node.cpp +++ b/be/src/vec/exec/vset_operation_node.cpp @@ -150,16 +150,16 @@ void VSetOperationNode::hash_table_init() { switch (_child_expr_lists[0][0]->root()->result_type()) { case TYPE_BOOLEAN: case TYPE_TINYINT: - _hash_table_variants.emplace<I8HashTableContext>(); + _hash_table_variants.emplace<I8HashTableContext<RowRefListWithFlags>>(); break; case TYPE_SMALLINT: - _hash_table_variants.emplace<I16HashTableContext>(); + _hash_table_variants.emplace<I16HashTableContext<RowRefListWithFlags>>(); break; case TYPE_INT: case TYPE_FLOAT: case TYPE_DATEV2: case TYPE_DECIMAL32: - _hash_table_variants.emplace<I32HashTableContext>(); + _hash_table_variants.emplace<I32HashTableContext<RowRefListWithFlags>>(); break; case TYPE_BIGINT: case TYPE_DOUBLE: @@ -167,15 +167,15 @@ void VSetOperationNode::hash_table_init() { case TYPE_DATE: case TYPE_DECIMAL64: case TYPE_DATETIMEV2: - _hash_table_variants.emplace<I64HashTableContext>(); + _hash_table_variants.emplace<I64HashTableContext<RowRefListWithFlags>>(); break; case TYPE_LARGEINT: case TYPE_DECIMALV2: case TYPE_DECIMAL128: - _hash_table_variants.emplace<I128HashTableContext>(); + _hash_table_variants.emplace<I128HashTableContext<RowRefListWithFlags>>(); break; default: - _hash_table_variants.emplace<SerializedHashTableContext>(); + _hash_table_variants.emplace<SerializedHashTableContext<RowRefListWithFlags>>(); } return; } @@ -208,24 +208,30 @@ void VSetOperationNode::hash_table_init() { if (use_fixed_key) { if (has_null) { if (std::tuple_size<KeysNullMap<UInt64>>::value + key_byte_size <= sizeof(UInt64)) { - _hash_table_variants.emplace<I64FixedKeyHashTableContext<true>>(); + _hash_table_variants + 
.emplace<I64FixedKeyHashTableContext<true, RowRefListWithFlags>>(); } else if (std::tuple_size<KeysNullMap<UInt128>>::value + key_byte_size <= sizeof(UInt128)) { - _hash_table_variants.emplace<I128FixedKeyHashTableContext<true>>(); + _hash_table_variants + .emplace<I128FixedKeyHashTableContext<true, RowRefListWithFlags>>(); } else { - _hash_table_variants.emplace<I256FixedKeyHashTableContext<true>>(); + _hash_table_variants + .emplace<I256FixedKeyHashTableContext<true, RowRefListWithFlags>>(); } } else { if (key_byte_size <= sizeof(UInt64)) { - _hash_table_variants.emplace<I64FixedKeyHashTableContext<false>>(); + _hash_table_variants + .emplace<I64FixedKeyHashTableContext<false, RowRefListWithFlags>>(); } else if (key_byte_size <= sizeof(UInt128)) { - _hash_table_variants.emplace<I128FixedKeyHashTableContext<false>>(); + _hash_table_variants + .emplace<I128FixedKeyHashTableContext<false, RowRefListWithFlags>>(); } else { - _hash_table_variants.emplace<I256FixedKeyHashTableContext<false>>(); + _hash_table_variants + .emplace<I256FixedKeyHashTableContext<false, RowRefListWithFlags>>(); } } } else { - _hash_table_variants.emplace<SerializedHashTableContext>(); + _hash_table_variants.emplace<SerializedHashTableContext<RowRefListWithFlags>>(); } } diff --git a/be/src/vec/exec/vset_operation_node.h b/be/src/vec/exec/vset_operation_node.h index ba4eba3013..8c20c1447e 100644 --- a/be/src/vec/exec/vset_operation_node.h +++ b/be/src/vec/exec/vset_operation_node.h @@ -102,53 +102,57 @@ void VSetOperationNode::refresh_hash_table() { [&](auto&& arg) { using HashTableCtxType = std::decay_t<decltype(arg)>; if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) { - HashTableCtxType tmp_hash_table; - bool is_need_shrink = - arg.hash_table.should_be_shrink(_valid_element_in_hash_tbl); - if (is_need_shrink) { - tmp_hash_table.hash_table.init_buf_size( - _valid_element_in_hash_tbl / arg.hash_table.get_factor() + 1); - } + if constexpr (std::is_same_v<typename 
HashTableCtxType::Mapped, + RowRefListWithFlags>) { + HashTableCtxType tmp_hash_table; + bool is_need_shrink = + arg.hash_table.should_be_shrink(_valid_element_in_hash_tbl); + if (is_need_shrink) { + tmp_hash_table.hash_table.init_buf_size( + _valid_element_in_hash_tbl / arg.hash_table.get_factor() + 1); + } - arg.init_once(); - auto& iter = arg.iter; - auto iter_end = arg.hash_table.end(); - while (iter != iter_end) { - auto& mapped = iter->get_second(); - auto it = mapped.begin(); + arg.init_once(); + auto& iter = arg.iter; + auto iter_end = arg.hash_table.end(); + while (iter != iter_end) { + auto& mapped = iter->get_second(); + auto it = mapped.begin(); - if constexpr (keep_matched) { //intersected - if (it->visited) { - it->visited = false; - if (is_need_shrink) { - tmp_hash_table.hash_table.insert(iter->get_value()); - } - ++iter; - } else { - if (!is_need_shrink) { - arg.hash_table.delete_zero_key(iter->get_first()); - // the ++iter would check if the current key is zero. if it does, the iterator will be moved to the container's head. - // so we do ++iter before set_zero to make the iterator move to next valid key correctly. - auto iter_prev = iter; + if constexpr (keep_matched) { //intersected + if (it->visited) { + it->visited = false; + if (is_need_shrink) { + tmp_hash_table.hash_table.insert(iter->get_value()); + } ++iter; - iter_prev->set_zero(); } else { - ++iter; + if (!is_need_shrink) { + arg.hash_table.delete_zero_key(iter->get_first()); + // the ++iter would check if the current key is zero. if it does, the iterator will be moved to the container's head. + // so we do ++iter before set_zero to make the iterator move to next valid key correctly. 
+ auto iter_prev = iter; + ++iter; + iter_prev->set_zero(); + } else { + ++iter; + } } + } else { //except + if (!it->visited && is_need_shrink) { + tmp_hash_table.hash_table.insert(iter->get_value()); + } + ++iter; } - } else { //except - if (!it->visited && is_need_shrink) { - tmp_hash_table.hash_table.insert(iter->get_value()); - } - ++iter; } - } - arg.inited = false; - if (is_need_shrink) { - arg.hash_table = std::move(tmp_hash_table.hash_table); + arg.inited = false; + if (is_need_shrink) { + arg.hash_table = std::move(tmp_hash_table.hash_table); + } + } else { + LOG(FATAL) << "FATAL: Invalid RowRefList"; } - } else { LOG(FATAL) << "FATAL: uninited hash table"; } @@ -179,24 +183,29 @@ struct HashTableProbe { KeyGetter key_getter(_probe_raw_ptrs, _operation_node->_probe_key_sz, nullptr); - for (; _probe_index < _probe_rows;) { - auto find_result = key_getter.find_key(hash_table_ctx.hash_table, _probe_index, _arena); - if (find_result.is_found()) { //if found, marked visited - auto it = find_result.get_mapped().begin(); - if (!(it->visited)) { - it->visited = true; - if constexpr (is_intersected) //intersected - _operation_node->_valid_element_in_hash_tbl++; - else - _operation_node->_valid_element_in_hash_tbl--; //except + if constexpr (std::is_same_v<typename HashTableContext::Mapped, RowRefListWithFlags>) { + for (; _probe_index < _probe_rows;) { + auto find_result = + key_getter.find_key(hash_table_ctx.hash_table, _probe_index, _arena); + if (find_result.is_found()) { //if found, marked visited + auto it = find_result.get_mapped().begin(); + if (!(it->visited)) { + it->visited = true; + if constexpr (is_intersected) //intersected + _operation_node->_valid_element_in_hash_tbl++; + else + _operation_node->_valid_element_in_hash_tbl--; //except + } } + _probe_index++; } - _probe_index++; + } else { + LOG(FATAL) << "Invalid RowRefListType!"; } return Status::OK(); } - void add_result_columns(RowRefList& value, int& block_size) { + void 
add_result_columns(RowRefListWithFlags& value, int& block_size) { auto it = value.begin(); for (auto idx = _build_col_idx.begin(); idx != _build_col_idx.end(); ++idx) { auto& column = *_build_blocks[it->block_offset].get_by_position(idx->first).column; @@ -213,18 +222,22 @@ struct HashTableProbe { auto& iter = hash_table_ctx.iter; auto block_size = 0; - for (; iter != hash_table_ctx.hash_table.end() && block_size < _batch_size; ++iter) { - auto& value = iter->get_second(); - auto it = value.begin(); - if constexpr (is_intersected) { - if (it->visited) { //intersected: have done probe, so visited values it's the result - add_result_columns(value, block_size); - } - } else { - if (!it->visited) { //except: haven't visited values it's the needed result - add_result_columns(value, block_size); + if constexpr (std::is_same_v<typename HashTableContext::Mapped, RowRefListWithFlags>) { + for (; iter != hash_table_ctx.hash_table.end() && block_size < _batch_size; ++iter) { + auto& value = iter->get_second(); + auto it = value.begin(); + if constexpr (is_intersected) { + if (it->visited) { //intersected: have done probe, so visited values it's the result + add_result_columns(value, block_size); + } + } else { + if (!it->visited) { //except: haven't visited values it's the needed result + add_result_columns(value, block_size); + } } } + } else { + LOG(FATAL) << "Invalid RowRefListType!"; } *eos = iter == hash_table_ctx.hash_table.end(); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org