This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new ebf474d9d89 [feature](nereids) deal the slots that appear both in agg 
func and grouping sets (#31318)
ebf474d9d89 is described below

commit ebf474d9d89cbca6728076ec1a27afbc0e51908f
Author: feiniaofeiafei <53502832+feiniaofeia...@users.noreply.github.com>
AuthorDate: Mon Feb 26 19:59:51 2024 +0800

    [feature](nereids) deal the slots that appear both in agg func and grouping 
sets (#31318)
    
    This PR supports slots that appear both in an aggregate function and in grouping sets.
    For example, SQL like the following:
    select sum(a) from t group by grouping sets ((a));
    
    Before this PR, Nereids threw an exception like the following:
    col_int_undef_signed cannot both in select list and aggregate functions 
when using GROUPING SETS/CUBE/ROLLUP, please use union instead.
    
    This PR removes the restriction and supports this situation.
---
 .../nereids/rules/analysis/NormalizeRepeat.java    | 100 ++++++++++++++++-----
 .../grouping_sets/test_grouping_sets.out           |  26 ++++++
 ...ot_both_appear_in_agg_fun_and_grouping_sets.out |  66 ++++++++++++++
 .../query_p0/grouping_sets/test_grouping_sets.out  |   5 ++
 .../grouping_sets/test_grouping_sets.groovy        |  26 ++----
 ...both_appear_in_agg_fun_and_grouping_sets.groovy |  62 +++++++++++++
 .../suites/nereids_syntax_p0/grouping_sets.groovy  |  16 ----
 .../grouping_sets/test_grouping_sets.groovy        |  27 +-----
 8 files changed, 248 insertions(+), 80 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 005cc663862..9326ee725ff 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.rules.RuleType;
 import 
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
 import 
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotTriplet;
 import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.OrderExpression;
@@ -44,8 +43,10 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.collect.Sets.SetView;
+import org.jetbrains.annotations.NotNull;
 
 import java.util.Collection;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -80,35 +81,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory 
{
             
logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> {
                 checkRepeatLegality(repeat);
                 // add virtual slot, LogicalAggregate and LogicalProject for 
normalize
-                return normalizeRepeat(repeat);
+                LogicalAggregate<Plan> agg = normalizeRepeat(repeat);
+                return dealSlotAppearBothInAggFuncAndGroupingSets(agg);
             })
         );
     }
 
     private void checkRepeatLegality(LogicalRepeat<Plan> repeat) {
-        checkIfAggFuncSlotInGroupingSets(repeat);
         checkGroupingSetsSize(repeat);
     }
 
