This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch new_join2
in repository https://gitbox.apache.org/repos/asf/doris.git

commit c8d72228e63422076e4b49bafe13ca7a5a128be9
Author: BiteTheDDDDt <pxl...@qq.com>
AuthorDate: Wed Nov 22 18:56:10 2023 +0800

    update fix
---
 be/src/pipeline/exec/hashjoin_build_sink.cpp       | 21 ++++++--
 be/src/vec/common/hash_table/hash_map.h            | 38 +++------------
 .../vec/exec/join/process_hash_table_probe_impl.h  | 57 ++++++++--------------
 be/src/vec/exec/join/vhash_join_node.cpp           |  4 +-
 be/test/exprs/bloom_filter_predicate_test.cpp      |  3 --
 5 files changed, 47 insertions(+), 76 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index ba7b3c0e1a0..8539274061e 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -232,6 +232,14 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state,
     RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, 
*_build_expr_call_timer, res_col_ids));
     if (p._join_op == TJoinOp::LEFT_OUTER_JOIN || p._join_op == 
TJoinOp::FULL_OUTER_JOIN) {
         _convert_block_to_null(block);
+        // first row is mocked
+        for (int i = 0; i < block.columns(); i++) {
+            assert_cast<vectorized::ColumnNullable*>(
+                    
(*std::move(block.safe_get_by_position(i).column)).mutate().get())
+                    ->get_null_map_column()
+                    .get_data()
+                    .data()[0] = 1;
+        }
     }
     // TODO: Now we are not sure whether a column is nullable only by 
ExecNode's `row_desc`
     //  so we have to initialize this flag by the first build block.
@@ -445,12 +453,17 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* state, vectorized::Block*
         // data from probe side.
         local_state._build_side_mem_used += in_block->allocated_bytes();
 
+        if (local_state._build_side_mutable_block.empty()) {
+            auto tmp_build_block = 
vectorized::VectorizedUtils::create_empty_columnswithtypename(
+                    _child_x->row_desc());
+            local_state._build_side_mutable_block =
+                    
vectorized::MutableBlock::build_mutable_block(&tmp_build_block);
+            RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(
+                    *(tmp_build_block.create_same_struct_block(1, false))));
+        }
+
         if (in_block->rows() != 0) {
             SCOPED_TIMER(local_state._build_side_merge_block_timer);
-            if (local_state._build_side_mutable_block.empty()) {
-                RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(
-                        *(in_block->create_same_struct_block(1, false))));
-            }
             
RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(*in_block));
             if (local_state._build_side_mutable_block.rows() >
                 std::numeric_limits<uint32_t>::max()) {
diff --git a/be/src/vec/common/hash_table/hash_map.h b/be/src/vec/common/hash_table/hash_map.h
index 80ff3481544..e6b5a527d74 100644
--- a/be/src/vec/common/hash_table/hash_map.h
+++ b/be/src/vec/common/hash_table/hash_map.h
@@ -22,8 +22,6 @@
 
 #include <gen_cpp/PlanNodes_types.h>
 
-#include <span>
-
 #include "common/compiler_util.h"
 #include "vec/columns/column_filter_helper.h"
 #include "vec/common/hash_table/hash.h"
@@ -260,7 +258,8 @@ public:
     template <int JoinOpType, bool with_other_conjuncts, bool is_mark_join, 
bool need_judge_null>
     auto find_batch(const Key* __restrict keys, const uint32_t* __restrict 
bucket_nums,
                     int probe_idx, uint32_t build_idx, int probe_rows,
-                    uint32_t* __restrict probe_idxs, bool& probe_visited, 
uint32_t* __restrict build_idxs,
+                    uint32_t* __restrict probe_idxs, bool& probe_visited,
+                    uint32_t* __restrict build_idxs,
                     doris::vectorized::ColumnFilterHelper* mark_column) {
         if constexpr (is_mark_join) {
             return _find_batch_mark<JoinOpType>(keys, bucket_nums, probe_idx, 
probe_rows,
@@ -277,7 +276,8 @@ public:
                       JoinOpType == doris::TJoinOp::LEFT_OUTER_JOIN ||
                       JoinOpType == doris::TJoinOp::RIGHT_OUTER_JOIN) {
             return _find_batch_inner_outer_join<JoinOpType>(keys, bucket_nums, 
probe_idx, build_idx,
-                                                            probe_rows, 
probe_idxs, probe_visited, build_idxs);
+                                                            probe_rows, 
probe_idxs, probe_visited,
+                                                            build_idxs);
         }
         if constexpr (JoinOpType == doris::TJoinOp::LEFT_ANTI_JOIN ||
                       JoinOpType == doris::TJoinOp::LEFT_SEMI_JOIN ||
@@ -392,29 +392,6 @@ private:
         return std::tuple {probe_idx, 0U, matched_cnt};
     }
 
-    auto _find_batch_left_semi_anti_conjunct(const Key* __restrict keys,
-                                             const uint32_t* __restrict 
bucket_nums, int probe_idx,
-                                             int probe_rows, uint32_t* 
__restrict probe_idxs,
-                                             uint32_t* __restrict build_idxs) {
-        auto matched_cnt = 0;
-        const auto batch_size = max_batch_size;
-
-        while (probe_idx < probe_rows && matched_cnt < batch_size) {
-            auto build_idx = first[bucket_nums[probe_idx]];
-
-            while (build_idx) {
-                if (keys[probe_idx] == build_keys[build_idx]) {
-                    probe_idxs[matched_cnt] = probe_idx;
-                    build_idxs[matched_cnt] = build_idx;
-                    matched_cnt++;
-                }
-                build_idx = next[build_idx];
-            }
-            probe_idx++;
-        }
-        return std::tuple {probe_idx, 0U, matched_cnt};
-    }
-
     template <int JoinOpType>
     auto _find_batch_conjunct(const Key* __restrict keys, const uint32_t* 
__restrict bucket_nums,
                               int probe_idx, uint32_t build_idx, int 
probe_rows,
@@ -442,7 +419,9 @@ private:
             }
 
             if constexpr (JoinOpType == doris::TJoinOp::LEFT_OUTER_JOIN ||
-                          JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN) {
+                          JoinOpType == doris::TJoinOp::FULL_OUTER_JOIN ||
+                          JoinOpType == doris::TJoinOp::LEFT_ANTI_JOIN ||
+                          JoinOpType == 
doris::TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
                 // may over batch_size when emplace 0 into build_idxs
                 if (!build_idx) {
                     probe_idxs[matched_cnt] = probe_idx;
@@ -471,8 +450,7 @@ private:
     auto _find_batch_inner_outer_join(const Key* __restrict keys,
                                       const uint32_t* __restrict bucket_nums, 
int probe_idx,
                                       uint32_t build_idx, int probe_rows,
-                                      uint32_t* __restrict probe_idxs,
-                                      bool& probe_visited,
+                                      uint32_t* __restrict probe_idxs, bool& 
probe_visited,
                                       uint32_t* __restrict build_idxs) {
         auto matched_cnt = 0;
         const auto batch_size = max_batch_size;
diff --git a/be/src/vec/exec/join/process_hash_table_probe_impl.h b/be/src/vec/exec/join/process_hash_table_probe_impl.h
index 4cd79510d4b..937c8db453d 100644
--- a/be/src/vec/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/vec/exec/join/process_hash_table_probe_impl.h
@@ -179,8 +179,8 @@ Status ProcessHashTableProbe<JoinOpType, Parent>::do_process(HashTableType& hash
               with_other_conjuncts, is_mark_join,
               need_null_map_for_probe &&
                       ignore_null > (hash_table_ctx.keys, 
hash_table_ctx.bucket_nums.data(),
-                                     probe_index, build_index, probe_rows, 
_probe_indexs.data(), _probe_visited,
-                                     _build_indexs.data(), mark_column.get());
+                                     probe_index, build_index, probe_rows, 
_probe_indexs.data(),
+                                     _probe_visited, _build_indexs.data(), 
mark_column.get());
         probe_index = new_probe_idx;
         build_index = new_build_idx;
         current_offset = new_current_offset;
@@ -269,45 +269,28 @@ Status ProcessHashTableProbe<JoinOpType, Parent>::do_other_join_conjuncts(
                 }
             }
         }
-        output_block->get_by_position(result_column_id).column = 
std::move(new_filter_column);
-    } else if constexpr (JoinOpType == TJoinOp::LEFT_SEMI_JOIN) {
-        auto new_filter_column = ColumnVector<UInt8>::create(row_count);
-        auto& filter_map = new_filter_column->get_data();
-
-        for (size_t i = 0; i < row_count; ++i) {
-            filter_map[i] = filter_column_ptr[i];
-        }
-
-        /// FIXME: incorrect result of semi mark join with other 
conjuncts(null value missed).
-        if (is_mark_join) {
-            auto mark_column =
-                    output_block->get_by_position(orig_columns - 
1).column->assume_mutable();
-            ColumnFilterHelper helper(*mark_column);
-
-            // For mark join, we only filter rows which have duplicate join 
keys.
-            // And then, we set matched_map to the join result to do the mark 
join's filtering.
-            for (size_t i = 0; i < row_count; ++i) {
-                helper.insert_value(filter_map[i]);
-            }
-        }
-
         output_block->get_by_position(result_column_id).column = 
std::move(new_filter_column);
     } else if constexpr (JoinOpType == TJoinOp::LEFT_ANTI_JOIN ||
-                         JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
+                         JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
+                         JoinOpType == TJoinOp::LEFT_SEMI_JOIN) {
         auto new_filter_column = ColumnVector<UInt8>::create(row_count);
         auto* __restrict filter_map = new_filter_column->get_data().data();
-
-        // for left anti join, the probe side is output only when
-        // there are no matched tuples for the probe row.
-
-        // If multiple equal-conjuncts-matched tuples is splitted into several
-        // sub blocks, just filter out all the other-conjuncts-NOT-matched 
tuples at first,
-        // and when processing the last sub block, check whether there are any
-        // equal-conjuncts-matched tuple is output in all sub blocks,
-        // if there are none, just pick a tuple and output.
-
         for (size_t i = 0; i < row_count; ++i) {
-            filter_map[i] = _build_indexs[i] && filter_column_ptr[i];
+            if (filter_column_ptr[i]) {
+                if constexpr (JoinOpType == TJoinOp::LEFT_SEMI_JOIN) {
+                    filter_map[i] = _parent->_last_probe_match != 
_probe_indexs[i];
+                    _parent->_last_probe_match = _probe_indexs[i];
+                } else {
+                    if (_build_indexs[i]) {
+                        filter_map[i] = false;
+                        _parent->_last_probe_match = _probe_indexs[i];
+                    } else {
+                        filter_map[i] = _parent->_last_probe_match != 
_probe_indexs[i];
+                    }
+                }
+            } else {
+                filter_map[i] = false;
+            }
         }
 
         if (is_mark_join) {
@@ -316,7 +299,7 @@ Status ProcessHashTableProbe<JoinOpType, Parent>::do_other_join_conjuncts(
                                                   .column->assume_mutable()))
                                         .get_data();
             for (int i = 0; i < row_count; ++i) {
-                matched_map.push_back(!filter_map[i]);
+                matched_map.push_back(filter_map[i] ^ (JoinOpType != 
TJoinOp::LEFT_SEMI_JOIN));
             }
         }
 
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index 417dbfb4a4d..3111dacd830 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -953,8 +953,8 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) {
         _convert_block_to_null(block);
         // first row is mocked
         for (int i = 0; i < block.columns(); i++) {
-            assert_cast<ColumnNullable*>(
-                    
(*std::move(block.safe_get_by_position(i).column)).mutate().get())
+            auto [column, is_const] = 
unpack_if_const(block.safe_get_by_position(i).column);
+            assert_cast<ColumnNullable*>(column->assume_mutable().get())
                     ->get_null_map_column()
                     .get_data()
                     .data()[0] = 1;
diff --git a/be/test/exprs/bloom_filter_predicate_test.cpp b/be/test/exprs/bloom_filter_predicate_test.cpp
index 4f4ecd7c876..8c33ed13a6d 100644
--- a/be/test/exprs/bloom_filter_predicate_test.cpp
+++ b/be/test/exprs/bloom_filter_predicate_test.cpp
@@ -53,9 +53,6 @@ TEST_F(BloomFilterPredicateTest, bloom_filter_func_int_test) {
     // test not exist val
     int not_exist_val = 0x3355ff;
     EXPECT_FALSE(func->find((const void*)&not_exist_val));
-    // TEST null value
-    func->insert(nullptr);
-    func->find(nullptr);
 }
 
 TEST_F(BloomFilterPredicateTest, bloom_filter_func_stringval_test) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to