This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ceb7b60a64 [fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740) ceb7b60a64 is described below commit ceb7b60a64725b430f2e124a382a05032fa115ff Author: jakevin <jakevin...@gmail.com> AuthorDate: Mon Oct 31 14:11:55 2022 +0800 [fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740) --- .../apache/doris/nereids/memo/GroupExpression.java | 12 +- .../java/org/apache/doris/nereids/memo/Memo.java | 13 +- .../rules/rewrite/AggregateDisassemble.java | 3 +- .../trees/plans/logical/LogicalAggregate.java | 4 +- .../org/apache/doris/nereids/memo/MemoTest.java | 6 +- .../rewrite/logical/AggregateDisassembleTest.java | 282 +++++++++------------ .../doris/nereids/stats/StatsCalculatorTest.java | 12 +- 7 files changed, 146 insertions(+), 186 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index 9e95f9e8d7..92067f607d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.statistics.StatsDeriveResult; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -43,7 +44,7 @@ public class GroupExpression { private double cost = 0.0; private CostEstimate costEstimate = null; private Group ownerGroup; - private List<Group> children; + private ImmutableList<Group> children; private final Plan plan; private final BitSet ruleMasks; private boolean statDerived; @@ -66,7 +67,7 @@ public class GroupExpression { public GroupExpression(Plan plan, List<Group> children) { this.plan = Objects.requireNonNull(plan, "plan can not be 
null") .withGroupExpression(Optional.of(this)); - this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null")); + this.children = ImmutableList.copyOf(Objects.requireNonNull(children, "children can not be null")); this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal()); this.statDerived = false; this.lowestCostTable = Maps.newHashMap(); @@ -84,10 +85,6 @@ public class GroupExpression { return children.size(); } - public void addChild(Group child) { - children.add(child); - } - public Group getOwnerGroup() { return ownerGroup; } @@ -108,12 +105,13 @@ public class GroupExpression { return children; } - public void setChildren(List<Group> children) { + public void setChildren(ImmutableList<Group> children) { this.children = children; } /** * replaceChild. + * * @param originChild origin child group * @param newChild new child group */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index e9b9bbfce9..f5616a71c7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -203,7 +203,7 @@ public class Memo { /** * add or replace the plan into the target group. - * + * <p> * the result truth table: * <pre> * +---------------------------------------+-----------------------------------+--------------------------------+ @@ -296,8 +296,7 @@ public class Memo { } } plan = replaceChildrenToGroupPlan(plan, childrenGroups); - GroupExpression newGroupExpression = new GroupExpression(plan); - newGroupExpression.setChildren(childrenGroups); + GroupExpression newGroupExpression = new GroupExpression(plan, childrenGroups); return insertGroupExpression(newGroupExpression, targetGroup, plan.getLogicalProperties()); // TODO: need to derive logical property if generate new group. 
currently we not copy logical plan into } @@ -388,13 +387,15 @@ public class Memo { } for (GroupExpression groupExpression : needReplaceChild) { groupExpressions.remove(groupExpression); - List<Group> children = groupExpression.children(); + List<Group> children = new ArrayList<>(groupExpression.children()); // TODO: use a better way to replace child, avoid traversing all groupExpression for (int i = 0; i < children.size(); i++) { if (children.get(i).equals(source)) { children.set(i, destination); } } + groupExpression.setChildren(ImmutableList.copyOf(children)); + GroupExpression that = groupExpressions.get(groupExpression); if (that != null && that.getOwnerGroup() != null && !that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) { @@ -487,14 +488,14 @@ public class Memo { /** * eliminate fromGroup, clear targetGroup, then move the logical group expressions in the fromGroup to the toGroup. - * + * <p> * the scenario is: * ``` * Group 1(project, the targetGroup) Group 1(logicalOlapScan, the targetGroup) * | => * Group 0(logicalOlapScan, the fromGroup) * ``` - * + * <p> * we should recycle the group 0, and recycle all group expressions in group 1, then move the logicalOlapScan to * the group 1, and reset logical properties of the group 1. 
*/ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index 911b6735ac..3deb794412 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -34,6 +34,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -147,7 +148,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory { // +-----------+---------------------+-------------------------+--------------------------------+ // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x // 2. collect local aggregate output expressions and local aggregate group by expression list - List<Expression> localGroupByExprs = aggregate.getGroupByExpressions(); + List<Expression> localGroupByExprs = new ArrayList<>(aggregate.getGroupByExpressions()); List<NamedExpression> localOutputExprs = Lists.newArrayList(); for (Expression originGroupByExpr : originGroupByExprs) { if (inputSubstitutionMap.containsKey(originGroupByExpr)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 3dfd2ab06d..626c74dab9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -119,8 +119,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) { 
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, child); - this.groupByExpressions = groupByExpressions; - this.outputExpressions = outputExpressions; + this.groupByExpressions = ImmutableList.copyOf(groupByExpressions); + this.outputExpressions = ImmutableList.copyOf(outputExpressions); this.partitionExpressions = partitionExpressions; this.disassembled = disassembled; this.normalized = normalized; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index 9ad9bf8d16..f9b04d9758 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -71,7 +71,7 @@ class MemoTest implements PatternMatchSupported { JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0)); @Test - void mergeGroup() throws Exception { + void mergeGroup() { Memo memo = new Memo(); GroupId gid2 = new GroupId(2); Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new)); @@ -85,13 +85,13 @@ class MemoTest implements PatternMatchSupported { GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup)); GroupId gid1 = new GroupId(1); Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new)); - Map<GroupId, Group> groups = (Map<GroupId, Group>) Deencapsulation.getField(memo, "groups"); + Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups"); groups.put(gid2, srcGroup); groups.put(gid3, dstGroup); groups.put(gid0, g1); groups.put(gid1, g2); Map<GroupExpression, GroupExpression> groupExpressions = - (Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions"); + Deencapsulation.getField(memo, "groupExpressions"); groupExpressions.put(ge1, ge1); groupExpressions.put(ge2, ge2); memo.mergeGroup(srcGroup, dstGroup); diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java index 6fa37387b9..d92d6efb4e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; @@ -31,14 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalUnary; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.PlanRewriter; -import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -46,15 +44,17 @@ import org.junit.jupiter.api.TestInstance; import java.util.List; 
@TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class AggregateDisassembleTest { +public class AggregateDisassembleTest implements PatternMatchSupported { private Plan rStudent; @BeforeAll public final void beforeAll() { - rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, ImmutableList.of("")); + rStudent = new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student, + ImmutableList.of("")); } /** + * <pre> * the initial plan is: * Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age]) * +--childPlan(id, name, age) @@ -62,6 +62,7 @@ public class AggregateDisassembleTest { * Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a]) * +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age]) * +--childPlan(id, name, age) + * </pre> */ @Test public void slotReferenceGroupBy() { @@ -70,50 +71,43 @@ public class AggregateDisassembleTest { List<NamedExpression> outputExpressionList = Lists.newArrayList( rStudent.getOutput().get(2).toSlot(), new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent); - - Plan after = rewrite(root); - - Assertions.assertTrue(after instanceof LogicalUnary); - Assertions.assertTrue(after instanceof LogicalAggregate); - Assertions.assertTrue(after.child(0) instanceof LogicalUnary); - LogicalAggregate<Plan> global = (LogicalAggregate) after; - LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0); - Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase()); - Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase()); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); Expression localGroupBy = 
rStudent.getOutput().get(2).toSlot(); - Assertions.assertEquals(2, local.getOutputExpressions().size()); - Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference); - Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0)); - Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias); - Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0)); - Assertions.assertEquals(1, local.getGroupByExpressions().size()); - Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0)); - - Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot(); - Expression globalOutput1 = new Sum(local.getOutputExpressions().get(1).toSlot()); - Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot(); - - Assertions.assertEquals(2, global.getOutputExpressions().size()); - Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference); - Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0)); - Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof Alias); - Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1).child(0)); - Assertions.assertEquals(1, global.getGroupByExpressions().size()); - Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0)); - - // check id: - Assertions.assertEquals(outputExpressionList.get(0).getExprId(), - global.getOutputExpressions().get(0).getExprId()); - Assertions.assertEquals(outputExpressionList.get(1).getExprId(), - global.getOutputExpressions().get(1).getExprId()); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new AggregateDisassemble()) + .printlnTree() + .matchesFromRoot( + logicalAggregate( + logicalAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1)) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + .when(agg -> agg.getOutputExpressions().get(1).child(0) + .equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot()))) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + .when(agg -> agg.getOutputExpressions().get(1).getExprId() + .equals(outputExpressionList.get(1).getExprId())) + ); } /** + * <pre> * the initial plan is: * Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: []) * +--childPlan(id, name, age) @@ -121,44 +115,41 @@ public class AggregateDisassembleTest { * Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: []) * +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: []) * +--childPlan(id, name, age) + * </pre> */ @Test public void globalAggregate() { List<Expression> groupExpressionList = Lists.newArrayList(); List<NamedExpression> outputExpressionList = Lists.newArrayList( - new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent); - - Plan after = rewrite(root); - - Assertions.assertTrue(after instanceof LogicalUnary); - Assertions.assertTrue(after instanceof LogicalAggregate); - Assertions.assertTrue(after.child(0) instanceof LogicalUnary); - LogicalAggregate<Plan> 
global = (LogicalAggregate) after; - LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0); - Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase()); - Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase()); + new Alias(new Sum(rStudent.getOutput().get(0)), "sum")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); Expression localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot()); - Assertions.assertEquals(1, local.getOutputExpressions().size()); - Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias); - Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0)); - Assertions.assertEquals(0, local.getGroupByExpressions().size()); - - Expression globalOutput0 = new Sum(local.getOutputExpressions().get(0).toSlot()); - - Assertions.assertEquals(1, global.getOutputExpressions().size()); - Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias); - Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0)); - Assertions.assertEquals(0, global.getGroupByExpressions().size()); - - // check id: - Assertions.assertEquals(outputExpressionList.get(0).getExprId(), - global.getOutputExpressions().get(0).getExprId()); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new AggregateDisassemble()) + .printlnTree() + .matchesFromRoot( + logicalAggregate( + logicalAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0).child(0).equals(localOutput0)) + .when(agg -> agg.getGroupByExpressions().size() == 0) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0) + .equals(new 
Sum(agg.child().getOutputExpressions().get(0).toSlot()))) + .when(agg -> agg.getGroupByExpressions().size() == 0) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + ); } /** + * <pre> * the initial plan is: * Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age]) * +--childPlan(id, name, age) @@ -166,6 +157,7 @@ public class AggregateDisassembleTest { * Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a]) * +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age]) * +--childPlan(id, name, age) + * </pre> */ @Test public void groupExpressionNotInOutput() { @@ -173,45 +165,40 @@ public class AggregateDisassembleTest { rStudent.getOutput().get(2).toSlot()); List<NamedExpression> outputExpressionList = Lists.newArrayList( new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), "sum")); - Plan root = new LogicalAggregate(groupExpressionList, outputExpressionList, rStudent); - - Plan after = rewrite(root); - - Assertions.assertTrue(after instanceof LogicalUnary); - Assertions.assertTrue(after instanceof LogicalAggregate); - Assertions.assertTrue(after.child(0) instanceof LogicalUnary); - LogicalAggregate<Plan> global = (LogicalAggregate) after; - LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0); - Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase()); - Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase()); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); Expression localOutput0 = rStudent.getOutput().get(2).toSlot(); Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot()); Expression localGroupBy = rStudent.getOutput().get(2).toSlot(); - Assertions.assertEquals(2, local.getOutputExpressions().size()); - Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference); - Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0)); - Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias); - Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0)); - Assertions.assertEquals(1, local.getGroupByExpressions().size()); - Assertions.assertEquals(localGroupBy, local.getGroupByExpressions().get(0)); - - Expression globalOutput0 = new Sum(local.getOutputExpressions().get(1).toSlot()); - Expression globalGroupBy = local.getOutputExpressions().get(0).toSlot(); - - Assertions.assertEquals(1, global.getOutputExpressions().size()); - Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof Alias); - Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0).child(0)); - Assertions.assertEquals(1, global.getGroupByExpressions().size()); - Assertions.assertEquals(globalGroupBy, global.getGroupByExpressions().get(0)); - - // check id: - Assertions.assertEquals(outputExpressionList.get(0).getExprId(), - global.getOutputExpressions().get(0).getExprId()); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new AggregateDisassemble()) + .printlnTree() + .matchesFromRoot( + logicalAggregate( + logicalAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 2) + .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).child(0).equals(localOutput1)) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0) + .equals(new Sum(agg.child().getOutputExpressions().get(1).toSlot()))) + .when(agg -> agg.getGroupByExpressions().size() == 1) + .when(agg -> 
agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0).toSlot())) + // check id: + .when(agg -> agg.getOutputExpressions().get(0).getExprId() + .equals(outputExpressionList.get(0).getExprId())) + ); } /** + * <pre> * the initial plan is: * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as c], groupByExpr: [id]) * +-- childPlan(id, name, age) @@ -220,6 +207,7 @@ public class AggregateDisassembleTest { * +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: [id, age]) * +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: [id, age]) * +-- childPlan(id, name, age) + * </pre> */ @Test public void distinctAggregateWithGroupBy() { @@ -229,68 +217,44 @@ public class AggregateDisassembleTest { new IntegerLiteral(2)), "c")); Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); - Plan after = rewrite(root); - - Assertions.assertTrue(after instanceof LogicalUnary); - Assertions.assertTrue(after instanceof LogicalAggregate); - Assertions.assertTrue(after.child(0) instanceof LogicalUnary); - LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after; - LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0); - LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0); - Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase()); - Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase()); - Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase()); // check local: // id - Expression localOutput0 = rStudent.getOutput().get(0).toSlot(); + Expression localOutput0 = rStudent.getOutput().get(0); // age - Expression localOutput1 = rStudent.getOutput().get(2).toSlot(); + Expression localOutput1 = rStudent.getOutput().get(2); // id - Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot(); + Expression localGroupBy0 = rStudent.getOutput().get(0); // age - Expression localGroupBy1 = 
rStudent.getOutput().get(2).toSlot(); - - Assertions.assertEquals(2, local.getOutputExpressions().size()); - Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof SlotReference); - Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0)); - Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof SlotReference); - Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1)); - Assertions.assertEquals(2, local.getGroupByExpressions().size()); - Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0)); - Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1)); - - // check global: - Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot(); - Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot(); - Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot(); - Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot(); - - Assertions.assertEquals(2, global.getOutputExpressions().size()); - Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference); - Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0)); - Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference); - Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1)); - Assertions.assertEquals(2, global.getGroupByExpressions().size()); - Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0)); - Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1)); - - // check distinct local: - Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true), - new IntegerLiteral(2)); - Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot(); - - Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size()); - 
Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias); - Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0)); - Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size()); - Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0)); - - // check id: - Assertions.assertEquals(outputExpressionList.get(0).getExprId(), - distinctLocal.getOutputExpressions().get(0).getExprId()); - } - - private Plan rewrite(Plan input) { - return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble()); + Expression localGroupBy1 = rStudent.getOutput().get(2); + + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new AggregateDisassemble()) + .matchesFromRoot( + logicalAggregate( + logicalAggregate( + logicalAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getOutputExpressions().get(0).equals(localOutput0)) + .when(agg -> agg.getOutputExpressions().get(1).equals(localOutput1)) + .when(agg -> agg.getGroupByExpressions().get(0).equals(localGroupBy0)) + .when(agg -> agg.getGroupByExpressions().get(1).equals(localGroupBy1)) + ).when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getOutputExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + .when(agg -> agg.getOutputExpressions().get(1) + .equals(agg.child().getOutputExpressions().get(1))) + .when(agg -> agg.getGroupByExpressions().get(0) + .equals(agg.child().getOutputExpressions().get(0))) + .when(agg -> agg.getGroupByExpressions().get(1) + .equals(agg.child().getOutputExpressions().get(1))) + ).when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL)) + .when(agg -> agg.getOutputExpressions().size() == 1) + .when(agg -> agg.getOutputExpressions().get(0) instanceof Alias) + .when(agg -> agg.getOutputExpressions().get(0).child(0) instanceof Add) + .when(agg -> 
agg.getGroupByExpressions().get(0) + .equals(agg.child().child().getOutputExpressions().get(0))) + .when(agg -> agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get( + 0).getExprId()) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java index b9fa35a335..3b8e2be8e1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java @@ -130,16 +130,14 @@ public class StatsCalculatorTest { childGroup.setStatistics(childStats); LogicalFilter<GroupPlan> logicalFilter = new LogicalFilter<>(and, groupPlan); - GroupExpression groupExpression = new GroupExpression(logicalFilter); - groupExpression.addChild(childGroup); + GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup)); Group ownerGroup = new Group(); groupExpression.setOwnerGroup(ownerGroup); StatsCalculator.estimate(groupExpression); Assertions.assertEquals((long) (10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001); LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan); - GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr); - groupExpressionOr.addChild(childGroup); + GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup)); Group ownerGroupOr = new Group(); groupExpressionOr.setOwnerGroup(ownerGroupOr); StatsCalculator.estimate(groupExpressionOr); @@ -243,8 +241,7 @@ public class StatsCalculatorTest { childGroup.setStatistics(childStats); LogicalLimit<GroupPlan> logicalLimit = new LogicalLimit<>(1, 2, groupPlan); - GroupExpression groupExpression = new GroupExpression(logicalLimit); - groupExpression.addChild(childGroup); + GroupExpression groupExpression = new GroupExpression(logicalLimit, 
ImmutableList.of(childGroup)); Group ownerGroup = new Group(); ownerGroup.addGroupExpression(groupExpression); StatsCalculator.estimate(groupExpression); @@ -274,8 +271,7 @@ public class StatsCalculatorTest { childGroup.setStatistics(childStats); LogicalTopN<GroupPlan> logicalTopN = new LogicalTopN<>(Collections.emptyList(), 1, 2, groupPlan); - GroupExpression groupExpression = new GroupExpression(logicalTopN); - groupExpression.addChild(childGroup); + GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup)); Group ownerGroup = new Group(); ownerGroup.addGroupExpression(groupExpression); StatsCalculator.estimate(groupExpression); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org