This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 39a7a4cc55 [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556) 39a7a4cc55 is described below commit 39a7a4cc55c9e3b00a6465d99f85db911b92d234 Author: jakevin <jakevin...@gmail.com> AuthorDate: Wed Apr 12 12:13:06 2023 +0800 [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556) --- .../org/apache/doris/nereids/rules/RuleType.java | 4 + .../nereids/rules/exploration/EagerCount.java | 4 +- .../nereids/rules/exploration/EagerGroupBy.java | 4 +- .../rules/exploration/EagerGroupByCount.java | 138 +++++++++++++++++ .../nereids/rules/exploration/EagerSplit.java | 164 +++++++++++++++++++++ .../rules/exploration/EagerGroupByCountTest.java | 101 +++++++++++++ .../nereids/rules/exploration/EagerSplitTest.java | 102 +++++++++++++ 7 files changed, 514 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index b14d8b5029..9f1234803d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -241,6 +241,10 @@ public enum RuleType { PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION), EAGER_COUNT(RuleTypeClass.EXPLORATION), EAGER_GROUP_BY(RuleTypeClass.EXPLORATION), + EAGER_GROUP_BY_COUNT(RuleTypeClass.EXPLORATION), + EAGER_SPLIT(RuleTypeClass.EXPLORATION), + + EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION), // implementation rules LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java index add6d9ccc2..cde94b33eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java @@ -49,7 +49,7 @@ import java.util.Set; * | * * (x) * -> - * aggregate: SUM(x * cnt) + * aggregate: SUM(x) * cnt * | * join * | \ @@ -62,7 +62,7 @@ public class EagerCount extends OneExplorationRuleFactory { @Override public Rule build() { - return logicalAggregate(logicalJoin()) + return logicalAggregate(innerLogicalJoin()) .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) .when(agg -> agg.getAggregateFunctions().stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java index 0db10dd1ee..27fcd149b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java @@ -54,13 +54,15 @@ import java.util.stream.Collectors; * | * * aggregate: SUM(x) as sum1 * </pre> + * After Eager Group By, new plan also can apply `Eager Count`. + * It's `Double Eager`. 
*/ public class EagerGroupBy extends OneExplorationRuleFactory { public static final EagerGroupBy INSTANCE = new EagerGroupBy(); @Override public Rule build() { - return logicalAggregate(logicalJoin()) + return logicalAggregate(innerLogicalJoin()) .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) .when(agg -> agg.getAggregateFunctions().stream() .allMatch(f -> f instanceof Sum diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java new file mode 100644 index 0000000000..c538250538 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Related paper "Eager aggregation and lazy aggregation". 
+ * <pre> + * aggregate: SUM(x), SUM(y) + * | + * join + * | \ + * | (y) + * (x) + * -> + * aggregate: SUM(sum1), SUM(y) * cnt + * | + * join + * | \ + * | (y) + * aggregate: SUM(x) as sum1 , COUNT as cnt + * </pre> + */ +public class EagerGroupByCount extends OneExplorationRuleFactory { + public static final EagerGroupByCount INSTANCE = new EagerGroupByCount(); + + @Override + public Rule build() { + return logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof Slot)) + .then(agg -> { + LogicalJoin<GroupPlan, GroupPlan> join = agg.child(); + List<Slot> leftOutput = join.left().getOutput(); + List<Sum> leftSums = new ArrayList<>(); + List<Sum> rightSums = new ArrayList<>(); + for (AggregateFunction f : agg.getAggregateFunctions()) { + Sum sum = (Sum) f; + if (leftOutput.contains((Slot) sum.child())) { + leftSums.add(sum); + } else { + rightSums.add(sum); + } + } + if (leftSums.size() == 0 || rightSums.size() == 0) { + return null; + } + + // left bottom agg + Set<Slot> bottomAggGroupBy = new HashSet<>(); + agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains) + .forEach(bottomAggGroupBy::add); + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + bottomAggGroupBy.add(slot); + } + })); + List<NamedExpression> bottomSums = new ArrayList<>(); + for (int i = 0; i < leftSums.size(); i++) { + bottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "sum" + i)); + } + Alias cnt = new Alias(new Count(Literal.of(1)), "cnt"); + List<NamedExpression> bottomAggOutput = ImmutableList.<NamedExpression>builder() + .addAll(bottomAggGroupBy).addAll(bottomSums).add(cnt).build(); + LogicalAggregate<GroupPlan> bottomAgg = new LogicalAggregate<>( + ImmutableList.copyOf(bottomAggGroupBy), bottomAggOutput, join.left()); + Plan 
newJoin = join.withChildren(bottomAgg, join.right()); + + // top agg + List<NamedExpression> newOutputExprs = new ArrayList<>(); + List<Alias> leftSumOutputExprs = new ArrayList<>(); + List<Alias> rightSumOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { + Alias sumOutput = (Alias) ne; + Slot child = (Slot) ((Sum) (sumOutput).child()).child(); + if (leftOutput.contains(child)) { + leftSumOutputExprs.add(sumOutput); + } else { + rightSumOutputExprs.add(sumOutput); + } + } else { + newOutputExprs.add(ne); + } + } + for (int i = 0; i < leftSumOutputExprs.size(); i++) { + Alias oldSum = leftSumOutputExprs.get(i); + // sum in bottom Agg + Slot bottomSum = bottomSums.get(i).toSlot(); + Alias newSum = new Alias(oldSum.getExprId(), new Sum(bottomSum), oldSum.getName()); + newOutputExprs.add(newSum); + } + for (Alias oldSum : rightSumOutputExprs) { + Sum oldSumFunc = (Sum) oldSum.child(); + newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()), + oldSum.getName())); + } + return agg.withAggOutput(newOutputExprs).withChildren(newJoin); + }).toRule(RuleType.EAGER_GROUP_BY_COUNT); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java new file mode 100644 index 0000000000..abf6dabad8 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Related paper "Eager aggregation and lazy aggregation". 
+ * <pre> + * aggregate: SUM(x), SUM(y) + * | + * join + * | \ + * | (y) + * (x) + * -> + * aggregate: SUM(sum1) * cnt2, SUM(sum2) * cnt1 + * | + * join + * | \ + * | aggregate: SUM(y) as sum2, COUNT: cnt2 + * aggregate: SUM(x) as sum1, COUNT: cnt1 + * </pre> + */ +public class EagerSplit extends OneExplorationRuleFactory { + public static final EagerSplit INSTANCE = new EagerSplit(); + + @Override + public Rule build() { + return logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum && ((Sum) f).child() instanceof SlotReference)) + .then(agg -> { + LogicalJoin<GroupPlan, GroupPlan> join = agg.child(); + List<Slot> leftOutput = join.left().getOutput(); + List<Slot> rightOutput = join.right().getOutput(); + List<Sum> leftSums = new ArrayList<>(); + List<Sum> rightSums = new ArrayList<>(); + for (AggregateFunction f : agg.getAggregateFunctions()) { + Sum sum = (Sum) f; + if (leftOutput.contains((Slot) sum.child())) { + leftSums.add(sum); + } else { + rightSums.add(sum); + } + } + if (leftSums.size() == 0 || rightSums.size() == 0) { + return null; + } + + // left bottom agg + Set<Slot> leftBottomAggGroupBy = new HashSet<>(); + agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains) + .forEach(leftBottomAggGroupBy::add); + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + leftBottomAggGroupBy.add(slot); + } + })); + List<NamedExpression> leftBottomSums = new ArrayList<>(); + for (int i = 0; i < leftSums.size(); i++) { + leftBottomSums.add(new Alias(new Sum(leftSums.get(i).child()), "left_sum" + i)); + } + Alias leftCnt = new Alias(new Count(Literal.of(1)), "left_cnt"); + List<NamedExpression> leftBottomAggOutput = ImmutableList.<NamedExpression>builder() + .addAll(leftBottomAggGroupBy).addAll(leftBottomSums).add(leftCnt).build(); + LogicalAggregate<GroupPlan> leftBottomAgg = new LogicalAggregate<>( + 
ImmutableList.copyOf(leftBottomAggGroupBy), leftBottomAggOutput, join.left()); + + // right bottom agg + Set<Slot> rightBottomAggGroupBy = new HashSet<>(); + agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(rightOutput::contains) + .forEach(rightBottomAggGroupBy::add); + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (rightOutput.contains(slot)) { + rightBottomAggGroupBy.add(slot); + } + })); + List<NamedExpression> rightBottomSums = new ArrayList<>(); + for (int i = 0; i < rightSums.size(); i++) { + rightBottomSums.add(new Alias(new Sum(rightSums.get(i).child()), "right_sum" + i)); + } + Alias rightCnt = new Alias(new Count(Literal.of(1)), "right_cnt"); + List<NamedExpression> rightBottomAggOutput = ImmutableList.<NamedExpression>builder() + .addAll(rightBottomAggGroupBy).addAll(rightBottomSums).add(rightCnt).build(); + LogicalAggregate<GroupPlan> rightBottomAgg = new LogicalAggregate<>( + ImmutableList.copyOf(rightBottomAggGroupBy), rightBottomAggOutput, join.right()); + + Plan newJoin = join.withChildren(leftBottomAgg, rightBottomAgg); + + // top agg + List<NamedExpression> newOutputExprs = new ArrayList<>(); + List<Alias> leftSumOutputExprs = new ArrayList<>(); + List<Alias> rightSumOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { + Alias sumOutput = (Alias) ne; + Slot child = (Slot) ((Sum) (sumOutput).child()).child(); + if (leftOutput.contains(child)) { + leftSumOutputExprs.add(sumOutput); + } else { + rightSumOutputExprs.add(sumOutput); + } + } else { + newOutputExprs.add(ne); + } + } + Preconditions.checkState(leftSumOutputExprs.size() == leftBottomSums.size()); + Preconditions.checkState(rightSumOutputExprs.size() == rightBottomSums.size()); + for (int i = 0; i < leftSumOutputExprs.size(); i++) { + Alias oldSum = leftSumOutputExprs.get(i); + Slot bottomSum = leftBottomSums.get(i).toSlot(); + Alias 
newSum = new Alias(oldSum.getExprId(), + new Multiply(new Sum(bottomSum), rightCnt.toSlot()), oldSum.getName()); + newOutputExprs.add(newSum); + } + for (int i = 0; i < rightSumOutputExprs.size(); i++) { + Alias oldSum = rightSumOutputExprs.get(i); + Slot bottomSum = rightBottomSums.get(i).toSlot(); + Alias newSum = new Alias(oldSum.getExprId(), + new Multiply(new Sum(bottomSum), leftCnt.toSlot()), oldSum.getName()); + newOutputExprs.add(newSum); + } + return agg.withAggOutput(newOutputExprs).withChildren(newJoin); + }).toRule(RuleType.EAGER_SPLIT); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java new file mode 100644 index 0000000000..de132d22d2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.algebra.Aggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class EagerGroupByCountTest implements MemoPatternMatchSupported { + + private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(), + PlanConstructor.student, ImmutableList.of("")); + private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(), + PlanConstructor.score, ImmutableList.of("")); + + @Test + void singleSum() { + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0, 4), + ImmutableList.of( + new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"), + new Alias(new Sum(scan2.getOutput().get(2)), "rsum0") + )) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyExploration(EagerGroupByCount.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalAggregate( + logicalJoin( + logicalAggregate().when( + bottomAgg -> bottomAgg.getOutputExprsSql().equals("id, sum(age) AS `sum0`, count(1) AS `cnt`")), + logicalOlapScan() + ) + ).when(newAgg -> + newAgg.getGroupByExpressions().equals(((Aggregate) 
agg).getGroupByExpressions()) + && newAgg.getOutputExprsSql().equals("sum(sum0) AS `lsum0`, (sum(grade) * cnt) AS `rsum0`")) + ); + } + + @Test + void multiSum() { + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0, 4), + ImmutableList.of( + new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"), + new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"), + new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"), + new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"), + new Alias(new Sum(scan2.getOutput().get(2)), "rsum1") + )) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyExploration(EagerGroupByCount.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalAggregate( + logicalJoin( + logicalAggregate().when(cntAgg -> cntAgg.getOutputExprsSql() + .equals("id, sum(gender) AS `sum0`, sum(name) AS `sum1`, sum(age) AS `sum2`, count(1) AS `cnt`")), + logicalOlapScan() + ) + ).when(newAgg -> + newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions()) + && newAgg.getOutputExprsSql() + .equals("sum(sum0) AS `lsum0`, sum(sum1) AS `lsum1`, sum(sum2) AS `lsum2`, (sum(cid) * cnt) AS `rsum0`, (sum(grade) * cnt) AS `rsum1`")) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java new file mode 100644 index 0000000000..37e347894f --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.algebra.Aggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class EagerSplitTest implements MemoPatternMatchSupported { + + private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(), + PlanConstructor.student, ImmutableList.of("")); + private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(), + PlanConstructor.score, ImmutableList.of("")); + + @Test + void singleSum() { + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0, 4), + ImmutableList.of( + new 
Alias(new Sum(scan1.getOutput().get(3)), "lsum0"), + new Alias(new Sum(scan2.getOutput().get(2)), "rsum0") + )) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyExploration(EagerSplit.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalAggregate( + logicalJoin( + logicalAggregate().when( + a -> a.getOutputExprsSql().equals("id, sum(age) AS `left_sum0`, count(1) AS `left_cnt`")), + logicalAggregate().when( + a -> a.getOutputExprsSql().equals("sid, sum(grade) AS `right_sum0`, count(1) AS `right_cnt`")) + ) + ).when(newAgg -> + newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions()) + && newAgg.getOutputExprsSql().equals("(sum(left_sum0) * right_cnt) AS `lsum0`, (sum(right_sum0) * left_cnt) AS `rsum0`")) + ); + } + + @Test + void multiSum() { + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0, 4), + ImmutableList.of( + new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"), + new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"), + new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"), + new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"), + new Alias(new Sum(scan2.getOutput().get(2)), "rsum1") + )) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyExploration(EagerSplit.INSTANCE.build()) + .printlnExploration() + .matchesExploration( + logicalAggregate( + logicalJoin( + logicalAggregate().when(a -> a.getOutputExprsSql() + .equals("id, sum(gender) AS `left_sum0`, sum(name) AS `left_sum1`, sum(age) AS `left_sum2`, count(1) AS `left_cnt`")), + logicalAggregate().when(a -> a.getOutputExprsSql() + .equals("sid, sum(cid) AS `right_sum0`, sum(grade) AS `right_sum1`, count(1) AS `right_cnt`")) + ) + ).when(newAgg -> + newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions()) + && newAgg.getOutputExprsSql() + 
.equals("(sum(left_sum0) * right_cnt) AS `lsum0`, (sum(left_sum1) * right_cnt) AS `lsum1`, (sum(left_sum2) * right_cnt) AS `lsum2`, (sum(right_sum0) * left_cnt) AS `rsum0`, (sum(right_sum1) * left_cnt) AS `rsum1`")) + ); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org For additional commands, e-mail: commits-help@doris.apache.org