This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 62b5210b59e [feature](Planner): Push down TopNDistinct through Join 
(#30216)
62b5210b59e is described below

commit 62b5210b59ea7cf12b5a2ab2912f8d21b5f182a9
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Tue Jan 23 16:46:35 2024 +0800

    [feature](Planner): Push down TopNDistinct through Join (#30216)
    
    Push down TopNDistinct through Outer/Cross Join
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   5 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   2 +
 .../rewrite/PushDownTopNDistinctThroughJoin.java   | 126 +++++++++++++++++++++
 .../org/apache/doris/nereids/util/PlanUtils.java   |   6 +
 .../limit_push_down/order_push_down.out            |  22 +++-
 .../push_down_top_n_distinct_through_join.out      |  47 ++++++++
 .../push_down_top_n_distinct_through_join.groovy   |  68 +++++++++++
 7 files changed, 271 insertions(+), 5 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index b2f1569ea1d..1704c629042 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -107,6 +107,7 @@ import 
org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion;
 import org.apache.doris.nereids.rules.rewrite.PushDownMinMaxThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoinOneSide;
+import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughUnion;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughWindow;
@@ -329,9 +330,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
                     topDown(new SplitLimit()),
                     topDown(
                             new PushDownLimit(),
-                            new PushDownTopNThroughJoin(),
                             new PushDownLimitDistinctThroughJoin(),
                             new PushDownLimitDistinctThroughUnion(),
+                            new PushDownTopNDistinctThroughJoin(),
+                            // new PushDownTopNDistinctThroughUnion(),
+                            new PushDownTopNThroughJoin(),
                             new PushDownTopNThroughWindow(),
                             new PushDownTopNThroughUnion()
                     ),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index e57934c9202..54ddd4152fa 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -265,6 +265,8 @@ public enum RuleType {
     // topN push down
     PUSH_DOWN_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
     PUSH_DOWN_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
+    PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSH_DOWN_TOP_N_DISTINCT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
     PUSH_DOWN_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
     PUSH_DOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
     PUSH_DOWN_TOP_N_THROUGH_UNION(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java
new file mode 100644
index 00000000000..98d1cdd3478
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.properties.OrderKey;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Push down TopN-Distinct through Left/Right Outer Join and Cross Join.
+ * <p>
+ * A copy of the TopN (limit = limit + offset, offset = 0) over a distinct of the join child is placed on
+ * the join side whose output contains every group-by slot of the distinct:
+ * <pre>
+ * topN                        topN
+ *  - distinct                  - distinct
+ *    - join          =>          - join
+ *      - left                      - topN(limit + offset, 0)
+ *      - right                       - distinct
+ *                                      - left
+ *                                  - right
+ * </pre>
+ * The original topN and distinct above the join are kept.
+ */
+public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
+
+    /**
+     * Build the two rewrite rules: topN -> distinct -> join and topN -> distinct -> project -> join.
+     * Both rules only match when every order key expression is a plain {@link Slot}.
+     */
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                // topN -> distinct -> join
+                logicalTopN(logicalAggregate(logicalJoin()).when(a -> a.isDistinct()))
+                        // TODO: complex orderby
+                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                .allMatch(Slot.class::isInstance))
+                        .then(topN -> {
+                            LogicalAggregate<LogicalJoin<Plan, Plan>> distinct = topN.child();
+                            LogicalJoin<Plan, Plan> join = distinct.child();
+                            Plan newJoin = pushTopNThroughJoin(topN, join);
+                            // Compare the join's own children with the new join's children.
+                            // topN.child() is the distinct, whose only child is the join itself,
+                            // so comparing its children here would never detect "nothing changed".
+                            if (newJoin == null || join.children().equals(newJoin.children())) {
+                                return null;
+                            }
+                            return topN.withChildren(distinct.withChildren(newJoin));
+                        })
+                        .toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN),
+
+                // topN -> distinct -> project -> join
+                logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(p -> p.isAllSlots()))
+                        .when(a -> a.isDistinct()))
+                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                .allMatch(Slot.class::isInstance))
+                        .then(topN -> {
+                            LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> distinct = topN.child();
+                            LogicalProject<LogicalJoin<Plan, Plan>> project = distinct.child();
+                            LogicalJoin<Plan, Plan> join = project.child();
+
+                            // If orderby exprs aren't all in the output of the join below the project,
+                            // we can't push down.
+                            // topN(order by: slot(a+1))
+                            // - project(a+1, b)
+                            // TODO: in the future, we also can push down it.
+                            Set<Slot> outputSet = join.getOutputSet();
+                            if (!topN.getOrderKeys().stream().map(OrderKey::getExpr)
+                                    .flatMap(e -> e.getInputSlots().stream())
+                                    .allMatch(outputSet::contains)) {
+                                return null;
+                            }
+
+                            Plan newJoin = pushTopNThroughJoin(topN, join);
+                            if (newJoin == null || join.children().equals(newJoin.children())) {
+                                return null;
+                            }
+                            // Rebuild in the matched order: topN -> distinct -> project -> join.
+                            return topN.withChildren(distinct.withChildren(project.withChildren(newJoin)));
+                        }).toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_PROJECT_JOIN)
+        );
+    }
+
+    /**
+     * Try to place a TopN-Distinct on one child of {@code join}.
+     *
+     * @param topN the matched TopN; its child must be the distinct {@link LogicalAggregate}
+     * @param join the join found below the distinct (possibly below a project)
+     * @return the join with one child replaced, or {@code null} when the join type is not
+     *         LEFT_OUTER/RIGHT_OUTER/CROSS or no eligible child outputs all group-by slots
+     */
+    private Plan pushTopNThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
+        List<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
+                .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
+        switch (join.getJoinType()) {
+            case LEFT_OUTER_JOIN:
+                if (join.left().getOutputSet().containsAll(groupBySlots)) {
+                    return join.withChildren(topNDistinct(topN, join.left()), join.right());
+                }
+                return null;
+            case RIGHT_OUTER_JOIN:
+                if (join.right().getOutputSet().containsAll(groupBySlots)) {
+                    return join.withChildren(join.left(), topNDistinct(topN, join.right()));
+                }
+                return null;
+            case CROSS_JOIN:
+                // Either side may take the TopN; the left side is tried first.
+                if (join.left().getOutputSet().containsAll(groupBySlots)) {
+                    return join.withChildren(topNDistinct(topN, join.left()), join.right());
+                } else if (join.right().getOutputSet().containsAll(groupBySlots)) {
+                    return join.withChildren(join.left(), topNDistinct(topN, join.right()));
+                } else {
+                    return null;
+                }
+            default:
+                // don't push topN through other join types.
+                return null;
+        }
+    }
+
+    /**
+     * Build topN(limit + offset, offset 0) over a distinct of {@code child}, reusing the
+     * order keys of {@code topN}.
+     */
+    private LogicalTopN<Plan> topNDistinct(LogicalTopN<? extends Plan> topN, Plan child) {
+        return topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, PlanUtils.distinct(child));
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 0dbd6044cae..94c0755642e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -22,10 +22,12 @@ import 
org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Sets;
 
 import java.util.List;
