This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 77b366fc4bc4ed1bee536ea46967079ffd2afd0b
Author: Jerry Hu <mrh...@gmail.com>
AuthorDate: Wed Jan 31 17:59:27 2024 +0800

    [fix](join) incorrect result of mark join (#30543)
    
     incorrect result of mark join
---
 be/src/pipeline/exec/hashjoin_build_sink.cpp       |   8 -
 be/src/pipeline/exec/hashjoin_build_sink.h         |   7 -
 be/src/pipeline/exec/hashjoin_probe_operator.cpp   |   1 -
 be/src/pipeline/exec/hashjoin_probe_operator.h     |   7 -
 be/src/vec/common/hash_table/join_hash_table.h     | 101 ++++++------
 be/src/vec/exec/join/process_hash_table_probe.h    |  19 ++-
 .../vec/exec/join/process_hash_table_probe_impl.h  | 172 +++++++++++++--------
 be/src/vec/exec/join/vhash_join_node.cpp           |   8 -
 be/src/vec/exec/join/vhash_join_node.h             |  16 +-
 .../data/nereids_p0/join/test_mark_join.out        |  43 ++++++
 .../suites/nereids_p0/join/test_mark_join.groovy   | 126 +++++++++++++++
 11 files changed, 352 insertions(+), 156 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp 
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index 5b6689e35a5..5c8995bebe9 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -76,12 +76,6 @@ Status HashJoinBuildSinkLocalState::init(RuntimeState* 
state, LocalSinkStateInfo
         _shared_hash_table_dependency->block();
         p._shared_hashtable_controller->append_dependency(p.node_id(),
                                                           
_shared_hash_table_dependency);
-    } else {
-        if ((p._join_op == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
-             p._join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) &&
-            p._have_other_join_conjunct) {
-            _build_indexes_null = std::make_shared<std::vector<uint32_t>>();
-        }
     }
 
     _build_blocks_memory_usage =
@@ -496,7 +490,6 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* 
state, vectorized::Block*
                 state, local_state._shared_state->build_block.get(), 
&local_state, use_global_rf));
         RETURN_IF_ERROR(
                 local_state.process_build_block(state, 
(*local_state._shared_state->build_block)));
-        local_state._shared_state->build_indexes_null = 
local_state._build_indexes_null;
         if (_shared_hashtable_controller) {
             _shared_hash_table_context->status = Status::OK();
             // arena will be shared with other instances.
@@ -542,7 +535,6 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* 
state, vectorized::Block*
                         _shared_hash_table_context->hash_table_variants));
 
         local_state._shared_state->build_block = 
_shared_hash_table_context->block;
-        local_state._build_indexes_null = 
_shared_hash_table_context->build_indexes_null;
         local_state._shared_state->build_indexes_null =
                 _shared_hash_table_context->build_indexes_null;
         const bool use_global_rf =
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h 
b/be/src/pipeline/exec/hashjoin_build_sink.h
index 2acc25151ab..efc0a46f3ba 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -120,13 +120,6 @@ protected:
     std::shared_ptr<SharedHashTableDependency> _shared_hash_table_dependency;
     std::vector<int> _build_col_ids;
 
