This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit ac495762292d3efb0f563d07414624808c56a563 Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com> AuthorDate: Mon May 27 20:40:57 2024 +0800 [Fix](nereids) fix merge aggregate setting top projection bug (#35348) Introduced by #31811. Example SQL: select col1, col2 from (select a as col1, a as col2 from mal_test1 group by a) t group by col1, col2; Transformation description: while optimizing the query, an agg-project-agg pattern is rewritten into a project-agg pattern. Before the transformation: LogicalAggregate +-- LogicalProject +-- LogicalAggregate After the transformation: LogicalProject +-- LogicalAggregate Before the transformation, the projections of the LogicalProject were `a AS col1` and `a AS col2`, and the group-by keys of the outer aggregate were col1, col2. After the transformation, the group-by keys of the aggregate became a, a, while the projections remained `a AS col1`, `a AS col2`. Problem: when building the projections of the new upper project, the group-by keys a, a must be mapped back to `a AS col1`, `a AS col2`. The old code built a map whose key was the slot under each alias and whose value was the alias itself. This did not account for several aliases referring to the same slot. In this query, both `a AS col1` and `a AS col2` wrap slot a, so the two map keys collide; `Collectors.toMap` without a merge function throws on such a duplicate key. Solution: the new code takes the exprId of each original outer-aggregate group-by expression. It searches the original project's projections for the NamedExpression with the same exprId and places that expression into the new projections, in group-by order. Because the lookup is keyed by exprId rather than by the underlying slot, each alias (col1, col2) is preserved correctly, which resolves the bug. 
--- .../doris/nereids/rules/rewrite/MergeAggregate.java | 20 ++++++++++++++++---- .../merge_aggregate/merge_aggregate.out | 9 +++++++++ .../merge_aggregate/merge_aggregate.groovy | 6 ++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 8ea8a7f217d..889adfb69f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -17,8 +17,10 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.annotation.DependsRules; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -44,6 +46,9 @@ import java.util.Set; import java.util.stream.Collectors; /**MergeAggregate*/ +@DependsRules({ + NormalizeAggregate.class +}) public class MergeAggregate implements RewriteRuleFactory { private static final ImmutableSet<String> ALLOW_MERGE_AGGREGATE_FUNCTIONS = ImmutableSet.of("min", "max", "sum", "any_value"); @@ -108,10 +113,17 @@ public class MergeAggregate implements RewriteRuleFactory { .withChildren(innerAgg.children()); // construct upper project - Map<SlotReference, Alias> childToAlias = project.getProjects().stream() - .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof SlotReference)) - .collect(Collectors.toMap(alias -> (SlotReference) alias.child(0), alias -> (Alias) alias)); - List<Expression> projectGroupBy = ExpressionUtils.replace(replacedGroupBy, childToAlias); + Map<ExprId, NamedExpression> exprIdToNameExpressionMap = new 
HashMap<>(); + for (NamedExpression pro : project.getProjects()) { + exprIdToNameExpressionMap.put(pro.getExprId(), pro); + } + List<Expression> originOuterAggGroupBy = outerAgg.getGroupByExpressions(); + List<Expression> projectGroupBy = new ArrayList<>(); + for (Expression expression : originOuterAggGroupBy) { + ExprId exprId = ((NamedExpression) expression).getExprId(); + NamedExpression namedExpression = exprIdToNameExpressionMap.get(exprId); + projectGroupBy.add(namedExpression); + } List<NamedExpression> upperProjects = ImmutableList.<NamedExpression>builder() .addAll(projectGroupBy.stream().map(namedExpr -> (NamedExpression) namedExpr).iterator()) .addAll(replacedAggFunc.stream().map(expr -> ((NamedExpression) expr).toSlot()).iterator()) diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index fba17e8d7b9..d7103bfed9f 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -297,3 +297,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalOlapScan[mal_test2] +-- !agg_project_agg_the_project_has_duplicate_slot_output -- +1 7 7 +2 4 4 +6 \N \N +7 1 1 +8 2 2 +8 5 5 +9 3 3 + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 039f087c938..4a20cf4d68b 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -256,4 +256,10 @@ suite("merge_aggregate") { explain shape plan select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; """ + + qt_agg_project_agg_the_project_has_duplicate_slot_output """ + select max(col1), col10, col11 
from + (select a,max(b) as col1, count(b) as col4, a as col10, a as col11 + from mal_test1 group by a) t group by col10, col11 order by 1,2,3; + """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org