@@ -70,6 +72,10 @@ public class PlanUtils {
         return project(projects, plan).map(Plan.class::cast).orElse(plan);
     }
 
+    /**
+     * Wrap {@code plan} in a {@link LogicalAggregate} built from a copy of the plan's output list,
+     * with {@code plan} as the aggregate's child.
+     * NOTE(review): presumably this is a DISTINCT over every output column of {@code plan}; the
+     * meaning of the {@code false} flag is not visible here - confirm against the constructor.
+     */
+    public static LogicalAggregate<Plan> distinct(Plan plan) {
+        return new LogicalAggregate<>(ImmutableList.copyOf(plan.getOutput()), 
false, plan);
+    }
+
     /**
      * merge childProjects with parentProjects
      */
diff --git 
a/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out 
b/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out
index b5cecee674e..b74605e19ec 100644
--- a/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out
+++ b/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out
@@ -119,7 +119,11 @@ PhysicalResultSink
 ------hashAgg[GLOBAL]
 --------hashAgg[LOCAL]
 ----------NestedLoopJoin[CROSS_JOIN]
-------------PhysicalOlapScan[t1]
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[GLOBAL]
+------------------hashAgg[LOCAL]
+--------------------PhysicalOlapScan[t1]
 ------------PhysicalOlapScan[t2]
 
 -- !limit_distinct --
@@ -129,7 +133,11 @@ PhysicalResultSink
 ------hashAgg[GLOBAL]
 --------hashAgg[LOCAL]
 ----------NestedLoopJoin[CROSS_JOIN]
-------------PhysicalOlapScan[t1]
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[GLOBAL]
+------------------hashAgg[LOCAL]
+--------------------PhysicalOlapScan[t1]
 ------------PhysicalOlapScan[t2]
 
 -- !limit_distinct --
