This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit ac495762292d3efb0f563d07414624808c56a563
Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com>
AuthorDate: Mon May 27 20:40:57 2024 +0800

    [Fix](nereids) fix merge aggregate setting top projection bug (#35348)
    
    Introduced by #31811.
    
    The bug is triggered by SQL like this:
    
        select col1, col2 from (select a as col1, a as col2 from mal_test1
        group by a) t group by col1, col2;
    
    Transformation description:
    While optimizing the query, MergeAggregate rewrites an agg-project-agg
    pattern into a project-agg pattern:
    Before Transformation:
    
    LogicalAggregate
    +-- LogicalProject
        +-- LogicalAggregate
    
    After Transformation:
    
    LogicalProject
    +-- LogicalAggregate
    
    Before the transformation, the LogicalProject's projections were
    `a AS col1, a AS col2`, and the outer aggregate's group-by keys were
    `col1, col2`. After the transformation, the aggregate's group-by keys
    became `a, a`, while the projections remained `a AS col1, a AS col2`.
    
    Problem:
    When building the upper project's projections, the group-by keys `a, a`
    must be mapped back to `a AS col1, a AS col2`. The old code built a map
    keyed by the aliased slot, with the alias as the value. This does not
    handle several aliases over the same slot. Here both `col1` and `col2`
    alias `a`, so `Collectors.toMap` (called without a merge function) fails
    on the duplicate key `a`. Even with a merge function, only one alias
    could survive for both keys.
    
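    Solution:
    The new code looks up each of the original outer aggregate's group-by
    expressions by exprId instead of by slot. For each one, it finds the
    NamedExpression in the original project's projections that has the same
    exprId (`col1` -> `a AS col1`, `col2` -> `a AS col2`). It then places
    that expression into the new projections, so every group-by key keeps
    its own alias and the bug is resolved. The new lookup casts each outer
    group-by expression to NamedExpression. The rule also now declares a
    dependency on NormalizeAggregate via @DependsRules.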
---
 .../doris/nereids/rules/rewrite/MergeAggregate.java  | 20 ++++++++++++++++----
 .../merge_aggregate/merge_aggregate.out              |  9 +++++++++
 .../merge_aggregate/merge_aggregate.groovy           |  6 ++++++
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
index 8ea8a7f217d..889adfb69f5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
@@ -17,8 +17,10 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.annotation.DependsRules;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -44,6 +46,9 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 /**MergeAggregate*/
+@DependsRules({
+        NormalizeAggregate.class
+})
 public class MergeAggregate implements RewriteRuleFactory {
     private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS =
             ImmutableSet.of("min", "max", "sum", "any_value");
@@ -108,10 +113,17 @@ public class MergeAggregate implements RewriteRuleFactory 
{
                 .withChildren(innerAgg.children());
 
         // construct upper project
-        Map<SlotReference, Alias> childToAlias = project.getProjects().stream()
-                .filter(expr -> (expr instanceof Alias) && (expr.child(0) 
instanceof SlotReference))
-                .collect(Collectors.toMap(alias -> (SlotReference) 
alias.child(0), alias -> (Alias) alias));
-        List<Expression> projectGroupBy = 
ExpressionUtils.replace(replacedGroupBy, childToAlias);
+        Map<ExprId, NamedExpression> exprIdToNameExpressionMap = new 
HashMap<>();
+        for (NamedExpression pro : project.getProjects()) {
+            exprIdToNameExpressionMap.put(pro.getExprId(), pro);
+        }
+        List<Expression> originOuterAggGroupBy = 
outerAgg.getGroupByExpressions();
+        List<Expression> projectGroupBy = new ArrayList<>();
+        for (Expression expression : originOuterAggGroupBy) {
+            ExprId exprId = ((NamedExpression) expression).getExprId();
+            NamedExpression namedExpression = 
exprIdToNameExpressionMap.get(exprId);
+            projectGroupBy.add(namedExpression);
+        }
         List<NamedExpression> upperProjects = 
ImmutableList.<NamedExpression>builder()
                 .addAll(projectGroupBy.stream().map(namedExpr -> 
(NamedExpression) namedExpr).iterator())
                 .addAll(replacedAggFunc.stream().map(expr -> 
((NamedExpression) expr).toSlot()).iterator())
diff --git 
a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out 
b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
index fba17e8d7b9..d7103bfed9f 100644
--- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
+++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
@@ -297,3 +297,12 @@ PhysicalResultSink
 ------hashAgg[LOCAL]
 --------PhysicalOlapScan[mal_test2]
 
+-- !agg_project_agg_the_project_has_duplicate_slot_output --
+1      7       7
+2      4       4
+6      \N      \N
+7      1       1
+8      2       2
+8      5       5
+9      3       3
+
diff --git 
a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
 
b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
index 039f087c938..4a20cf4d68b 100644
--- 
a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
@@ -256,4 +256,10 @@ suite("merge_aggregate") {
         explain shape plan
         select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) 
col2, max(pk) col3 from mal_test2 group by a) t;
     """
+
+    qt_agg_project_agg_the_project_has_duplicate_slot_output """
+    select max(col1), col10, col11 from 
+        (select a,max(b) as col1, count(b) as col4, a as col10, a as col11 
+        from mal_test1 group by a) t group by col10, col11 order by 1,2,3;
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to