This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new fa3bdbce966 [opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE 
(#43856)
fa3bdbce966 is described below

commit fa3bdbce966512da6986046df8558c1d04e93f61
Author: minghong <zhoumingh...@selectdb.com>
AuthorDate: Thu Nov 28 11:19:32 2024 +0800

    [opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE (#43856)
    
    ### What problem does this PR solve?
    PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE has the following restrictions:
    
    - it does not support count(*);
    - it does not support joins that have other (non-hash) join conditions;
    - it does not support a project between the agg and the join that contains
      non-slot expressions.
    This PR removes the above restrictions for the pattern agg-project-join.
---
 .../rewrite/PushDownAggThroughJoinOneSide.java     | 123 +++++++++++++++------
 .../rewrite/PushDownMinMaxSumThroughJoinTest.java  |  16 ++-
 .../push_down_count_through_join_one_side.out      |  22 ++++
 .../push_down_count_through_join_one_side.groovy   |  95 ++++++++++++++++
 4 files changed, 212 insertions(+), 44 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
index f32bf8ea91b..c5d3d0fb49a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
@@ -36,6 +36,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.Lists;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -74,8 +75,8 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
                                     .allMatch(f -> (f instanceof Min || f 
instanceof Max || f instanceof Sum
-                                            || (f instanceof Count && 
!((Count) f).isCountStar())) && !f.isDistinct()
-                                            && f.child(0) instanceof Slot);
+                                            || f instanceof Count && 
!f.isDistinct()
+                                            && (f.children().isEmpty() || 
f.child(0) instanceof Slot)));
                         })
                         .thenApply(ctx -> {
                             Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
@@ -88,15 +89,16 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
                         })
                         .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
                 logicalAggregate(logicalProject(innerLogicalJoin()))
-                        .when(agg -> agg.child().isAllSlots())
-                        .when(agg -> 
agg.child().child().getOtherJoinConjuncts().isEmpty())
-                        .whenNot(agg -> 
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
+                        // .when(agg -> agg.child().isAllSlots())
+                        // .when(agg -> 
agg.child().child().getOtherJoinConjuncts().isEmpty())
+                        .whenNot(agg -> agg.child()
+                                .child(0).children().stream().anyMatch(p -> p 
instanceof LogicalAggregate))
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
                                     .allMatch(f -> (f instanceof Min || f 
instanceof Max || f instanceof Sum
-                                            || (f instanceof Count && 
(!((Count) f).isCountStar()))) && !f.isDistinct()
-                                            && f.child(0) instanceof Slot);
+                                            || f instanceof Count) && 
!f.isDistinct()
+                                            && (f.children().isEmpty() || 
f.child(0) instanceof Slot));
                         })
                         .thenApply(ctx -> {
                             Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
@@ -118,23 +120,6 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
             LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
         List<Slot> leftOutput = join.left().getOutput();
         List<Slot> rightOutput = join.right().getOutput();
-
-        List<AggregateFunction> leftFuncs = new ArrayList<>();
-        List<AggregateFunction> rightFuncs = new ArrayList<>();
-        for (AggregateFunction func : agg.getAggregateFunctions()) {
-            Slot slot = (Slot) func.child(0);
-            if (leftOutput.contains(slot)) {
-                leftFuncs.add(func);
-            } else if (rightOutput.contains(slot)) {
-                rightFuncs.add(func);
-            } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
-            }
-        }
-        if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
-            return null;
-        }
-
         Set<Slot> leftGroupBy = new HashSet<>();
         Set<Slot> rightGroupBy = new HashSet<>();
         for (Expression e : agg.getGroupByExpressions()) {
@@ -144,18 +129,71 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
             } else if (rightOutput.contains(slot)) {
                 rightGroupBy.add(slot);
             } else {
-                return null;
+                if (projects.isEmpty()) {
+                    // TODO: select ... from ... group by A , B, 1.2; 1.2 is 
constant
+                    return null;
+                } else {
+                    for (NamedExpression proj : projects) {
+                        if (proj instanceof Alias && 
proj.toSlot().equals(slot)) {
+                            Set<Slot> inputForAliasSet = proj.getInputSlots();
+                            for (Slot aliasInputSlot : inputForAliasSet) {
+                                if (leftOutput.contains(aliasInputSlot)) {
+                                    leftGroupBy.add(aliasInputSlot);
+                                } else if 
(rightOutput.contains(aliasInputSlot)) {
+                                    rightGroupBy.add(aliasInputSlot);
+                                } else {
+                                    return null;
+                                }
+                            }
+                            break;
+                        }
+                    }
+                }
             }
         }
