This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new c699270abad [fix](mark join) mark join column should be nullable (#24910) (#26237)
c699270abad is described below

commit c699270abadeac590774e72060aeb9954bf4da39
Author: Jerry Hu <mrh...@gmail.com>
AuthorDate: Wed Nov 1 21:15:31 2023 +0800

    [fix](mark join) mark join column should be nullable (#24910) (#26237)
---
 be/src/pipeline/exec/hashjoin_build_sink.cpp       |  2 +-
 be/src/vec/columns/column_filter_helper.cpp        | 46 ++++++++++++
 .../columns/column_filter_helper.h}                | 23 ++++--
 .../vec/exec/join/process_hash_table_probe_impl.h  | 57 ++++++++++-----
 be/src/vec/exec/join/vhash_join_node.cpp           | 31 +++++----
 be/src/vec/exec/join/vhash_join_node.h             |  4 +-
 be/src/vec/exec/join/vjoin_node_base.cpp           |  8 +--
 be/src/vec/exec/join/vjoin_node_base.h             |  8 +--
 be/src/vec/exec/join/vnested_loop_join_node.cpp    | 25 +++----
 .../trees/expressions/MarkJoinSlotReference.java   |  6 +-
 .../org/apache/doris/nereids/util/JoinUtils.java   |  3 +-
 .../nereids_syntax_p0/sub_query_correlated.out     | 15 ++++
 .../nereids_tpcds_shape_sf100_p0/shape/query10.out | 81 +++++++++++-----------
 .../nereids_tpcds_shape_sf100_p0/shape/query35.out | 49 +++++++------
 .../nereids_syntax_p0/sub_query_correlated.groovy  | 38 +++++++++-
 15 files changed, 258 insertions(+), 138 deletions(-)

diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index 55aa25397c9..f4ab399a555 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -25,4 +25,4 @@ namespace doris::pipeline {
 
 OPERATOR_CODE_GENERATOR(HashJoinBuildSink, StreamingOperator)
 
-} // namespace doris::pipeline
\ No newline at end of file
+} // namespace doris::pipeline
diff --git a/be/src/vec/columns/column_filter_helper.cpp b/be/src/vec/columns/column_filter_helper.cpp
new file mode 100644
index 00000000000..f65bd8d8649
--- /dev/null
+++ b/be/src/vec/columns/column_filter_helper.cpp
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/columns/column_filter_helper.h"
+
+namespace doris::vectorized {
+ColumnFilterHelper::ColumnFilterHelper(IColumn& column_)
+        : _column(assert_cast<ColumnNullable&>(column_)),
+          
_value_column(assert_cast<ColumnUInt8&>(_column.get_nested_column())),
+          _null_map_column(_column.get_null_map_column()) {}
+
+void ColumnFilterHelper::resize_fill(size_t size, doris::vectorized::UInt8 
value) {
+    _value_column.get_data().resize_fill(size, value);
+    _null_map_column.get_data().resize_fill(size, 0);
+}
+
+void ColumnFilterHelper::insert_value(doris::vectorized::UInt8 value) {
+    _value_column.get_data().push_back(value);
+    _null_map_column.get_data().push_back(0);
+}
+
+void ColumnFilterHelper::insert_null() {
+    _value_column.insert_default();
+    _null_map_column.get_data().push_back(1);
+}
+
+void ColumnFilterHelper::reserve(size_t size) {
+    _value_column.reserve(size);
+    _null_map_column.reserve(size);
+}
+
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/vec/columns/column_filter_helper.h
similarity index 62%
copy from be/src/pipeline/exec/hashjoin_build_sink.cpp
copy to be/src/vec/columns/column_filter_helper.h
index 55aa25397c9..2dc529ef3b4 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/vec/columns/column_filter_helper.h
@@ -15,14 +15,25 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#include "hashjoin_build_sink.h"
+#pragma once
 
-#include <string>
+#include "column_nullable.h"
 
-#include "pipeline/exec/operator.h"
+namespace doris::vectorized {
+class ColumnFilterHelper {
+public:
+    ColumnFilterHelper(IColumn&);
 
-namespace doris::pipeline {
+    void resize_fill(size_t size, UInt8 value);
+    void insert_null();
+    void insert_value(UInt8 value);
+    void reserve(size_t size);
 
-OPERATOR_CODE_GENERATOR(HashJoinBuildSink, StreamingOperator)
+    [[nodiscard]] size_t size() const { return _column.size(); }
 
-} // namespace doris::pipeline
\ No newline at end of file
+private:
+    ColumnNullable& _column;
+    ColumnUInt8& _value_column;
+    ColumnUInt8& _null_map_column;
+};
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/exec/join/process_hash_table_probe_impl.h b/be/src/vec/exec/join/process_hash_table_probe_impl.h
index d32bfdbe37b..5e54d826629 100644
--- a/be/src/vec/exec/join/process_hash_table_probe_impl.h
+++ b/be/src/vec/exec/join/process_hash_table_probe_impl.h
@@ -21,6 +21,7 @@
 #include "process_hash_table_probe.h"
 #include "runtime/thread_context.h" // IWYU pragma: keep
 #include "util/simd/bits.h"
+#include "vec/columns/column_filter_helper.h"
 #include "vec/exprs/vexpr_context.h"
 #include "vhash_join_node.h"
 
@@ -212,6 +213,11 @@ Status 
ProcessHashTableProbe<JoinOpType>::do_process(HashTableType& hash_table_c
     size_t probe_size = 0;
     auto& probe_row_match_iter =
             
std::get<ForwardIterator<Mapped>>(_join_node->_probe_row_match_iter);
+
+    std::unique_ptr<ColumnFilterHelper> mark_column;
+    if (is_mark_join) {
+        mark_column = std::make_unique<ColumnFilterHelper>(*mcol[mcol.size() - 
1]);
+    }
     {
         SCOPED_TIMER(_search_hashtable_timer);
         if constexpr (!is_right_semi_anti_join) {
@@ -285,9 +291,14 @@ Status 
ProcessHashTableProbe<JoinOpType>::do_process(HashTableType& hash_table_c
                               JoinOpType == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
                     if (is_mark_join) {
                         ++current_offset;
-                        
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(*mcol[mcol.size() - 1])
-                                .get_data()
-                                .template push_back(!find_result.is_found());
+                        bool null_result =
+                                (*null_map)[probe_index] ||
+                                (!find_result.is_found() && 
_join_node->_has_null_in_build_side);
+                        if (null_result) {
+                            mark_column->insert_null();
+                        } else {
+                            mark_column->insert_value(!find_result.is_found());
+                        }
                     } else {
                         if (!find_result.is_found()) {
                             ++current_offset;
@@ -297,9 +308,14 @@ Status 
ProcessHashTableProbe<JoinOpType>::do_process(HashTableType& hash_table_c
                 } else if constexpr (JoinOpType == TJoinOp::LEFT_SEMI_JOIN) {
                     if (is_mark_join) {
                         ++current_offset;
-                        
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(*mcol[mcol.size() - 1])
-                                .get_data()
-                                .template push_back(find_result.is_found());
+                        bool null_result =
+                                (*null_map)[probe_index] ||
+                                (!find_result.is_found() && 
_join_node->_has_null_in_build_side);
+                        if (null_result) {
+                            mark_column->insert_null();
+                        } else {
+                            mark_column->insert_value(find_result.is_found());
+                        }
                     } else {
                         if (find_result.is_found()) {
                             ++current_offset;
@@ -840,20 +856,20 @@ Status 
ProcessHashTableProbe<JoinOpType>::do_process_with_other_join_conjuncts(
                 }
 
                 if (is_mark_join) {
-                    auto& matched_map = 
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(
-                                                
*(output_block->get_by_position(num_cols - 1)
-                                                          
.column->assume_mutable()))
-                                                .get_data();
+                    /// FIXME: incorrect result of semi mark join with other 
conjuncts(null value missed).
+                    auto mark_column =
+                            output_block->get_by_position(num_cols - 
1).column->assume_mutable();
+                    ColumnFilterHelper helper(*mark_column);
 
                     // For mark join, we only filter rows which have duplicate 
join keys.
                     // And then, we set matched_map to the join result to do 
the mark join's filtering.
                     for (size_t i = 1; i < row_count; ++i) {
                         if (!same_to_prev[i]) {
-                            matched_map.push_back(filter_map[i - 1]);
+                            helper.insert_value(filter_map[i - 1]);
                             filter_map[i - 1] = true;
                         }
                     }
-                    matched_map.push_back(filter_map[filter_map.size() - 1]);
+                    helper.insert_value(filter_map[filter_map.size() - 1]);
                     filter_map[filter_map.size() - 1] = true;
                 }
 
@@ -913,17 +929,20 @@ Status 
ProcessHashTableProbe<JoinOpType>::do_process_with_other_join_conjuncts(
                 }
 
                 if (is_mark_join) {
-                    auto& matched_map = 
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(
-                                                
*(output_block->get_by_position(num_cols - 1)
-                                                          
.column->assume_mutable()))
-                                                .get_data();
-                    for (int i = 1; i < row_count; ++i) {
+                    /// FIXME: incorrect result of semi mark join with other 
conjuncts(null value missed).
+                    auto mark_column =
+                            output_block->get_by_position(num_cols - 
1).column->assume_mutable();
+                    ColumnFilterHelper helper(*mark_column);
+
+                    // For mark join, we only filter rows which have duplicate 
join keys.
+                    // And then, we set matched_map to the join result to do 
the mark join's filtering.
+                    for (size_t i = 1; i < row_count; ++i) {
                         if (!same_to_prev[i]) {
-                            matched_map.push_back(!filter_map[i - 1]);
+                            helper.insert_value(filter_map[i - 1]);
                             filter_map[i - 1] = true;
                         }
                     }
-                    matched_map.push_back(!filter_map[row_count - 1]);
+                    helper.insert_value(filter_map[row_count - 1]);
                     filter_map[row_count - 1] = true;
                 } else {
                     int end_row_idx;
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index f6dafef06a7..c8bf5c9b36d 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -180,6 +180,9 @@ struct ProcessHashTableBuild {
                 }
                 if constexpr (ignore_null) {
                     if ((*null_map)[k]) {
+                        if (has_null_key != nullptr) {
+                            *has_null_key = true;
+                        }
                         continue;
                     }
                 }
@@ -535,11 +538,10 @@ Status HashJoinNode::pull(doris::RuntimeState* state, 
vectorized::Block* output_
         return Status::OK();
     }
 
-    if (_short_circuit_for_null_in_probe_side && _join_op == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN &&
-        _is_mark_join) {
-        /// If `_short_circuit_for_null_in_probe_side` is true, this indicates 
no rows
-        /// match the join condition, and this is 'mark join', so we need to 
create a column as mark
-        /// with all rows set to 0.
+    /// `_has_null_in_build_side` means have null value in build side.
+    /// `_short_circuit_for_null_in_build_side` means short circuit if has 
null in build side(e.g. null aware left anti join).
+    if (_has_null_in_build_side && _short_circuit_for_null_in_build_side && 
_is_mark_join) {
+        /// We need to create a column as mark with all rows set to NULL.
         auto block_rows = _probe_block.rows();
         if (block_rows == 0) {
             *eos = _probe_eos;
@@ -553,8 +555,10 @@ Status HashJoinNode::pull(doris::RuntimeState* state, 
vectorized::Block* output_
                 temp_block.insert(_probe_block.get_by_position(i));
             }
         }
-        auto mark_column = ColumnUInt8::create(block_rows, 0);
-        temp_block.insert({std::move(mark_column), 
std::make_shared<DataTypeUInt8>(), ""});
+        auto mark_column = 
ColumnNullable::create(ColumnUInt8::create(block_rows, 0),
+                                                  
ColumnUInt8::create(block_rows, 1));
+        temp_block.insert(
+                {std::move(mark_column), 
make_nullable(std::make_shared<DataTypeUInt8>()), ""});
 
         {
             SCOPED_TIMER(_join_filter_timer);
@@ -878,7 +882,7 @@ Status HashJoinNode::_materialize_build_side(RuntimeState* 
state) {
         Block block;
         // If eos or have already met a null value using short-circuit 
strategy, we do not need to pull
         // data from data.
-        while (!eos && !_short_circuit_for_null_in_probe_side) {
+        while (!eos && (!_short_circuit_for_null_in_build_side || 
!_has_null_in_build_side)) {
             block.clear_column_data();
             RETURN_IF_CANCELLED(state);
             {
@@ -906,8 +910,8 @@ Status HashJoinNode::sink(doris::RuntimeState* state, 
vectorized::Block* in_bloc
     // make one block for each 4 gigabytes
     constexpr static auto BUILD_BLOCK_MAX_SIZE = 4 * 1024UL * 1024UL * 1024UL;
 
-    if (_short_circuit_for_null_in_probe_side) {
-        // TODO: if _short_circuit_for_null_in_probe_side is true we should 
finish current pipeline task.
+    if (_has_null_in_build_side) {
+        // TODO: if _has_null_in_build_side is true we should finish current 
pipeline task.
         DCHECK(state->enable_pipeline_exec());
         return Status::OK();
     }
@@ -979,7 +983,7 @@ Status HashJoinNode::sink(doris::RuntimeState* state, 
vectorized::Block* in_bloc
             _shared_hash_table_context->blocks = _build_blocks;
             _shared_hash_table_context->hash_table_variants = 
_hash_table_variants;
             _shared_hash_table_context->short_circuit_for_null_in_probe_side =
-                    _short_circuit_for_null_in_probe_side;
+                    _has_null_in_build_side;
             if (_runtime_filter_slots) {
                 
_runtime_filter_slots->copy_to_shared_context(_shared_hash_table_context);
             }
@@ -997,8 +1001,7 @@ Status HashJoinNode::sink(doris::RuntimeState* state, 
vectorized::Block* in_bloc
         _build_phase_profile->add_info_string(
                 "SharedHashTableFrom",
                 
print_id(_shared_hashtable_controller->get_builder_fragment_instance_id(id())));
-        _short_circuit_for_null_in_probe_side =
-                
_shared_hash_table_context->short_circuit_for_null_in_probe_side;
+        _has_null_in_build_side = 
_shared_hash_table_context->short_circuit_for_null_in_probe_side;
         _hash_table_variants = std::static_pointer_cast<HashTableVariants>(
                 _shared_hash_table_context->hash_table_variants);
         _build_blocks = _shared_hash_table_context->blocks;
@@ -1183,7 +1186,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* 
state, Block& block, uin
                                         has_null_value || 
short_circuit_for_null_in_build_side
                                                 ? &null_map_val->get_data()
                                                 : nullptr,
-                                        
&_short_circuit_for_null_in_probe_side);
+                                        &_has_null_in_build_side);
                     }},
             *_hash_table_variants, make_bool_variant(_build_side_ignore_null),
             make_bool_variant(_short_circuit_for_null_in_build_side));
diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h
index bce0e34828c..a9e7b65aa01 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -264,8 +264,8 @@ public:
 private:
     void _init_short_circuit_for_probe() override {
         _short_circuit_for_probe =
-                (_short_circuit_for_null_in_probe_side &&
-                 _join_op == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN && 
!_is_mark_join) ||
+                (_has_null_in_build_side && _join_op == 
TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN &&
+                 !_is_mark_join) ||
                 (_build_blocks->empty() && _join_op == TJoinOp::INNER_JOIN && 
!_is_mark_join) ||
                 (_build_blocks->empty() && _join_op == TJoinOp::LEFT_SEMI_JOIN 
&& !_is_mark_join) ||
                 (_build_blocks->empty() && _join_op == 
TJoinOp::RIGHT_OUTER_JOIN) ||
diff --git a/be/src/vec/exec/join/vjoin_node_base.cpp b/be/src/vec/exec/join/vjoin_node_base.cpp
index d9cb9e81fd4..912f7826fe2 100644
--- a/be/src/vec/exec/join/vjoin_node_base.cpp
+++ b/be/src/vec/exec/join/vjoin_node_base.cpp
@@ -144,11 +144,9 @@ void VJoinNodeBase::_construct_mutable_join_block() {
             _join_block.insert({type_ptr->create_column(), type_ptr, 
slot_desc->col_name()});
         }
     }
-    if (_is_mark_join) {
-        _join_block.replace_by_position(
-                _join_block.columns() - 1,
-                
remove_nullable(_join_block.get_by_position(_join_block.columns() - 1).column));
-    }
+
+    DCHECK(!_is_mark_join ||
+           _join_block.get_by_position(_join_block.columns() - 
1).column->is_nullable());
 }
 
 Status VJoinNodeBase::_build_output_block(Block* origin_block, Block* 
output_block,
diff --git a/be/src/vec/exec/join/vjoin_node_base.h b/be/src/vec/exec/join/vjoin_node_base.h
index 22eafdbb346..2a108e7cea7 100644
--- a/be/src/vec/exec/join/vjoin_node_base.h
+++ b/be/src/vec/exec/join/vjoin_node_base.h
@@ -117,13 +117,13 @@ protected:
 
     // For null aware left anti join, we apply a short circuit strategy.
     // 1. Set _short_circuit_for_null_in_build_side to true if join operator 
is null aware left anti join.
-    // 2. In build phase, we stop materialize build side when we meet the 
first null value and set _short_circuit_for_null_in_probe_side to true.
-    // 3. In probe phase, if _short_circuit_for_null_in_probe_side is true, 
join node returns empty block directly. Otherwise, probing will continue as the 
same as generic left anti join.
+    // 2. In build phase, we stop materialize build side when we meet the 
first null value and set _has_null_in_build_side to true.
+    // 3. In probe phase, if _has_null_in_build_side is true, join node 
returns empty block directly. Otherwise, probing will continue as the same as 
generic left anti join.
     const bool _short_circuit_for_null_in_build_side = false;
-    bool _short_circuit_for_null_in_probe_side = false;
+    bool _has_null_in_build_side = false;
 
     // For some join case, we can apply a short circuit strategy
-    // 1. _short_circuit_for_null_in_probe_side = true
+    // 1. _has_null_in_build_side = true
     // 2. build side rows is empty, Join op is: inner join/right outer 
join/left semi/right semi/right anti
     bool _short_circuit_for_probe = false;
 
diff --git a/be/src/vec/exec/join/vnested_loop_join_node.cpp b/be/src/vec/exec/join/vnested_loop_join_node.cpp
index eee179f7837..4463f7a3955 100644
--- a/be/src/vec/exec/join/vnested_loop_join_node.cpp
+++ b/be/src/vec/exec/join/vnested_loop_join_node.cpp
@@ -46,6 +46,7 @@
 #include "util/simd/bits.h"
 #include "util/telemetry/telemetry.h"
 #include "vec/columns/column_const.h"
+#include "vec/columns/column_filter_helper.h"
 #include "vec/columns/column_nullable.h"
 #include "vec/columns/column_vector.h"
 #include "vec/columns/columns_number.h"
@@ -298,10 +299,9 @@ void 
VNestedLoopJoinNode::_append_left_data_with_null(MutableBlock& mutable_bloc
     for (size_t i = 0; i < _num_build_side_columns; ++i) {
         dst_columns[_num_probe_side_columns + 
i]->insert_many_defaults(_left_side_process_count);
     }
-    IColumn::Filter& mark_data = 
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(
-                                         *dst_columns[dst_columns.size() - 1])
-                                         .get_data();
-    mark_data.resize_fill(mark_data.size() + _left_side_process_count, 0);
+
+    auto& mark_column = *dst_columns[dst_columns.size() - 1];
+    ColumnFilterHelper(mark_column).resize_fill(mark_column.size() + 
_left_side_process_count, 0);
 }
 
 void VNestedLoopJoinNode::_process_left_child_block(MutableBlock& 
mutable_block,
@@ -361,12 +361,9 @@ void VNestedLoopJoinNode::_update_additional_flags(Block* 
block) {
         }
     }
     if (_is_mark_join) {
-        IColumn::Filter& mark_data =
-                assert_cast<doris::vectorized::ColumnVector<UInt8>&>(
-                        *block->get_by_position(block->columns() - 
1).column->assume_mutable())
-                        .get_data();
-        if (mark_data.size() < block->rows()) {
-            mark_data.resize_fill(block->rows(), 1);
+        auto mark_column = block->get_by_position(block->columns() - 
1).column->assume_mutable();
+        if (mark_column->size() < block->rows()) {
+            ColumnFilterHelper(*mark_column).resize_fill(block->rows(), 1);
         }
     }
 }
@@ -488,14 +485,12 @@ void 
VNestedLoopJoinNode::_finalize_current_phase(MutableBlock& mutable_block, s
                 _resize_fill_tuple_is_null_column(new_size, 0, 1);
             }
         } else {
-            IColumn::Filter& mark_data = 
assert_cast<doris::vectorized::ColumnVector<UInt8>&>(
-                                                 
*dst_columns[dst_columns.size() - 1])
-                                                 .get_data();
-            mark_data.reserve(mark_data.size() + _left_side_process_count);
+            ColumnFilterHelper mark_column(*dst_columns[dst_columns.size() - 
1]);
+            mark_column.reserve(mark_column.size() + _left_side_process_count);
             DCHECK_LE(_left_block_start_pos + _left_side_process_count, 
_left_block.rows());
             for (int j = _left_block_start_pos;
                  j < _left_block_start_pos + _left_side_process_count; ++j) {
-                mark_data.emplace_back(IsSemi == 
_cur_probe_row_visited_flags[j]);
+                mark_column.insert_value(IsSemi == 
_cur_probe_row_visited_flags[j]);
             }
             for (size_t i = 0; i < _num_probe_side_columns; ++i) {
                 const ColumnWithTypeAndName src_column = 
_left_block.get_by_position(i);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java
index 099e64eb5d2..021fcea1a3a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java
@@ -29,17 +29,17 @@ public class MarkJoinSlotReference extends SlotReference 
implements SlotNotFromC
     final boolean existsHasAgg;
 
     public MarkJoinSlotReference(String name) {
-        super(name, BooleanType.INSTANCE, false);
+        super(name, BooleanType.INSTANCE, true);
         this.existsHasAgg = false;
     }
 
     public MarkJoinSlotReference(String name, boolean existsHasAgg) {
-        super(name, BooleanType.INSTANCE, false);
+        super(name, BooleanType.INSTANCE, true);
         this.existsHasAgg = existsHasAgg;
     }
 
     public MarkJoinSlotReference(ExprId exprId, String name, boolean 
existsHasAgg) {
-        super(exprId, name, BooleanType.INSTANCE, false, ImmutableList.of());
+        super(exprId, name, BooleanType.INSTANCE, true, ImmutableList.of());
         this.existsHasAgg = existsHasAgg;
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
index 3e40db15fd8..eda7d2e6ad1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
@@ -53,7 +53,8 @@ import java.util.stream.Collectors;
 public class JoinUtils {
     public static boolean couldShuffle(Join join) {
         // Cross-join and Null-Aware-Left-Anti-Join only can be broadcast join.
-        return !(join.getJoinType().isCrossJoin()) && 
!(join.getJoinType().isNullAwareLeftAntiJoin());
+        // Because mark join would consider null value from both build and 
probe side, so must use broadcast join too.
+        return !(join.getJoinType().isCrossJoin() || 
join.getJoinType().isNullAwareLeftAntiJoin() || join.isMarkJoin());
     }
 
     public static boolean couldBroadcast(Join join) {
diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
index 732a72a3907..647babc200d 100644
--- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
+++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
@@ -450,3 +450,18 @@
 22     3
 24     4
 
+-- !mark_join_nullable --
+\N
+\N
+\N
+\N
+\N
+\N
+true
+true
+true
+true
+\N
+\N
+\N
+
diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query10.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query10.out
index 485a11ae7d4..aee9e1657c3 100644
--- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query10.out
+++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query10.out
@@ -5,54 +5,51 @@ PhysicalResultSink
 ----PhysicalDistribute
 ------PhysicalTopN
 --------PhysicalProject
-----------hashAgg[GLOBAL]
-------------PhysicalDistribute
---------------hashAgg[LOCAL]
-----------------PhysicalProject
-------------------filter(($c$1 OR $c$2))
---------------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
catalog_sales.cs_ship_customer_sk)
-----------------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
web_sales.ws_bill_customer_sk)
+----------hashAgg[LOCAL]
+------------PhysicalProject
+--------------filter(($c$1 OR $c$2))
+----------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
catalog_sales.cs_ship_customer_sk)
+------------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
web_sales.ws_bill_customer_sk)
+--------------------PhysicalProject
+----------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk = 
c.c_current_cdemo_sk)
+------------------------PhysicalOlapScan[customer_demographics]
 ------------------------PhysicalDistribute
 --------------------------PhysicalProject
-----------------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk
 = c.c_current_cdemo_sk)
-------------------------------PhysicalOlapScan[customer_demographics]
+----------------------------hashJoin[RIGHT_SEMI_JOIN](c.c_customer_sk = 
store_sales.ss_customer_sk)
 ------------------------------PhysicalDistribute
 --------------------------------PhysicalProject
-----------------------------------hashJoin[RIGHT_SEMI_JOIN](c.c_customer_sk = 
store_sales.ss_customer_sk)
+----------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk
 = date_dim.d_date_sk)
+------------------------------------PhysicalProject
+--------------------------------------PhysicalOlapScan[store_sales]
 ------------------------------------PhysicalDistribute
 --------------------------------------PhysicalProject
-----------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk
 = date_dim.d_date_sk)
-------------------------------------------PhysicalProject
---------------------------------------------PhysicalOlapScan[store_sales]
-------------------------------------------PhysicalDistribute
---------------------------------------------PhysicalProject
-----------------------------------------------filter((date_dim.d_moy <= 
4)(date_dim.d_moy >= 1)(date_dim.d_year = 2001))
-------------------------------------------------PhysicalOlapScan[date_dim]
-------------------------------------PhysicalDistribute
---------------------------------------hashJoin[INNER_JOIN](c.c_current_addr_sk 
= ca.ca_address_sk)
-----------------------------------------PhysicalDistribute
-------------------------------------------PhysicalProject
---------------------------------------------PhysicalOlapScan[customer]
-----------------------------------------PhysicalDistribute
-------------------------------------------PhysicalProject
---------------------------------------------filter(ca_county IN ('Storey 
County', 'Marquette County', 'Warren County', 'Cochran County', 'Kandiyohi 
County'))
-----------------------------------------------PhysicalOlapScan[customer_address]
-------------------------PhysicalDistribute
---------------------------PhysicalProject
-----------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = 
date_dim.d_date_sk)
-------------------------------PhysicalProject
---------------------------------PhysicalOlapScan[web_sales]
+----------------------------------------filter((date_dim.d_moy <= 
4)(date_dim.d_moy >= 1)(date_dim.d_year = 2001))
+------------------------------------------PhysicalOlapScan[date_dim]
 ------------------------------PhysicalDistribute
---------------------------------PhysicalProject
-----------------------------------filter((date_dim.d_moy <= 4)(date_dim.d_moy 
>= 1)(date_dim.d_year = 2001))
-------------------------------------PhysicalOlapScan[date_dim]
-----------------------PhysicalDistribute
-------------------------PhysicalProject
---------------------------hashJoin[INNER_JOIN](catalog_sales.cs_sold_date_sk = 
date_dim.d_date_sk)
+--------------------------------hashJoin[INNER_JOIN](c.c_current_addr_sk = 
ca.ca_address_sk)
+----------------------------------PhysicalDistribute
+------------------------------------PhysicalProject
+--------------------------------------PhysicalOlapScan[customer]
+----------------------------------PhysicalDistribute
+------------------------------------PhysicalProject
+--------------------------------------filter(ca_county IN ('Storey County', 
'Marquette County', 'Warren County', 'Cochran County', 'Kandiyohi County'))
+----------------------------------------PhysicalOlapScan[customer_address]
+--------------------PhysicalDistribute
+----------------------PhysicalProject
+------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = 
date_dim.d_date_sk)
+--------------------------PhysicalProject
+----------------------------PhysicalOlapScan[web_sales]
+--------------------------PhysicalDistribute
 ----------------------------PhysicalProject
-------------------------------PhysicalOlapScan[catalog_sales]
-----------------------------PhysicalDistribute
-------------------------------PhysicalProject
---------------------------------filter((date_dim.d_moy >= 1)(date_dim.d_moy <= 
4)(date_dim.d_year = 2001))
-----------------------------------PhysicalOlapScan[date_dim]
+------------------------------filter((date_dim.d_moy <= 4)(date_dim.d_moy >= 
1)(date_dim.d_year = 2001))
+--------------------------------PhysicalOlapScan[date_dim]
+------------------PhysicalDistribute
+--------------------PhysicalProject
+----------------------hashJoin[INNER_JOIN](catalog_sales.cs_sold_date_sk = 
date_dim.d_date_sk)
+------------------------PhysicalProject
+--------------------------PhysicalOlapScan[catalog_sales]
+------------------------PhysicalDistribute
+--------------------------PhysicalProject
+----------------------------filter((date_dim.d_moy >= 1)(date_dim.d_moy <= 
4)(date_dim.d_year = 2001))
+------------------------------PhysicalOlapScan[date_dim]
 
diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query35.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query35.out
index 1a1d022d75b..b8294ca5c37 100644
--- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query35.out
+++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query35.out
@@ -12,32 +12,31 @@ PhysicalResultSink
 ------------------filter(($c$1 OR $c$2))
 --------------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
catalog_sales.cs_ship_customer_sk)
 ----------------------hashJoin[LEFT_SEMI_JOIN](c.c_customer_sk = 
web_sales.ws_bill_customer_sk)
-------------------------PhysicalDistribute
---------------------------PhysicalProject
-----------------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk
 = c.c_current_cdemo_sk)
-------------------------------PhysicalDistribute
---------------------------------PhysicalProject
-----------------------------------hashJoin[INNER_JOIN](c.c_current_addr_sk = 
ca.ca_address_sk)
-------------------------------------PhysicalDistribute
---------------------------------------hashJoin[RIGHT_SEMI_JOIN](c.c_customer_sk
 = store_sales.ss_customer_sk)
-----------------------------------------PhysicalDistribute
-------------------------------------------PhysicalProject
---------------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk
 = date_dim.d_date_sk)
+------------------------PhysicalProject
+--------------------------hashJoin[INNER_JOIN](customer_demographics.cd_demo_sk
 = c.c_current_cdemo_sk)
+----------------------------PhysicalDistribute
+------------------------------PhysicalProject
+--------------------------------hashJoin[INNER_JOIN](c.c_current_addr_sk = 
ca.ca_address_sk)
+----------------------------------PhysicalDistribute
+------------------------------------hashJoin[RIGHT_SEMI_JOIN](c.c_customer_sk 
= store_sales.ss_customer_sk)
+--------------------------------------PhysicalDistribute
+----------------------------------------PhysicalProject
+------------------------------------------hashJoin[INNER_JOIN](store_sales.ss_sold_date_sk
 = date_dim.d_date_sk)
+--------------------------------------------PhysicalProject
+----------------------------------------------PhysicalOlapScan[store_sales]
+--------------------------------------------PhysicalDistribute
 ----------------------------------------------PhysicalProject
-------------------------------------------------PhysicalOlapScan[store_sales]
-----------------------------------------------PhysicalDistribute
-------------------------------------------------PhysicalProject
---------------------------------------------------filter((date_dim.d_qoy < 
4)(date_dim.d_year = 2001))
-----------------------------------------------------PhysicalOlapScan[date_dim]
-----------------------------------------PhysicalDistribute
-------------------------------------------PhysicalProject
---------------------------------------------PhysicalOlapScan[customer]
-------------------------------------PhysicalDistribute
---------------------------------------PhysicalProject
-----------------------------------------PhysicalOlapScan[customer_address]
-------------------------------PhysicalDistribute
---------------------------------PhysicalProject
-----------------------------------PhysicalOlapScan[customer_demographics]
+------------------------------------------------filter((date_dim.d_qoy < 
4)(date_dim.d_year = 2001))
+--------------------------------------------------PhysicalOlapScan[date_dim]
+--------------------------------------PhysicalDistribute
+----------------------------------------PhysicalProject
+------------------------------------------PhysicalOlapScan[customer]
+----------------------------------PhysicalDistribute
+------------------------------------PhysicalProject
+--------------------------------------PhysicalOlapScan[customer_address]
+----------------------------PhysicalDistribute
+------------------------------PhysicalProject
+--------------------------------PhysicalOlapScan[customer_demographics]
 ------------------------PhysicalDistribute
 --------------------------PhysicalProject
 ----------------------------hashJoin[INNER_JOIN](web_sales.ws_sold_date_sk = 
date_dim.d_date_sk)
diff --git 
a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy 
b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
index 482eab7a6aa..6664ad0c6c7 100644
--- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
+++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
@@ -50,6 +50,14 @@ suite ("sub_query_correlated") {
         DROP TABLE IF EXISTS `sub_query_correlated_subquery7`
     """
 
+    sql """
+        DROP TABLE IF EXISTS `sub_query_correlated_subquery8`
+    """
+
+    sql """
+        DROP TABLE IF EXISTS `sub_query_correlated_subquery9`
+    """
+
     sql """
         create table if not exists sub_query_correlated_subquery1
         (k1 bigint, k2 bigint)
@@ -105,6 +113,21 @@ suite ("sub_query_correlated") {
             properties('replication_num' = '1');
     """
 
+    sql """
+        create table if not exists sub_query_correlated_subquery8
+        (k1 bigint, k2 bigint)
+        duplicate key(k1)
+        distributed by hash(k2) buckets 1
+        properties('replication_num' = '1')
+    """
+
+    sql """
+        create table if not exists sub_query_correlated_subquery9
+            (k1 int, k2 varchar(128), k3 bigint, v1 bigint, v2 bigint)
+            distributed by hash(k2) buckets 1
+            properties('replication_num' = '1');
+    """
+
     sql """
         insert into sub_query_correlated_subquery1 values (1,2), (1,3), (2,4), 
(2,5), (3,3), (3,4), (20,2), (22,3), (24,4)
     """
@@ -126,7 +149,7 @@ suite ("sub_query_correlated") {
         insert into sub_query_correlated_subquery5 values (5,4), (5,2), (8,3), 
(5,4), (6,7), (8,9)
     """
 
-     sql """
+    sql """
         insert into sub_query_correlated_subquery6 values 
(1,null),(null,1),(1,2), (null,2),(1,3), (2,4), (2,5), (3,3), (3,4), (20,2), 
(22,3), (24,4),(null,null);
     """
 
@@ -135,6 +158,15 @@ suite ("sub_query_correlated") {
             (2,"uvw",3,4,2), (2,"uvw",3,4,2), (3,"abc",4,5,3), 
(3,"abc",4,5,3), (null,null,null,null,null);
     """
 
+    sql """
+        insert into sub_query_correlated_subquery8 values 
(1,null),(null,1),(1,2), (null,2),(1,3), (2,4), (2,5), (3,3), (3,4), (20,2), 
(22,3), (24,4),(null,null);
+    """
+
+    sql """
+        insert into sub_query_correlated_subquery9 values (1,"abc",2,3,4), 
(1,"abcd",3,3,4),
+            (2,"xyz",2,4,2),(2,"uvw",3,4,2), (2,"uvw",3,4,2), (3,"abc",4,5,3), 
(3,"abc",4,5,3), (null,null,null,null,null);
+    """
+
     sql "SET enable_fallback_to_original_planner=false"
 
     //------------------Correlated-----------------
@@ -496,6 +528,10 @@ suite ("sub_query_correlated") {
         order by k1, k2;
     """
 
+    qt_mark_join_nullable """
+        select sub_query_correlated_subquery8.k1 in (select 
sub_query_correlated_subquery9.k3 from sub_query_correlated_subquery9) from 
sub_query_correlated_subquery8 order by k1, k2;
+    """
+
     // order_qt_doris_6937_2 """
     //     select * from sub_query_correlated_subquery1 where 
sub_query_correlated_subquery1.k1 not in (select 
sub_query_correlated_subquery3.k3 from sub_query_correlated_subquery3 where 
sub_query_correlated_subquery3.v2 > sub_query_correlated_subquery1.k2) or k1 < 
10 order by k1, k2;
     // """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to