This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 561709451c5aa14e441261ef305448bcd9e090be
Author: 924060929 <924060...@qq.com>
AuthorDate: Wed Mar 6 21:54:35 2024 +0800

    [fix](Nereids) fix group_concat(distinct) failed (#31873)
---
 .../rules/implementation/AggregateStrategies.java  | 33 ++++++++--------
 .../functions/agg/AggregateFunction.java           |  3 ++
 .../expressions/functions/agg/GroupConcat.java     |  9 +++++
 .../nereids/trees/plans/algebra/Aggregate.java     |  2 +-
 .../data/nereids_syntax_p0/group_concat.out        |  8 ++++
 .../suites/nereids_syntax_p0/group_concat.groovy   | 45 +++++++++++++++++++++-
 6 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index 254e014240b..10b21d0b979 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -1108,7 +1108,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
      * <p>
      *  single node aggregate:
      * <p>
-     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
+     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id))], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
      *                                          |
@@ -1118,12 +1118,10 @@ public class AggregateStrategies implements ImplementationRuleFactory {
      * <p>
      * distribute node aggregate:
      * <p>
-     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
+     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id))], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
      *                                          |
-     *                 PhysicalDistribute(distributionSpec=HASH(name))
-     *                                          |
      *                LogicalOlapScan(table=tbl, **if distribute by name**)
      *
      */
@@ -1175,8 +1173,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                     if (outputChild instanceof AggregateFunction) {
                         AggregateFunction aggregateFunction = 
(AggregateFunction) outputChild;
                         if (aggregateFunction.isDistinct()) {
-                            Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
-                            Preconditions.checkArgument(aggChild.size() == 1,
+                            Set<Expression> aggChild = 
Sets.newLinkedHashSet(aggregateFunction.children());
+                            Preconditions.checkArgument(aggChild.size() == 1
+                                            || 
aggregateFunction.getDistinctArguments().size() == 1,
                                     "cannot process more than one child in 
aggregate distinct function: "
                                             + aggregateFunction);
                             AggregateFunction nonDistinct = aggregateFunction
@@ -1236,7 +1235,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
      * after:
      *  single node aggregate:
      * <p>
-     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
+     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id))], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
      *                                          |
@@ -1248,7 +1247,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
      * <p>
      *  distribute node aggregate:
      * <p>
-     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
+     *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id))], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
      *                                          |
