This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 04237f60e053925f0d29b40afd1ecacadf9fe8ed Author: jakevin <jakevin...@gmail.com> AuthorDate: Fri Jan 26 21:10:08 2024 +0800 [feature](Nereids): eager aggreagate support mix agg function (#30400) --- .../doris/nereids/jobs/executor/Rewriter.java | 10 +- .../org/apache/doris/nereids/rules/RuleType.java | 5 +- ...oin.java => PushDownAggThroughJoinOneSide.java} | 50 +-- .../rewrite/PushDownCountThroughJoinOneSide.java | 216 ------------- .../rewrite/PushDownSumThroughJoinOneSide.java | 98 ------ .../PushDownCountThroughJoinOneSideTest.java | 139 -------- .../rewrite/PushDownMinMaxSumThroughJoinTest.java | 357 +++++++++++++++++++++ .../rewrite/PushDownMinMaxThroughJoinTest.java | 183 ----------- .../rewrite/PushDownSumThroughJoinOneSideTest.java | 135 -------- .../push_down_count_through_join_one_side.out | 6 +- .../eager_aggregate/push_down_max_through_join.out | 6 +- .../eager_aggregate/push_down_min_through_join.out | 6 +- .../push_down_sum_through_join_one_side.out | 6 +- .../nereids_rules_p0/eager_aggregate/basic.groovy | 2 +- .../eager_aggregate/basic_one_side.groovy | 4 +- .../push_down_count_through_join_one_side.groovy | 2 +- .../push_down_max_through_join.groovy | 2 +- .../push_down_min_through_join.groovy | 2 +- .../push_down_sum_through_join_one_side.groovy | 2 +- 19 files changed, 413 insertions(+), 818 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 4b488b6cfee..2c0e57b715e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -98,16 +98,14 @@ import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan; +import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoin; -import org.apache.doris.nereids.rules.rewrite.PushDownCountThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; import org.apache.doris.nereids.rules.rewrite.PushDownLimit; import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughUnion; -import org.apache.doris.nereids.rules.rewrite.PushDownMinMaxThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoin; -import org.apache.doris.nereids.rules.rewrite.PushDownSumThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughUnion; import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin; @@ -291,13 +289,9 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eager aggregation", topDown( new PushDownSumThroughJoin(), - new PushDownMinMaxThroughJoin(), + new PushDownAggThroughJoinOneSide(), new PushDownCountThroughJoin() ), - topDown( - new PushDownSumThroughJoinOneSide(), - new PushDownCountThroughJoinOneSide() - ), custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 6a994c1b6e5..58947760b42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -167,13 +167,10 @@ public enum RuleType { COLUMN_PRUNING(RuleTypeClass.REWRITE), ELIMINATE_SORT(RuleTypeClass.REWRITE), - PUSH_DOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), PUSH_DOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE), PUSH_DOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE), - PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), - PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), - TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE), LOGICAL_SEMI_JOIN_COMMUTE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java similarity index 81% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index 3057f1eafc4..f32bf8ea91b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -24,8 +24,10 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -46,22 +48,22 @@ import java.util.Set; * TODO: distinct * Related paper "Eager aggregation and lazy aggregation". * <pre> - * aggregate: Min/Max(x) + * aggregate: Min/Max/Sum(x) * | * join * | \ * | * * (x) * -> - * aggregate: Min/Max(min1) + * aggregate: Min/Max/Sum(min1) * | * join * | \ * | * - * aggregate: Min/Max(x) as min1 + * aggregate: Min/Max/Sum(x) as min1 * </pre> */ -public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { +public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { @Override public List<Rule> buildRules() { return ImmutableList.of( @@ -71,19 +73,20 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { .when(agg -> { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child( - 0) instanceof Slot); + .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum + || (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct() + && f.child(0) instanceof Slot); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type())) { + if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type())) { return null; } LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushMinMaxSum(agg, agg.child(), ImmutableList.of()); + return pushMinMaxSumCount(agg, agg.child(), ImmutableList.of()); }) - .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN), + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE), logicalAggregate(logicalProject(innerLogicalJoin())) .when(agg -> agg.child().isAllSlots()) .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) @@ -91,27 +94,27 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { .when(agg -> { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() - .allMatch( - f -> (f instanceof Min || f instanceof Max) && !f.isDistinct() && f.child( - 0) instanceof Slot); + .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum + || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct() + && f.child(0) instanceof Slot); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type())) { + if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type())) { return null; } LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushMinMaxSum(agg, agg.child().child(), agg.child().getProjects()); + return pushMinMaxSumCount(agg, agg.child().child(), agg.child().getProjects()); }) - .toRule(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN) + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE) ); } /** * Push down Min/Max/Sum through join. */ - public static LogicalAggregate<Plan> pushMinMaxSum(LogicalAggregate<? extends Plan> agg, + public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? extends Plan> agg, LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { List<Slot> leftOutput = join.left().getOutput(); List<Slot> rightOutput = join.right().getOutput(); @@ -183,21 +186,22 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { Preconditions.checkState(left != join.left() || right != join.right()); Plan newJoin = join.withChildren(left, right); - // top agg + // top agg TODO: AVG // replace // min(x) -> min(min#) // max(x) -> max(max#) // sum(x) -> sum(sum#) + // count(x) -> sum(count#) List<NamedExpression> newOutputExprs = new ArrayList<>(); for (NamedExpression ne : agg.getOutputExpressions()) { if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); Slot slot = (Slot) func.child(0); if (leftSlotToOutput.containsKey(slot)) { - Expression newFunc = func.withChildren(leftSlotToOutput.get(slot).toSlot()); + Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot()); newOutputExprs.add((NamedExpression) ne.withChildren(newFunc)); } else if (rightSlotToOutput.containsKey(slot)) { - Expression newFunc = func.withChildren(rightSlotToOutput.get(slot).toSlot()); + Expression newFunc = replaceAggFunc(func, rightSlotToOutput.get(slot).toSlot()); newOutputExprs.add((NamedExpression) ne.withChildren(newFunc)); } else { throw new IllegalStateException("Slot " + slot + " not found in join output"); @@ -210,4 +214,12 @@ public class PushDownMinMaxThroughJoin implements RewriteRuleFactory { // TODO: column prune project return agg.withAggOutputChild(newOutputExprs, newJoin); } + + private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { + if (func instanceof Count) { + return new Sum(inputSlot); + } else { + return func.withChildren(inputSlot); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java deleted file mode 100644 index 5abe33fb142..00000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSide.java +++ /dev/null @@ -1,216 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * TODO: distinct | just push one level - * Support Pushdown Count(col). - * Count(col) -> Sum( cnt ) - * <p> - * Related paper "Eager aggregation and lazy aggregation". - * <pre> - * aggregate: count(x) - * | - * join - * | \ - * | * - * (x) - * -> - * aggregate: Sum( cnt ) - * | - * join - * | \ - * | * - * aggregate: count(x) as cnt - * </pre> - * Notice: rule can't optimize condition that groupby is empty when Count(*) exists. - */ -public class PushDownCountThroughJoinOneSide implements RewriteRuleFactory { - @Override - public List<Rule> buildRules() { - return ImmutableList.of( - logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() - && (!((Count) f).isCountStar() && f.child(0) instanceof Slot)); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type())) { - return null; - } - LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return pushCount(agg, agg.child(), ImmutableList.of()); - }) - .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE), - logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() - && (!((Count) f).isCountStar() && f.child(0) instanceof Slot)); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type())) { - return null; - } - LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return pushCount(agg, agg.child().child(), agg.child().getProjects()); - }) - .toRule(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE) - ); - } - - private LogicalAggregate<Plan> pushCount(LogicalAggregate<? extends Plan> agg, - LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { - List<Slot> leftOutput = join.left().getOutput(); - List<Slot> rightOutput = join.right().getOutput(); - - List<Count> leftCounts = new ArrayList<>(); - List<Count> rightCounts = new ArrayList<>(); - for (AggregateFunction f : agg.getAggregateFunctions()) { - Count count = (Count) f; - Slot slot = (Slot) count.child(0); - if (leftOutput.contains(slot)) { - leftCounts.add(count); - } else if (rightOutput.contains(slot)) { - rightCounts.add(count); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - - Set<Slot> leftGroupBy = new HashSet<>(); - Set<Slot> rightGroupBy = new HashSet<>(); - for (Expression e : agg.getGroupByExpressions()) { - Slot slot = (Slot) e; - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - return null; - } - } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - })); - - Plan left = join.left(); - Plan right = join.right(); - - Map<Slot, NamedExpression> leftCntSlotToOutput = new HashMap<>(); - Map<Slot, NamedExpression> rightCntSlotToOutput = new HashMap<>(); - - // left Count agg - if (!leftCounts.isEmpty()) { - Builder<NamedExpression> leftCntAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(leftGroupBy); - leftCounts.forEach(func -> { - Alias alias = func.alias(func.getName()); - leftCntSlotToOutput.put((Slot) func.child(0), alias); - leftCntAggOutputBuilder.add(alias); - }); - left = new LogicalAggregate<>(ImmutableList.copyOf(leftGroupBy), leftCntAggOutputBuilder.build(), - join.left()); - } - - // right Count agg - if (!rightCounts.isEmpty()) { - Builder<NamedExpression> rightCntAggOutputBuilder = ImmutableList.<NamedExpression>builder() - .addAll(rightGroupBy); - rightCounts.forEach(func -> { - Alias alias = func.alias(func.getName()); - rightCntSlotToOutput.put((Slot) func.child(0), alias); - rightCntAggOutputBuilder.add(alias); - }); - - right = new LogicalAggregate<>(ImmutableList.copyOf(rightGroupBy), rightCntAggOutputBuilder.build(), - join.right()); - } - - Preconditions.checkState(left != join.left() || right != join.right()); - Plan newJoin = join.withChildren(left, right); - - // top Sum agg - // count(slot) -> sum( count(slot) as cnt ) - List<NamedExpression> newOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Count) { - Count oldTopCnt = (Count) ((Alias) ne).child(); - - Slot slot = (Slot) oldTopCnt.child(0); - if (leftCntSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(leftCntSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else if (rightCntSlotToOutput.containsKey(slot)) { - Expression expr = new Sum(rightCntSlotToOutput.get(slot).toSlot()); - newOutputExprs.add((NamedExpression) ne.withChildren(expr)); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } else { - newOutputExprs.add(ne); - } - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java deleted file mode 100644 index 88b13b383a3..00000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSide.java +++ /dev/null @@ -1,98 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; - -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Set; - -/** - * TODO: distinct - * Related paper "Eager aggregation and lazy aggregation". - * <pre> - * aggregate: Sum(x) - * | - * join - * | \ - * | * - * (x) - * -> - * aggregate: Sum(sum1) - * | - * join - * | \ - * | * - * aggregate: Sum(x) as sum1 - * </pre> - */ -public class PushDownSumThroughJoinOneSide implements RewriteRuleFactory { - @Override - public List<Rule> buildRules() { - return ImmutableList.of( - logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type())) { - return null; - } - LogicalAggregate<LogicalJoin<Plan, Plan>> agg = ctx.root; - return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child(), ImmutableList.of()); - }) - .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN), - logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) - .when(agg -> { - Set<AggregateFunction> funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Sum && !f.isDistinct() && f.child(0) instanceof Slot); - }) - .thenApply(ctx -> { - Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() - .getSessionVariable().getEnableNereidsRules(); - if (!enableNereidsRules.contains(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type())) { - return null; - } - LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = ctx.root; - return PushDownMinMaxThroughJoin.pushMinMaxSum(agg, agg.child().child(), - agg.child().getProjects()); - }) - .toRule(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java deleted file mode 100644 index 3106eb30f45..00000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountThroughJoinOneSideTest.java +++ /dev/null @@ -1,139 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.util.LogicalPlanBuilder; -import org.apache.doris.nereids.util.MemoPatternMatchSupported; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.qe.SessionVariable; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import mockit.Mock; -import mockit.MockUp; -import org.junit.jupiter.api.Test; - -import java.util.Set; - -class PushDownCountThroughJoinOneSideTest implements MemoPatternMatchSupported { - private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { - @Mock - public Set<Integer> getEnableNereidsRules() { - return ImmutableSet.of(RuleType.PUSH_DOWN_COUNT_THROUGH_JOIN_ONE_SIDE.type()); - } - }; - - @Test - void testSingleCount() { - Alias count = new Count(scan1.getOutput().get(0)).alias("count"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoinOneSide()) - .printlnTree() - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testMultiCount() { - Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1"); - Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2"); - Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), - ImmutableList.of(scan1.getOutput().get(0), leftCnt1, leftCnt2, rightCnt1)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalAggregate() - ) - ) - ); - } - - @Test - void testSingleCountStar() { - Alias count = new Count().alias("countStar"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoinOneSide()) - .printlnTree() - .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testBothSideCountAndCountStar() { - Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt"); - Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt"); - Alias countStar = new Count().alias("countStar"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), - ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownCountThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) - ) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java new file mode 100644 index 00000000000..58ab7fbe9e9 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import mockit.Mock; +import mockit.MockUp; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +class PushDownMinMaxSumThroughJoinTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); + private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { + @Mock + public Set<Integer> getEnableNereidsRules() { + return ImmutableSet.of(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE.type()); + } + }; + + @Test + void testSingleJoin() { + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testMultiJoin() { + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .printlnTree() + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testAggNotOutputGroupBy() { + // agg don't output group by + Alias min = new Min(scan1.getOutput().get(0)).alias("min"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testBothSideSingleJoin() { + Alias min = new Min(scan1.getOutput().get(1)).alias("min"); + Alias max = new Max(scan2.getOutput().get(1)).alias("max"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min, max)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testBothSide() { + Alias min = new Min(scan1.getOutput().get(1)).alias("min"); + Alias max = new Max(scan3.getOutput().get(1)).alias("max"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min, max)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSingleJoinLeftSum() { + Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testSingleJoinRightSum() { + Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalOlapScan(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSumAggNotOutputGroupBy() { + // agg don't output group by + Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(sum)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testMultiSum() { + Alias leftSum1 = new Sum(scan1.getOutput().get(0)).alias("leftSum1"); + Alias leftSum2 = new Sum(scan1.getOutput().get(1)).alias("leftSum2"); + Alias rightSum1 = new Sum(scan2.getOutput().get(1)).alias("rightSum1"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of(scan1.getOutput().get(0), leftSum1, leftSum2, rightSum1)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSingleCount() { + Alias count = new Count(scan1.getOutput().get(0)).alias("count"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .printlnTree() + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testMultiCount() { + Alias leftCnt1 = new Count(scan1.getOutput().get(0)).alias("leftCnt1"); + Alias leftCnt2 = new Count(scan1.getOutput().get(1)).alias("leftCnt2"); + Alias rightCnt1 = new Count(scan2.getOutput().get(1)).alias("rightCnt1"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of(scan1.getOutput().get(0), leftCnt1, leftCnt2, rightCnt1)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalAggregate(), + logicalAggregate() + ) + ) + ); + } + + @Test + void testSingleCountStar() { + Alias count = new Count().alias("countStar"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), count)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .printlnTree() + .matches( + logicalAggregate( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testBothSideCountAndCountStar() { + Alias leftCnt = new Count(scan1.getOutput().get(0)).alias("leftCnt"); + Alias rightCnt = new Count(scan2.getOutput().get(0)).alias("rightCnt"); + Alias countStar = new Count().alias("countStar"); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of(scan1.getOutput().get(0), leftCnt, rightCnt, countStar)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new PushDownAggThroughJoinOneSide()) + .matches( + logicalAggregate( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java deleted file mode 100644 index cf28954a47c..00000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxThroughJoinTest.java +++ /dev/null @@ -1,183 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.util.LogicalPlanBuilder; -import org.apache.doris.nereids.util.MemoPatternMatchSupported; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.qe.SessionVariable; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import mockit.Mock; -import mockit.MockUp; -import org.junit.jupiter.api.Test; - -import java.util.Set; - -class PushDownMinMaxThroughJoinTest implements MemoPatternMatchSupported { - private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - private static final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); - private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { - @Mock - public Set<Integer> getEnableNereidsRules() { - return ImmutableSet.of(RuleType.PUSH_DOWN_MIN_MAX_THROUGH_JOIN.type()); - } - }; - - @Test - void testSingleJoin() { - Alias min = new Min(scan1.getOutput().get(0)).alias("min"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownMinMaxThroughJoin()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testMultiJoin() { - Alias min = new Min(scan1.getOutput().get(0)).alias("min"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) - .join(scan4, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownMinMaxThroughJoin()) - .printlnTree() - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate( - logicalJoin( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ), - logicalOlapScan() - ) - ), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testAggNotOutputGroupBy() { - // agg don't output group by - Alias min = new Min(scan1.getOutput().get(0)).alias("min"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownMinMaxThroughJoin()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testBothSideSingleJoin() { - Alias min = new Min(scan1.getOutput().get(1)).alias("min"); - Alias max = new Max(scan2.getOutput().get(1)).alias("max"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), min, max)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .printlnTree() - .applyTopDown(new PushDownMinMaxThroughJoin()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalAggregate() - ) - ) - ); - } - - @Test - void testBothSide() { - Alias min = new Min(scan1.getOutput().get(1)).alias("min"); - Alias max = new Max(scan3.getOutput().get(1)).alias("max"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(min, max)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownMinMaxThroughJoin()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ), - logicalAggregate() - ) - ) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java deleted file mode 100644 index 2e0f124b810..00000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownSumThroughJoinOneSideTest.java +++ /dev/null @@ -1,135 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.util.LogicalPlanBuilder; -import org.apache.doris.nereids.util.MemoPatternMatchSupported; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.qe.SessionVariable; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import mockit.Mock; -import mockit.MockUp; -import org.junit.jupiter.api.Test; - -import java.util.Set; - -class PushDownSumThroughJoinOneSideTest implements MemoPatternMatchSupported { - private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - private MockUp<SessionVariable> mockUp = new MockUp<SessionVariable>() { - @Mock - public Set<Integer> getEnableNereidsRules() { - return ImmutableSet.of(RuleType.PUSH_DOWN_SUM_THROUGH_JOIN_ONE_SIDE.type()); - } - }; - - @Test - void testSingleJoinLeftSum() { - Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testSingleJoinRightSum() { - Alias sum = new Sum(scan2.getOutput().get(1)).alias("sum"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(scan1.getOutput().get(0), sum)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalAggregate() - ) - ) - ); - } - - @Test - void testAggNotOutputGroupBy() { - // agg don't output group by - Alias sum = new Sum(scan1.getOutput().get(1)).alias("sum"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), ImmutableList.of(sum)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalOlapScan() - ) - ) - ); - } - - @Test - void testMultiSum() { - Alias leftSum1 = new Sum(scan1.getOutput().get(0)).alias("leftSum1"); - Alias leftSum2 = new Sum(scan1.getOutput().get(1)).alias("leftSum2"); - Alias rightSum1 = new Sum(scan2.getOutput().get(1)).alias("rightSum1"); - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .aggGroupUsingIndex(ImmutableList.of(0), - ImmutableList.of(scan1.getOutput().get(0), leftSum1, leftSum2, rightSum1)) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(new PushDownSumThroughJoinOneSide()) - .matches( - logicalAggregate( - logicalJoin( - logicalAggregate(), - logicalAggregate() - ) - ) - ); - } -} diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index 59c57e460e8..0de2a12166a 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -148,8 +148,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=() ---------PhysicalOlapScan[count_t_one_side] ---------PhysicalOlapScan[count_t_one_side] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[count_t_one_side] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[count_t_one_side] -- !groupby_pushdown_equal_conditions_non_aggregate_selection -- PhysicalResultSink diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out index bd4430fcb66..9a7cfa6a4f5 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_max_through_join.out @@ -148,8 +148,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=() ---------PhysicalOlapScan[max_t] ---------PhysicalOlapScan[max_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[max_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[max_t] -- !groupby_pushdown_equal_conditions_non_aggregate_selection -- PhysicalResultSink diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out index a0a2acd9449..3e2ccc6f432 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_min_through_join.out @@ -148,8 +148,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=() ---------PhysicalOlapScan[min_t] ---------PhysicalOlapScan[min_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[min_t] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[min_t] -- !groupby_pushdown_equal_conditions_non_aggregate_selection -- PhysicalResultSink diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out index 8046cec6d95..65d3a7b68f1 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.out @@ -148,8 +148,10 @@ PhysicalResultSink --hashAgg[GLOBAL] ----hashAgg[LOCAL] ------hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id) and (t1.name = t2.name)) otherCondition=() ---------PhysicalOlapScan[sum_t_one_side] ---------PhysicalOlapScan[sum_t_one_side] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t_one_side] +--------hashAgg[LOCAL] +----------PhysicalOlapScan[sum_t_one_side] -- !groupby_pushdown_equal_conditions_non_aggregate_selection -- PhysicalResultSink diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy index afa64135d39..58d50b3add4 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic.groovy @@ -21,7 +21,7 @@ suite("eager_aggregate_basic") { sql "SET enable_fallback_to_original_planner=false" sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" - sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join" sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join" diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy index cb84e0cc1ec..cc1c0c8c736 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/basic_one_side.groovy @@ -21,9 +21,7 @@ suite("eager_aggregate_basic_one_side") { sql "SET enable_fallback_to_original_planner=false" sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" - sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join_one_side" - sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join_one_side" - sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join_one_side" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" sql """ DROP TABLE IF EXISTS shunt_log_com_dd_library_one_side; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy index 037368f051f..88862874362 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy @@ -48,7 +48,7 @@ suite("push_down_count_through_join_one_side") { sql "insert into count_t_one_side values (9, 3, null)" sql "insert into count_t_one_side values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_count_through_join_one_side" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" qt_groupby_pushdown_basic """ explain shape plan select count(t1.score) from count_t_one_side t1, count_t_one_side t2 where t1.id = t2.id group by t1.name; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy index 68d1946b35e..26772637fe7 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_max_through_join.groovy @@ -48,7 +48,7 @@ suite("push_down_max_through_join") { sql "insert into max_t values (9, 3, null)" sql "insert into max_t values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" qt_groupby_pushdown_basic """ explain shape plan select max(t1.score) from max_t t1, max_t t2 where t1.id = t2.id group by t1.name; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy index 560bf1c0d72..7942fbd28c4 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_min_through_join.groovy @@ -48,7 +48,7 @@ suite("push_down_min_through_join") { sql "insert into min_t values (9, 3, null)" sql "insert into min_t values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_min_max_through_join" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" qt_groupby_pushdown_basic """ explain shape plan select min(t1.score) from min_t t1, min_t t2 where t1.id = t2.id group by t1.name; diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy index 1ecc6aa48a8..fecf1410261 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_sum_through_join_one_side.groovy @@ -48,7 +48,7 @@ suite("push_down_sum_through_join_one_side") { sql "insert into sum_t_one_side values (9, 3, null)" sql "insert into sum_t_one_side values (10, null, null)" - sql "SET ENABLE_NEREIDS_RULES=push_down_sum_through_join_one_side" + sql "SET ENABLE_NEREIDS_RULES=push_down_agg_through_join_one_side" qt_groupby_pushdown_basic """ explain shape plan select sum(t1.score) from sum_t_one_side t1, sum_t_one_side t2 where t1.id = t2.id group by t1.name; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org