This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new ceb7b60a64 [fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)
ceb7b60a64 is described below

commit ceb7b60a64725b430f2e124a382a05032fa115ff
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Mon Oct 31 14:11:55 2022 +0800

    [fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)
---
 .../apache/doris/nereids/memo/GroupExpression.java |  12 +-
 .../java/org/apache/doris/nereids/memo/Memo.java   |  13 +-
 .../rules/rewrite/AggregateDisassemble.java        |   3 +-
 .../trees/plans/logical/LogicalAggregate.java      |   4 +-
 .../org/apache/doris/nereids/memo/MemoTest.java    |   6 +-
 .../rewrite/logical/AggregateDisassembleTest.java  | 282 +++++++++------------
 .../doris/nereids/stats/StatsCalculatorTest.java   |  12 +-
 7 files changed, 146 insertions(+), 186 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
index 9e95f9e8d7..92067f607d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.statistics.StatsDeriveResult;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 
@@ -43,7 +44,7 @@ public class GroupExpression {
     private double cost = 0.0;
     private CostEstimate costEstimate = null;
     private Group ownerGroup;
-    private List<Group> children;
+    private ImmutableList<Group> children;
     private final Plan plan;
     private final BitSet ruleMasks;
     private boolean statDerived;
@@ -66,7 +67,7 @@ public class GroupExpression {
     public GroupExpression(Plan plan, List<Group> children) {
         this.plan = Objects.requireNonNull(plan, "plan can not be null")
                 .withGroupExpression(Optional.of(this));
-        this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
+        this.children = ImmutableList.copyOf(Objects.requireNonNull(children, "children can not be null"));
         this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
         this.statDerived = false;
         this.lowestCostTable = Maps.newHashMap();
@@ -84,10 +85,6 @@ public class GroupExpression {
         return children.size();
     }
 
-    public void addChild(Group child) {
-        children.add(child);
-    }
-
     public Group getOwnerGroup() {
         return ownerGroup;
     }
@@ -108,12 +105,13 @@ public class GroupExpression {
         return children;
     }
 
-    public void setChildren(List<Group> children) {
+    public void setChildren(ImmutableList<Group> children) {
         this.children = children;
     }
 
     /**
      * replaceChild.
+     *
      * @param originChild origin child group
      * @param newChild new child group
      */
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index e9b9bbfce9..f5616a71c7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -203,7 +203,7 @@ public class Memo {
 
     /**
      * add or replace the plan into the target group.
-     *
+     * <p>
      * the result truth table:
      * <pre>
      * 
+---------------------------------------+-----------------------------------+--------------------------------+
@@ -296,8 +296,7 @@ public class Memo {
             }
         }
         plan = replaceChildrenToGroupPlan(plan, childrenGroups);
-        GroupExpression newGroupExpression = new GroupExpression(plan);
-        newGroupExpression.setChildren(childrenGroups);
+        GroupExpression newGroupExpression = new GroupExpression(plan, childrenGroups);
         return insertGroupExpression(newGroupExpression, targetGroup, 
plan.getLogicalProperties());
         // TODO: need to derive logical property if generate new group. 
currently we not copy logical plan into
     }
@@ -388,13 +387,15 @@ public class Memo {
         }
         for (GroupExpression groupExpression : needReplaceChild) {
             groupExpressions.remove(groupExpression);
-            List<Group> children = groupExpression.children();
+            List<Group> children = new ArrayList<>(groupExpression.children());
             // TODO: use a better way to replace child, avoid traversing all 
groupExpression
             for (int i = 0; i < children.size(); i++) {
                 if (children.get(i).equals(source)) {
                     children.set(i, destination);
                 }
             }
+            groupExpression.setChildren(ImmutableList.copyOf(children));
+
             GroupExpression that = groupExpressions.get(groupExpression);
             if (that != null && that.getOwnerGroup() != null
                     && 
!that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) {
@@ -487,14 +488,14 @@ public class Memo {
 
     /**
      * eliminate fromGroup, clear targetGroup, then move the logical group 
expressions in the fromGroup to the toGroup.
-     *
+     * <p>
      * the scenario is:
      * ```
      *  Group 1(project, the targetGroup)                  Group 
1(logicalOlapScan, the targetGroup)
      *               |                             =>
      *  Group 0(logicalOlapScan, the fromGroup)
      * ```
-     *
+     * <p>
      * we should recycle the group 0, and recycle all group expressions in 
group 1, then move the logicalOlapScan to
      * the group 1, and reset logical properties of the group 1.
      */
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 911b6735ac..3deb794412 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -34,6 +34,7 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -147,7 +148,7 @@ public class AggregateDisassemble extends 
OneRewriteRuleFactory {
         //    
+-----------+---------------------+-------------------------+--------------------------------+
         //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: 
ExprId x
         // 2. collect local aggregate output expressions and local aggregate 
group by expression list
-        List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+        List<Expression> localGroupByExprs = new 
ArrayList<>(aggregate.getGroupByExpressions());
         List<NamedExpression> localOutputExprs = Lists.newArrayList();
         for (Expression originGroupByExpr : originGroupByExprs) {
             if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 3dfd2ab06d..626c74dab9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -119,8 +119,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             Optional<LogicalProperties> logicalProperties,
             CHILD_TYPE child) {
         super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, 
child);
-        this.groupByExpressions = groupByExpressions;
-        this.outputExpressions = outputExpressions;
+        this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
+        this.outputExpressions = ImmutableList.copyOf(outputExpressions);
         this.partitionExpressions = partitionExpressions;
         this.disassembled = disassembled;
         this.normalized = normalized;
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
index 9ad9bf8d16..f9b04d9758 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
@@ -71,7 +71,7 @@ class MemoTest implements PatternMatchSupported {
             JoinType.INNER_JOIN, logicalJoinAB, 
PlanConstructor.newLogicalOlapScan(2, "C", 0));
 
     @Test
-    void mergeGroup() throws Exception {
+    void mergeGroup() {
         Memo memo = new Memo();
         GroupId gid2 = new GroupId(2);
         Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), 
new LogicalProperties(ArrayList::new));
@@ -85,13 +85,13 @@ class MemoTest implements PatternMatchSupported {
         GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
         GroupId gid1 = new GroupId(1);
         Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
-        Map<GroupId, Group> groups = (Map<GroupId, Group>) 
Deencapsulation.getField(memo, "groups");
+        Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups");
         groups.put(gid2, srcGroup);
         groups.put(gid3, dstGroup);
         groups.put(gid0, g1);
         groups.put(gid1, g2);
         Map<GroupExpression, GroupExpression> groupExpressions =
-                (Map<GroupExpression, GroupExpression>) 
Deencapsulation.getField(memo, "groupExpressions");
+                Deencapsulation.getField(memo, "groupExpressions");
         groupExpressions.put(ge1, ge1);
         groupExpressions.put(ge2, ge2);
         memo.mergeGroup(srcGroup, dstGroup);
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index 6fa37387b9..d92d6efb4e 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@@ -31,14 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.RelationId;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
@@ -46,15 +44,17 @@ import org.junit.jupiter.api.TestInstance;
 import java.util.List;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class AggregateDisassembleTest {
+public class AggregateDisassembleTest implements PatternMatchSupported {
     private Plan rStudent;
 
     @BeforeAll
     public final void beforeAll() {
-        rStudent = new 
LogicalOlapScan(RelationId.createGenerator().getNextId(), 
PlanConstructor.student, ImmutableList.of(""));
+        rStudent = new 
LogicalOlapScan(RelationId.createGenerator().getNextId(), 
PlanConstructor.student,
+                ImmutableList.of(""));
     }
 
     /**
+     * <pre>
      * the initial plan is:
      *   Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], 
groupByExpr: [age])
      *   +--childPlan(id, name, age)
@@ -62,6 +62,7 @@ public class AggregateDisassembleTest {
      *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: 
[a])
      *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], 
groupByExpr: [age])
      *       +--childPlan(id, name, age)
+     * </pre>
      */
     @Test
     public void slotReferenceGroupBy() {
@@ -70,50 +71,43 @@ public class AggregateDisassembleTest {
         List<NamedExpression> outputExpressionList = Lists.newArrayList(
                 rStudent.getOutput().get(2).toSlot(),
                 new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
-        Plan root = new LogicalAggregate(groupExpressionList, 
outputExpressionList, rStudent);
-
-        Plan after = rewrite(root);
-
-        Assertions.assertTrue(after instanceof LogicalUnary);
-        Assertions.assertTrue(after instanceof LogicalAggregate);
-        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
-        LogicalAggregate<Plan> global = (LogicalAggregate) after;
-        LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
-        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
-        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
         Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
         Expression localOutput1 = new 
Sum(rStudent.getOutput().get(0).toSlot());
         Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
 
-        Assertions.assertEquals(2, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof 
SlotReference);
-        Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0));
-        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof 
Alias);
-        Assertions.assertEquals(localOutput1, 
local.getOutputExpressions().get(1).child(0));
-        Assertions.assertEquals(1, local.getGroupByExpressions().size());
-        Assertions.assertEquals(localGroupBy, 
local.getGroupByExpressions().get(0));
-
-        Expression globalOutput0 = 
local.getOutputExpressions().get(0).toSlot();
-        Expression globalOutput1 = new 
Sum(local.getOutputExpressions().get(1).toSlot());
-        Expression globalGroupBy = 
local.getOutputExpressions().get(0).toSlot();
-
-        Assertions.assertEquals(2, global.getOutputExpressions().size());
-        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof 
SlotReference);
-        Assertions.assertEquals(globalOutput0, 
global.getOutputExpressions().get(0));
-        Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof 
Alias);
-        Assertions.assertEquals(globalOutput1, 
global.getOutputExpressions().get(1).child(0));
-        Assertions.assertEquals(1, global.getGroupByExpressions().size());
-        Assertions.assertEquals(globalGroupBy, 
global.getGroupByExpressions().get(0));
-
-        // check id:
-        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
-                global.getOutputExpressions().get(0).getExprId());
-        Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
-                global.getOutputExpressions().get(1).getExprId());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .printlnTree()
+                .matchesFromRoot(
+                        logicalAggregate(
+                                logicalAggregate()
+                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                        .when(agg -> 
agg.getOutputExpressions().size() == 2)
+                                        .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                        .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+                                        .when(agg -> 
agg.getGroupByExpressions().size() == 1)
+                                        .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy))
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                .when(agg -> agg.getOutputExpressions().size() 
== 2)
+                                .when(agg -> agg.getOutputExpressions().get(0)
+                                        
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+                                .when(agg -> 
agg.getOutputExpressions().get(1).child(0)
+                                        .equals(new 
Sum(agg.child().getOutputExpressions().get(1).toSlot())))
+                                .when(agg -> 
agg.getGroupByExpressions().size() == 1)
+                                .when(agg -> agg.getGroupByExpressions().get(0)
+                                        
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+                                // check id:
+                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId()
+                                        
.equals(outputExpressionList.get(0).getExprId()))
+                                .when(agg -> 
agg.getOutputExpressions().get(1).getExprId()
+                                        
.equals(outputExpressionList.get(1).getExprId()))
+                );
     }
 
     /**
+     * <pre>
      * the initial plan is:
      *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: 
[])
      *   +--childPlan(id, name, age)
@@ -121,44 +115,41 @@ public class AggregateDisassembleTest {
      *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
      *   +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: 
[])
      *       +--childPlan(id, name, age)
+     * </pre>
      */
     @Test
     public void globalAggregate() {
         List<Expression> groupExpressionList = Lists.newArrayList();
         List<NamedExpression> outputExpressionList = Lists.newArrayList(
-                new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
-        Plan root = new LogicalAggregate(groupExpressionList, 
outputExpressionList, rStudent);
-
-        Plan after = rewrite(root);
-
-        Assertions.assertTrue(after instanceof LogicalUnary);
-        Assertions.assertTrue(after instanceof LogicalAggregate);
-        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
-        LogicalAggregate<Plan> global = (LogicalAggregate) after;
-        LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
-        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
-        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+                new Alias(new Sum(rStudent.getOutput().get(0)), "sum"));
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
         Expression localOutput0 = new 
Sum(rStudent.getOutput().get(0).toSlot());
 
-        Assertions.assertEquals(1, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof 
Alias);
-        Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0).child(0));
-        Assertions.assertEquals(0, local.getGroupByExpressions().size());
-
-        Expression globalOutput0 = new 
Sum(local.getOutputExpressions().get(0).toSlot());
-
-        Assertions.assertEquals(1, global.getOutputExpressions().size());
-        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof 
Alias);
-        Assertions.assertEquals(globalOutput0, 
global.getOutputExpressions().get(0).child(0));
-        Assertions.assertEquals(0, global.getGroupByExpressions().size());
-
-        // check id:
-        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
-                global.getOutputExpressions().get(0).getExprId());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .printlnTree()
+                .matchesFromRoot(
+                        logicalAggregate(
+                                logicalAggregate()
+                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                        .when(agg -> 
agg.getOutputExpressions().size() == 1)
+                                        .when(agg -> 
agg.getOutputExpressions().get(0).child(0).equals(localOutput0))
+                                        .when(agg -> 
agg.getGroupByExpressions().size() == 0)
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                .when(agg -> agg.getOutputExpressions().size() 
== 1)
+                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
+                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0)
+                                        .equals(new 
Sum(agg.child().getOutputExpressions().get(0).toSlot())))
+                                .when(agg -> 
agg.getGroupByExpressions().size() == 0)
+                                // check id:
+                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId()
+                                        
.equals(outputExpressionList.get(0).getExprId()))
+                );
     }
 
     /**
+     * <pre>
      * the initial plan is:
      *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: 
[age])
      *   +--childPlan(id, name, age)
@@ -166,6 +157,7 @@ public class AggregateDisassembleTest {
      *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: 
[a])
      *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], 
groupByExpr: [age])
      *       +--childPlan(id, name, age)
+     * </pre>
      */
     @Test
     public void groupExpressionNotInOutput() {
@@ -173,45 +165,40 @@ public class AggregateDisassembleTest {
                 rStudent.getOutput().get(2).toSlot());
         List<NamedExpression> outputExpressionList = Lists.newArrayList(
                 new Alias(new Sum(rStudent.getOutput().get(0).toSlot()), 
"sum"));
-        Plan root = new LogicalAggregate(groupExpressionList, 
outputExpressionList, rStudent);
-
-        Plan after = rewrite(root);
-
-        Assertions.assertTrue(after instanceof LogicalUnary);
-        Assertions.assertTrue(after instanceof LogicalAggregate);
-        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
-        LogicalAggregate<Plan> global = (LogicalAggregate) after;
-        LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
-        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
-        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
         Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
         Expression localOutput1 = new 
Sum(rStudent.getOutput().get(0).toSlot());
         Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
 
-        Assertions.assertEquals(2, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof 
SlotReference);
-        Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0));
-        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof 
Alias);
-        Assertions.assertEquals(localOutput1, 
local.getOutputExpressions().get(1).child(0));
-        Assertions.assertEquals(1, local.getGroupByExpressions().size());
-        Assertions.assertEquals(localGroupBy, 
local.getGroupByExpressions().get(0));
-
-        Expression globalOutput0 = new 
Sum(local.getOutputExpressions().get(1).toSlot());
-        Expression globalGroupBy = 
local.getOutputExpressions().get(0).toSlot();
-
-        Assertions.assertEquals(1, global.getOutputExpressions().size());
-        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof 
Alias);
-        Assertions.assertEquals(globalOutput0, 
global.getOutputExpressions().get(0).child(0));
-        Assertions.assertEquals(1, global.getGroupByExpressions().size());
-        Assertions.assertEquals(globalGroupBy, 
global.getGroupByExpressions().get(0));
-
-        // check id:
-        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
-                global.getOutputExpressions().get(0).getExprId());
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .printlnTree()
+                .matchesFromRoot(
+                        logicalAggregate(
+                                logicalAggregate()
+                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                        .when(agg -> 
agg.getOutputExpressions().size() == 2)
+                                        .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                        .when(agg -> 
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+                                        .when(agg -> 
agg.getGroupByExpressions().size() == 1)
+                                        .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy))
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                .when(agg -> agg.getOutputExpressions().size() 
== 1)
+                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
+                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0)
+                                        .equals(new 
Sum(agg.child().getOutputExpressions().get(1).toSlot())))
+                                .when(agg -> 
agg.getGroupByExpressions().size() == 1)
+                                .when(agg -> agg.getGroupByExpressions().get(0)
+                                        
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+                                // check id:
+                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId()
+                                        
.equals(outputExpressionList.get(0).getExprId()))
+                );
     }
 
     /**
+     * <pre>
      * the initial plan is:
      *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as 
c], groupByExpr: [id])
      *   +-- childPlan(id, name, age)
@@ -220,6 +207,7 @@ public class AggregateDisassembleTest {
      *   +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr: 
[id, age])
      *       +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr: 
[id, age])
      *           +-- childPlan(id, name, age)
+     * </pre>
      */
     @Test
     public void distinctAggregateWithGroupBy() {
@@ -229,68 +217,44 @@ public class AggregateDisassembleTest {
                         new IntegerLiteral(2)), "c"));
         Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
 