@@ -1331,14 +1330,14 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                     if (expr instanceof AggregateFunction) {
                         AggregateFunction aggregateFunction = 
(AggregateFunction) expr;
                         if (aggregateFunction.isDistinct()) {
-                            Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
-                            Preconditions.checkArgument(aggChild.size() == 1,
+                            Set<Expression> aggChild = 
Sets.newLinkedHashSet(aggregateFunction.children());
+                            Preconditions.checkArgument(aggChild.size() == 1
+                                            || 
aggregateFunction.getDistinctArguments().size() == 1,
                                     "cannot process more than one child in 
aggregate distinct function: "
                                             + aggregateFunction);
                             AggregateFunction nonDistinct = aggregateFunction
                                     .withDistinctAndChildren(false, 
ImmutableList.copyOf(aggChild));
-                            return new AggregateExpression(nonDistinct,
-                                    bufferToResultParam, 
aggregateFunction.child(0));
+                            return new AggregateExpression(nonDistinct, 
bufferToResultParam, aggregateFunction);
                         } else {
                             Alias alias = 
nonDistinctAggFunctionToAliasPhase2.get(expr);
                             return new AggregateExpression(aggregateFunction,
@@ -1727,8 +1726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                         if (expr instanceof AggregateFunction) {
                             AggregateFunction aggregateFunction = 
(AggregateFunction) expr;
                             if (aggregateFunction.isDistinct()) {
-                                Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
-                                Preconditions.checkArgument(aggChild.size() == 
1,
+                                Set<Expression> aggChild = 
Sets.newLinkedHashSet(aggregateFunction.children());
+                                Preconditions.checkArgument(aggChild.size() == 
1
+                                                || 
aggregateFunction.getDistinctArguments().size() == 1,
                                         "cannot process more than one child in 
aggregate distinct function: "
                                                 + aggregateFunction);
                                 AggregateFunction nonDistinct = 
aggregateFunction
@@ -1767,8 +1767,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                 if (expr instanceof AggregateFunction) {
                     AggregateFunction aggregateFunction = (AggregateFunction) 
expr;
                     if (aggregateFunction.isDistinct()) {
-                        Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
-                        Preconditions.checkArgument(aggChild.size() == 1,
+                        Set<Expression> aggChild = 
Sets.newLinkedHashSet(aggregateFunction.children());
+                        Preconditions.checkArgument(aggChild.size() == 1
+                                || 
aggregateFunction.getDistinctArguments().size() == 1,
                                 "cannot process more than one child in 
aggregate distinct function: "
                                         + aggregateFunction);
                         AggregateFunction nonDistinct = aggregateFunction
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index a7e523dfdb5..4f53b383d24 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -124,4 +124,7 @@ public abstract class AggregateFunction extends BoundFunction implements Expects
         return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
     }
 
+    public List<Expression> getDistinctArguments() {
+        return distinct ? getArguments() : ImmutableList.of();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
index 0f2e7bcb03a..d8b6646cff7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java
@@ -111,6 +111,15 @@ public class GroupConcat extends NullableAggregateFunction
                 .anyMatch(expression -> !(expression instanceof 
OrderExpression) && expression.nullable());
     }
 
+    @Override
+    public List<Expression> getDistinctArguments() {
+        if (distinct) {
+            return ImmutableList.of(getArgument(0));
+        } else {
+            return ImmutableList.of();
+        }
+    }
+
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType typeOrArg0 = getArgumentType(0);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
index 8361e230be7..15fd5bec868 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
@@ -56,7 +56,7 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE
     default Set<Expression> getDistinctArguments() {
         return getAggregateFunctions().stream()
                 .filter(AggregateFunction::isDistinct)
-                .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
+                .flatMap(aggregateFunction -> 
aggregateFunction.getDistinctArguments().stream())
                 .collect(ImmutableSet.toImmutableSet());
     }
 }
diff --git a/regression-test/data/nereids_syntax_p0/group_concat.out b/regression-test/data/nereids_syntax_p0/group_concat.out
new file mode 100644
index 00000000000..6e3ab42329b
--- /dev/null
+++ b/regression-test/data/nereids_syntax_p0/group_concat.out
@@ -0,0 +1,8 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !group_by_distinct --
+1      \N
+2      a
+3      b
+4      c
+5      \N
+
diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy
index 60f52c2ba06..b46091616ba 100644
--- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy
+++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy
@@ -48,6 +48,47 @@ suite("group_concat") {
         sql "select group_concat(cast(number as string), NULL) from 
numbers('number'='10')"
         result([[null]])
     }
-    
-    
+
+    def testGroupByDistinct = {
+        sql "drop table if exists test_group_concat_distinct_tbl1"
+        sql """create table test_group_concat_distinct_tbl1(
+                        tbl1_id1 int
+                    ) distributed by hash(tbl1_id1)
+                    properties('replication_num'='1')
+                    """
+
+        sql "insert into test_group_concat_distinct_tbl1 values(1), (2), (3), 
(4), (5)"
+
+
+        sql "drop table if exists test_group_concat_distinct_tbl2"
+        sql """create table test_group_concat_distinct_tbl2(
+                        tbl2_id1 int,
+                        tbl2_id2 int,
+                    ) distributed by hash(tbl2_id1)
+                    properties('replication_num'='1')
+                    """
+        sql "insert into test_group_concat_distinct_tbl2 values(1, 11), (2, 
22), (3, 33), (4, 44)"
+
+
+        sql "drop table if exists test_group_concat_distinct_tbl3"
+        sql """create table test_group_concat_distinct_tbl3(
+                        tbl3_id2 int,
+                        tbl3_name varchar(255)
+                    ) distributed by hash(tbl3_id2)
+                    properties('replication_num'='1')
+                    """
+        sql "insert into test_group_concat_distinct_tbl3 values(22, 'a'), (33, 
'b'), (44, 'c')"
+
+        sql "sync"
+
+        order_qt_group_by_distinct """
+            SELECT
+                 tbl1.tbl1_id1,
+                 group_concat(DISTINCT tbl3.tbl3_name, ',') AS `names`
+             FROM test_group_concat_distinct_tbl1 tbl1
+             LEFT OUTER JOIN test_group_concat_distinct_tbl2 tbl2 ON 
tbl2.tbl2_id1 = tbl1.tbl1_id1
+             LEFT OUTER JOIN test_group_concat_distinct_tbl3 tbl3 ON 
tbl3.tbl3_id2 = tbl2.tbl2_id2
+             GROUP BY tbl1.tbl1_id1
+           """
+    }()
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to