-    private void checkIfAggFuncSlotInGroupingSets(LogicalRepeat<Plan> repeat) {
-        Set<Slot> aggUsedSlots = repeat.getOutputExpressions().stream()
-                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
-                .flatMap(e -> 
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
-                .collect(ImmutableSet.toImmutableSet());
-        Set<ExprId> groupingSetsUsedSlotExprIds = 
repeat.getGroupingSets().stream()
-                .flatMap(Collection::stream)
-                .flatMap(e -> 
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
-                .map(SlotReference::getExprId)
-                .collect(Collectors.toSet());
-        for (Slot slot : aggUsedSlots) {
-            if (groupingSetsUsedSlotExprIds.contains(slot.getExprId())) {
-                throw new AnalysisException("column: " + slot.toSql() + " 
cannot both in select "
-                        + "list and aggregate functions when using GROUPING 
SETS/CUBE/ROLLUP, "
-                        + "please use union instead.");
-            }
-        }
-    }
-
     private void checkGroupingSetsSize(LogicalRepeat<Plan> repeat) {
         Set<Expression> flattenGroupingSetExpr = ImmutableSet.copyOf(
                 ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
@@ -265,4 +247,78 @@ public class NormalizeRepeat extends 
OneAnalysisRuleFactory {
             return expr;
         }
     }
+
+    /*
+     * compute slots that appear both in agg func and grouping sets,
+     * copy the slots and output in the project below the repeat as new copied 
slots,
+     * and refer the new copied slots in aggregate parameters.
+     * eg: original plan after normalizedRepeat
+     * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, 
GROUPING_ID#1, sum(a#0) as `sum(a)`#2])
+     *   +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, 
GROUPING_ID#1]
+     *      +--LogicalProject (projects =[a#0])
+     * After:
+     * LogicalAggregate (groupByExpr=[a#0, GROUPING_ID#1], outputExpr=[a#0, 
GROUPING_ID#1, sum(a#3) as `sum(a)`#2])
+     *   +--LogicalRepeat (groupingSets=[[a#0]], outputExpr=[a#0, a#3, 
GROUPING_ID#1]
+     *      +--LogicalProject (projects =[a#0, a#0 as `a`#3])
+     */
+    private LogicalAggregate<Plan> dealSlotAppearBothInAggFuncAndGroupingSets(
+            @NotNull LogicalAggregate<Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        Set<Slot> aggUsedSlots = aggregate.getOutputExpressions().stream()
+                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+                .flatMap(e -> 
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
+                .collect(ImmutableSet.toImmutableSet());
+        Set<Slot> groupingSetsUsedSlot = repeat.getGroupingSets().stream()
+                .flatMap(Collection::stream)
+                .flatMap(e -> 
e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
+                .collect(Collectors.toSet());
+
+        Set<Slot> resSet = new HashSet<>(aggUsedSlots);
+        resSet.retainAll(groupingSetsUsedSlot);
+        if (resSet.isEmpty()) {
+            return aggregate;
+        }
+        Map<Slot, Alias> slotMapping = resSet.stream().collect(
+                Collectors.toMap(key -> key, Alias::new)
+        );
+        Set<Alias> newAliases = new HashSet<>(slotMapping.values());
+        List<Slot> newSlots = newAliases.stream()
+                .map(Alias::toSlot)
+                .collect(Collectors.toList());
+
+        // modify repeat child to a new project with more projections
+        List<Slot> originSlots = repeat.child().getOutput();
+        ImmutableList<NamedExpression> immList =
+                
ImmutableList.<NamedExpression>builder().addAll(originSlots).addAll(newAliases).build();
+        LogicalProject<Plan> newProject = new LogicalProject<>(immList, 
repeat.child());
+        repeat = repeat.withChildren(ImmutableList.of(newProject));
+
+        // modify repeat outputs
+        List<Slot> originRepeatSlots = repeat.getOutput();
+        repeat = repeat.withAggOutput(ImmutableList
+                .<NamedExpression>builder()
+                .addAll(originRepeatSlots.stream().filter(slot -> ! (slot 
instanceof VirtualSlotReference))
+                        .collect(Collectors.toList()))
+                .addAll(newSlots)
+                .addAll(originRepeatSlots.stream().filter(slot -> (slot 
instanceof VirtualSlotReference))
+                        .collect(Collectors.toList()))
+                .build());
+        aggregate = aggregate.withChildren(ImmutableList.of(repeat));
+
+        // modify aggregate functions' parameter slot reference to new copied 
slots
+        List<NamedExpression> newOutputExpressions = 
aggregate.getOutputExpressions().stream()
+                .map(output -> (NamedExpression) 
output.rewriteDownShortCircuit(expr -> {
+                    if (expr instanceof AggregateFunction) {
+                        return expr.rewriteDownShortCircuit(e -> {
+                            if (e instanceof Slot && 
slotMapping.containsKey(e)) {
+                                return slotMapping.get(e).toSlot();
+                            }
+                            return e;
+                        });
+                    }
+                    return expr;
+                })
+                ).collect(Collectors.toList());
+        return aggregate.withAggOutput(newOutputExpressions);
+    }
 }
diff --git 
a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out 
b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
index f2da1d2f673..67d76e45936 100644
--- a/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
+++ b/regression-test/data/nereids_p0/grouping_sets/test_grouping_sets.out
@@ -48,4 +48,30 @@
 2      10      1991
 
 -- !select7 --
+\N     \N      1002
+\N     \N      2002
+\N     \N      3004
+\N     1986    1001
+\N     1989    2003
+1      \N      1001
+1      1989    1001
+2      \N      1001
+2      1986    1001
+3      \N      1002
+3      1989    1002
+
+-- !select8 --
+\N     \N      0.9990029910269193
+\N     \N      0.9995007488766849
+\N     \N      0.9996672212978369
+\N     1986    0.999001996007984
+\N     1989    0.9995009980039921
+1      \N      0.999001996007984
+1      1989    0.999001996007984
+2      \N      0.999001996007984
+2      1986    0.999001996007984
+3      \N      0.9990029910269193
+3      1989    0.9990029910269193
+
+-- !select9 --
 
diff --git 
a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
 
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
new file mode 100644
index 00000000000..901226f8548
--- /dev/null
+++ 
b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out
@@ -0,0 +1,66 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select1 --
+\N
+\N
+-48
+-48
+-43
+-43
+-43
+-12
+82
+82
+89
+89
+
+-- !select2 --
+\N
+\N
+-46
+-46
+-39
+-39
+-38
+-11
+91
+91
+97
+97
+
+-- !select3 --
+\N
+\N
+\N
+-47
+-47
+-47
+-42
+-42
+-42
+-42
+-11
+83
+83
+90
+90
+16055
+19197
+
+-- !select4 --
+\N
+a
+how
+j
+say
+yeah
+
+-- !select5 --
+1
+1
+1
+2
+3
+3
+4
+5
+
diff --git a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out 
b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
index b3d3050ee77..052d4e1c35d 100644
--- a/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
+++ b/regression-test/data/query_p0/grouping_sets/test_grouping_sets.out
@@ -203,3 +203,8 @@ test        2
 1989-03-21     \N      1001    0       1       1
 2012-03-14     \N      1002    0       1       1
 
+-- !select24 --
+1      0
+2      0
+3      0
+
diff --git 
a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy 
b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
index b5671a77a56..79a193c95e2 100644
--- a/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
+++ b/regression-test/suites/nereids_p0/grouping_sets/test_grouping_sets.groovy
@@ -45,27 +45,15 @@ suite("test_grouping_sets") {
                  group by grouping sets((k_if, k1),()) order by k_if, k1, 
k2_sum
                """
 
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3) FROM nereids_test_query_db.test
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
+    qt_select7 """
+              SELECT k1, k2, SUM(k3) k3_ FROM nereids_test_query_db.test
+              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2, k3_
             """
-            check{result, exception, startTime, endTime ->
-                assertTrue(exception != null)
-                logger.info(exception.message)
-            }
-    }
 
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM 
nereids_test_query_db.test
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
+    qt_select8 """
+              SELECT k1, k2, SUM(k3)/(SUM(k3)+1) k3_ FROM 
nereids_test_query_db.test
+              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2, k3_
             """
-            check{result, exception, startTime, endTime ->
-                assertTrue(exception != null)
-                logger.info(exception.message)
-            }
-    }
 
-   qt_select7 """ select k1,k2,sum(k3) from nereids_test_query_db.test where 1 
= 2 group by grouping sets((k1), (k1,k2)) """ 
+    qt_select9 """ select k1,k2,sum(k3) from nereids_test_query_db.test where 
1 = 2 group by grouping sets((k1), (k1,k2)) """
 }
diff --git 
a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
 
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
new file mode 100644
index 00000000000..ac711cf5aab
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+suite("slot_both_appear_in_agg_fun_and_grouping_sets") {
+
+    sql """
+         DROP TABLE IF EXISTS table_10_undef_undef4
+        """
+
+    sql """
+        create table table_10_undef_undef4 (`pk` int,`col_int_undef_signed` 
int  ,
+        `col_text_undef_signed` text  ) engine=olap distributed by hash(pk) 
buckets 10
+         properties(    'replication_num' = '1'); 
+         """
+
+    sql """
+     insert into table_10_undef_undef4 values (0,16054,null),(1,-12,null),
+     
(2,-48,'j'),(3,null,null),(4,-43,"say"),(5,-43,null),(6,null,'a'),(7,19196,null),
+     (8,89,"how"),(9,82,"yeah"); 
+
+     """
+
+    qt_select1 """
+         SELECT MIN(`col_int_undef_signed`) FROM table_10_undef_undef4 AS T1 
GROUP BY 
+         GROUPING SETS((`col_int_undef_signed`,`col_text_undef_signed`), 
(`col_text_undef_signed`), ())
+          HAVING T1.`col_int_undef_signed` < 3 OR T1.col_text_undef_signed > 
'' order by 1; 
+     """
+
+    qt_select2 """
+         SELECT MIN(col_int_undef_signed+pk) FROM table_10_undef_undef4 AS T1 
GROUP BY 
+         GROUPING SETS((col_int_undef_signed,col_text_undef_signed), 
+         (col_text_undef_signed), (pk),()) HAVING T1.col_int_undef_signed < 3 
OR T1.col_text_undef_signed > '' order by 1; 
+     """
+
+    qt_select3 """
+          SELECT MIN(col_int_undef_signed+1) FROM table_10_undef_undef4 AS T1 
GROUP BY 
+          GROUPING SETS((col_int_undef_signed+1,col_text_undef_signed), 
(col_text_undef_signed), ())  order by 1;
+     """
+
+    qt_select4 """
+          select group_concat(col_text_undef_signed,',' ) from 
table_10_undef_undef4 
+          group by grouping sets((col_text_undef_signed)) order by 1;
+    """
+
+    qt_select5 """
+          select sum(rank() over (partition by col_text_undef_signed order by 
col_int_undef_signed)) 
+          as col1 from table_10_undef_undef4 group by grouping 
sets((col_int_undef_signed)) order by 1;
+    """
+}
diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy 
b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
index 0845d705e86..8ca787fabfb 100644
--- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
+++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy
@@ -138,22 +138,6 @@ suite("test_nereids_grouping_sets") {
                  group by grouping sets((k_if, k1),()) order by k_if, k1, 
k2_sum
                """
 
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3) FROM groupingSetsTable
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
-            """
-        exception "java.sql.SQLException: errCode = 2, detailMessage = column: 
k3 cannot both in select list and aggregate functions when using GROUPING 
SETS/CUBE/ROLLUP, please use union instead."
-    }
-
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM groupingSetsTable
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
-            """
-        exception "java.sql.SQLException: errCode = 2, detailMessage = column: 
k3 cannot both in select list and aggregate functions when using GROUPING 
SETS/CUBE/ROLLUP, please use union instead."
-    }
-
     order_qt_select """
         select k1, sum(k2) from (select k1, k2, grouping(k1), grouping(k2) 
from groupingSetsTableNotNullable group by grouping sets((k1), (k2)))a group by 
k1
     """
diff --git 
a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy 
b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
index 6564bca3509..c56ba366bbb 100644
--- a/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
+++ b/regression-test/suites/query_p0/grouping_sets/test_grouping_sets.groovy
@@ -52,15 +52,6 @@ suite("test_grouping_sets", "p0") {
         exception "errCode = 2, detailMessage = column: `k3` cannot both in 
select list and aggregate functions"
     }
 
-    sql """set enable_nereids_planner=true;"""
-    sql """set enable_fallback_to_original_planner=false;"""
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3) FROM test_query_db.test
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
-            """
-        exception "errCode = 2, detailMessage = column: k3 cannot both in 
select list and aggregate functions"
-    }
     sql """set enable_nereids_planner=false;"""
     sql """set enable_fallback_to_original_planner=true;"""
     test {
@@ -71,15 +62,6 @@ suite("test_grouping_sets", "p0") {
         exception "errCode = 2, detailMessage = column: `k3` cannot both in 
select list and aggregate functions"
     }
 
-    sql """set enable_nereids_planner=true;"""
-    sql """set enable_fallback_to_original_planner=false;"""
-    test {
-        sql """
-              SELECT k1, k2, SUM(k3)/(SUM(k3)+1) FROM test_query_db.test
-              GROUP BY GROUPING SETS ((k1, k2), (k1), (k2), ( ), (k3) ) order 
by k1, k2
-            """
-        exception "errCode = 2, detailMessage = column: k3 cannot both in 
select list and aggregate functions"
-    }
     sql """set enable_nereids_planner=false;"""
     sql """set enable_fallback_to_original_planner=true;"""
 
@@ -269,9 +251,8 @@ suite("test_grouping_sets", "p0") {
     sql """set enable_nereids_planner=true;"""
     sql """set enable_fallback_to_original_planner=false;"""
 
-    test {
-        sql "select k1, if(grouping(k1)=1, count(k1), 0) from 
test_query_db.test group by grouping sets((k1))"
-        exception "k1 cannot both in select list and aggregate functions " +
-                "when using GROUPING SETS/CUBE/ROLLUP, please use union 
instead."
-    }
+    qt_select24 """
+        select k1, if(grouping(k1)=1, count(k1), 0) from test_query_db.test 
group by grouping sets((k1))
+        order by 1,2
+        """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to