This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ebf474d9d89 [feature](nereids) deal the slots that appear both in agg func and grouping sets (#31318) ebf474d9d89 is described below commit ebf474d9d89cbca6728076ec1a27afbc0e51908f Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com> AuthorDate: Mon Feb 26 19:59:51 2024 +0800 [feature](nereids) deal the slots that appear both in agg func and grouping sets (#31318) This PR supports slots that appear both in aggregate functions and in grouping sets, for SQL like the following: select sum(a) from t group by grouping sets ((a)); Before this PR, Nereids threw an exception like the following: col_int_undef_signed cannot both in select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead. This PR removes the restriction and supports this situation. --- .../nereids/rules/analysis/NormalizeRepeat.java | 100 ++++++++++++++++----- .../grouping_sets/test_grouping_sets.out | 26 ++++++ ...ot_both_appear_in_agg_fun_and_grouping_sets.out | 66 ++++++++++++++ .../query_p0/grouping_sets/test_grouping_sets.out | 5 ++ .../grouping_sets/test_grouping_sets.groovy | 26 ++---- ...both_appear_in_agg_fun_and_grouping_sets.groovy | 62 +++++++++++++ .../suites/nereids_syntax_p0/grouping_sets.groovy | 16 ---- .../grouping_sets/test_grouping_sets.groovy | 27 +----- 8 files changed, 248 insertions(+), 80 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 005cc663862..9326ee725ff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext; import
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotTriplet; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.OrderExpression; @@ -44,8 +43,10 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; +import org.jetbrains.annotations.NotNull; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -80,35 +81,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> { checkRepeatLegality(repeat); // add virtual slot, LogicalAggregate and LogicalProject for normalize - return normalizeRepeat(repeat); + LogicalAggregate<Plan> agg = normalizeRepeat(repeat); + return dealSlotAppearBothInAggFuncAndGroupingSets(agg); }) ); } private void checkRepeatLegality(LogicalRepeat<Plan> repeat) { - checkIfAggFuncSlotInGroupingSets(repeat); checkGroupingSetsSize(repeat); } - private void checkIfAggFuncSlotInGroupingSets(LogicalRepeat<Plan> repeat) { - Set<Slot> aggUsedSlots = repeat.getOutputExpressions().stream() - .flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream()) - .flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream()) - .collect(ImmutableSet.toImmutableSet()); - Set<ExprId> groupingSetsUsedSlotExprIds = repeat.getGroupingSets().stream() - .flatMap(Collection::stream) - .flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream()) - .map(SlotReference::getExprId) - .collect(Collectors.toSet()); - for (Slot slot : aggUsedSlots) { - if 
(groupingSetsUsedSlotExprIds.contains(slot.getExprId())) { - throw new AnalysisException("column: " + slot.toSql() + " cannot both in select " - + "list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, " - + "please use union instead."); - } - } - } - private void checkGroupingSetsSize(LogicalRepeat<Plan> repeat) { Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf( ExpressionUtils.flatExpressions(repeat.getGroupingSets())); @@ -265,4 +247,78 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { return expr; } } + + /* + * compute slots that appear both in agg func and grouping sets, + * copy the slots and output in the project below the repeat as new copied slots, + * and refer the new copied slots in aggregate parameters. + * eg: original plan after normalizedRepeat + * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#0) as `sum(a)`#2]) + * +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, GROUPING_ID#1] + * +--LogicalProject (projects =[a#0]) + * After: + * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, GROUPING_ID#1, sum(a#3) as `sum(a)`#2]) + * +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, a#3, GROUPING_ID#1] + * +--LogicalProject (projects =[a#0, a#0 as `a`#3]) + */ + private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets( + @NotNull LogicalAggregate<Plan> aggregate) { + LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child(); + Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream() + .flatMap(e -> e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream()) + .flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream()) + .collect(ImmutableSet.toImmutableSet()); + Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream() + .flatMap(Collection::stream) + .flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream()) + 
.collect(Collectors.toSet()); + + Set<Slot> resSet = new HashSet<>(aggUsedSlots); + resSet.retainAll(groupingSetsUsedSlot); + if (resSet.isEmpty()) { + return aggregate; + } + Map<Slot, Alias> slotMapping = resSet.stream().collect( + Collectors.toMap(key -> key, Alias::new) + ); + Set<Alias> newAliases = new HashSet<>(slotMapping.values()); + List<Slot> newSlots = newAliases.stream() + .map(Alias::toSlot) + .collect(Collectors.toList()); + + // modify repeat child to a new project with more projections + List<Slot> originSlots = repeat.child().getOutput(); + ImmutableList<NamedExpression> immList = + ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build(); + LogicalProject<Plan> newProject = new LogicalProject<>(immList, repeat.child()); + repeat = repeat.withChildren(ImmutableList.of(newProject)); + + // modify repeat outputs + List<Slot> originRepeatSlots = repeat.getOutput(); + repeat = repeat.withAggOutput(ImmutableList + .<NamedExpression>builder() + .addAll(originRepeatSlots.stream().filter(slot -> ! 
(slot instanceof VirtualSlotReference)) + .collect(Collectors.toList())) + .addAll(newSlots) + .addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference)) + .collect(Collectors.toList())) + .build()); + aggregate = aggregate.withChildren(ImmutableList.of(repeat)); + + // modify aggregate functions' parameter slot reference to new copied slots + List<NamedExpression> newOutputExpressions = aggregate.getOutputExpressions().stream() + .map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> { + if (expr instanceof AggregateFunction) { + return expr.rewriteDownShortCircuit(e -> { + if (e instanceof Slot && slotMapping.containsKey(e)) { + return slotMapping.get(e).toSlot(); + } + return e; + }); + } + return expr; + }) + ).collect(Collectors.toList()); + return aggregate.withAggOutput(newOutputExpressions); + } } diff --git a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out index f2da1d2f673..67d76e45936 100644 --- a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out +++ b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out @@ -48,4 +48,30 @@ 2 10 1991 -- !select7 -- +\N \N 1002 +\N \N 2002 +\N \N 3004 +\N 1986 1001 +\N 1989 2003 +1 \N 1001 +1 1989 1001 +2 \N 1001 +2 1986 1001 +3 \N 1002 +3 1989 1002 + +-- !select8 -- +\N \N 0.9990029910269193 +\N \N 0.9995007488766849 +\N \N 0.9996672212978369 +\N 1986 0.999001996007984 +\N 1989 0.9995009980039921 +1 \N 0.999001996007984 +1 1989 0.999001996007984 +2 \N 0.999001996007984 +2 1986 0.999001996007984 +3 \N 0.9990029910269193 +3 1989 0.9990029910269193 + +-- !select9 -- diff --git a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out new file mode 100644 index 00000000000..901226f8548 --- /dev/null +++ 
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out @@ -0,0 +1,66 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select1 -- +\N +\N +-48 +-48 +-43 +-43 +-43 +-12 +82 +82 +89 +89 + +-- !select2 -- +\N +\N +-46 +-46 +-39 +-39 +-38 +-11 +91 +91 +97 +97 + +-- !select3 -- +\N +\N +\N +-47 +-47 +-47 +-42 +-42 +-42 +-42 +-11 +83 +83 +90 +90 +16055 +19197 + +-- !select4 -- +\N +a +how +j +say +yeah + +-- !select5 -- +1 +1 +1 +2 +3 +3 +4 +5 + diff --git a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out index b3d3050ee77..052d4e1c35d 100644 --- a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out +++ b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out @@ -203,3 +203,8 @@ test 2 1989-03-21 \N 1001 0 1 1 2012-03-14 \N 1002 0 1 1 +-- !select24 -- +1 0 +2 0 +3 0 + diff --git a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy index b5671a77a56..79a193c95e2 100644 --- a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy +++ b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy @@ -45,27 +45,15 @@ suite("test_grouping_sets") { group by grouping sets((k_if, k1),()) order by k_if, k1, k2_sum """ - test { - sql """ - SELECT k1, k2, SUM(k3) FROM nereids_test_query_db.test - GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 + qt_select7 """ + SELECT k1, k2, SUM(k3) k3_ FROM nereids_test_query_db.test + GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2, k3_ """ - check{result, exception, startTime, endTime -> - assertTrue(exception != null) - logger.info(exception.message) - } - } - test { - sql """ - SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM nereids_test_query_db.test - 
GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 + qt_select8 """ + SELECT k1, k2, SUM(k3)/(SUM(k3)+1) k3_ FROM nereids_test_query_db.test + GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2, k3_ """ - check{result, exception, startTime, endTime -> - assertTrue(exception != null) - logger.info(exception.message) - } - } - qt_select7 """ select k1,k2,sum(k3) from nereids_test_query_db.test where 1 = 2 group by grouping sets((k1), (k1,k2)) """ + qt_select9 """ select k1,k2,sum(k3) from nereids_test_query_db.test where 1 = 2 group by grouping sets((k1), (k1,k2)) """ } diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy new file mode 100644 index 00000000000..ac711cf5aab --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+suite("slot_both_appear_in_agg_fun_and_grouping_sets") { + + sql """ + DROP TABLE IF EXISTS table_10_undef_undef4 + """ + + sql """ + create table table_10_undef_undef4 (`pk` int,`col_int_undef_signed` int , + `col_text_undef_signed` text ) engine=olap distributed by hash(pk) buckets 10 + properties( 'replication_num' = '1'); + """ + + sql """ + insert into table_10_undef_undef4 values (0,16054,null),(1,-12,null), + (2,-48,'j'),(3,null,null),(4,-43,"say"),(5,-43,null),(6,null,'a'),(7,19196,null), + (8,89,"how"),(9,82,"yeah"); + + """ + + qt_select1 """ + SELECT MIN(`col_int_undef_signed`) FROM table_10_undef_undef4 AS T1 GROUP BY + GROUPING SETS((`col_int_undef_signed`,`col_text_undef_signed`), (`col_text_undef_signed`), ()) + HAVING T1.`col_int_undef_signed` < 3 OR T1.col_text_undef_signed > '' order by 1; + """ + + qt_select2 """ + SELECT MIN(col_int_undef_signed+pk) FROM table_10_undef_undef4 AS T1 GROUP BY + GROUPING SETS((col_int_undef_signed,col_text_undef_signed), + (col_text_undef_signed), (pk),()) HAVING T1.col_int_undef_signed < 3 OR T1.col_text_undef_signed > '' order by 1; + """ + + qt_select3 """ + SELECT MIN(col_int_undef_signed+1) FROM table_10_undef_undef4 AS T1 GROUP BY + GROUPING SETS((col_int_undef_signed+1,col_text_undef_signed), (col_text_undef_signed), ()) order by 1; + """ + + qt_select4 """ + select group_concat(col_text_undef_signed,',' ) from table_10_undef_undef4 + group by grouping sets((col_text_undef_signed)) order by 1; + """ + + qt_select5 """ + select sum(rank() over (partition by col_text_undef_signed order by col_int_undef_signed)) + as col1 from table_10_undef_undef4 group by grouping sets((col_int_undef_signed)) order by 1; + """ +} diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy index 0845d705e86..8ca787fabfb 100644 --- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy +++ 
b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy @@ -138,22 +138,6 @@ suite("test_nereids_grouping_sets") { group by grouping sets((k_if, k1),()) order by k_if, k1, k2_sum """ - test { - sql """ - SELECT k1, k2, SUM(k3) FROM groupingSetsTable - GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 - """ - exception "java.sql.SQLException: errCode = 2, detailMessage = column: k3 cannot both in select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead." - } - - test { - sql """ - SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM groupingSetsTable - GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 - """ - exception "java.sql.SQLException: errCode = 2, detailMessage = column: k3 cannot both in select list and aggregate functions when using GROUPING SETS/CUBE/ROLLUP, please use union instead." - } - order_qt_select """ select k1, sum(k2) from (select k1, k2, grouping(k1), grouping(k2) from groupingSetsTableNotNullable group by grouping sets((k1), (k2)))a group by k1 """ diff --git a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy index 6564bca3509..c56ba366bbb 100644 --- a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy +++ b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy @@ -52,15 +52,6 @@ suite("test_grouping_sets", "p0") { exception "errCode = 2, detailMessage = column: `k3` cannot both in select list and aggregate functions" } - sql """set enable_nereids_planner=true;""" - sql """set enable_fallback_to_original_planner=false;""" - test { - sql """ - SELECT k1, k2, SUM(k3) FROM test_query_db.test - GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 - """ - exception "errCode = 2, detailMessage = column: k3 cannot both in select list and aggregate functions" - } sql """set enable_nereids_planner=false;""" sql 
"""set enable_fallback_to_original_planner=true;""" test { @@ -71,15 +62,6 @@ suite("test_grouping_sets", "p0") { exception "errCode = 2, detailMessage = column: `k3` cannot both in select list and aggregate functions" } - sql """set enable_nereids_planner=true;""" - sql """set enable_fallback_to_original_planner=false;""" - test { - sql """ - SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM test_query_db.test - GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order by k1, k2 - """ - exception "errCode = 2, detailMessage = column: k3 cannot both in select list and aggregate functions" - } sql """set enable_nereids_planner=false;""" sql """set enable_fallback_to_original_planner=true;""" @@ -269,9 +251,8 @@ suite("test_grouping_sets", "p0") { sql """set enable_nereids_planner=true;""" sql """set enable_fallback_to_original_planner=false;""" - test { - sql "select k1, if(grouping(k1)=1, count(k1), 0) from test_query_db.test group by grouping sets((k1))" - exception "k1 cannot both in select list and aggregate functions " + - "when using GROUPING SETS/CUBE/ROLLUP, please use union instead." - } + qt_select24 """ + select k1, if(grouping(k1)=1, count(k1), 0) from test_query_db.test group by grouping sets((k1)) + order by 1,2 + """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org