This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 59d7f64360b [Fix](Nereids) fix pipelineX distribute expr list with child output expr ids (#29621)
59d7f64360b is described below

commit 59d7f64360bd4839594675c08a5d4295d0588a0c
Author: Gabriel <gabrielleeb...@gmail.com>
AuthorDate: Mon Jan 8 10:46:27 2024 +0800

    [Fix](Nereids) fix pipelineX distribute expr list with child output expr ids (#29621)
---
 be/src/pipeline/exec/aggregation_sink_operator.cpp |  3 +-
 be/src/pipeline/exec/analytic_sink_operator.cpp    |  3 +-
 be/src/pipeline/exec/hashjoin_build_sink.cpp       |  8 +--
 be/src/pipeline/exec/hashjoin_build_sink.h         |  2 +-
 be/src/pipeline/exec/hashjoin_probe_operator.cpp   |  6 +-
 be/src/pipeline/exec/hashjoin_probe_operator.h     |  2 +-
 be/src/pipeline/exec/sort_sink_operator.cpp        |  8 ++-
 be/src/pipeline/exec/sort_sink_operator.h          |  9 ++-
 be/src/pipeline/pipeline.h                         |  5 +-
 .../glue/translator/PhysicalPlanTranslator.java    | 60 ++++++++++++++++-
 .../java/org/apache/doris/planner/PlanNode.java    | 20 ++++++
 .../java/org/apache/doris/planner/SortNode.java    |  7 ++
 gensrc/thrift/PlanNodes.thrift                     |  4 ++
 .../test_bucket_hash_local_shuffle.out             | 14 ++++
 .../test_bucket_hash_local_shuffle.groovy          | 76 ++++++++++++++++++++++
 15 files changed, 212 insertions(+), 15 deletions(-)

diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp 
b/be/src/pipeline/exec/aggregation_sink_operator.cpp
index cc4328b8c08..3ba4dd05dc7 100644
--- a/be/src/pipeline/exec/aggregation_sink_operator.cpp
+++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp
@@ -734,7 +734,8 @@ 
AggSinkOperatorX<LocalStateType>::AggSinkOperatorX(ObjectPool* pool, int operato
           _limit(tnode.limit),
           _have_conjuncts(tnode.__isset.vconjunct && 
!tnode.vconjunct.nodes.empty()),
           _is_streaming(is_streaming),
-          _partition_exprs(tnode.agg_node.grouping_exprs),
+          _partition_exprs(tnode.__isset.distribute_expr_lists ? 
tnode.distribute_expr_lists[0]
+                                                               : 
std::vector<TExpr> {}),
           _is_colocate(tnode.agg_node.__isset.is_colocate && 
tnode.agg_node.is_colocate) {}
 
 template <typename LocalStateType>
diff --git a/be/src/pipeline/exec/analytic_sink_operator.cpp 
b/be/src/pipeline/exec/analytic_sink_operator.cpp
index 3e936456990..d9923a68f24 100644
--- a/be/src/pipeline/exec/analytic_sink_operator.cpp
+++ b/be/src/pipeline/exec/analytic_sink_operator.cpp
@@ -193,7 +193,8 @@ AnalyticSinkOperatorX::AnalyticSinkOperatorX(ObjectPool* 
pool, int operator_id,
                                      ? tnode.analytic_node.buffered_tuple_id
                                      : 0),
           _is_colocate(tnode.analytic_node.__isset.is_colocate && 
tnode.analytic_node.is_colocate),
-          _partition_exprs(tnode.analytic_node.partition_exprs) {}
+          _partition_exprs(tnode.__isset.distribute_expr_lists ? 
tnode.distribute_expr_lists[0]
+                                                               : 
std::vector<TExpr> {}) {}
 
 Status AnalyticSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* 
state) {
     RETURN_IF_ERROR(DataSinkOperatorX::init(tnode, state));
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp 
b/be/src/pipeline/exec/hashjoin_build_sink.cpp
index 8f3a7259582..f34f835d121 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.cpp
+++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp
@@ -385,9 +385,10 @@ 
HashJoinBuildSinkOperatorX::HashJoinBuildSinkOperatorX(ObjectPool* pool, int ope
                                                                     : 
TJoinDistributionType::NONE),
           _is_broadcast_join(tnode.hash_join_node.__isset.is_broadcast_join &&
                              tnode.hash_join_node.is_broadcast_join),
-          _use_global_rf(use_global_rf) {
-    _runtime_filter_descs = tnode.runtime_filters;
-}
+          _partition_exprs(tnode.__isset.distribute_expr_lists && 
!_is_broadcast_join
+                                   ? tnode.distribute_expr_lists[1]
+                                   : std::vector<TExpr> {}),
+          _use_global_rf(use_global_rf) {}
 
 Status HashJoinBuildSinkOperatorX::prepare(RuntimeState* state) {
     if (_is_broadcast_join) {
@@ -413,7 +414,6 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& 
tnode, RuntimeState* st
     for (const auto& eq_join_conjunct : eq_join_conjuncts) {
         vectorized::VExprContextSPtr ctx;
         
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, 
ctx));
-        _partition_exprs.push_back(eq_join_conjunct.right);
         _build_expr_ctxs.push_back(ctx);
 
         const auto vexpr = _build_expr_ctxs.back()->root();
diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h 
b/be/src/pipeline/exec/hashjoin_build_sink.h
index fa4635afad1..5ea504d488d 100644
--- a/be/src/pipeline/exec/hashjoin_build_sink.h
+++ b/be/src/pipeline/exec/hashjoin_build_sink.h
@@ -180,7 +180,7 @@ private:
 
     vectorized::SharedHashTableContextPtr _shared_hash_table_context = nullptr;
     std::vector<TRuntimeFilterDesc> _runtime_filter_descs;
-    std::vector<TExpr> _partition_exprs;
+    const std::vector<TExpr> _partition_exprs;
 
     const bool _use_global_rf;
 };
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp 
b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
index ca1172501c1..f852d3c4440 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp
@@ -218,7 +218,10 @@ HashJoinProbeOperatorX::HashJoinProbeOperatorX(ObjectPool* 
pool, const TPlanNode
                              tnode.hash_join_node.is_broadcast_join),
           
_hash_output_slot_ids(tnode.hash_join_node.__isset.hash_output_slot_ids
                                         ? 
tnode.hash_join_node.hash_output_slot_ids
-                                        : std::vector<SlotId> {}) {}
+                                        : std::vector<SlotId> {}),
+          _partition_exprs(tnode.__isset.distribute_expr_lists && 
!_is_broadcast_join
+                                   ? tnode.distribute_expr_lists[0]
+                                   : std::vector<TExpr> {}) {}
 
 Status HashJoinProbeOperatorX::pull(doris::RuntimeState* state, 
vectorized::Block* output_block,
                                     SourceState& source_state) const {
@@ -543,7 +546,6 @@ Status HashJoinProbeOperatorX::init(const TPlanNode& tnode, 
RuntimeState* state)
     for (const auto& eq_join_conjunct : eq_join_conjuncts) {
         vectorized::VExprContextSPtr ctx;
         
RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.left, 
ctx));
-        _partition_exprs.push_back(eq_join_conjunct.left);
         _probe_expr_ctxs.push_back(ctx);
         bool null_aware = eq_join_conjunct.__isset.opcode &&
                           eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL;
diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h 
b/be/src/pipeline/exec/hashjoin_probe_operator.h
index 76103c4d8fa..093884b6d0f 100644
--- a/be/src/pipeline/exec/hashjoin_probe_operator.h
+++ b/be/src/pipeline/exec/hashjoin_probe_operator.h
@@ -193,7 +193,7 @@ private:
     std::vector<bool> _left_output_slot_flags;
     std::vector<bool> _right_output_slot_flags;
     std::vector<std::string> _right_table_column_names;
-    std::vector<TExpr> _partition_exprs;
+    const std::vector<TExpr> _partition_exprs;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/exec/sort_sink_operator.cpp 
b/be/src/pipeline/exec/sort_sink_operator.cpp
index 0eb14fe056a..e2c851f758f 100644
--- a/be/src/pipeline/exec/sort_sink_operator.cpp
+++ b/be/src/pipeline/exec/sort_sink_operator.cpp
@@ -77,7 +77,13 @@ SortSinkOperatorX::SortSinkOperatorX(ObjectPool* pool, int 
operator_id, const TP
           _use_topn_opt(tnode.sort_node.use_topn_opt),
           _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples),
           _use_two_phase_read(tnode.sort_node.sort_info.use_two_phase_read),
-          _merge_by_exchange(tnode.sort_node.merge_by_exchange) {}
+          _merge_by_exchange(tnode.sort_node.merge_by_exchange),
+          _is_colocate(tnode.sort_node.__isset.is_colocate ? 
tnode.sort_node.is_colocate : false),
+          _is_analytic_sort(tnode.sort_node.__isset.is_analytic_sort
+                                    ? tnode.sort_node.is_analytic_sort
+                                    : false),
+          _partition_exprs(tnode.__isset.distribute_expr_lists ? 
tnode.distribute_expr_lists[0]
+                                                               : 
std::vector<TExpr> {}) {}
 
 Status SortSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
     RETURN_IF_ERROR(DataSinkOperatorX::init(tnode, state));
diff --git a/be/src/pipeline/exec/sort_sink_operator.h 
b/be/src/pipeline/exec/sort_sink_operator.h
index 64beb53ba9e..7069183f3b2 100644
--- a/be/src/pipeline/exec/sort_sink_operator.h
+++ b/be/src/pipeline/exec/sort_sink_operator.h
@@ -94,7 +94,11 @@ public:
     Status sink(RuntimeState* state, vectorized::Block* in_block,
                 SourceState source_state) override;
     DataDistribution required_data_distribution() const override {
-        if (_merge_by_exchange) {
+        if (_is_analytic_sort) {
+            return _is_colocate
+                           ? 
DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE, _partition_exprs)
+                           : DataDistribution(ExchangeType::HASH_SHUFFLE, 
_partition_exprs);
+        } else if (_merge_by_exchange) {
             // The current sort node is used for the ORDER BY
             return {ExchangeType::PASSTHROUGH};
         }
@@ -121,6 +125,9 @@ private:
     const RowDescriptor _row_descriptor;
     const bool _use_two_phase_read;
     const bool _merge_by_exchange;
+    const bool _is_colocate = false;
+    const bool _is_analytic_sort = false;
+    const std::vector<TExpr> _partition_exprs;
 };
 
 } // namespace pipeline
diff --git a/be/src/pipeline/pipeline.h b/be/src/pipeline/pipeline.h
index ef0acfba258..ab6850b704b 100644
--- a/be/src/pipeline/pipeline.h
+++ b/be/src/pipeline/pipeline.h
@@ -138,8 +138,9 @@ public:
                 return true;
             }
             return _data_distribution.distribution_type !=