-        Plan after = rewrite(root);
-
-        Assertions.assertTrue(after instanceof LogicalUnary);
-        Assertions.assertTrue(after instanceof LogicalAggregate);
-        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
-        LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
-        LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
-        LogicalAggregate<Plan> local = (LogicalAggregate) 
after.child(0).child(0);
-        Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, 
distinctLocal.getAggPhase());
-        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
-        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
         // check local:
         // id
-        Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
+        Expression localOutput0 = rStudent.getOutput().get(0);
         // age
-        Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
+        Expression localOutput1 = rStudent.getOutput().get(2);
         // id
-        Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
+        Expression localGroupBy0 = rStudent.getOutput().get(0);
         // age
-        Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
-
-        Assertions.assertEquals(2, local.getOutputExpressions().size());
-        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof 
SlotReference);
-        Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0));
-        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof 
SlotReference);
-        Assertions.assertEquals(localOutput1, 
local.getOutputExpressions().get(1));
-        Assertions.assertEquals(2, local.getGroupByExpressions().size());
-        Assertions.assertEquals(localGroupBy0, 
local.getGroupByExpressions().get(0));
-        Assertions.assertEquals(localGroupBy1, 
local.getGroupByExpressions().get(1));
-
-        // check global:
-        Expression globalOutput0 = 
local.getOutputExpressions().get(0).toSlot();
-        Expression globalOutput1 = 
local.getOutputExpressions().get(1).toSlot();
-        Expression globalGroupBy0 = 
local.getOutputExpressions().get(0).toSlot();
-        Expression globalGroupBy1 = 
local.getOutputExpressions().get(1).toSlot();
-
-        Assertions.assertEquals(2, global.getOutputExpressions().size());
-        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof 
SlotReference);
-        Assertions.assertEquals(globalOutput0, 
global.getOutputExpressions().get(0));
-        Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof 
SlotReference);
-        Assertions.assertEquals(globalOutput1, 
global.getOutputExpressions().get(1));
-        Assertions.assertEquals(2, global.getGroupByExpressions().size());
-        Assertions.assertEquals(globalGroupBy0, 
global.getGroupByExpressions().get(0));
-        Assertions.assertEquals(globalGroupBy1, 
global.getGroupByExpressions().get(1));
-
-        // check distinct local:
-        Expression distinctLocalOutput = new Add(new 
Count(local.getOutputExpressions().get(1).toSlot(), true),
-                new IntegerLiteral(2));
-        Expression distinctLocalGroupBy = 
local.getOutputExpressions().get(0).toSlot();
-
-        Assertions.assertEquals(1, 
distinctLocal.getOutputExpressions().size());
-        Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) 
instanceof Alias);
-        Assertions.assertEquals(distinctLocalOutput, 
distinctLocal.getOutputExpressions().get(0).child(0));
-        Assertions.assertEquals(1, 
distinctLocal.getGroupByExpressions().size());
-        Assertions.assertEquals(distinctLocalGroupBy, 
distinctLocal.getGroupByExpressions().get(0));
-
-        // check id:
-        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
-                distinctLocal.getOutputExpressions().get(0).getExprId());
-    }
-
-    private Plan rewrite(Plan input) {
-        return PlanRewriter.topDownRewrite(input, new ConnectContext(), new 
AggregateDisassemble());
+        Expression localGroupBy1 = rStudent.getOutput().get(2);
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new AggregateDisassemble())
+                .matchesFromRoot(
+                        logicalAggregate(
+                                logicalAggregate(
+                                        logicalAggregate()
+                                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                                .when(agg -> 
agg.getOutputExpressions().get(0).equals(localOutput0))
+                                                .when(agg -> 
agg.getOutputExpressions().get(1).equals(localOutput1))
+                                                .when(agg -> 
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+                                                .when(agg -> 
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
+                                ).when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                        .when(agg -> 
agg.getOutputExpressions().get(0)
+                                                
.equals(agg.child().getOutputExpressions().get(0)))
+                                        .when(agg -> 
agg.getOutputExpressions().get(1)
+                                                
.equals(agg.child().getOutputExpressions().get(1)))
+                                        .when(agg -> 
agg.getGroupByExpressions().get(0)
+                                                
.equals(agg.child().getOutputExpressions().get(0)))
+                                        .when(agg -> 
agg.getGroupByExpressions().get(1)
+                                                
.equals(agg.child().getOutputExpressions().get(1)))
+                        ).when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+                                .when(agg -> agg.getOutputExpressions().size() 
== 1)
+                                .when(agg -> agg.getOutputExpressions().get(0) 
instanceof Alias)
+                                .when(agg -> 
agg.getOutputExpressions().get(0).child(0) instanceof Add)
+                                .when(agg -> agg.getGroupByExpressions().get(0)
+                                        
.equals(agg.child().child().getOutputExpressions().get(0)))
+                                .when(agg -> 
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
+                                        0).getExprId())
+                );
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
index b9fa35a335..3b8e2be8e1 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
@@ -130,16 +130,14 @@ public class StatsCalculatorTest {
         childGroup.setStatistics(childStats);
 
         LogicalFilter<GroupPlan> logicalFilter = new LogicalFilter<>(and, 
groupPlan);
-        GroupExpression groupExpression = new GroupExpression(logicalFilter);
-        groupExpression.addChild(childGroup);
+        GroupExpression groupExpression = new GroupExpression(logicalFilter, 
ImmutableList.of(childGroup));
         Group ownerGroup = new Group();
         groupExpression.setOwnerGroup(ownerGroup);
         StatsCalculator.estimate(groupExpression);
         Assertions.assertEquals((long) (10000 * 0.1 * 0.05), 
ownerGroup.getStatistics().getRowCount(), 0.001);
 
         LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, 
groupPlan);
-        GroupExpression groupExpressionOr = new 
GroupExpression(logicalFilterOr);
-        groupExpressionOr.addChild(childGroup);
+        GroupExpression groupExpressionOr = new 
GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
         Group ownerGroupOr = new Group();
         groupExpressionOr.setOwnerGroup(ownerGroupOr);
         StatsCalculator.estimate(groupExpressionOr);
@@ -243,8 +241,7 @@ public class StatsCalculatorTest {
         childGroup.setStatistics(childStats);
 
         LogicalLimit<GroupPlan> logicalLimit = new LogicalLimit<>(1, 2, 
groupPlan);
-        GroupExpression groupExpression = new GroupExpression(logicalLimit);
-        groupExpression.addChild(childGroup);
+        GroupExpression groupExpression = new GroupExpression(logicalLimit, 
ImmutableList.of(childGroup));
         Group ownerGroup = new Group();
         ownerGroup.addGroupExpression(groupExpression);
         StatsCalculator.estimate(groupExpression);
@@ -274,8 +271,7 @@ public class StatsCalculatorTest {
         childGroup.setStatistics(childStats);
 
         LogicalTopN<GroupPlan> logicalTopN = new 
LogicalTopN<>(Collections.emptyList(), 1, 2, groupPlan);
-        GroupExpression groupExpression = new GroupExpression(logicalTopN);
-        groupExpression.addChild(childGroup);
+        GroupExpression groupExpression = new GroupExpression(logicalTopN, 
ImmutableList.of(childGroup));
         Group ownerGroup = new Group();
         ownerGroup.addGroupExpression(groupExpression);
         StatsCalculator.estimate(groupExpression);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to