This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new c583563087b [feature](Nereids): double eager support mix function (#30468) c583563087b is described below commit c583563087bc5a0db9920aa88aafb63a5bd61e19 Author: jakevin <jakevin...@gmail.com> AuthorDate: Mon Jan 29 13:08:09 2024 +0800 [feature](Nereids): double eager support mix function (#30468) --- .../doris/nereids/jobs/executor/Rewriter.java | 6 +- .../org/apache/doris/nereids/rules/RuleType.java | 3 +- ...hroughJoin.java => PushDownAggThroughJoin.java} | 107 +++++------ .../rules/rewrite/PushDownSumThroughJoin.java | 212 --------------------- .../rewrite/PushDownCountThroughJoinTest.java | 13 +- .../rules/rewrite/PushDownSumThroughJoinTest.java | 29 ++- .../eager_aggregate/push_down_sum_through_join.out | 12 +- .../nereids_rules_p0/eager_aggregate/basic.groovy | 3 +- .../push_down_count_through_join.groovy | 2 +- .../push_down_sum_through_join.groovy | 4 +- 10 files changed, 101 insertions(+), 290 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2c0e57b715e..34f7afe4995 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -98,14 +98,13 @@ import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan; +import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide; -import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; import org.apache.doris.nereids.rules.rewrite.PushDownLimit; import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion; -import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughUnion; import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin; @@ -288,9 +287,8 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eager aggregation", topDown( - new PushDownSumThroughJoin(), new PushDownAggThroughJoinOneSide(), - new PushDownCountThroughJoin() + new PushDownAggThroughJoin() ), custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index b35c7e03b72..594f49a3b70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -167,8 +167,7 @@ public enum RuleType { ELIMINATE_SORT(RuleTypeClass.REWRITE), PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), - PUSH_DOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE), - PUSH_DOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSH_DOWN_AGG_THROUGH_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java similarity index 69% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java index 462180ab7a6..f003d2ac2cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoin.java @@ -67,7 +67,7 @@ import java.util.Set; * </pre> * Notice: rule can't optimize condition that groupby is empty when Count(*) exists. */ -public class PushDownCountThroughJoin implements RewriteRuleFactory { +public class PushDownAggThroughJoin implements RewriteRuleFactory { @Override public List<Rule> buildRules() { return ImmutableList.of( @@ -78,19 +78,22 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory { .when(agg -> { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() - && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); + .allMatch(f -> !f.isDistinct() + && (f instanceof Count && (((Count) f).isCountStar() || f.child( + 0) instanceof Slot) + || (f instanceof Sum && f.child(0) instanceof Slot)) + ); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) { + if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) { return null; } LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushCount(agg, agg.child(), ImmutableList.of()); + return pushAgg(agg, agg.child(), ImmutableList.of()); }) - .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN), + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN), logicalAggregate(logicalProject(innerLogicalJoin())) .when(agg -> agg.child().isAllSlots()) .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) @@ -99,40 +102,42 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory { .when(agg -> { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() - && (((Count) f).isCountStar() || f.child(0) instanceof Slot)); + .allMatch(f -> !f.isDistinct() + && (f instanceof Count && (((Count) f).isCountStar() || f.child( + 0) instanceof Slot) + || (f instanceof Sum && f.child(0) instanceof Slot)) + ); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type())) { + if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type())) { return null; } LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushCount(agg, agg.child().child(), agg.child().getProjects()); + return pushAgg(agg, agg.child().child(), agg.child().getProjects()); }) - .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN) + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN) ); } - private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg, + private static LogicalAggregate<Plan> pushAgg(LogicalAggregate<? extends Plan> agg, LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { List<Slot> leftOutput = join.left().getOutput(); List<Slot> rightOutput = join.right().getOutput(); - List<Count> leftCounts = new ArrayList<>(); - List<Count> rightCounts = new ArrayList<>(); + List<AggregateFunction> leftAggs = new ArrayList<>(); + List<AggregateFunction> rightAggs = new ArrayList<>(); List<Count> countStars = new ArrayList<>(); for (AggregateFunction f : agg.getAggregateFunctions()) { - Count count = (Count) f; - if (count.isCountStar()) { - countStars.add(count); + if (f instanceof Count && ((Count) f).isCountStar()) { + countStars.add((Count) f); } else { - Slot slot = (Slot) count.child(0); + Slot slot = (Slot) f.child(0); if (leftOutput.contains(slot)) { - leftCounts.add(count); + leftAggs.add(f); } else if (rightOutput.contains(slot)) { - rightCounts.add(count); + rightAggs.add(f); } else { throw new IllegalStateException("Slot " + slot + " not found in join output"); } @@ -168,63 +173,59 @@ public class PushDownCountThroughJoin implements RewriteRuleFactory { Alias leftCnt = null; Alias rightCnt = null; - // left Count agg - Map<Slot, NamedExpression> leftCntSlotToOutput = new HashMap<>(); - Builder<NamedExpression> leftCntAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(leftGroupBy); - leftCounts.forEach(func -> { + // left agg + Map<Slot, NamedExpression> leftSlotToOutput = new HashMap<>(); + Builder<NamedExpression> leftAggOutputBuilder = ImmutableList.<NamedExpression>builder().addAll(leftGroupBy); + leftAggs.forEach(func -> { Alias alias = func.alias(func.getName()); - leftCntSlotToOutput.put((Slot) func.child(0), alias); - leftCntAggOutputBuilder.add(alias); + leftSlotToOutput.put((Slot) func.child(0), alias); + leftAggOutputBuilder.add(alias); }); - if (!rightCounts.isEmpty() || !countStars.isEmpty()) { + if (!rightAggs.isEmpty() || !countStars.isEmpty()) { leftCnt = new Count().alias("leftCntStar"); - leftCntAggOutputBuilder.add(leftCnt); + leftAggOutputBuilder.add(leftCnt); } - LogicalAggregate<Plan> leftCntAgg = new LogicalAggregate<>( - ImmutableList.copyOf(leftGroupBy), leftCntAggOutputBuilder.build(), join.left()); - - // right Count agg - Map<Slot, NamedExpression> rightCntSlotToOutput = new HashMap<>(); - Builder<NamedExpression> rightCntAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(rightGroupBy); - rightCounts.forEach(func -> { + LogicalAggregate<Plan> leftAgg = new LogicalAggregate<>( + ImmutableList.copyOf(leftGroupBy), leftAggOutputBuilder.build(), join.left()); + // right agg + Map<Slot, NamedExpression> rightSlotToOutput = new HashMap<>(); + Builder<NamedExpression> rightAggOutputBuilder = ImmutableList.<NamedExpression>builder().addAll(rightGroupBy); + rightAggs.forEach(func -> { Alias alias = func.alias(func.getName()); - rightCntSlotToOutput.put((Slot) func.child(0), alias); - rightCntAggOutputBuilder.add(alias); + rightSlotToOutput.put((Slot) func.child(0), alias); + rightAggOutputBuilder.add(alias); }); - - if (!leftCounts.isEmpty() || !countStars.isEmpty()) { + if (!leftAggs.isEmpty() || !countStars.isEmpty()) { rightCnt = new Count().alias("rightCntStar"); - rightCntAggOutputBuilder.add(rightCnt); + rightAggOutputBuilder.add(rightCnt); } - LogicalAggregate<Plan> rightCntAgg = new LogicalAggregate<>( - ImmutableList.copyOf(rightGroupBy), rightCntAggOutputBuilder.build(), join.right()); + LogicalAggregate<Plan> rightAgg = new LogicalAggregate<>( + ImmutableList.copyOf(rightGroupBy), rightAggOutputBuilder.build(), join.right()); - Plan newJoin = join.withChildren(leftCntAgg, rightCntAgg); + Plan newJoin = join.withChildren(leftAgg, rightAgg); // top Sum agg // count(slot) -> sum( count(slot) * cntStar ) // count(*) -> sum( leftCntStar * leftCntStar ) List<NamedExpression> newOutputExprs = new ArrayList<>(); for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) { - Count oldTopCnt = (Count) ((Alias) ne).child(); - if (oldTopCnt.isCountStar()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { + AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); + if (func instanceof Count && ((Count) func).isCountStar()) { Preconditions.checkState(rightCnt != null && leftCnt != null); Expression expr = new Sum(new Multiply(leftCnt.toSlot(), rightCnt.toSlot())); newOutputExprs.add((NamedExpression) ne.withChildren(expr)); } else { - Slot slot = (Slot) oldTopCnt.child(0); - if (leftCntSlotToOutput.containsKey(slot)) { + Slot slot = (Slot) func.child(0); + if (leftSlotToOutput.containsKey(slot)) { Preconditions.checkState(rightCnt != null); Expression expr = new Sum( - new Multiply(leftCntSlotToOutput.get(slot).toSlot(), rightCnt.toSlot())); + new Multiply(leftSlotToOutput.get(slot).toSlot(), rightCnt.toSlot())); newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else if (rightCntSlotToOutput.containsKey(slot)) { + } else if (rightSlotToOutput.containsKey(slot)) { Preconditions.checkState(leftCnt != null); Expression expr = new Sum( - new Multiply(rightCntSlotToOutput.get(slot).toSlot(), leftCnt.toSlot())); + new Multiply(rightSlotToOutput.get(slot).toSlot(), leftCnt.toSlot())); newOutputExprs.add((NamedExpression) ne.withChildren(expr)); } else { throw new IllegalStateException("Slot " + slot + " not found in join output"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java deleted file mode 100644 index e8987e670a5..00000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoin.java +++ /dev/null @@ -1,212 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Multiply; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * TODO: distinct - * Related paper "Eager aggregation and lazy aggregation". - * <pre> - * aggregate: Sum(x) - * | - * join - * | \ - * | * - * (x) - * -> - * aggregate: Sum(sum1) - * | - * join - * | \ - * | * - * aggregate: Sum(x) as sum1 - * </pre> - */ -public class PushDownSumThroughJoin implements RewriteRuleFactory { - @Override - public List<Rule> buildRules() { - return ImmutableList.of( - logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) { - return null; - } - LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushSum(agg, agg.child(), ImmutableList.of()); - }) - .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN), - logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type())) { - return null; - } - LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushSum(agg, agg.child().child(), agg.child().getProjects()); - }) - .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN) - ); - } - - private LogicalAggregate<Plan> pushSum(LogicalAggregate<? extends Plan> agg, - LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { - List<Slot> leftOutput = join.left().getOutput(); - List<Slot> rightOutput = join.right().getOutput(); - - List<Sum> leftSums = new ArrayList<>(); - List<Sum> rightSums = new ArrayList<>(); - for (AggregateFunction f : agg.getAggregateFunctions()) { - Sum sum = (Sum) f; - Slot slot = (Slot) sum.child(); - if (leftOutput.contains(slot)) { - leftSums.add(sum); - } else if (rightOutput.contains(slot)) { - rightSums.add(sum); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftSums.isEmpty() && rightSums.isEmpty() - || (!leftSums.isEmpty() && !rightSums.isEmpty())) { - return null; - } - - Set<Slot> leftGroupBy = new HashSet<>(); - Set<Slot> rightGroupBy = new HashSet<>(); - for (Expression e : agg.getGroupByExpressions()) { - Slot slot = (Slot) e; - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - return null; - } - } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - })); - - List<Sum> sums; - Set<Slot> sumGroupBy; - Set<Slot> cntGroupBy; - Plan sumChild; - Plan cntChild; - if (!leftSums.isEmpty()) { - sums = leftSums; - sumGroupBy = leftGroupBy; - cntGroupBy = rightGroupBy; - sumChild = join.left(); - cntChild = join.right(); - } else { - sums = rightSums; - sumGroupBy = rightGroupBy; - cntGroupBy = leftGroupBy; - sumChild = join.right(); - cntChild = join.left(); - } - - // Sum agg - Map<Slot, NamedExpression> sumSlotToOutput = new HashMap<>(); - Builder<NamedExpression> sumAggOutputBuilder = ImmutableList.<NamedExpression>builder().addAll(sumGroupBy); - sums.forEach(func -> { - Alias alias = func.alias(func.getName()); - sumSlotToOutput.put((Slot) func.child(0), alias); - sumAggOutputBuilder.add(alias); - }); - LogicalAggregate<Plan> sumAgg = new LogicalAggregate<>( - ImmutableList.copyOf(sumGroupBy), sumAggOutputBuilder.build(), sumChild); - - // Count agg - Alias cnt = new Count().alias("cnt"); - List<NamedExpression> cntAggOutput = ImmutableList.<NamedExpression>builder() - .addAll(cntGroupBy).add(cnt) - .build(); - LogicalAggregate<Plan> cntAgg = new LogicalAggregate<>( - ImmutableList.copyOf(cntGroupBy), cntAggOutput, cntChild); - - Plan newJoin = !leftSums.isEmpty() ? join.withChildren(sumAgg, cntAgg) : join.withChildren(cntAgg, sumAgg); - - // top Sum agg - // replace sum(x) -> sum(sum# * cnt) - List<NamedExpression> newOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { - AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); - Slot slot = (Slot) func.child(0); - if (sumSlotToOutput.containsKey(slot)) { - Expression expr = func.withChildren(new Multiply(sumSlotToOutput.get(slot).toSlot(), cnt.toSlot())); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } else { - newOutputExprs.add(ne); - } - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java index 34ccfe70f70..8e0e0e15df3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinTest.java @@ -45,7 +45,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { @Mock public Set<Integer> getEnableNereidsRules() { - return ImmutableSet.of(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN.type()); + return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type()); } }; @@ -58,7 +58,8 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) + .printlnTree() .matches( logicalAggregate( logicalJoin( @@ -81,7 +82,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( @@ -101,7 +102,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( @@ -122,7 +123,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { // shouldn't rewrite. PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( @@ -145,7 +146,7 @@ class PushDownCountThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java index 088372b0d76..29a745b379f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinTest.java @@ -45,7 +45,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported { private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { @Mock public Set<Integer> getEnableNereidsRules() { - return ImmutableSet.of(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN.type()); + return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN.type()); } }; @@ -58,7 +58,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( @@ -78,7 +78,28 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSingleJoinBothSum() { + Alias leftSum = new Sum(scan1.getOutput().get(1)).alias("leftSum"); + Alias rightSum = new Sum(scan2.getOutput().get(1)).alias("rightSum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), leftSum, rightSum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( @@ -99,7 +120,7 @@ class PushDownSumThroughJoinTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoin()) + .applyTopDown(new PushDownAggThroughJoin()) .matches( logicalAggregate( logicalJoin( diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out index da05df5419d..106d8882079 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.out @@ -176,8 +176,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=() ---------PhysicalOlapScan[sum_t] ---------PhysicalOlapScan[sum_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t] -- !groupby_pushdown_with_where_clause -- PhysicalResultSink @@ -195,8 +197,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() ---------PhysicalOlapScan[sum_t] ---------PhysicalOlapScan[sum_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t] -- !groupby_pushdown_with_order_by_limit -- PhysicalResultSink diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy index 58d50b3add4..249e7af4bb4 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy @@ -22,8 +22,7 @@ suite("eager_aggregate_basic") { sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" - sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join" - sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join" sql """ DROP TABLE IF EXISTS shunt_log_com_dd_library; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy index f5f4bf53b45..37cd6000941 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join.groovy @@ -48,7 +48,7 @@ suite("push_down_count_through_join") { sql "insert into count_t values (9, 3, null)" sql "insert into count_t values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join" qt_groupby_pushdown_basic """ explain shape plan select count(t1.score) from count_t t1, count_t t2 where t1.id = t2.id group by t1.name; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy index e51899dcc3d..95736d26475 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join.groovy @@ -48,7 +48,7 @@ suite("push_down_sum_through_join") { sql "insert into sum_t values (9, 3, null)" sql "insert into sum_t values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join" qt_groupby_pushdown_basic """ explain shape plan select sum(t1.score) from sum_t t1, sum_t t2 where t1.id = t2.id group by t1.name; @@ -131,7 +131,7 @@ suite("push_down_sum_through_join") { """ qt_groupby_pushdown_varied_aggregates """ - explain shape plan select sum(t1.score), avg(t1.id), count(t2.name) from sum_t t1 join sum_t t2 on t1.id = t2.id group by t1.name; + explain shape plan select sum(t1.score), count(t2.name) from sum_t t1 join sum_t t2 on t1.id = t2.id group by t1.name; """ qt_groupby_pushdown_with_order_by_limit """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org