This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new b32aac9195 [feature](Nereids)add normalize aggregate rule (#12013)
b32aac9195 is described below

commit b32aac919528f47a58247b4d05aa973a8643f885
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Wed Aug 24 18:30:18 2022 +0800

    [feature](Nereids)add normalize aggregate rule (#12013)
---
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../expression/rewrite/ExpressionRewrite.java      |   4 +-
 .../rules/rewrite/AggregateDisassemble.java        |   3 +-
 .../rules/rewrite/logical/NormalizeAggregate.java  | 138 +++++++++++++++++
 .../trees/plans/logical/LogicalAggregate.java      |  27 +++-
 .../trees/plans/logical/LogicalOlapScan.java       |   2 +-
 .../rewrite/logical/NormalizeAggregateTest.java    | 168 +++++++++++++++++++++
 7 files changed, 331 insertions(+), 12 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 7f4c22ba71..9dd73d04b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -45,6 +45,7 @@ public enum RuleType {
     CHECK_ANALYSIS(RuleTypeClass.CHECK),
 
     // rewrite rules
+    NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
     AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
     COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
     ELIMINATE_ALIAS_NODE(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
index b285eaa2fa..96e61abc10 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
@@ -96,8 +96,8 @@ public class ExpressionRewrite implements RewriteRuleFactory {
                 if (outputExpressions.containsAll(newOutputExpressions)) {
                     return agg;
                 }
-                return new LogicalAggregate<>(newGroupByExprs, 
newOutputExpressions, agg.isDisassembled(),
-                        agg.getAggPhase(), agg.child());
+                return new LogicalAggregate<>(newGroupByExprs, 
newOutputExpressions,
+                        agg.isDisassembled(), agg.isNormalized(), 
agg.getAggPhase(), agg.child());
             }).toRule(RuleType.REWRITE_AGG_EXPRESSION);
         }
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 8a0363e103..1e8da9a14f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -51,7 +51,6 @@ import java.util.stream.Collectors;
  * TODO:
  *     1. use different class represent different phase aggregate
  *     2. if instance count is 1, shouldn't disassemble the agg plan
- *     3. we need another rule to removing duplicated expressions in group by 
expression list
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
@@ -123,6 +122,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
                     localGroupByExprs,
                     localOutputExprs,
                     true,
+                    aggregate.isNormalized(),
                     AggPhase.LOCAL,
                     aggregate.child()
             );
@@ -130,6 +130,7 @@ public class AggregateDisassemble extends OneRewriteRuleFactory {
                     globalGroupByExprs,
                     globalOutputExprs,
                     true,
+                    aggregate.isNormalized(),
                     AggPhase.GLOBAL,
                     localAggregate
             );
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
new file mode 100644
index 0000000000..5aa70a0af3
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionReplacer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * normalize aggregate's group keys to SlotReference and generate a 
LogicalProject top on LogicalAggregate
+ * to hold to order of aggregate output, since aggregate output's order could 
change when we do translate.
+ *
+ * Apply this rule could simplify the processing of enforce and translate.
+ *
+ * Original Plan:
+ * Aggregate(
+ *   keys:[k1#1, K2#2 + 1],
+ *   outputs:[k1#1, Alias(K2# + 1)#4, Alias(k1#1 + 1)#5, Alias(SUM(v1#3))#6,
+ *            Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8])
+ *
+ * After rule:
+ * Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10))#6, 
Alias(SR#11))#7, Alias(SR#10 + 1)#8)
+ * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, 
Alias(SUM(v1#3 + 1))#11])
+ *     +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
+ *
+ * More example could get from UT {@link NormalizeAggregateTest}
+ */
+public class NormalizeAggregate extends OneRewriteRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate().when(aggregate -> 
!aggregate.isNormalized()).then(aggregate -> {
+            // substitution map used to substitute expression in aggregate's 
output to use it as top projections
+            Map<Expression, Expression> substitutionMap = Maps.newHashMap();
+            List<Expression> keys = aggregate.getGroupByExpressions();
+            List<NamedExpression> newOutputs = Lists.newArrayList();
+
+            // keys
+            Map<Boolean, List<Expression>> partitionedKeys = keys.stream()
+                    
.collect(Collectors.groupingBy(SlotReference.class::isInstance));
+            List<Expression> newKeys = Lists.newArrayList();
+            List<NamedExpression> bottomProjections = Lists.newArrayList();
+            if (partitionedKeys.containsKey(false)) {
+                // process non-SlotReference keys
+                newKeys.addAll(partitionedKeys.get(false).stream()
+                        .map(e -> new Alias(e, e.toSql()))
+                        .peek(a -> substitutionMap.put(a.child(), a.toSlot()))
+                        .peek(bottomProjections::add)
+                        .map(Alias::toSlot)
+                        .collect(Collectors.toList()));
+            }
+            if (partitionedKeys.containsKey(true)) {
+                // process SlotReference keys
+                partitionedKeys.get(true).stream()
+                        .map(SlotReference.class::cast)
+                        .peek(s -> substitutionMap.put(s, s))
+                        .peek(bottomProjections::add)
+                        .forEach(newKeys::add);
+            }
+            // add all necessary key to output
+            substitutionMap.entrySet().stream()
+                    .filter(kv -> aggregate.getOutputExpressions().stream()
+                            .anyMatch(e -> e.anyMatch(kv.getKey()::equals)))
+                    .map(Entry::getValue)
+                    .map(NamedExpression.class::cast)
+                    .forEach(newOutputs::add);
+
+            // if we generate bottom, we need to generate to project too.
+            // output
+            List<NamedExpression> outputs = aggregate.getOutputExpressions();
+            Map<Boolean, List<NamedExpression>> partitionedOutputs = 
outputs.stream()
+                    .collect(Collectors.groupingBy(e -> 
e.anyMatch(AggregateFunction.class::isInstance)));
+            if (partitionedOutputs.containsKey(true)) {
+                // process expressions that contain aggregate function
+                Set<AggregateFunction> aggregateFunctions = 
partitionedOutputs.get(true).stream()
+                        .flatMap(e -> 
e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+                        .collect(Collectors.toSet());
+                newOutputs.addAll(aggregateFunctions.stream()
+                        .map(f -> new Alias(f, f.toSql()))
+                        .peek(a -> substitutionMap.put(a.child(), a.toSlot()))
+                        .collect(Collectors.toList()));
+                // add slot references in aggregate function to bottom 
projections
+                bottomProjections.addAll(aggregateFunctions.stream()
+                        .flatMap(f -> 
f.<List<SlotReference>>collect(SlotReference.class::isInstance).stream())
+                        .map(SlotReference.class::cast)
+                        .collect(Collectors.toSet()));
+            }
+
+
+            // assemble
+            LogicalPlan root = aggregate.child();
+            if (partitionedKeys.containsKey(false)) {
+                root = new LogicalProject<>(bottomProjections, root);
+            }
+            root = new LogicalAggregate<>(newKeys, newOutputs, 
aggregate.isDisassembled(),
+                    true, aggregate.getAggPhase(), root);
+            List<NamedExpression> projections = outputs.stream()
+                    .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
substitutionMap))
+                    .map(NamedExpression.class::cast)
+                    .collect(Collectors.toList());
+            root = new LogicalProject<>(projections, root);
+
+            return root;
+        }).toRule(RuleType.NORMALIZE_AGGREGATE);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 06ec298851..019c090a96 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -52,6 +52,7 @@ import java.util.Optional;
 public class LogicalAggregate<CHILD_TYPE extends Plan> extends 
LogicalUnary<CHILD_TYPE> implements Aggregate {
 
     private final boolean disassembled;
+    private final boolean normalized;
     private final List<Expression> groupByExpressions;
     private final List<NamedExpression> outputExpressions;
     private final AggPhase aggPhase;
@@ -63,16 +64,18 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, false, AggPhase.GLOBAL, 
child);
+        this(groupByExpressions, outputExpressions, false, false, 
AggPhase.GLOBAL, child);
     }
 
     public LogicalAggregate(
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
             boolean disassembled,
+            boolean normalized,
             AggPhase aggPhase,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, disassembled, aggPhase, 
Optional.empty(), Optional.empty(), child);
+        this(groupByExpressions, outputExpressions, disassembled, normalized,
+                aggPhase, Optional.empty(), Optional.empty(), child);
     }
 
     /**
@@ -82,6 +85,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
             boolean disassembled,
+            boolean normalized,
             AggPhase aggPhase,
             Optional<GroupExpression> groupExpression,
             Optional<LogicalProperties> logicalProperties,
@@ -90,6 +94,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
         this.groupByExpressions = groupByExpressions;
         this.outputExpressions = outputExpressions;
         this.disassembled = disassembled;
+        this.normalized = normalized;
         this.aggPhase = aggPhase;
     }
 
@@ -136,6 +141,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
         return disassembled;
     }
 
+    public boolean isNormalized() {
+        return normalized;
+    }
+
     /**
      * Determine the equality with another plan
      */
