This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new c583563087b [feature](Nereids): double eager support mix function (#30468)
c583563087b is described below

commit c583563087bc5a0db9920aa88aafb63a5bd61e19
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Mon Jan 29 13:08:09 2024 +0800

    [feature](Nereids): double eager support mix function (#30468)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   6 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   3 +-
 ...hroughJoin.java => PushDownAggThroughJoin.java} | 107 +++++------
 .../rules/rewrite/PushDownSumThroughJoin.java      | 212 ---------------------
 .../rewrite/PushDownCountThroughJoinTest.java      |  13 +-
 .../rules/rewrite/PushDownSumThroughJoinTest.java  |  29 ++-
 .../eager_aggregate/push_down_sum_through_join.out |  12 +-
 .../nereids_rules_p0/eager_aggregate/basic.groovy  |   3 +-
 .../push_down_count_through_join.groovy            |   2 +-
 .../push_down_sum_through_join.groovy              |   4 +-
 10 files changed, 101 insertions(+), 290 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 2c0e57b715e..34f7afe4995 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -98,14 +98,13 @@ import 
org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
 import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
+import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
-import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
 import org.apache.doris.nereids.rules.rewrite.PushDownLimit;
 import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
 import 
org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion;
-import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughUnion;
 import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;
@@ -288,9 +287,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
 
             topic("Eager aggregation",
                     topDown(
-                            new PushDownSumThroughJoin(),
                             new PushDownAggThroughJoinOneSide(),
-                            new PushDownCountThroughJoin()
+                            new PushDownAggThroughJoin()
                     ),
                     custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, 
PushDownDistinctThroughJoin::new)
             ),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b35c7e03b72..594f49a3b70 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -167,8 +167,7 @@ public enum RuleType {
     ELIMINATE_SORT(RuleTypeClass.REWRITE),
 
     PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE),
-    PUSH_DOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
-    PUSH_DOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSH_DOWN_AGG_THROUGH_JOIN(RuleTypeClass.REWRITE),
 
     TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE),
     TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
similarity index 69%
rename from 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java
rename to 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
index 462180ab7a6..f003d2ac2cc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java
@@ -67,7 +67,7 @@ import java.util.Set;
  *  </pre>
  * Notice: rule can't optimize condition that groupby is empty when Count(*) 
exists.
  */
-public class PushDownCountThroughJoin implements RewriteRuleFactory {
+public class PushDownAggThroughJoin implements RewriteRuleFactory {
     @Override
     public List<Rule> buildRules() {
         return ImmutableList.of(
@@ -78,19 +78,22 @@ public class PushDownCountThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Count && 
!f.isDistinct()
-                                            && (((Count) f).isCountStar() || 
f.child(0) instanceof Slot));
+                                    .allMatch(f -> !f.isDistinct()
+                                            && (f instanceof Count && 
(((Count) f).isCountStar() || f.child(
+                                            0) instanceof Slot)
+                                            || (f instanceof Sum && f.child(0) 
instanceof Slot))
+                                    );
                         })
                         .thenApply(ctx -> {
                             Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
                                     
.getSessionVariable().getEnableNereidsRules();
-                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) {
+                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) {
                                 return null;
                             }
                             LogicalAggregate<LogicalJoin<Plan, Plan>> agg = 
ctx.root;
-                            return pushCount(agg, agg.child(), 
ImmutableList.of());
+                            return pushAgg(agg, agg.child(), 
ImmutableList.of());
                         })
-                        .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN),
+                        .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN),
                 logicalAggregate(logicalProject(innerLogicalJoin()))
                         .when(agg -> agg.child().isAllSlots())
                         .when(agg -> 
agg.child().child().getOtherJoinConjuncts().isEmpty())
@@ -99,40 +102,42 @@ public class PushDownCountThroughJoin implements 
RewriteRuleFactory {
                         .when(agg -> {
                             Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
                             return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Count && 
!f.isDistinct()
-                                            && (((Count) f).isCountStar() || 
f.child(0) instanceof Slot));
+                                    .allMatch(f -> !f.isDistinct()
+                                            && (f instanceof Count && 
(((Count) f).isCountStar() || f.child(
+                                            0) instanceof Slot)
+                                            || (f instanceof Sum && f.child(0) 
instanceof Slot))
+                                    );
                         })
                         .thenApply(ctx -> {
                             Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
                                     
.getSessionVariable().getEnableNereidsRules();
-                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) {
+                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) {
                                 return null;
                             }
                             LogicalAggregate<LogicalProject<LogicalJoin<Plan, 
Plan>>> agg = ctx.root;
-                            return pushCount(agg, agg.child().child(), 
agg.child().getProjects());
+                            return pushAgg(agg, agg.child().child(), 
agg.child().getProjects());
                         })