@@ -139,7 +147,10 @@ PhysicalResultSink
 ------hashAgg[GLOBAL]
 --------hashAgg[LOCAL]
 ----------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.id = t2.id)) 
otherCondition=()
-------------PhysicalOlapScan[t1]
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[LOCAL]
+------------------PhysicalOlapScan[t1]
 ------------PhysicalOlapScan[t2]
 
 -- !limit_distinct --
@@ -149,7 +160,10 @@ PhysicalResultSink
 ------hashAgg[GLOBAL]
 --------hashAgg[LOCAL]
 ----------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.id = t2.id)) 
otherCondition=()
-------------PhysicalOlapScan[t1]
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[LOCAL]
+------------------PhysicalOlapScan[t1]
 ------------PhysicalOlapScan[t2]
 
 -- !limit_window --
diff --git 
a/regression-test/data/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.out
 
b/regression-test/data/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.out
new file mode 100644
index 00000000000..72452ed3c14
--- /dev/null
+++ 
b/regression-test/data/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.out
@@ -0,0 +1,47 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !push_down_topn_through_join --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalTopN[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------NestedLoopJoin[LEFT_OUTER_JOIN](t1.id = t1.id)
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[LOCAL]
+------------------PhysicalOlapScan[table_join]
+------------PhysicalOlapScan[table_join]
+
+-- !push_down_topn_through_join_data --
+0
+1
+2
+3
+4
+5
+6
+7
+
+-- !push_down_topn_through_join --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalTopN[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------NestedLoopJoin[CROSS_JOIN]
+------------PhysicalTopN[MERGE_SORT]
+--------------PhysicalTopN[LOCAL_SORT]
+----------------hashAgg[LOCAL]
+------------------PhysicalOlapScan[table_join]
+------------PhysicalOlapScan[table_join]
+
+-- !push_down_topn_through_join_data --
+0
+1
+2
+3
+4
+5
+6
+7
+
diff --git 
a/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.groovy
 
b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.groovy
new file mode 100644
index 00000000000..890e2ba897d
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/push_down_top_n/push_down_top_n_distinct_through_join.groovy
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression suite for pushing TopN-Distinct through joins: for a left join and a cross join it
+// records the shape plan and the query result of "select distinct ... order by id limit 10".
+suite("push_down_top_n_distinct_through_join") {
+    // Session settings applied before the plans are captured.
+    sql "SET enable_nereids_planner=true"
+    sql "set runtime_filter_mode=OFF"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
+    sql "SET disable_join_reorder=true"
+
+    // Recreate the fixture table from scratch.
+    sql """
+        DROP TABLE IF EXISTS table_join;
+    """
+
+    sql """
+    CREATE TABLE IF NOT EXISTS table_join(
+      `id` int(32) NULL,
+      `score` int(64) NULL,
+      `name` varchar(64) NULL
+    ) ENGINE = OLAP
+    DISTRIBUTED BY HASH(id) BUCKETS 4
+    PROPERTIES (
+      "replication_allocation" = "tag.location.default: 1"
+    );
+    """
+
+    // Eight rows with ids 0..7; the row with id 0 has a NULL score.
+    sql """
+        insert into table_join values
+            (0, NULL, 'Test'),
+            (1, 1, 'Test'),
+            (2, 2, 'Test'),
+            (3, 3, 'Test'),
+            (4, 4, 'Test'),
+            (5, 5, 'Test'),
+            (6, 6, 'Test'),
+            (7, 7, 'Test');
+    """
+
+    // Left join case: shape plan, then data.
+    // NOTE(review): the join predicate is `t1.id = t1.id` (both sides are t1); presumably
+    // `t1.id = t2.id` was intended - confirm, and regenerate the .out file if it is changed.
+    qt_push_down_topn_through_join """
+        explain shape plan select distinct * from (select t1.id from 
table_join t1 left join table_join t2 on t1.id = t1.id) t order by id limit 10;
+    """
+
+    qt_push_down_topn_through_join_data """
+        select distinct * from (select t1.id from table_join t1 left join 
table_join t2 on t1.id = t1.id) t order by id limit 10;
+    """
+
+    // Cross join case: shape plan, then data.
+    qt_push_down_topn_through_join """
+        explain shape plan select distinct * from (select t1.id from 
table_join t1 cross join table_join t2) t order by id limit 10;
+    """
+
+    qt_push_down_topn_through_join_data """
+        select distinct * from (select t1.id from table_join t1 cross join 
table_join t2) t order by id limit 10;
+    """
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to