-                           target_data_distribution.distribution_type ||
-                   _data_distribution.partition_exprs != 
target_data_distribution.partition_exprs;
+                           target_data_distribution.distribution_type &&
+                   !(is_hash_exchange(_data_distribution.distribution_type) &&
+                     
is_hash_exchange(target_data_distribution.distribution_type));
         } else {
             return _data_distribution.distribution_type !=
                            target_data_distribution.distribution_type &&
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index c82590940bb..1e72200f157 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -270,6 +270,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends 
Plan> distribute,
             PlanTranslatorContext context) {
         PlanFragment inputFragment = distribute.child().accept(this, context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(distribute.child());
         // TODO: why need set streaming here? should remove this.
         if (inputFragment.getPlanRoot() instanceof AggregationNode
                 && distribute.child() instanceof PhysicalHashAggregate
@@ -315,6 +316,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         }
         DataPartition dataPartition = 
toDataPartition(distribute.getDistributionSpec(), validOutputIds, context);
         exchangeNode.setPartitionType(dataPartition.getType());
+        exchangeNode.setDistributeExprLists(distributeExprLists);
         PlanFragment parentFragment = new 
PlanFragment(context.nextFragmentId(), exchangeNode, dataPartition);
         if (distribute.getDistributionSpec() instanceof 
DistributionSpecGather) {
             // gather to one instance
@@ -807,6 +809,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
             PlanTranslatorContext context) {
 
         PlanFragment inputPlanFragment = aggregate.child(0).accept(this, 
context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(aggregate.child(0));
 
         List<Expression> groupByExpressions = 
aggregate.getGroupByExpressions();
         List<NamedExpression> outputExpressions = 
aggregate.getOutputExpressions();
@@ -849,6 +852,9 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 aggFunOutputIds, isPartial, outputTupleDesc, outputTupleDesc, 
aggregate.getAggPhase().toExec());
         AggregationNode aggregationNode = new 
AggregationNode(aggregate.translatePlanNodeId(),
                 inputPlanFragment.getPlanRoot(), aggInfo);
+
+        aggregationNode.setDistributeExprLists(distributeExprLists);
+
         if (!aggregate.getAggMode().isFinalPhase) {
             aggregationNode.unsetNeedsFinalize();
         }
@@ -941,10 +947,12 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalAssertNumRows(PhysicalAssertNumRows<? 
extends Plan> assertNumRows,
             PlanTranslatorContext context) {
         PlanFragment currentFragment = assertNumRows.child().accept(this, 
context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(assertNumRows.child());
         // create assertNode
         AssertNumRowsNode assertNumRowsNode = new 
AssertNumRowsNode(assertNumRows.translatePlanNodeId(),
                 currentFragment.getPlanRoot(),
                 
ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement()));
+        assertNumRowsNode.setDistributeExprLists(distributeExprLists);
         addPlanRoot(currentFragment, assertNumRowsNode, assertNumRows);
         return currentFragment;
     }
@@ -1143,6 +1151,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         // NOTICE: We must visit from right to left, to ensure the last 
fragment is root fragment
         PlanFragment rightFragment = hashJoin.child(1).accept(this, context);
         PlanFragment leftFragment = hashJoin.child(0).accept(this, context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(physicalHashJoin.left(), physicalHashJoin.right());
 
         if (JoinUtils.shouldNestedLoopJoin(hashJoin)) {
             throw new RuntimeException("Physical hash join could not execute 
without equal join condition.");
@@ -1161,7 +1170,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         HashJoinNode hashJoinNode = new 
HashJoinNode(hashJoin.translatePlanNodeId(), leftPlanRoot,
                 rightPlanRoot, JoinType.toJoinOperator(joinType), 
execEqConjuncts, Lists.newArrayList(),
                 null, null, null, hashJoin.isMarkJoin());
-
+        hashJoinNode.setDistributeExprLists(distributeExprLists);
         PlanFragment currentFragment = connectJoinNode(hashJoinNode, 
leftFragment, rightFragment, context, hashJoin);
 
         if (joinType == JoinType.NULL_AWARE_LEFT_ANTI_JOIN) {
@@ -1183,6 +1192,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         } else {
             hashJoinNode.setDistributionMode(DistributionMode.PARTITIONED);
         }
+
         // Nereids does not care about output order of join,
         // but BE need left child's output must be before right child's output.
         // So we need to swap the output order of left and right child if 
necessary.
@@ -1394,6 +1404,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         //       PhysicalPlan plan, PlanVisitor visitor, Context context).
         PlanFragment rightFragment = nestedLoopJoin.child(1).accept(this, 
context);
         PlanFragment leftFragment = nestedLoopJoin.child(0).accept(this, 
context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(nestedLoopJoin.child(0), nestedLoopJoin.child(1));
         PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
         PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
         if (JoinUtils.shouldNestedLoopJoin(nestedLoopJoin)) {
@@ -1407,6 +1418,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
             NestedLoopJoinNode nestedLoopJoinNode = new 
NestedLoopJoinNode(nestedLoopJoin.translatePlanNodeId(),
                     leftFragmentPlanRoot, rightFragmentPlanRoot, tupleIds, 
JoinType.toJoinOperator(joinType),
                     null, null, null, nestedLoopJoin.isMarkJoin());
+            nestedLoopJoinNode.setDistributeExprLists(distributeExprLists);
             if (nestedLoopJoin.getStats() != null) {
                 nestedLoopJoinNode.setCardinality((long) 
nestedLoopJoin.getStats().getRowCount());
             }
@@ -1573,8 +1585,10 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalPartitionTopN(PhysicalPartitionTopN<? 
extends Plan> partitionTopN,
             PlanTranslatorContext context) {
         PlanFragment inputFragment = partitionTopN.child(0).accept(this, 
context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(partitionTopN.child(0));
         PartitionSortNode partitionSortNode = translatePartitionSortNode(
                 partitionTopN, inputFragment.getPlanRoot(), context);
+        partitionSortNode.setDistributeExprLists(distributeExprLists);
         addPlanRoot(inputFragment, partitionSortNode, partitionTopN);
         // in pipeline engine, we use parallel scan by default, but it broke 
the rule of data distribution
         // we need turn of parallel scan to ensure to get correct result.
@@ -1818,11 +1832,13 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalQuickSort(PhysicalQuickSort<? extends 
Plan> sort,
             PlanTranslatorContext context) {
         PlanFragment inputFragment = sort.child(0).accept(this, context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(sort.child(0));
 
         // 2. According to the type of sort, generate physical plan
         if (!sort.getSortPhase().isMerge()) {
             // For localSort or Gather->Sort, we just need to add sortNode
             SortNode sortNode = translateSortNode(sort, 
inputFragment.getPlanRoot(), context);
+            sortNode.setDistributeExprLists(distributeExprLists);
             addPlanRoot(inputFragment, sortNode, sort);
         } else {
             // For mergeSort, we need to push sortInfo to exchangeNode
@@ -1835,6 +1851,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
             SortNode sortNode = (SortNode) 
inputFragment.getPlanRoot().getChild(0);
             ((ExchangeNode) 
inputFragment.getPlanRoot()).setMergeInfo(sortNode.getSortInfo());
             sortNode.setMergeByExchange();
+            sortNode.setDistributeExprLists(distributeExprLists);
         }
         return inputFragment;
     }
@@ -1842,6 +1859,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     @Override
     public PlanFragment visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, 
PlanTranslatorContext context) {
         PlanFragment inputFragment = topN.child(0).accept(this, context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(topN.child(0));
 
         // 2. According to the type of sort, generate physical plan
         if (!topN.getSortPhase().isMerge()) {
@@ -1874,6 +1892,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                     }
                 }
             }
+            sortNode.setDistributeExprLists(distributeExprLists);
             addPlanRoot(inputFragment, sortNode, topN);
         } else {
             // For mergeSort, we need to push sortInfo to exchangeNode
@@ -1886,6 +1905,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 return inputFragment;
             }
             ExchangeNode exchangeNode = (ExchangeNode) 
inputFragment.getPlanRoot();
+            exchangeNode.setDistributeExprLists(distributeExprLists);
             exchangeNode.setMergeInfo(((SortNode) 
exchangeNode.getChild(0)).getSortInfo());
             exchangeNode.setLimit(topN.getLimit());
             exchangeNode.setOffset(topN.getOffset());
@@ -1918,6 +1938,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     @Override
     public PlanFragment visitPhysicalRepeat(PhysicalRepeat<? extends Plan> 
repeat, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = repeat.child(0).accept(this, context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(repeat.child(0));
 
         Set<VirtualSlotReference> sortedVirtualSlots = 
repeat.getSortedVirtualSlots();
         TupleDescriptor virtualSlotsTuple =
@@ -1965,6 +1986,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         RepeatNode repeatNode = new RepeatNode(repeat.translatePlanNodeId(),
                 inputPlanFragment.getPlanRoot(), groupingInfo, 
repeatSlotIdList,
                 allSlotId, 
repeat.computeVirtualSlotValues(sortedVirtualSlots));
+        repeatNode.setDistributeExprLists(distributeExprLists);
         addPlanRoot(inputPlanFragment, repeatNode, repeat);
         updateLegacyPlanIdToPhysicalPlan(inputPlanFragment.getPlanRoot(), 
repeat);
         return inputPlanFragment;
@@ -1974,6 +1996,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalWindow(PhysicalWindow<? extends Plan> 
physicalWindow,
             PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = physicalWindow.child(0).accept(this, 
context);
+        List<List<Expr>> distributeExprLists = 
getDistributeExprs(physicalWindow.child(0));
 
         // 1. translate to old optimizer variable
         // variable in Nereids
@@ -2049,6 +2072,11 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 orderElementsIsNullableMatched,
                 bufferedTupleDesc
         );
+        analyticEvalNode.setDistributeExprLists(distributeExprLists);
+        PlanNode root = inputPlanFragment.getPlanRoot();
+        if (root instanceof SortNode) {
+            ((SortNode) root).setIsAnalyticSort(true);
+        }
         inputPlanFragment.addPlanRoot(analyticEvalNode);
 
         // in pipeline engine, we use parallel scan by default, but it broke 
the rule of data distribution
@@ -2057,6 +2085,9 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         if 
(findOlapScanNodesByPassExchangeAndJoinNode(inputPlanFragment.getPlanRoot())) {
             inputPlanFragment.setHasColocatePlanNode(true);
             analyticEvalNode.setColocate(true);
+            if (root instanceof SortNode) {
+                ((SortNode) root).setColocate(true);
+            }
         }
         return inputPlanFragment;
     }
@@ -2447,4 +2478,31 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         }
         return false;
     }
+
+    private List<List<Expr>> getDistributeExprs(Plan ... children) {
+        List<List<Expr>> distributeExprLists = Lists.newArrayList();
+        for (Plan child : children) {
+            DistributionSpec spec = ((PhysicalPlan) 
child).getPhysicalProperties().getDistributionSpec();
+            
distributeExprLists.add(getDistributeExpr(child.getOutputExprIds(), spec));
+        }
+        return distributeExprLists;
+    }
+
+    private List<Expr> getDistributeExpr(List<ExprId> childOutputIds, 
DistributionSpec spec) {
+        if (spec instanceof DistributionSpecHash) {
+            DistributionSpecHash distributionSpecHash = (DistributionSpecHash) 
spec;
+            List<Expr> partitionExprs = Lists.newArrayList();
+            for (int i = 0; i < 
distributionSpecHash.getEquivalenceExprIds().size(); i++) {
+                Set<ExprId> equivalenceExprId = 
distributionSpecHash.getEquivalenceExprIds().get(i);
+                for (ExprId exprId : equivalenceExprId) {
+                    if (childOutputIds.contains(exprId)) {
+                        partitionExprs.add(context.findSlotRef(exprId));
+                        break;
+                    }
+                }
+            }
+            return partitionExprs;
+        }
+        return Lists.newArrayList();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index 3f32ee59060..6885862e091 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -153,6 +153,8 @@ public abstract class PlanNode extends TreeNode<PlanNode> 
implements PlanStats {
 
     protected List<Expr> projectList;
 
+    private List<List<Expr>> distributeExprLists = new ArrayList<>();
+
     protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String 
planNodeName,
             StatisticalType statisticalType) {
         this.id = id;
@@ -526,6 +528,12 @@ public abstract class PlanNode extends TreeNode<PlanNode> 
implements PlanStats {
             expBuilder.append(detailPrefix).append("project output tuple id: ")
                     .append(outputTupleDesc.getId().asInt()).append("\n");
         }
+        if (!CollectionUtils.isEmpty(distributeExprLists)) {
+            for (List<Expr> distributeExprList : distributeExprLists) {
+                expBuilder.append(detailPrefix).append("distribute expr lists: 
")
+                    .append(getExplainString(distributeExprList)).append("\n");
+            }
+        }
         // Output Tuple Ids only when explain plan level is set to verbose
         if (detailLevel.equals(TExplainLevel.VERBOSE)) {
             expBuilder.append(detailPrefix + "tuple ids: ");
@@ -618,6 +626,14 @@ public abstract class PlanNode extends TreeNode<PlanNode> 
implements PlanStats {
                 msg.addToOutputSlotIds(slotId.asInt());
             }
         }
+        if (!CollectionUtils.isEmpty(distributeExprLists)) {
+            for (List<Expr> exprList : distributeExprLists) {
+                msg.addToDistributeExprLists(new ArrayList<>());
+                for (Expr expr : exprList) {
+                    
msg.distribute_expr_lists.get(msg.distribute_expr_lists.size() - 
1).add(expr.treeToThrift());
+                }
+            }
+        }
         toThrift(msg);
         container.addToNodes(msg);
         if (projectList != null) {
@@ -1174,6 +1190,10 @@ public abstract class PlanNode extends 
TreeNode<PlanNode> implements PlanStats {
         this.pushDownAggNoGroupingOp = pushDownAggNoGroupingOp;
     }
 
+    public void setDistributeExprLists(List<List<Expr>> distributeExprLists) {
+        this.distributeExprLists = distributeExprLists;
+    }
+
     public TPushAggOp getPushDownAggNoGroupingOp() {
         return pushDownAggNoGroupingOp;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
index dba8f117985..27d42385363 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
@@ -71,6 +71,7 @@ public class SortNode extends PlanNode {
     private boolean isDefaultLimit;
     // if true, the output of this node feeds an AnalyticNode
     private boolean isAnalyticSort;
+    private boolean isColocate = false;
     private DataPartition inputPartition;
 
     private boolean isUnusedExprRemoved = false;
@@ -318,6 +319,8 @@ public class SortNode extends PlanNode {
         msg.sort_node.setOffset(offset);
         msg.sort_node.setUseTopnOpt(useTopnOpt);
         msg.sort_node.setMergeByExchange(this.mergeByexchange);
+        msg.sort_node.setIsAnalyticSort(isAnalyticSort);
+        msg.sort_node.setIsColocate(isColocate);
     }
 
     @Override
@@ -339,4 +342,8 @@ public class SortNode extends PlanNode {
         Expr.getIds(materializedTupleExprs, null, result);
         return new HashSet<>(result);
     }
+
+    public void setColocate(boolean colocate) {
+        isColocate = colocate;
+    }
 }
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 23da0f23ebb..2f45f355b1e 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -892,6 +892,8 @@ struct TSortNode {
   6: optional bool is_default_limit                                            
  
   7: optional bool use_topn_opt
   8: optional bool merge_by_exchange
+  9: optional bool is_analytic_sort
+  10: optional bool is_colocate
 }
 
 enum TopNAlgorithm {
@@ -1251,6 +1253,8 @@ struct TPlanNode {
   48: optional TPushAggOp push_down_agg_type_opt
 
   49: optional i64 push_down_count
+
+  50: optional list<list<Exprs.TExpr>> distribute_expr_lists
   
   101: optional list<Exprs.TExpr> projections
   102: optional Types.TTupleId output_tuple_id
diff --git 
a/regression-test/data/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.out
 
b/regression-test/data/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.out
new file mode 100644
index 00000000000..1e2a2e65b40
--- /dev/null
+++ 
b/regression-test/data/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.out
@@ -0,0 +1,14 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !bucket_shuffle_join --
+0
+
+-- !colocate_join --
+0
+
+-- !analytic --
+1
+1
+1
+1
+1
+
diff --git 
a/regression-test/suites/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.groovy
 
b/regression-test/suites/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.groovy
new file mode 100644
index 00000000000..bbe3e6a3c89
--- /dev/null
+++ 
b/regression-test/suites/pipeline_p0/local_shuffle/test_bucket_hash_local_shuffle.groovy
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_bucket_hash_local_shuffle") {
+    try {
+      sql """
+          CREATE TABLE IF NOT EXISTS date_dim (
+              d_date_sk bigint not null
+          )
+          DUPLICATE KEY(d_date_sk)
+          DISTRIBUTED BY HASH(d_date_sk) BUCKETS 12
+          PROPERTIES (
+            "replication_num" = "1"
+          );
+          """
+      sql """ insert into date_dim values(1) """
+      sql """
+          CREATE TABLE IF NOT EXISTS store_sales (
+              ss_item_sk bigint not null,
+              ss_ticket_number bigint not null,
+              ss_sold_date_sk bigint
+          )
+          DUPLICATE KEY(ss_item_sk, ss_ticket_number)
+          DISTRIBUTED BY HASH(ss_item_sk, ss_ticket_number) BUCKETS 32
+          PROPERTIES (
+            "replication_num" = "1",
+            "colocate_with" = "store"
+          );
+          """
+      sql """ insert into store_sales values(1, 1, 1),(1, 2, 1),(3, 2, 
1),(100, 2, 1),(12130, 2, 1)  """
+        sql """
+            CREATE TABLE IF NOT EXISTS store_returns (
+                sr_item_sk bigint not null,
+                sr_ticket_number bigint not null
+            )
+            duplicate key(sr_item_sk, sr_ticket_number)
+            distributed by hash (sr_item_sk, sr_ticket_number) buckets 32
+            properties (
+              "replication_num" = "1",
+              "colocate_with" = "store"
+            );
+            """
+        sql """ insert into store_returns values(1, 1),(1, 2),(3, 2),(100, 
2),(12130, 2)"""
+        qt_bucket_shuffle_join """ select 
/*+SET_VAR(disable_join_reorder=true,disable_colocate_plan=true,ignore_storage_data_distribution=false)*/
 count(*)
+                      from store_sales
+                      join date_dim on ss_sold_date_sk = d_date_sk
+                      left join store_returns on 
sr_ticket_number=ss_ticket_number and ss_item_sk=sr_item_sk where 
sr_ticket_number is null """
+        qt_colocate_join """ select 
/*+SET_VAR(disable_join_reorder=true,ignore_storage_data_distribution=false)*/ 
count(*)
+                      from store_sales
+                      join date_dim on ss_sold_date_sk = d_date_sk
+                      left join store_returns on 
sr_ticket_number=ss_ticket_number and ss_item_sk=sr_item_sk where 
sr_ticket_number is null """
+        qt_analytic """ select 
/*+SET_VAR(ignore_storage_data_distribution=false)*/ max(ss_sold_date_sk)
+                                OVER (PARTITION BY ss_ticket_number, 
ss_item_sk) from (select *
+                                                      from store_sales
+                                                      join date_dim on 
ss_sold_date_sk = d_date_sk) result """
+    } finally {
+        sql """ DROP TABLE IF EXISTS store_sales """
+        sql """ DROP TABLE IF EXISTS date_dim """
+        sql """ DROP TABLE IF EXISTS store_returns """
+    }
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to