This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 561709451c5aa14e441261ef305448bcd9e090be Author: 924060929 <924060...@qq.com> AuthorDate: Wed Mar 6 21:54:35 2024 +0800 [fix](Nereids) fix group_concat(distinct) failed (#31873) --- .../rules/implementation/AggregateStrategies.java | 33 ++++++++-------- .../functions/agg/AggregateFunction.java | 3 ++ .../expressions/functions/agg/GroupConcat.java | 9 +++++ .../nereids/trees/plans/algebra/Aggregate.java | 2 +- .../data/nereids_syntax_p0/group_concat.out | 8 ++++ .../suites/nereids_syntax_p0/group_concat.groovy | 45 +++++++++++++++++++++- 6 files changed, 81 insertions(+), 19 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 254e014240b..10b21d0b979 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -1108,7 +1108,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { * <p> * single node aggregate: * <p> - * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | @@ -1118,12 +1118,10 @@ public class AggregateStrategies implements ImplementationRuleFactory { * <p> * distribute node aggregate: * <p> - * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | - * PhysicalDistribute(distributionSpec=HASH(name)) - * | * LogicalOlapScan(table=tbl, **if distribute by name**) * */ @@ -1175,8 +1173,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (outputChild instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) outputChild; if (aggregateFunction.isDistinct()) { - Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); - Preconditions.checkArgument(aggChild.size() == 1, + Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, "cannot process more than one child in aggregate distinct function: " + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction @@ -1236,7 +1235,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { * after: * single node aggregate: * <p> - * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) * | @@ -1248,7 +1247,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { * <p> * distribute node aggregate: * <p> - * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) + * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) * | @@ -1331,14 +1330,14 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); - Preconditions.checkArgument(aggChild.size() == 1, + Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, "cannot process more than one child in aggregate distinct function: " + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); - return new AggregateExpression(nonDistinct, - bufferToResultParam, aggregateFunction.child(0)); + return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction); } else { Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr); return new AggregateExpression(aggregateFunction, @@ -1727,8 +1726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); - Preconditions.checkArgument(aggChild.size() == 1, + Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, "cannot process more than one child in aggregate distinct function: " + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction @@ -1767,8 +1767,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); - Preconditions.checkArgument(aggChild.size() == 1, + Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, "cannot process more than one child in aggregate distinct function: " + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index a7e523dfdb5..4f53b383d24 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -124,4 +124,7 @@ public abstract class AggregateFunction extends BoundFunction implements Expects return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")"; } + public List<Expression> getDistinctArguments() { + return distinct ? getArguments() : ImmutableList.of(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java index 0f2e7bcb03a..d8b6646cff7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupConcat.java @@ -111,6 +111,15 @@ public class GroupConcat extends NullableAggregateFunction .anyMatch(expression -> !(expression instanceof OrderExpression) && expression.nullable()); } + @Override + public List<Expression> getDistinctArguments() { + if (distinct) { + return ImmutableList.of(getArgument(0)); + } else { + return ImmutableList.of(); + } + } + @Override public void checkLegalityBeforeTypeCoercion() { DataType typeOrArg0 = getArgumentType(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 8361e230be7..15fd5bec868 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -56,7 +56,7 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE default Set<Expression> getDistinctArguments() { return getAggregateFunctions().stream() .filter(AggregateFunction::isDistinct) - .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .flatMap(aggregateFunction -> aggregateFunction.getDistinctArguments().stream()) .collect(ImmutableSet.toImmutableSet()); } } diff --git a/regression-test/data/nereids_syntax_p0/group_concat.out b/regression-test/data/nereids_syntax_p0/group_concat.out new file mode 100644 index 00000000000..6e3ab42329b --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/group_concat.out @@ -0,0 +1,8 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !group_by_distinct -- +1 \N +2 a +3 b +4 c +5 \N + diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy index 60f52c2ba06..b46091616ba 100644 --- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy +++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy @@ -48,6 +48,47 @@ suite("group_concat") { sql "select group_concat(cast(number as string), NULL) from numbers('number'='10')" result([[null]]) } - - + + def testGroupByDistinct = { + sql "drop table if exists test_group_concat_distinct_tbl1" + sql """create table test_group_concat_distinct_tbl1( + tbl1_id1 int + ) distributed by hash(tbl1_id1) + properties('replication_num'='1') + """ + + sql "insert into test_group_concat_distinct_tbl1 values(1), (2), (3), (4), (5)" + + + sql "drop table if exists test_group_concat_distinct_tbl2" + sql """create table test_group_concat_distinct_tbl2( + tbl2_id1 int, + tbl2_id2 int, + ) distributed by hash(tbl2_id1) + properties('replication_num'='1') + """ + sql "insert into test_group_concat_distinct_tbl2 values(1, 11), (2, 22), (3, 33), (4, 44)" + + + sql "drop table if exists test_group_concat_distinct_tbl3" + sql """create table test_group_concat_distinct_tbl3( + tbl3_id2 int, + tbl3_name varchar(255) + ) distributed by hash(tbl3_id2) + properties('replication_num'='1') + """ + sql "insert into test_group_concat_distinct_tbl3 values(22, 'a'), (33, 'b'), (44, 'c')" + + sql "sync" + + order_qt_group_by_distinct """ + SELECT + tbl1.tbl1_id1, + group_concat(DISTINCT tbl3.tbl3_name, ',') AS `names` + FROM test_group_concat_distinct_tbl1 tbl1 + LEFT OUTER JOIN test_group_concat_distinct_tbl2 tbl2 ON tbl2.tbl2_id1 = tbl1.tbl1_id1 + LEFT OUTER JOIN test_group_concat_distinct_tbl3 tbl3 ON tbl3.tbl3_id2 = tbl2.tbl2_id2 + GROUP BY tbl1.tbl1_id1 + """ + }() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org