-                        .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN)
+                        .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN)
         );
     }
 
-    private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> 
agg,
+    private static LogicalAggregate<Plan> pushAgg(LogicalAggregate<? extends 
Plan> agg,
             LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
         List<Slot> leftOutput = join.left().getOutput();
         List<Slot> rightOutput = join.right().getOutput();
 
-        List<Count> leftCounts = new ArrayList<>();
-        List<Count> rightCounts = new ArrayList<>();
+        List<AggregateFunction> leftAggs = new ArrayList<>();
+        List<AggregateFunction> rightAggs = new ArrayList<>();
         List<Count> countStars = new ArrayList<>();
         for (AggregateFunction f : agg.getAggregateFunctions()) {
-            Count count = (Count) f;
-            if (count.isCountStar()) {
-                countStars.add(count);
+            if (f instanceof Count && ((Count) f).isCountStar()) {
+                countStars.add((Count) f);
             } else {
-                Slot slot = (Slot) count.child(0);
+                Slot slot = (Slot) f.child(0);
                 if (leftOutput.contains(slot)) {
-                    leftCounts.add(count);
+                    leftAggs.add(f);
                 } else if (rightOutput.contains(slot)) {
-                    rightCounts.add(count);
+                    rightAggs.add(f);
                 } else {
                     throw new IllegalStateException("Slot " + slot + " not 
found in join output");
                 }
@@ -168,63 +173,59 @@ public class PushDownCountThroughJoin implements 
RewriteRuleFactory {
 
         Alias leftCnt = null;
         Alias rightCnt = null;
-        // left Count agg
-        Map<Slot, NamedExpression> leftCntSlotToOutput = new HashMap<>();
-        Builder<NamedExpression> leftCntAggOutputBuilder = 
ImmutableList.<NamedExpression>builder()
-                .addAll(leftGroupBy);
-        leftCounts.forEach(func -> {
+        // left agg
+        Map<Slot, NamedExpression> leftSlotToOutput = new HashMap<>();
+        Builder<NamedExpression> leftAggOutputBuilder = 
ImmutableList.<NamedExpression>builder().addAll(leftGroupBy);
+        leftAggs.forEach(func -> {
             Alias alias = func.alias(func.getName());
-            leftCntSlotToOutput.put((Slot) func.child(0), alias);
-            leftCntAggOutputBuilder.add(alias);
+            leftSlotToOutput.put((Slot) func.child(0), alias);
+            leftAggOutputBuilder.add(alias);
         });
-        if (!rightCounts.isEmpty() || !countStars.isEmpty()) {
+        if (!rightAggs.isEmpty() || !countStars.isEmpty()) {
             leftCnt = new Count().alias("leftCntStar");
-            leftCntAggOutputBuilder.add(leftCnt);
+            leftAggOutputBuilder.add(leftCnt);
         }
-        LogicalAggregate<Plan> leftCntAgg = new LogicalAggregate<>(
-                ImmutableList.copyOf(leftGroupBy), 
leftCntAggOutputBuilder.build(), join.left());
-
-        // right Count agg
-        Map<Slot, NamedExpression> rightCntSlotToOutput = new HashMap<>();
-        Builder<NamedExpression> rightCntAggOutputBuilder = 
ImmutableList.<NamedExpression>builder()
-                .addAll(rightGroupBy);
-        rightCounts.forEach(func -> {
+        LogicalAggregate<Plan> leftAgg = new LogicalAggregate<>(
+                ImmutableList.copyOf(leftGroupBy), 
leftAggOutputBuilder.build(), join.left());
+        // right agg
+        Map<Slot, NamedExpression> rightSlotToOutput = new HashMap<>();
+        Builder<NamedExpression> rightAggOutputBuilder = 
ImmutableList.<NamedExpression>builder().addAll(rightGroupBy);
+        rightAggs.forEach(func -> {
             Alias alias = func.alias(func.getName());
-            rightCntSlotToOutput.put((Slot) func.child(0), alias);
-            rightCntAggOutputBuilder.add(alias);
+            rightSlotToOutput.put((Slot) func.child(0), alias);
+            rightAggOutputBuilder.add(alias);
         });
-
-        if (!leftCounts.isEmpty() || !countStars.isEmpty()) {
+        if (!leftAggs.isEmpty() || !countStars.isEmpty()) {
             rightCnt = new Count().alias("rightCntStar");
-            rightCntAggOutputBuilder.add(rightCnt);
+            rightAggOutputBuilder.add(rightCnt);
         }
-        LogicalAggregate<Plan> rightCntAgg = new LogicalAggregate<>(
-                ImmutableList.copyOf(rightGroupBy), 
rightCntAggOutputBuilder.build(), join.right());
+        LogicalAggregate<Plan> rightAgg = new LogicalAggregate<>(
+                ImmutableList.copyOf(rightGroupBy), 
rightAggOutputBuilder.build(), join.right());
 
-        Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg);
+        Plan newJoin = join.withChildren(leftAgg, rightAgg);
 
         // top Sum agg
         // count(slot) -> sum( count(slot) * cntStar )
         // count(*) -> sum( leftCntStar * leftCntStar )
         List<NamedExpression> newOutputExprs = new ArrayList<>();
         for (NamedExpression ne : agg.getOutputExpressions()) {
-            if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) {
-                Count oldTopCnt = (Count) ((Alias) ne).child();
-                if (oldTopCnt.isCountStar()) {
+            if (ne instanceof Alias && ((Alias) ne).child() instanceof 
AggregateFunction) {
+                AggregateFunction func = (AggregateFunction) ((Alias) 
ne).child();
+                if (func instanceof Count && ((Count) func).isCountStar()) {
                     Preconditions.checkState(rightCnt != null && leftCnt != 
null);
                     Expression expr = new Sum(new Multiply(leftCnt.toSlot(), 
rightCnt.toSlot()));
                     newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
                 } else {
-                    Slot slot = (Slot) oldTopCnt.child(0);
-                    if (leftCntSlotToOutput.containsKey(slot)) {
+                    Slot slot = (Slot) func.child(0);
+                    if (leftSlotToOutput.containsKey(slot)) {
                         Preconditions.checkState(rightCnt != null);
                         Expression expr = new Sum(
-                                new 
Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
+                                new 
Multiply(leftSlotToOutput.get(slot).toSlot(), rightCnt.toSlot()));
                         newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
-                    } else if (rightCntSlotToOutput.containsKey(slot)) {
+                    } else if (rightSlotToOutput.containsKey(slot)) {
                         Preconditions.checkState(leftCnt != null);
                         Expression expr = new Sum(
-                                new 
Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
+                                new 
Multiply(rightSlotToOutput.get(slot).toSlot(), leftCnt.toSlot()));
                         newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
                     } else {
                         throw new IllegalStateException("Slot " + slot + " not 
found in join output");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
deleted file mode 100644
index e8987e670a5..00000000000
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java
+++ /dev/null
@@ -1,212 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Multiply;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * TODO: distinct
- * Related paper "Eager aggregation and lazy aggregation".
- * <pre>
- * aggregate: Sum(x)
- * |
- * join
- * |   \
- * |    *
- * (x)
- * ->
- * aggregate: Sum(sum1)
- * |
- * join
- * |   \
- * |    *
- * aggregate: Sum(x) as sum1
- * </pre>
- */
-public class PushDownSumThroughJoin implements RewriteRuleFactory {
-    @Override
-    public List<Rule> buildRules() {
-        return ImmutableList.of(
-                logicalAggregate(innerLogicalJoin())
-                        .when(agg -> 
agg.child().getOtherJoinConjuncts().isEmpty())
-                        .whenNot(agg -> 
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
-                        .when(agg -> {
-                            Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
-                            return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && 
!f.isDistinct() && f.child(0) instanceof Slot);
-                        })
-                        .thenApply(ctx -> {
-                            Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
-                                    
.getSessionVariable().getEnableNereidsRules();
-                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) {
-                                return null;
-                            }
-                            LogicalAggregate<LogicalJoin<Plan, Plan>> agg = 
ctx.root;
-                            return pushSum(agg, agg.child(), 
ImmutableList.of());
-                        })
-                        .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN),
-                logicalAggregate(logicalProject(innerLogicalJoin()))
-                        .when(agg -> agg.child().isAllSlots())
-                        .when(agg -> 
agg.child().child().getOtherJoinConjuncts().isEmpty())
-                        .whenNot(agg -> 
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
-                        .when(agg -> {
-                            Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
-                            return !funcs.isEmpty() && funcs.stream()
-                                    .allMatch(f -> f instanceof Sum && 
!f.isDistinct() && f.child(0) instanceof Slot);
-                        })
-                        .thenApply(ctx -> {
-                            Set<Integer> enableNereidsRules = 
ctx.cascadesContext.getConnectContext()
-                                    
.getSessionVariable().getEnableNereidsRules();
-                            if 
(!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) {
-                                return null;
-                            }
-                            LogicalAggregate<LogicalProject<LogicalJoin<Plan, 
Plan>>> agg = ctx.root;
-                            return pushSum(agg, agg.child().child(), 
agg.child().getProjects());
-                        })
-                        .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN)
-        );
-    }
-
-    private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan> 
agg,
-            LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
-        List<Slot> leftOutput = join.left().getOutput();
-        List<Slot> rightOutput = join.right().getOutput();
-
-        List<Sum> leftSums = new ArrayList<>();
-        List<Sum> rightSums = new ArrayList<>();
-        for (AggregateFunction f : agg.getAggregateFunctions()) {
-            Sum sum = (Sum) f;
-            Slot slot = (Slot) sum.child();
-            if (leftOutput.contains(slot)) {
-                leftSums.add(sum);
-            } else if (rightOutput.contains(slot)) {
-                rightSums.add(sum);
-            } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
-            }
-        }
-        if (leftSums.isEmpty() && rightSums.isEmpty()
-                || (!leftSums.isEmpty() && !rightSums.isEmpty())) {
-            return null;
-        }
-
-        Set<Slot> leftGroupBy = new HashSet<>();
-        Set<Slot> rightGroupBy = new HashSet<>();
-        for (Expression e : agg.getGroupByExpressions()) {
-            Slot slot = (Slot) e;
-            if (leftOutput.contains(slot)) {
-                leftGroupBy.add(slot);
-            } else if (rightOutput.contains(slot)) {
-                rightGroupBy.add(slot);
-            } else {
-                return null;
-            }
-        }
-        join.getHashJoinConjuncts().forEach(e -> 
e.getInputSlots().forEach(slot -> {
-            if (leftOutput.contains(slot)) {
-                leftGroupBy.add(slot);
-            } else if (rightOutput.contains(slot)) {
-                rightGroupBy.add(slot);
-            } else {
-                throw new IllegalStateException("Slot " + slot + " not found 
in join output");
-            }
-        }));
-
-        List<Sum> sums;
-        Set<Slot> sumGroupBy;
-        Set<Slot> cntGroupBy;
-        Plan sumChild;
-        Plan cntChild;
-        if (!leftSums.isEmpty()) {
-            sums = leftSums;
-            sumGroupBy = leftGroupBy;
-            cntGroupBy = rightGroupBy;
-            sumChild = join.left();
-            cntChild = join.right();
-        } else {
-            sums = rightSums;
-            sumGroupBy = rightGroupBy;
-            cntGroupBy = leftGroupBy;
-            sumChild = join.right();
-            cntChild = join.left();
-        }
-
-        // Sum agg
-        Map<Slot, NamedExpression> sumSlotToOutput = new HashMap<>();
-        Builder<NamedExpression> sumAggOutputBuilder = 
ImmutableList.<NamedExpression>builder().addAll(sumGroupBy);
-        sums.forEach(func -> {
-            Alias alias = func.alias(func.getName());
-            sumSlotToOutput.put((Slot) func.child(0), alias);
-            sumAggOutputBuilder.add(alias);
-        });
-        LogicalAggregate<Plan> sumAgg = new LogicalAggregate<>(
-                ImmutableList.copyOf(sumGroupBy), sumAggOutputBuilder.build(), 
sumChild);
-
-        // Count agg
-        Alias cnt = new Count().alias("cnt");
-        List<NamedExpression> cntAggOutput = 
ImmutableList.<NamedExpression>builder()
-                .addAll(cntGroupBy).add(cnt)
-                .build();
-        LogicalAggregate<Plan> cntAgg = new LogicalAggregate<>(
-                ImmutableList.copyOf(cntGroupBy), cntAggOutput, cntChild);
-
-        Plan newJoin = !leftSums.isEmpty() ? join.withChildren(sumAgg, cntAgg) 
: join.withChildren(cntAgg, sumAgg);
-
-        // top Sum agg
-        // replace sum(x) -> sum(sum# * cnt)
-        List<NamedExpression> newOutputExprs = new ArrayList<>();
-        for (NamedExpression ne : agg.getOutputExpressions()) {
-            if (ne instanceof Alias && ((Alias) ne).child() instanceof 
AggregateFunction) {
-                AggregateFunction func = (AggregateFunction) ((Alias) 
ne).child();
-                Slot slot = (Slot) func.child(0);
-                if (sumSlotToOutput.containsKey(slot)) {
-                    Expression expr = func.withChildren(new 
Multiply(sumSlotToOutput.get(slot).toSlot(), cnt.toSlot()));
-                    newOutputExprs.add((NamedExpression) 
ne.withChildren(expr));
-                } else {
-                    throw new IllegalStateException("Slot " + slot + " not 
found in join output");
-                }
-            } else {
-                newOutputExprs.add(ne);
-            }
-        }
-        return agg.withAggOutputChild(newOutputExprs, newJoin);
-    }
-}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
index 34ccfe70f70..8e0e0e15df3 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java
@@ -45,7 +45,7 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
     private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() {
         @Mock
         public Set<Integer> getEnableNereidsRules() {
-            return 
ImmutableSet.of(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type());
+            return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type());
         }
     };
 
@@ -58,7 +58,8 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownCountThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
+                .printlnTree()
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -81,7 +82,7 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownCountThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -101,7 +102,7 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownCountThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -122,7 +123,7 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
 
         // shouldn't rewrite.
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownCountThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -145,7 +146,7 @@ class PushDownCountThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownCountThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
index 088372b0d76..29a745b379f 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java
@@ -45,7 +45,7 @@ class PushDownSumThroughJoinTest implements 
MemoPatternMatchSupported {
     private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() {
         @Mock
         public Set<Integer> getEnableNereidsRules() {
-            return ImmutableSet.of(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type());
+            return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type());
         }
     };
 
@@ -58,7 +58,7 @@ class PushDownSumThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownSumThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -78,7 +78,28 @@ class PushDownSumThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownSumThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
+                .matches(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate(),
+                                        logicalAggregate()
+                                )
+                        )
+                );
+    }
+
+    @Test
+    void testSingleJoinBothSum() {
+        Alias leftSum = new Sum(scan1.getOutput().get(1)).alias("leftSum");
+        Alias rightSum = new Sum(scan2.getOutput().get(1)).alias("rightSum");
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0), 
ImmutableList.of(scan1.getOutput().get(0), leftSum, rightSum))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
@@ -99,7 +120,7 @@ class PushDownSumThroughJoinTest implements 
MemoPatternMatchSupported {
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
-                .applyTopDown(new PushDownSumThroughJoin())
+                .applyTopDown(new PushDownAggThroughJoin())
                 .matches(
                         logicalAggregate(
                                 logicalJoin(
diff --git 
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
 
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
index da05df5419d..106d8882079 100644
--- 
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
+++ 
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out
@@ -176,8 +176,10 @@ PhysicalResultSink
 --hashAgg[GLOBAL]
 ----hashAgg[LOCAL]
 ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = 
t2.name)) otherCondition=()
---------PhysicalOlapScan[sum_t]
---------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
 
 -- !groupby_pushdown_with_where_clause --
 PhysicalResultSink
@@ -195,8 +197,10 @@ PhysicalResultSink
 --hashAgg[GLOBAL]
 ----hashAgg[LOCAL]
 ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=()
---------PhysicalOlapScan[sum_t]
---------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
+--------hashAgg[LOCAL]
+----------PhysicalOlapScan[sum_t]
 
 -- !groupby_pushdown_with_order_by_limit --
 PhysicalResultSink
diff --git 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
index 58d50b3add4..249e7af4bb4 100644
--- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
+++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy
@@ -22,8 +22,7 @@ suite("eager_aggregate_basic") {
     sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
 
     sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side"
-    sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join"
-    sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join"
+    sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
 
     sql """
         DROP TABLE IF EXISTS shunt_log_com_dd_library;
diff --git 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
index f5f4bf53b45..37cd6000941 100644
--- 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_count_through_join") {
     sql "insert into count_t values (9, 3, null)"
     sql "insert into count_t values (10, null, null)"
 
-    sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join"
+    sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
 
     qt_groupby_pushdown_basic """
         explain shape plan select count(t1.score) from count_t t1, count_t t2 
where t1.id = t2.id group by t1.name;
diff --git 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
index e51899dcc3d..95736d26475 100644
--- 
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy
@@ -48,7 +48,7 @@ suite("push_down_sum_through_join") {
     sql "insert into sum_t values (9, 3, null)"
     sql "insert into sum_t values (10, null, null)"
 
-    sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join"
+    sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join"
 
     qt_groupby_pushdown_basic """
         explain shape plan select sum(t1.score) from sum_t t1, sum_t t2 where 
t1.id = t2.id group by t1.name;
@@ -131,7 +131,7 @@ suite("push_down_sum_through_join") {
     """
 
     qt_groupby_pushdown_varied_aggregates """
-        explain shape plan select sum(t1.score), avg(t1.id), count(t2.name) 
from sum_t t1 join sum_t t2 on t1.id = t2.id group by t1.name;
+        explain shape plan select sum(t1.score), count(t2.name) from sum_t t1 
join sum_t t2 on t1.id = t2.id group by t1.name;
     """
 
     qt_groupby_pushdown_with_order_by_limit """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to