-    /*
-     * For null aware anti/semi join with other join conjuncts, we do need to 
care about the rows in
-     * build side with null keys,
-     * because the other join conjuncts' result may be changed from null to 
false(null & false == false).
-     */
-    std::shared_ptr<std::vector<uint32_t>> _build_indexes_null;
-
     RuntimeProfile::Counter* _build_table_timer = nullptr;
     RuntimeProfile::Counter* _build_expr_call_timer = nullptr;
     RuntimeProfile::Counter* _build_table_insert_timer = nullptr;
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp 
b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
index f7a06655b19..a2fb0012ffa 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
@@ -301,7 +301,6 @@ Status HashJoinProbeOperatorX::pull(doris::RuntimeState* 
state, vectorized::Bloc
 
     Status st;
     if (local_state._probe_index < local_state._probe_block.rows()) {
-        local_state._build_indexes_null = 
local_state._shared_state->build_indexes_null;
         DCHECK(local_state._has_set_need_null_map_for_probe);
         RETURN_IF_CATCH_EXCEPTION({
             std::visit(
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h 
b/be/src/pipeline/exec/hashjoin_probe_operator.h
index 1bdb9864c40..4b7f9271920 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.h
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.h
@@ -125,13 +125,6 @@ private:
     // For mark join, last probe index of null mark
     int _last_probe_null_mark;
 
-    /*
-     * For null aware anti/semi join with other join conjuncts, we do need to 
care about the rows in
-     * build side with null keys,
-     * because the other join conjuncts' result may be changed from null to 
false(null & false == false).
-     */
-    std::shared_ptr<std::vector<uint32_t>> _build_indexes_null;
-
     vectorized::Block _probe_block;
     vectorized::ColumnRawPtrs _probe_columns;
     // other expr
diff --git a/be/src/vec/common/hash_table/join_hash_table.h 
b/be/src/vec/common/hash_table/join_hash_table.h
index 08311989b5d..85665e76853 100644
--- a/be/src/vec/common/hash_table/join_hash_table.h
+++ b/be/src/vec/common/hash_table/join_hash_table.h
@@ -68,6 +68,7 @@ public:
 
     std::vector<uint8_t>& get_visited() { return visited; }
 
+    template <int JoinOpType, bool with_other_conjuncts>
     void build(const Key* __restrict keys, const uint32_t* __restrict 
bucket_nums,
                size_t num_elem) {
         build_keys = keys;
@@ -76,19 +77,24 @@ public:
             next[i] = first[bucket_num];
             first[bucket_num] = i;
         }
-        first[bucket_size] = 0; // index = bucket_num means null
+        if constexpr ((JoinOpType != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN &&
+                       JoinOpType != TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) ||
+                      !with_other_conjuncts) {
+            /// Only a null-aware join with other conjuncts needs to access the 
null values in the hash table
+            first[bucket_size] = 0; // index = bucket_num means null
+        }
     }
 
     template <int JoinOpType, bool with_other_conjuncts, bool is_mark_join, 
bool need_judge_null>
     auto find_batch(const Key* __restrict keys, const uint32_t* __restrict 
build_idx_map,
                     int probe_idx, uint32_t build_idx, int probe_rows,
                     uint32_t* __restrict probe_idxs, bool& probe_visited,
-                    uint32_t* __restrict build_idxs, 
vectorized::ColumnFilterHelper* mark_column) {
+                    uint32_t* __restrict build_idxs) {
         if constexpr (JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
             if (_empty_build_side) {
                 return _process_null_aware_left_anti_join_for_empty_build_side<
-                        JoinOpType, with_other_conjuncts, is_mark_join>(
-                        probe_idx, probe_rows, probe_idxs, build_idxs, 
mark_column);
+                        JoinOpType, with_other_conjuncts, 
is_mark_join>(probe_idx, probe_rows,
+                                                                        
probe_idxs, build_idxs);
             }
         }
 
@@ -128,51 +134,48 @@ public:
      * select 'a' not in ('b', null) => null => 'a' != 'b' and 'a' != null => 
true and null => null
      * select 'a' not in ('a', 'b', null) => false
      */
-    auto find_null_aware_with_other_conjuncts(
-            const Key* __restrict keys, const uint32_t* __restrict 
build_idx_map, int probe_idx,
-            uint32_t build_idx, int probe_rows, uint32_t* __restrict 
probe_idxs,
-            uint32_t* __restrict build_idxs, std::set<uint32_t>& null_result,
-            const std::vector<uint32_t>& build_indexes_null, const size_t 
build_block_count) {
+    auto find_null_aware_with_other_conjuncts(const Key* __restrict keys,
+                                              const uint32_t* __restrict 
build_idx_map,
+                                              int probe_idx, uint32_t 
build_idx, int probe_rows,
+                                              uint32_t* __restrict probe_idxs,
+                                              uint32_t* __restrict build_idxs,
+                                              uint8_t* __restrict null_flags,
+                                              bool picking_null_keys) {
         auto matched_cnt = 0;
         const auto batch_size = max_batch_size;
 
-        bool has_matched = false;
         auto do_the_probe = [&]() {
+            /// If no rows match the probe key, start handling the null 
keys on the build side.
+            /// The result of "Any = null" is null.
+            if (build_idx == 0 && !picking_null_keys) {
+                build_idx = first[bucket_size];
+                picking_null_keys = true; // now pick null from build side
+            }
+
             while (build_idx && matched_cnt < batch_size) {
-                if (build_idx == bucket_size) {
-                    /// All rows in build side should be executed with other 
join conjuncts.
-                    for (size_t i = 1; i != build_block_count; ++i) {
-                        build_idxs[matched_cnt] = i;
-                        probe_idxs[matched_cnt] = probe_idx;
-                        matched_cnt++;
-                    }
-                    null_result.emplace(probe_idx);
-                    build_idx = 0;
-                    has_matched = true;
-                    break;
-                } else if (keys[probe_idx] == build_keys[build_idx]) {
+                if (picking_null_keys || keys[probe_idx] == 
build_keys[build_idx]) {
                     build_idxs[matched_cnt] = build_idx;
                     probe_idxs[matched_cnt] = probe_idx;
+                    null_flags[matched_cnt] = picking_null_keys;
                     matched_cnt++;
-                    has_matched = true;
                 }
 
                 build_idx = next[build_idx];
+
+                // If `build_idx` is 0, all matched keys are handled,
+                // now need to handle null keys in build side.
+                if (!build_idx && !picking_null_keys) {
+                    build_idx = first[bucket_size];
+                    picking_null_keys = true; // now pick null keys from build 
side
+                }
             }
 
             // may over batch_size when emplace 0 into build_idxs
             if (!build_idx) {
-                if (!has_matched) { // has no any row matched
-                    for (auto index : build_indexes_null) {
-                        build_idxs[matched_cnt] = index;
-                        probe_idxs[matched_cnt] = probe_idx;
-                        matched_cnt++;
-                    }
-                }
                 probe_idxs[matched_cnt] = probe_idx;
                 build_idxs[matched_cnt] = 0;
+                picking_null_keys = false;
                 matched_cnt++;
-                has_matched = false;
             }
 
             probe_idx++;
@@ -184,11 +187,20 @@ public:
 
         while (probe_idx < probe_rows && matched_cnt < batch_size) {
             build_idx = build_idx_map[probe_idx];
+
+            /// If the probe key is null
+            if (build_idx == bucket_size) {
+                probe_idx++;
+                break;
+            }
             do_the_probe();
+            if (picking_null_keys) {
+                break;
+            }
         }
 
         probe_idx -= (build_idx != 0);
-        return std::tuple {probe_idx, build_idx, matched_cnt};
+        return std::tuple {probe_idx, build_idx, matched_cnt, 
picking_null_keys};
     }
 
     template <int JoinOpType>
@@ -215,21 +227,23 @@ public:
 
     bool has_null_key() { return _has_null_key; }
 
-    void pre_build_idxs(std::vector<uint32>& bucksets, const uint8_t* 
null_map) {
+    void pre_build_idxs(std::vector<uint32>& buckets, const uint8_t* null_map) 
{
         if (null_map) {
-            first[bucket_size] = bucket_size; // distinguish between not 
matched and null
-        }
-
-        for (uint32_t i = 0; i < bucksets.size(); i++) {
-            bucksets[i] = first[bucksets[i]];
+            for (unsigned int& bucket : buckets) {
+                bucket = bucket == bucket_size ? bucket_size : first[bucket];
+            }
+        } else {
+            for (unsigned int& bucket : buckets) {
+                bucket = first[bucket];
+            }
         }
     }
 
 private:
     template <int JoinOpType, bool with_other_conjuncts, bool is_mark_join>
-    auto _process_null_aware_left_anti_join_for_empty_build_side(
-            int probe_idx, int probe_rows, uint32_t* __restrict probe_idxs,
-            uint32_t* __restrict build_idxs, vectorized::ColumnFilterHelper* 
mark_column) {
+    auto _process_null_aware_left_anti_join_for_empty_build_side(int 
probe_idx, int probe_rows,
+                                                                 uint32_t* 
__restrict probe_idxs,
+                                                                 uint32_t* 
__restrict build_idxs) {
         static_assert(JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN);
         auto matched_cnt = 0;
         const auto batch_size = max_batch_size;
@@ -240,11 +254,6 @@ private:
             ++matched_cnt;
         }
 
-        if constexpr (is_mark_join && !with_other_conjuncts) {
-            // we will flip the mark column later for anti join, so here set 0 
into mark column.
-            mark_column->resize_fill(matched_cnt, 0);
-        }
-
         return std::tuple {probe_idx, 0U, matched_cnt};
     }
 
diff --git a/be/src/vec/exec/join/process_hash_table_probe.h 
b/be/src/vec/exec/join/process_hash_table_probe.h
index 02bf242e55a..9f4ddbabdcb 100644
--- a/be/src/vec/exec/join/process_hash_table_probe.h
+++ b/be/src/vec/exec/join/process_hash_table_probe.h
@@ -67,12 +67,11 @@ struct ProcessHashTableProbe {
     // each matching join column need to be processed by other join conjunct. 
so the struct of mutable block
     // and output block may be different
     // The output result is determined by the other join conjunct result and 
same_to_prev struct
-    Status do_other_join_conjuncts(Block* output_block, bool is_mark_join,
-                                   std::vector<uint8_t>& visited, bool 
has_null_in_build_side);
+    Status do_other_join_conjuncts(Block* output_block, std::vector<uint8_t>& 
visited,
+                                   bool has_null_in_build_side);
 
     template <bool with_other_conjuncts>
-    Status do_mark_join_conjuncts(Block* output_block, size_t 
hash_table_bucket_size,
-                                  const std::set<uint32_t>& null_result);
+    Status do_mark_join_conjuncts(Block* output_block, size_t 
hash_table_bucket_size);
 
     template <typename HashTableType>
     typename HashTableType::State _init_probe_side(HashTableType& 
hash_table_ctx, size_t probe_rows,
@@ -85,6 +84,10 @@ struct ProcessHashTableProbe {
     Status process_data_in_hashtable(HashTableType& hash_table_ctx, 
MutableBlock& mutable_block,
                                      Block* output_block, bool* eos);
 
+    /// For null aware join with other conjuncts, if the probe key of one row 
on left side is null,
+    /// we should make this row match with all rows in build side.
+    size_t _process_probe_null_key(uint32_t probe_idx);
+
     Parent* _parent = nullptr;
     const int _batch_size;
     const std::shared_ptr<Block>& _build_block;
@@ -93,7 +96,15 @@ struct ProcessHashTableProbe {
 
     std::vector<uint32_t> _probe_indexs;
     bool _probe_visited = false;
+    bool _picking_null_keys = false;
     std::vector<uint32_t> _build_indexs;
+    std::vector<uint8_t> _null_flags;
+
+    /// If the probe key of one row on left side is null,
+    /// we will make all rows in build side match with this row,
+    /// `_build_index_for_null_probe_key` is used to record the progress if 
the build block is too big.
+    uint32_t _build_index_for_null_probe_key {0};
+
     std::vector<int> _build_blocks_locs;
     // only need set the tuple is null in RIGHT_OUTER_JOIN and FULL_OUTER_JOIN
     ColumnUInt8::Container* _tuple_is_null_left_flags = nullptr;
diff --git a/be/src/vec/exec/join/process_hash_table_probe_impl.h 
b/be/src/vec/exec/join/process_hash_table_probe_impl.h
index 1939b702c69..06dfdac9074 100644
--- a/be/src/vec/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/vec/exec/join/process_hash_table_probe_impl.h
@@ -131,6 +131,11 @@ typename HashTableType::State 
ProcessHashTableProbe<JoinOpType, Parent>::_init_p
     // may over batch size 1 for some outer join case
     _probe_indexs.resize(_batch_size + 1);
     _build_indexs.resize(_batch_size + 1);
+    if constexpr (JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
+                  JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) {
+        _null_flags.resize(_batch_size + 1);
+        memset(_null_flags.data(), 0, _batch_size + 1);
+    }
 
     if (!_parent->_ready_probe) {
         _parent->_ready_probe = true;
@@ -175,26 +180,41 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_process(HashTableType& hash
     auto& mcol = mutable_block.mutable_columns();
 
     int current_offset = 0;
-
-    std::unique_ptr<ColumnFilterHelper> mark_column;
-    if (is_mark_join) {
-        mark_column = std::make_unique<ColumnFilterHelper>(*mcol[mcol.size() - 
1]);
-    }
-
-    /// `null_result` set which contains the probe indexes of null results.
-    std::set<uint32_t> null_result;
     if constexpr ((JoinOpType == doris::TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
                    JoinOpType == doris::TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) &&
                   with_other_conjuncts) {
         SCOPED_TIMER(_search_hashtable_timer);
-        auto [new_probe_idx, new_build_idx, new_current_offset] =
-                
hash_table_ctx.hash_table->find_null_aware_with_other_conjuncts(
-                        hash_table_ctx.keys, 
hash_table_ctx.bucket_nums.data(), probe_index,
-                        build_index, probe_rows, _probe_indexs.data(), 
_build_indexs.data(),
-                        null_result, *(_parent->_build_indexes_null), 
_build_block->rows());
-        probe_index = new_probe_idx;
-        build_index = new_build_idx;
-        current_offset = new_current_offset;
+
+        /// If `_build_index_for_null_probe_key` is not zero, it means we are 
in the middle of handling a null probe key.
+        if (_build_index_for_null_probe_key) {
+            DCHECK_EQ(build_index, 
hash_table_ctx.hash_table->get_bucket_size());
+            current_offset = _process_probe_null_key(probe_index);
+            if (!_build_index_for_null_probe_key) {
+                probe_index++;
+                build_index = 0;
+            }
+        } else {
+            auto [new_probe_idx, new_build_idx, new_current_offset, 
picking_null_keys] =
+                    
hash_table_ctx.hash_table->find_null_aware_with_other_conjuncts(
+                            hash_table_ctx.keys, 
hash_table_ctx.bucket_nums.data(), probe_index,
+                            build_index, probe_rows, _probe_indexs.data(), 
_build_indexs.data(),
+                            _null_flags.data(), _picking_null_keys);
+            probe_index = new_probe_idx;
+            build_index = new_build_idx;
+            current_offset = new_current_offset;
+            _picking_null_keys = picking_null_keys;
+
+            if (build_index == hash_table_ctx.hash_table->get_bucket_size()) {
+                _build_index_for_null_probe_key = 1;
+                if (current_offset == 0) {
+                    current_offset = _process_probe_null_key(probe_index);
+                    if (!_build_index_for_null_probe_key) {
+                        probe_index++;
+                        build_index = 0;
+                    }
+                }
+            }
+        }
     } else {
         SCOPED_TIMER(_search_hashtable_timer);
         auto [new_probe_idx, new_build_idx,
@@ -203,7 +223,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_process(HashTableType& hash
               need_null_map_for_probe &&
                       ignore_null > (hash_table_ctx.keys, 
hash_table_ctx.bucket_nums.data(),
                                      probe_index, build_index, probe_rows, 
_probe_indexs.data(),
-                                     _probe_visited, _build_indexs.data(), 
mark_column.get());
+                                     _probe_visited, _build_indexs.data());
         probe_index = new_probe_idx;
         build_index = new_build_idx;
         current_offset = new_current_offset;
@@ -235,20 +255,76 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_process(HashTableType& hash
 
     if constexpr (is_mark_join) {
         return do_mark_join_conjuncts<with_other_conjuncts>(
-                output_block, hash_table_ctx.hash_table->get_bucket_size(), 
null_result);
+                output_block, hash_table_ctx.hash_table->get_bucket_size());
     } else if constexpr (with_other_conjuncts) {
-        return do_other_join_conjuncts(output_block, is_mark_join,
-                                       
hash_table_ctx.hash_table->get_visited(),
+        return do_other_join_conjuncts(output_block, 
hash_table_ctx.hash_table->get_visited(),
                                        
hash_table_ctx.hash_table->has_null_key());
     }
 
     return Status::OK();
 }
 
+template <int JoinOpType, typename Parent>
+size_t ProcessHashTableProbe<JoinOpType, 
Parent>::_process_probe_null_key(uint32_t probe_index) {
+    const auto rows = _build_block->rows();
+
+    DCHECK_LT(_build_index_for_null_probe_key, rows);
+    DCHECK_LT(0, _build_index_for_null_probe_key);
+    size_t matched_cnt = 0;
+    for (; _build_index_for_null_probe_key < rows && matched_cnt < 
_batch_size; ++matched_cnt) {
+        _probe_indexs[matched_cnt] = probe_index;
+        _build_indexs[matched_cnt] = _build_index_for_null_probe_key++;
+        _null_flags[matched_cnt] = 1;
+    }
+
+    if (_build_index_for_null_probe_key == rows) {
+        _build_index_for_null_probe_key = 0;
+        _probe_indexs[matched_cnt] = probe_index;
+        _build_indexs[matched_cnt] = 0;
+        _null_flags[matched_cnt] = 0;
+        matched_cnt++;
+    }
+
+    return matched_cnt;
+}
+
+/**
+     * Mark join: there is a column named mark column which stores the result 
of mark join conjunct.
+     * For example:
+     * ```sql
+     *  select * from t1 where t1.k1 not in (select t2.k1 from t2 where t2.k2 
= t1.k2 and t2.k3 > t1.k3) or t1.k4 < 10;
+     * ```
+     * equal join conjuncts: t2.k2 = t1.k2
+     * mark join conjunct: t1.k1 = t2.k1
+     * other join conjuncts: t2.k3 > t1.k3
+     * other predicates: $c$1 or t1.k4 < 10   # `$c$1` means the result of 
mark join conjunct(mark column)
+     *
+     * Executing flow:
+     *
+     * Equal join conjuncts (probe hash table)
+     *                  ↓↓
+     * Mark join conjuncts (result is nullable, stored in mark column)
+     *                  ↓↓
+     * Other join conjuncts (update the mark column)
+     *                  ↓↓
+     * Other predicates (filter rows)
+     *
+     * ```sql
+     *   select * from t1 where t1.k1 not in (select t2.k1 from t2 where t2.k3 
> t1.k3) or t1.k4 < 10;
+     * ```
+     * This sql has no equal join conjuncts:
+     * equal join conjuncts: none
+     * mark join conjunct: t1.k1 = t2.k1
+     * other join conjuncts: t2.k3 > t1.k3
+     * other predicates: $c$1 or t1.k4 < 10   # `$c$1` means the result of 
mark join conjunct(mark column)
+     *
+     * To avoid using nested loop join, we use the mark join conjunct(`t1.k1 = 
t2.k1`) as the equal join conjunct.
+     * So this query will be a "null aware left anti join", which means the 
equal conjunct's result should be nullable.
+     */
 template <int JoinOpType, typename Parent>
 template <bool with_other_conjuncts>
 Status ProcessHashTableProbe<JoinOpType, Parent>::do_mark_join_conjuncts(
-        Block* output_block, size_t hash_table_bucket_size, const 
std::set<uint32_t>& null_result) {
+        Block* output_block, size_t hash_table_bucket_size) {
     DCHECK(JoinOpType == TJoinOp::LEFT_ANTI_JOIN ||
            JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
            JoinOpType == TJoinOp::LEFT_SEMI_JOIN ||
@@ -260,6 +336,10 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_mark_join_conjuncts(
                                         JoinOpType == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN;
 
     const auto row_count = output_block->rows();
+    if (!row_count) {
+        return Status::OK();
+    }
+
     auto mark_column_mutable =
             
output_block->get_by_position(_parent->_mark_column_id).column->assume_mutable();
     auto& mark_column = assert_cast<ColumnNullable&>(*mark_column_mutable);
@@ -281,8 +361,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_mark_join_conjuncts(
             filter_data[i] = _build_indexs[i] != 0 && _build_indexs[i] != 
hash_table_bucket_size;
             if constexpr (is_null_aware_join) {
                 if constexpr (with_other_conjuncts) {
-                    mark_null_map[i] =
-                            null_result.contains(_probe_indexs[i]) && 
_build_indexs[i] != 0;
+                    mark_null_map[i] = _null_flags[i];
                 } else {
                     if (filter_data[i]) {
                         last_probe_matched = _probe_indexs[i];
@@ -361,7 +440,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_mark_join_conjuncts(
     if constexpr (is_anti_join) {
         // flip the mark column
         for (size_t i = 0; i != row_count; ++i) {
-            mark_filter_data[i] ^= 1;
+            mark_filter_data[i] ^= 1; // not null/ null
         }
     }
 
@@ -372,8 +451,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_mark_join_conjuncts(
 
 template <int JoinOpType, typename Parent>
 Status ProcessHashTableProbe<JoinOpType, Parent>::do_other_join_conjuncts(
-        Block* output_block, bool is_mark_join, std::vector<uint8_t>& visited,
-        bool has_null_in_build_side) {
+        Block* output_block, std::vector<uint8_t>& visited, bool 
has_null_in_build_side) {
     // dispose the other join conjunct exec
     auto row_count = output_block->rows();
     if (!row_count) {
@@ -440,30 +518,18 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_other_join_conjuncts(
         for (size_t i = 0; i < row_count; ++i) {
             bool not_matched_before = _parent->_last_probe_match != 
_probe_indexs[i];
 
-            // _build_indexs[i] == 0 means the end of this probe index
-            // if a probe row not matched with any build row, we need output a 
false value into mark column
             if constexpr (JoinOpType == TJoinOp::LEFT_SEMI_JOIN) {
                 if (_build_indexs[i] == 0) {
-                    filter_map[i] = is_mark_join && not_matched_before;
-                    filter_column_ptr[i] = false;
+                    filter_map[i] = false;
+                } else if (filter_column_ptr[i]) {
+                    filter_map[i] = not_matched_before;
+                    _parent->_last_probe_match = _probe_indexs[i];
                 } else {
-                    if (filter_column_ptr[i]) {
-                        filter_map[i] = not_matched_before;
-                        _parent->_last_probe_match = _probe_indexs[i];
-                    } else {
-                        filter_map[i] = false;
-                    }
+                    filter_map[i] = false;
                 }
             } else {
                 if (_build_indexs[i] == 0) {
-                    if (not_matched_before) {
-                        filter_map[i] = true;
-                    } else if (is_mark_join) {
-                        filter_map[i] = true;
-                        filter_column_ptr[i] = false;
-                    } else {
-                        filter_map[i] = false;
-                    }
+                    filter_map[i] = not_matched_before;
                 } else {
                     filter_map[i] = false;
                     if (filter_column_ptr[i]) {
@@ -473,21 +539,6 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_other_join_conjuncts(
             }
         }
 
-        if (is_mark_join) {
-            auto mark_column =
-                    output_block->get_by_position(orig_columns - 
1).column->assume_mutable();
-            ColumnFilterHelper helper(*mark_column);
-            for (size_t i = 0; i < row_count; ++i) {
-                bool mathced = filter_column_ptr[i] &&
-                               (_build_indexs[i] != 0) == (JoinOpType == 
TJoinOp::LEFT_SEMI_JOIN);
-                if (has_null_in_build_side && !mathced) {
-                    helper.insert_null();
-                } else {
-                    helper.insert_value(mathced);
-                }
-            }
-        }
-
         output_block->get_by_position(result_column_id).column = 
std::move(new_filter_column);
     } else if constexpr (JoinOpType == TJoinOp::RIGHT_SEMI_JOIN ||
                          JoinOpType == TJoinOp::RIGHT_ANTI_JOIN) {
@@ -512,8 +563,7 @@ Status ProcessHashTableProbe<JoinOpType, 
Parent>::do_other_join_conjuncts(
                       JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
             orig_columns = _right_col_idx;
         }
-        RETURN_IF_ERROR(Block::filter_block(output_block, result_column_id,
-                                            is_mark_join ? 
output_block->columns() : orig_columns));
+        RETURN_IF_ERROR(Block::filter_block(output_block, result_column_id, 
orig_columns));
     }
 
     return Status::OK();
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp 
b/be/src/vec/exec/join/vhash_join_node.cpp
index a813ec565a4..ec630f3fe32 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -180,12 +180,6 @@ Status HashJoinNode::init(const TPlanNode& tnode, 
RuntimeState* state) {
     }
 #endif
 
-    if ((_join_op == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
-         _join_op == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) &&
-        _have_other_join_conjunct) {
-        _build_indexes_null = std::make_shared<std::vector<uint32_t>>();
-    }
-
     _runtime_filters.resize(_runtime_filter_descs.size());
     for (size_t i = 0; i < _runtime_filter_descs.size(); i++) {
         RETURN_IF_ERROR(state->runtime_filter_mgr()->register_producer_filter(
@@ -761,7 +755,6 @@ Status HashJoinNode::sink(doris::RuntimeState* state, 
vectorized::Block* in_bloc
             // arena will be shared with other instances.
             _shared_hash_table_context->arena = _arena;
             _shared_hash_table_context->block = _build_block;
-            _shared_hash_table_context->build_indexes_null = 
_build_indexes_null;
             _shared_hash_table_context->hash_table_variants = 
_hash_table_variants;
             _shared_hash_table_context->short_circuit_for_null_in_probe_side =
                     _has_null_in_build_side;
@@ -794,7 +787,6 @@ Status HashJoinNode::sink(doris::RuntimeState* state, vectorized::Block* in_bloc
                 *std::static_pointer_cast<HashTableVariants>(
                         _shared_hash_table_context->hash_table_variants));
         _build_block = _shared_hash_table_context->block;
-        _build_indexes_null = _shared_hash_table_context->build_indexes_null;
 
         if (!_shared_hash_table_context->runtime_filters.empty()) {
             auto ret = std::visit(
diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h
index 95c59094ba6..c38f8f563ea 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -117,11 +117,6 @@ struct ProcessHashTableBuild {
             for (uint32_t i = 1; i < _rows; i++) {
                 if ((*null_map)[i]) {
                     *has_null_key = true;
-                    if constexpr (with_other_conjuncts &&
-                                  (JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN ||
-                                   JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN)) {
-                        _parent->_build_indexes_null->emplace_back(i);
-                    }
                 }
             }
             if (short_circuit_for_null && *has_null_key) {
@@ -136,8 +131,8 @@ struct ProcessHashTableBuild {
         hash_table_ctx.init_serialized_keys(_build_raw_ptrs, _rows,
                                             null_map ? null_map->data() : nullptr, true, true,
                                             hash_table_ctx.hash_table->get_bucket_size());
-        hash_table_ctx.hash_table->build(hash_table_ctx.keys, hash_table_ctx.bucket_nums.data(),
-                                         _rows);
+        hash_table_ctx.hash_table->template build<JoinOpType, with_other_conjuncts>(
+                hash_table_ctx.keys, hash_table_ctx.bucket_nums.data(), _rows);
         hash_table_ctx.bucket_nums.resize(_batch_size);
         hash_table_ctx.bucket_nums.shrink_to_fit();
 
@@ -303,13 +298,6 @@ private:
     std::vector<uint16_t> _probe_column_disguise_null;
     std::vector<uint16_t> _probe_column_convert_to_null;
 
-    /*
-     * For null aware anti/semi join with other join conjuncts, we do need to care about the rows in
-     * build side with null keys,
-     * because the other join conjuncts' result maybe change null to false(null & false == false).
-     */
-    std::shared_ptr<std::vector<uint32_t>> _build_indexes_null;
-
     DataTypes _right_table_data_types;
     DataTypes _left_table_data_types;
     std::vector<std::string> _right_table_column_names;
diff --git a/regression-test/data/nereids_p0/join/test_mark_join.out b/regression-test/data/nereids_p0/join/test_mark_join.out
new file mode 100644
index 00000000000..4098502b75d
--- /dev/null
+++ b/regression-test/data/nereids_p0/join/test_mark_join.out
@@ -0,0 +1,43 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !mark_join1 --
+1      1       true
+2      2       true
+3      \N      true
+3      \N      true
+4      \N      \N
+
+-- !mark_join2 --
+1      1       \N
+2      2       \N
+3      \N      \N
+3      \N      true
+4      \N      true
+
+-- !mark_join3 --
+1      1       false
+2      2       false
+3      \N      false
+3      \N      false
+4      \N      false
+
+-- !mark_join4 --
+1      1       false
+2      2       false
+3      \N      \N
+3      \N      true
+4      \N      true
+
+-- !mark_join5 --
+1      1       false
+2      2       false
+3      \N      true
+3      \N      true
+4      \N      \N
+
+-- !mark_join6 --
+1      1       true
+2      2       true
+3      \N      false
+3      \N      true
+4      \N      false
+
diff --git a/regression-test/suites/nereids_p0/join/test_mark_join.groovy b/regression-test/suites/nereids_p0/join/test_mark_join.groovy
new file mode 100644
index 00000000000..6008919d831
--- /dev/null
+++ b/regression-test/suites/nereids_p0/join/test_mark_join.groovy
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_mark_join", "nereids_p0") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+
+    sql "drop table if exists `test_mark_join_t1`;"
+    sql "drop table if exists `test_mark_join_t2`;"
+
+    sql """
+        CREATE TABLE IF NOT EXISTS `test_mark_join_t1` (
+          k1 int not null,
+          k2 int,
+          k3 bigint,
+          v1 varchar(255) not null,
+          v2 varchar(255),
+          v3 varchar(255)
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`k1`, `k2`)
+        COMMENT "OLAP"
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 3
+          PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1",
+          "in_memory" = "false",
+          "storage_format" = "V2"
+        );
+    """
+
+    sql """
+        CREATE TABLE IF NOT EXISTS `test_mark_join_t2` (
+          k1 int not null,
+          k2 int,
+          k3 bigint,
+          v1 varchar(255) not null,
+          v2 varchar(255),
+          v3 varchar(255)
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`k1`, `k2`)
+        COMMENT "OLAP"
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 3
+          PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1",
+          "in_memory" = "false",
+          "storage_format" = "V2"
+        );
+    """
+
+    sql """
+        insert into `test_mark_join_t1` values
+            (1,     1,      1,      'abc',      'efg',      'hjk'),
+            (2,     2,      2,      'aabb',     'eeff',     'ccdd'),
+            (3,     null,   3,      'iii',      null,       null),
+            (3,     null,   null,   'hhhh',     null,       null),
+            (4,     null,   4,      'dddd',     'ooooo',    'kkkkk'
+        );
+    """
+
+    sql """
+        insert into `test_mark_join_t2` values
+            (1,     1,      1,      'abc',      'efg',      'hjk'),
+            (2,     2,      2,      'aabb',     'eeff',     'ccdd'),
+            (3,     null,   null,   'diid',     null,       null),
+            (3,     null,   3,      'ooekd',    null,       null),
+            (4,     4,   null,   'oepeld',   null,       'kkkkk'
+        );
+    """
+
+    qt_mark_join1 """
+        select
+            k1, k2
+            , k1 not in (select test_mark_join_t2.k2 from test_mark_join_t2 where test_mark_join_t2.k3 < test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+    qt_mark_join2 """
+        select
+            k1, k2
+            , k2 not in (select test_mark_join_t2.k3 from test_mark_join_t2 where test_mark_join_t2.k2 > test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+    qt_mark_join3 """
+        select
+            k1, k2
+            , k1 in (select test_mark_join_t2.k1 from test_mark_join_t2 where test_mark_join_t2.k3 < test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+    qt_mark_join4 """
+        select
+            k1, k2
+            , k1 not in (select test_mark_join_t2.k2 from test_mark_join_t2 where test_mark_join_t2.k3 = test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+    qt_mark_join5 """
+        select
+            k1, k2
+            , k2 not in (select test_mark_join_t2.k3 from test_mark_join_t2 where test_mark_join_t2.k2 = test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+    qt_mark_join6 """
+        select
+            k1, k2
+            , k1 in (select test_mark_join_t2.k1 from test_mark_join_t2 where test_mark_join_t2.k3 = test_mark_join_t1.k3) vv
+        from test_mark_join_t1 order by 1, 2, 3;
+    """
+
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to