This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 686938f5db486091dd0bd0fe56aa9a316f4329fe Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com> AuthorDate: Thu Feb 29 16:02:43 2024 +0800 [fix](nereids) window function with grouping sets work not well (#31475) ```sql select a, c, sum(sum(b)) over(partition by c order by c rows between unbounded preceding and current row) from test_window_table2 group by grouping sets((a),( c)) having a > 1 order by 1,2,3; ``` For queries of this kind (a window function wrapping an aggregate, e.g. sum(sum(col)) over (...), combined with grouping sets), Nereids fails with a "cannot find slot" error because the output slots of the repeat and aggregate nodes are computed incorrectly. Collecting only the trivial (non-windowed) aggregate functions in NormalizeRepeat fixes this problem. Co-authored-by: feiniaofeiafei <moail...@selectdb.com> --- .../nereids/rules/analysis/NormalizeAggregate.java | 23 ++----------- .../nereids/rules/analysis/NormalizeRepeat.java | 7 ++-- .../org/apache/doris/nereids/util/PlanUtils.java | 25 ++++++++++++++ .../grouping_sets/window_agg_grouping_sets.out | 4 +++ .../grouping_sets/window_agg_grouping_sets.groovy | 40 ++++++++++++++++++++++ 5 files changed, 76 insertions(+), 23 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 8071cb4cb84..5874c26e177 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -31,13 +31,13 @@ import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -145,7 +145,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { // collect all trival-agg List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions(); List<AggregateFunction> aggFuncs = Lists.newArrayList(); - aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); + aggregateOutput.forEach(o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); // split non-distinct agg child as two part // TRUE part 1: need push down itself, if it contains subqury or window expression @@ -291,23 +291,4 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { } return builder.build(); } - - private static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> { - - private static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); - - @Override - public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) { - for (Expression child : windowExpression.getExpressionsInWindowSpec()) { - child.accept(this, context); - } - return null; - } - - @Override - public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) { - context.add(aggregateFunction); - return null; - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java 
index 9326ee725ff..3c893ce4bec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -37,9 +37,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; @@ -169,8 +171,9 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .flatMap(function -> function.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); - Set<AggregateFunction> aggregateFunctions = ExpressionUtils.collect( - repeat.getOutputExpressions(), AggregateFunction.class::isInstance); + List<AggregateFunction> aggregateFunctions = Lists.newArrayList(); + repeat.getOutputExpressions().forEach( + o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream() .flatMap(function -> function.getArguments().stream().map(arg -> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index 45dbd5b8e82..c2ac7e4314f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -22,6 +22,9 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import 
org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.WindowExpression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; @@ -121,4 +124,26 @@ public class PlanUtils { .collect(ImmutableSet.toImmutableSet()); return resultSet; } + + /** + * collect non_window_agg_func + */ + public static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> { + + public static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); + + @Override + public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) { + for (Expression child : windowExpression.getExpressionsInWindowSpec()) { + child.accept(this, context); + } + return null; + } + + @Override + public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) { + context.add(aggregateFunction); + return null; + } + } } diff --git a/regression-test/data/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.out b/regression-test/data/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.out new file mode 100644 index 00000000000..3ad36a2579a --- /dev/null +++ b/regression-test/data/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.out @@ -0,0 +1,4 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !select1 -- +2 \N 46.0000000000 + diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.groovy new file mode 100644 index 00000000000..a485980df40 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/window_agg_grouping_sets.groovy @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+suite("window_agg_grouping_sets") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql """ + DROP TABLE IF EXISTS test_window_table2 + """ + + sql """ + create table test_window_table2 ( a varchar(100) null, b decimalv3(18,10) null, c int, ) ENGINE=OLAP + DUPLICATE KEY(`a`) DISTRIBUTED BY HASH(`a`) BUCKETS 1 PROPERTIES + ( "replication_allocation" = "tag.location.default: 1" ); + """ + + sql """ + insert into test_window_table2 values("1", 1, 1),("1", 1, 2),("1", 2, 1),("1", 2, 2), + ("2", 11, 1),("2", 11, 2),("2", 12, 1),("2", 12, 2); + """ + + qt_select1 """ + select a, c, sum(sum(b)) over(partition by c order by c rows between unbounded preceding and current row) + from test_window_table2 group by grouping sets((a),( c)) having a > 1 order by 1,2,3; + """ + +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org