This is an automated email from the ASF dual-hosted git repository. dataroaring pushed a commit to branch branch-4.0-preview in repository https://gitbox.apache.org/repos/asf/doris.git
commit 4e6af32545f44c04a711bd614f806c259d8a2365 Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com> AuthorDate: Fri Apr 26 12:34:24 2024 +0800 [Fix](nereids) fix rule merge_aggregate when has project (#33892) --- .../doris/nereids/jobs/executor/Rewriter.java | 4 +- .../nereids/rules/rewrite/MergeAggregate.java | 23 ++++--- .../merge_aggregate/merge_aggregate.out | 51 ++++++++++++++ .../merge_aggregate/merge_aggregate.groovy | 80 ++++++++++++++++++++++ 4 files changed, 148 insertions(+), 10 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index e8223524367..80da080daf6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -297,7 +297,9 @@ public class Rewriter extends AbstractBatchJobExecutor { topic("Eliminate GroupBy", topDown(new EliminateGroupBy(), - new MergeAggregate()) + new MergeAggregate(), + // need to adjust min/max/sum nullable attribute after merge aggregate + new AdjustAggregateNullableForEmptySet()) ), topic("Eager aggregation", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 9a0b9f8b5e0..a2c23dd9b41 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory { private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) { LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child(); LogicalAggregate<Plan> innerAgg = project.child(); - + List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions(); + List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections( + project.getProjects(), (List) outputExpressions); // rewrite agg function. e.g. max(max) - List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream() + List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc)) .collect(Collectors.toList()); - // rewrite agg function directly refer to the slot below the project - List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(), - (List) aggFunc); // replace groupByKeys directly refer to the slot below the project List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(), outerAgg.getGroupByExpressions()); @@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory { } boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg, - boolean sameGroupBy) { + boolean sameGroupBy, Optional<LogicalProject> projectOptional) { innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream() .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction)) .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0), (existValue, newValue) -> existValue)); Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions(); - for (AggregateFunction outerFunc : aggregateFunctions) { + List<AggregateFunction> replacedAggFunctions = projectOptional.map(project -> + (List<AggregateFunction>) PlanUtils.replaceExpressionByProjections( + projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions))) + .orElse(new ArrayList<>(aggregateFunctions)); + for (AggregateFunction outerFunc : replacedAggFunctions) { if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) { return false; } @@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory { } boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size()); - return commonCheck(outerAgg, innerAgg, sameGroupBy); + return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty()); } private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) { @@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory { return false; } boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size()); - return commonCheck(outerAgg, innerAgg, sameGroupBy); + return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project)); } } diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out index ba5b127a56f..fba17e8d7b9 100644 --- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out +++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out @@ -246,3 +246,54 @@ PhysicalResultSink --------------------PhysicalProject ----------------------PhysicalOlapScan[mal_test1] +-- !test_has_project_distinct_cant_transform -- +1 + +-- !test_has_project_distinct_cant_transform_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalProject +----------hashAgg[GLOBAL] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_distinct_expr_transform -- +-1 + +-- !test_distinct_expr_transform_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalProject +----------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_has_project_distinct_expr_transform -- +1 +1 +1 + +-- !test_has_project_distinct_expr_transform -- +PhysicalResultSink +--PhysicalDistribute[DistributionSpecGather] +----PhysicalProject +------hashAgg[GLOBAL] +--------PhysicalDistribute[DistributionSpecHash] +----------hashAgg[LOCAL] +------------PhysicalProject +--------------PhysicalOlapScan[mal_test_merge_agg] + +-- !test_sum_empty_table -- +\N \N \N + +-- !test_sum_empty_table_shape -- +PhysicalResultSink +--hashAgg[GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[LOCAL] +--------PhysicalOlapScan[mal_test2] + diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy index 44c256e2f57..46cd4a0a9b7 100644 --- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy +++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy @@ -174,4 +174,84 @@ suite("merge_aggregate") { group by a order by 1,2; """ + sql "drop table if exists mal_test_merge_agg" + sql """ + create table mal_test_merge_agg( + k1 int null, + k2 int not null, + k3 string null, + k4 varchar(100) null + ) + duplicate key (k1,k2) + distributed BY hash(k1) buckets 3 + properties("replication_num" = "1"); + """ + sql "insert into mal_test_merge_agg select 1,1,'1','a';" + sql "insert into mal_test_merge_agg select 2,2,'2','b';" + sql "insert into mal_test_merge_agg select 3,-3,null,'c';" + sql "sync" + + qt_test_has_project_distinct_cant_transform """ + select max(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + qt_test_has_project_distinct_cant_transform_shape """ + explain shape plan + select max(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + + qt_test_distinct_expr_transform """ + select max(count_col) + from ( + select k4, + max(-abs(k1)) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + qt_test_distinct_expr_transform_shape """ + explain shape plan + select max(count_col) + from ( + select k4, + max(-abs(k1)) as count_col + from mal_test_merge_agg group by k4 + ) t ; + """ + + qt_test_has_project_distinct_expr_transform """ + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ + + qt_test_has_project_distinct_expr_transform """ + explain shape plan + select sum(count_col) + from ( + select k4, + count(distinct case when k3 is null then 1 else 0 end) as count_col + from mal_test_merge_agg group by k4 + ) t group by k4; + """ + + qt_test_sum_empty_table """ + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ + + qt_test_sum_empty_table_shape """ + explain shape plan + select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t; + """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org