-        join.getHashJoinConjuncts().forEach(e -> 
e.getInputSlots().forEach(slot -> {
-            if (leftOutput.contains(slot)) {
-                leftGroupBy.add(slot);
-            } else if (rightOutput.contains(slot)) {
-                rightGroupBy.add(slot);
+
+        List<AggregateFunction> leftFuncs = new ArrayList<>();
+        List<AggregateFunction> rightFuncs = new ArrayList<>();
+        Count countStar = null;
+        Count rewrittenCountStar = null;
+        for (AggregateFunction func : agg.getAggregateFunctions()) {
+            if (func instanceof Count && ((Count) func).isCountStar()) {
+                countStar = (Count) func;
+            } else {
+                Slot slot = (Slot) func.child(0);
+                if (leftOutput.contains(slot)) {
+                    leftFuncs.add(func);
+                } else if (rightOutput.contains(slot)) {
+                    rightFuncs.add(func);
+                } else {
+                    throw new IllegalStateException("Slot " + slot + " not 
found in join output");
+                }
+            }
+        }
+        // rewrite count(*) to count(A), where A is slot from left/right group 
by key
+        if (countStar != null) {
+            if (!leftGroupBy.isEmpty()) {
+                rewrittenCountStar = (Count) 
countStar.withChildren(leftGroupBy.iterator().next());
+                leftFuncs.add(rewrittenCountStar);
+            } else if (!rightGroupBy.isEmpty()) {
+                rewrittenCountStar = (Count) 
countStar.withChildren(rightGroupBy.iterator().next());
+                rightFuncs.add(rewrittenCountStar);
             } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
+                return null;
+            }
+        }
+        for (Expression condition : join.getHashJoinConjuncts()) {
+            for (Slot joinConditionSlot : condition.getInputSlots()) {
+                if (leftOutput.contains(joinConditionSlot)) {
+                    leftGroupBy.add(joinConditionSlot);
+                } else if (rightOutput.contains(joinConditionSlot)) {
+                    rightGroupBy.add(joinConditionSlot);
+                } else {
+                    // apply failed
+                    return null;
+                }
             }
-        }));
+        }
 
         Plan left = join.left();
         Plan right = join.right();
@@ -196,6 +234,10 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
         for (NamedExpression ne : agg.getOutputExpressions()) {
             if (ne instanceof Alias && ((Alias) ne).child() instanceof 
AggregateFunction) {
                 AggregateFunction func = (AggregateFunction) ((Alias) 
ne).child();
+                if (func instanceof Count && ((Count) func).isCountStar()) {
+                    // countStar is already rewritten as count(left_slot) or 
count(right_slot)
+                    func = rewrittenCountStar;
+                }
                 Slot slot = (Slot) func.child(0);
                 if (leftSlotToOutput.containsKey(slot)) {
                     Expression newFunc = replaceAggFunc(func, 
leftSlotToOutput.get(slot).toSlot());
@@ -210,9 +252,20 @@ public class PushDownAggThroughJoinOneSide implements 
RewriteRuleFactory {
                 newOutputExprs.add(ne);
             }
         }
-
-        // TODO: column prune project
-        return agg.withAggOutputChild(newOutputExprs, newJoin);
+        Plan newAggChild = newJoin;
+        if (agg.child() instanceof LogicalProject) {
+            LogicalProject project = (LogicalProject) agg.child();
+            List<NamedExpression> newProjections = Lists.newArrayList();
+            newProjections.addAll(project.getProjects());
+            Set<NamedExpression> leftDifference = new 
HashSet<NamedExpression>(left.getOutput());
+            leftDifference.removeAll(project.getProjects());
+            newProjections.addAll(leftDifference);
+            Set<NamedExpression> rightDifference = new 
HashSet<NamedExpression>(right.getOutput());
+            rightDifference.removeAll(project.getProjects());
+            newProjections.addAll(rightDifference);
+            newAggChild = ((LogicalProject) 
agg.child()).withProjectsAndChild(newProjections, newJoin);
+        }
+        return agg.withAggOutputChild(newOutputExprs, newAggChild);
     }
 
     private static Expression replaceAggFunc(AggregateFunction func, Slot 
inputSlot) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
index 58ab7fbe9e9..cffe91045d0 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
@@ -323,11 +323,11 @@ class PushDownMinMaxSumThroughJoinTest implements 
MemoPatternMatchSupported {
                 .applyTopDown(new PushDownAggThroughJoinOneSide())
                 .printlnTree()
                 .matches(
-                        logicalAggregate(
-                                logicalJoin(
-                                        logicalOlapScan(),
+                        logicalJoin(
+                                logicalAggregate(
                                         logicalOlapScan()
-                                )
+                                ),
+                                logicalOlapScan()
                         )
                 );
     }
@@ -346,11 +346,9 @@ class PushDownMinMaxSumThroughJoinTest implements 
MemoPatternMatchSupported {
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                 .applyTopDown(new PushDownAggThroughJoinOneSide())
                 .matches(
-                        logicalAggregate(
-                                logicalJoin(
-                                        logicalOlapScan(),
-                                        logicalOlapScan()
-                                )
+                        logicalJoin(
+                                logicalAggregate(logicalOlapScan()),
+                                logicalAggregate(logicalOlapScan())
                         )
                 );
     }
diff --git 
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
 
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
index da69919becd..8267eb3e38f 100644
--- 
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
+++ 
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
@@ -1034,3 +1034,25 @@ Used:
 UnUsed: use_push_down_agg_through_join_one_side
 SyntaxError:
 
+-- !shape --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalTopN[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------hashJoin[INNER_JOIN] 
hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = 
dw_user_b2c_tracking_info_tmp_ymd.dt) and 
(dwd_tracking_sensor_init_tmp_ymd.guid = 
dw_user_b2c_tracking_info_tmp_ymd.guid)) 
otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= 
substring(first_visit_time, 1, 10)))
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') 
and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'))
+------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd]
+------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'))
+--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd]
+
+Hint log:
+Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
+UnUsed:
+SyntaxError:
+
+-- !agg_pushed --
+2      是       2024-08-19
+
diff --git 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
index 02e06710296..e551fa04c91 100644
--- 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
@@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") {
     qt_with_hint_groupby_pushdown_nested_queries """
         explain shape plan select /*+ 
USE_CBO_RULE(push_down_agg_through_join_one_side) */  count(*) from (select * 
from count_t_one_side where score > 20) t1 join (select * from count_t_one_side 
where id < 100) t2 on t1.id = t2.id group by t1.name;
     """