@@ -160,23 +169,25 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
     @Override
     public LogicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
-        return new LogicalAggregate(groupByExpressions, outputExpressions, 
disassembled, aggPhase, children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                disassembled, normalized, aggPhase, children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
-        return new LogicalAggregate(groupByExpressions, outputExpressions, 
disassembled, aggPhase, groupExpression,
-            Optional.of(logicalProperties), children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                disassembled, normalized, aggPhase, groupExpression, 
Optional.of(logicalProperties), children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
-        return new LogicalAggregate(groupByExpressions, outputExpressions, 
disassembled, aggPhase, Optional.empty(),
-            logicalProperties, children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
+                disassembled, normalized, aggPhase, Optional.empty(), 
logicalProperties, children.get(0));
     }
 
     public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> 
groupByExprList,
                                                  List<NamedExpression> 
outputExpressionList) {
-        return new LogicalAggregate(groupByExprList, outputExpressionList, 
disassembled, aggPhase, child());
+        return new LogicalAggregate<>(groupByExprList, outputExpressionList,
+                disassembled, normalized, aggPhase, child());
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
index c071a84ea1..a2d5372327 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
@@ -68,7 +68,7 @@ public class LogicalOlapScan extends LogicalRelation  {
         return "ScanOlapTable ("
                 + qualifiedName()
                 + ", output: "
-                + 
computeOutput().stream().map(Objects::toString).collect(Collectors.joining(", 
", "[",  "]"))
+                + 
getOutput().stream().map(Objects::toString).collect(Collectors.joining(", ", 
"[",  "]"))
                 + ")";
     }
 
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
new file mode 100644
index 0000000000..fd44a0d628
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.FieldChecker;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.List;
+
+/**
+ * Unit tests for {@link NormalizeAggregate}.
+ *
+ * Each case builds a LogicalAggregate over a scan of the student table, applies the rule top-down and
+ * checks the shape of the rewritten plan: a top LogicalProject that keeps the ExprIds of the original
+ * outputs, a normalized LogicalAggregate and, when a group-by key is not a SlotReference, a bottom
+ * LogicalProject computing that key.
+ */
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class NormalizeAggregateTest implements PatternMatchSupported {
+    // scan of PlanConstructor.student; per the plans documented below its output is
+    // [id#0, gender#1, name#2, age#3], so getOutput().get(i) picks the column by that index
+    private Plan rStudent;
+
+    // PER_CLASS lifecycle lets this @BeforeAll method be non-static; the scan is built once and shared
+    @BeforeAll
+    public final void beforeAll() {
+        rStudent = new LogicalOlapScan(PlanConstructor.student, 
ImmutableList.of("student"));
+    }
+
+    /**
+     * original plan:
+     * LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS 
`sum`#4], groupBy: [name#2])
+     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, 
age#3])
+     *
+     * after rewrite:
+     * LogicalProject (name#2, sum(id)#5 AS `sum`#4)
+     * +--LogicalAggregate (phase: [GLOBAL], output: [name#2, sum(id#0) AS 
`sum(id)`#5], groupBy: [name#2])
+     *    +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, 
age#3])
+     */
+    @Test
+    public void testSimpleKeyWithSimpleAggregateFunction() {
+        // group by name, output [name, sum(id) AS sum]
+        NamedExpression key = rStudent.getOutput().get(2).toSlot();
+        NamedExpression aggregateFunction = new Alias(new 
Sum(rStudent.getOutput().get(0).toSlot()), "sum");
+        List<Expression> groupExpressionList = Lists.newArrayList(key);
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(key, 
aggregateFunction);
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
+        // the key is already a slot, so no bottom project is expected; the top project must keep
+        // the ExprId of the original `sum` alias while reading the aggregate's slot
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new NormalizeAggregate())
+                .matchesFromRoot(
+                        logicalProject(
+                                logicalAggregate(
+                                        logicalOlapScan()
+                                
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
+                                        .when(aggregate -> 
aggregate.getOutputExpressions().get(0).equals(key))
+                                        .when(aggregate -> 
aggregate.getOutputExpressions().get(1).child(0).equals(aggregateFunction.child(0)))
+                                        .when(FieldChecker.check("normalized", 
true))
+                        ).when(project -> 
project.getProjects().get(0).equals(key))
+                                .when(project -> project.getProjects().get(1) 
instanceof Alias)
+                                .when(project -> ((Alias) 
(project.getProjects().get(1))).getExprId().equals(aggregateFunction.getExprId()))
+                                .when(project -> 
project.getProjects().get(1).child(0) instanceof SlotReference)
+                );
+    }
+
+    /**
+     * original plan:
+     * LogicalAggregate (phase: [GLOBAL], output: [(sum((id#0 * 1)) + 2) AS 
`(sum((id * 1)) + 2)`#4], groupBy: [name#2])
+     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, 
age#3])
+     *
+     * after rewrite:
+     * LogicalProject ((sum((id * 1))#5 + 2) AS `(sum((id * 1)) + 2)`#4)
+     * +--LogicalAggregate (phase: [GLOBAL], output: [sum((id#0 * 1)) AS 
`sum((id * 1))`#5], groupBy: [name#2])
+     *    +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, 
age#3])
+     */
+    @Test
+    public void testComplexFuncWithComplexOutputOfFunc() {
+        // group by name, output [sum(id * 1) + 2]; the key itself is not referenced by the output
+        NamedExpression key = rStudent.getOutput().get(2).toSlot();
+        List<Expression> groupExpressionList = Lists.newArrayList(key);
+        Expression aggregateFunction = new Sum(new 
Multiply(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(1)));
+        Expression complexOutput = new Add(aggregateFunction, new 
IntegerLiteral(2));
+        Alias output = new Alias(complexOutput, complexOutput.toSql());
+        List<NamedExpression> outputExpressionList = 
Lists.newArrayList(output);
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
+        // the aggregate outputs only the aggregate function; the "+ 2" is re-applied by the top
+        // project on the aggregate's slot, under the ExprId of the original output alias
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new NormalizeAggregate())
+                .matchesFromRoot(
+                        logicalProject(
+                                logicalAggregate(
+                                        logicalOlapScan()
+                                
).when(FieldChecker.check("groupByExpressions", ImmutableList.of(key)))
+                                        .when(aggregate -> 
aggregate.getOutputExpressions().size() == 1)
+                                        .when(aggregate -> 
aggregate.getOutputExpressions().get(0).child(0).equals(aggregateFunction))
+                        ).when(project -> project.getProjects().size() == 1)
+                                .when(project -> project.getProjects().get(0) 
instanceof Alias)
+                                .when(project -> 
project.getProjects().get(0).getExprId().equals(output.getExprId()))
+                                .when(project -> 
project.getProjects().get(0).child(0) instanceof Add)
+                                .when(project -> 
project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
+                                .when(project -> 
project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+                );
+    }
+
+
+    /**
+     * original plan:
+     * LogicalAggregate (phase: [GLOBAL], output: [((gender#1 + 1) + 2) AS 
`((gender + 1) + 2)`#4], groupBy: [(gender#1 + 1)])
+     * +--ScanOlapTable (student.student, output: [id#0, gender#1, name#2, 
age#3])
+     *
+     * after rewrite:
+     * LogicalProject (((gender + 1)#5 + 2) AS `((gender + 1) + 2)`#4)
+     * +--LogicalAggregate (phase: [GLOBAL], output: [(gender + 1)#5], 
groupBy: [(gender + 1)#5])
+     *    +--LogicalProject ((gender#1 + 1) AS `(gender + 1)`#5)
+     *       +--ScanOlapTable (student.student, output: [id#0, gender#1, 
name#2, age#3])
+     */
+    @Test
+    public void testComplexKeyWithComplexOutputOfKey() {
+        // group by (gender + 1), output [(gender + 1) + 2]
+        Expression key = new Add(rStudent.getOutput().get(1).toSlot(), new 
IntegerLiteral(1));
+        Expression complexKeyOutput = new Add(key, new IntegerLiteral(2));
+        NamedExpression keyOutput = new Alias(complexKeyOutput, 
complexKeyOutput.toSql());
+        List<Expression> groupExpressionList = Lists.newArrayList(key);
+        List<NamedExpression> outputExpressionList = 
Lists.newArrayList(keyOutput);
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
+        // the non-slot key is computed by a bottom project; the aggregate groups by and outputs
+        // the resulting slot, and the top project re-applies "+ 2" on it
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new NormalizeAggregate())
+                .matchesFromRoot(
+                        logicalProject(
+                                logicalAggregate(
+                                        logicalProject(
+                                                logicalOlapScan()
+                                        ).when(project -> 
project.getProjects().size() == 1)
+                                                .when(project -> 
project.getProjects().get(0) instanceof Alias)
+                                                .when(project -> 
project.getProjects().get(0).child(0).equals(key))
+                                ).when(aggregate -> 
aggregate.getGroupByExpressions().get(0) instanceof SlotReference)
+                                        .when(aggregate -> 
aggregate.getOutputExpressions().get(0) instanceof SlotReference)
+                                        .when(aggregate -> 
aggregate.getGroupByExpressions().equals(aggregate.getOutputExpressions()))
+                        ).when(project -> 
project.getProjects().get(0).getExprId().equals(keyOutput.getExprId()))
+                                .when(project -> 
project.getProjects().get(0).child(0) instanceof Add)
+                                .when(project -> 
project.getProjects().get(0).child(0).child(0) instanceof SlotReference)
+                                .when(project -> 
project.getProjects().get(0).child(0).child(1).equals(new IntegerLiteral(2)))
+
+                );
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to