+
+    sql """
+    drop table if exists dw_user_b2c_tracking_info_tmp_ymd;
+    create table dw_user_b2c_tracking_info_tmp_ymd (
+        guid int,
+        dt varchar,
+        first_visit_time varchar
+    )Engine=Olap
+    DUPLICATE KEY(guid)
+    distributed by hash(dt) buckets 3
+    properties('replication_num' = '1');
+
+    insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', 
'2024-08-19');
+
+    drop table if exists dwd_tracking_sensor_init_tmp_ymd;
+    create table dwd_tracking_sensor_init_tmp_ymd (
+        guid int,
+        dt varchar,
+        tracking_type varchar
+    )Engine=Olap
+    DUPLICATE KEY(guid)
+    distributed by hash(dt) buckets 3
+    properties('replication_num' = '1');
+
+    insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 
'click'), (1, '2024-08-19', 'click');
+    """
+    sql """ 
+    set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE";
+    set disable_join_reorder=true;
+    """
+
+    qt_shape """
+    explain shape plan
+    SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
+        Count(*)                            AS accee593,
+       CASE
+         WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
+              Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+              10) THEN
+         '是'
+         WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
+              Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+              10) THEN
+         '否'
+         ELSE '-1'
+       end                                 AS a1302fb2,
+       dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
+    FROM   dwd_tracking_sensor_init_tmp_ymd
+        LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
+                ON dwd_tracking_sensor_init_tmp_ymd.guid =
+                    dw_user_b2c_tracking_info_tmp_ymd.guid
+                    AND dwd_tracking_sensor_init_tmp_ymd.dt =
+                        dw_user_b2c_tracking_info_tmp_ymd.dt
+    WHERE  dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
+        AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
+        AND dwd_tracking_sensor_init_tmp_ymd.dt >=
+            Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 
10)
+        AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
+    GROUP  BY 2,
+            3
+    ORDER  BY 3 ASC
+    LIMIT  10000; 
+    """
+
+    qt_agg_pushed  """
+    SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
+        Count(*)                            AS accee593,
+       CASE
+         WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
+              Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+              10) THEN
+         '是'
+         WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
+              Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+              10) THEN
+         '否'
+         ELSE '-1'
+       end                                 AS a1302fb2,
+       dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
+    FROM   dwd_tracking_sensor_init_tmp_ymd
+        LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
+                ON dwd_tracking_sensor_init_tmp_ymd.guid =
+                    dw_user_b2c_tracking_info_tmp_ymd.guid
+                    AND dwd_tracking_sensor_init_tmp_ymd.dt =
+                        dw_user_b2c_tracking_info_tmp_ymd.dt
+    WHERE  dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
+        AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
+        AND dwd_tracking_sensor_init_tmp_ymd.dt >=
+            Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 
10)
+        AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
+    GROUP  BY 2,
+            3
+    ORDER  BY 3 ASC
+    LIMIT  10000; 
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to