This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 6cf39e82c04 [enhance](nereids) expand support for eliminate agg by unique (#48317) 6cf39e82c04 is described below commit 6cf39e82c04aaa359bc98a4b3f86ee770399a1f3 Author: feiniaofeiafei <moail...@selectdb.com> AuthorDate: Thu May 29 15:24:22 2025 +0800 [enhance](nereids) expand support for eliminate agg by unique (#48317) ### What problem does this PR solve? Issue Number: close #xxx Related PR: #xxx Problem Summary: This PR enhances the EliminateGroupBy rule to support eliminating aggregations when the group-by key is unique, extending its functionality to handle scenarios where the aggregate function's child is not a slot reference. Additionally, it adds support for the avg, sum0, percentile, stddev, stddev_samp, variance, variance_samp, min_by, max_by, and avg_weighted functions. e.g. select a,max(a+1) from t group by a; -> select a,a+1 from t; select a,count(1) from t group by a; -> select a,1 from t; select a,avg(b) from t group by a; -> select a,cast(b as double) from t; --- .../nereids/rules/rewrite/EliminateGroupBy.java | 145 +++++++++++++++------ .../nereids/trees/expressions/literal/Literal.java | 22 ++++ .../doris/nereids/properties/UniqueTest.java | 2 +- .../rules/rewrite/EliminateGroupByTest.java | 47 +++++++ .../eliminate_gby_key/eliminate_group_by.out | Bin 0 -> 3494 bytes .../eliminate_gby_key/eliminate_group_by.groovy | 54 ++++++++ 6 files changed, 228 insertions(+), 42 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java index 9325607dd70..4bf1c04791e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java @@ -25,13 +25,31 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import 
org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.AvgWeighted; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy; +import org.apache.doris.nereids.trees.expressions.functions.agg.Median; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.MinBy; +import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; +import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev; +import org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; +import org.apache.doris.nereids.trees.expressions.functions.agg.Variance; +import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.types.DoubleType; import 
org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -46,54 +64,99 @@ import java.util.List; * Eliminate GroupBy. */ public class EliminateGroupBy extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends Expression>> supportedBasicFunctions + = ImmutableSet.of(Sum.class, Avg.class, Min.class, Max.class, Median.class, AnyValue.class); + private static final ImmutableSet<Class<? extends Expression>> supportedTwoArgsFunctions + = ImmutableSet.of(MinBy.class, MaxBy.class, AvgWeighted.class, Percentile.class); + private static final ImmutableSet<Class<? extends Expression>> supportedDevLikeFunctions + = ImmutableSet.of(Stddev.class, StddevSamp.class, Variance.class, VarianceSamp.class); + private static final ImmutableSet<Class<? extends Expression>> supportedFunctionSum0 + = ImmutableSet.of(Sum0.class); + private static final ImmutableSet<Class<? extends Expression>> allFunctionsExceptCount + = ImmutableSet.<Class<? 
extends Expression>>builder() + .addAll(supportedBasicFunctions) + .addAll(supportedTwoArgsFunctions) + .addAll(supportedDevLikeFunctions) + .addAll(supportedFunctionSum0) + .build(); @Override public Rule build() { return logicalAggregate() .when(agg -> ExpressionUtils.allMatch(agg.getGroupByExpressions(), Slot.class::isInstance)) - .then(agg -> { - List<Expression> groupByExpressions = agg.getGroupByExpressions(); - Builder<Slot> groupBySlots - = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); - for (Expression groupByExpression : groupByExpressions) { - groupBySlots.add((Slot) groupByExpression); - } - Plan child = agg.child(); - boolean unique = child.getLogicalProperties() - .getTrait() - .isUniqueAndNotNull(groupBySlots.build()); - if (!unique) { - return null; - } - for (AggregateFunction f : agg.getAggregateFunctions()) { - if (!((f instanceof Sum || f instanceof Count || f instanceof Min || f instanceof Max) - && (f.arity() == 1 && f.child(0) instanceof Slot))) { - return null; - } - } - List<NamedExpression> outputExpressions = agg.getOutputExpressions(); + .then(this::rewrite).toRule(RuleType.ELIMINATE_GROUP_BY); + } - ImmutableList.Builder<NamedExpression> newOutput - = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + private Plan rewrite(LogicalAggregate<Plan> agg) { + List<Expression> groupByExpressions = agg.getGroupByExpressions(); + Builder<Slot> groupBySlots + = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); + for (Expression groupByExpression : groupByExpressions) { + groupBySlots.add((Slot) groupByExpression); + } + Plan child = agg.child(); + boolean unique = child.getLogicalProperties() + .getTrait() + .isUniqueAndNotNull(groupBySlots.build()); + if (!unique) { + return null; + } + for (AggregateFunction f : agg.getAggregateFunctions()) { + if (!canRewrite(f)) { + return null; + } + } + List<NamedExpression> outputExpressions = agg.getOutputExpressions(); - for (NamedExpression ne : 
outputExpressions) { - if (ne instanceof Alias && ne.child(0) instanceof AggregateFunction) { - AggregateFunction f = (AggregateFunction) ne.child(0); - if (f instanceof Sum || f instanceof Min || f instanceof Max) { - newOutput.add(new Alias(ne.getExprId(), TypeCoercionUtils - .castIfNotSameType(f.child(0), f.getDataType()), ne.getName())); - } else if (f instanceof Count) { - newOutput.add((NamedExpression) ne.withChildren( - new If(new IsNull(f.child(0)), new BigIntLiteral(0), - new BigIntLiteral(1)))); - } else { - throw new IllegalStateException("Unexpected aggregate function: " + f); - } - } else { - newOutput.add(ne); - } + ImmutableList.Builder<NamedExpression> newOutput + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + + for (NamedExpression ne : outputExpressions) { + if (ne instanceof Alias && ne.child(0) instanceof AggregateFunction) { + AggregateFunction f = (AggregateFunction) ne.child(0); + if (supportedBasicFunctions.contains(f.getClass())) { + newOutput.add(new Alias(ne.getExprId(), TypeCoercionUtils + .castIfNotSameType(f.child(0), f.getDataType()), ne.getName())); + } else if (f instanceof Count) { + if (((Count) f).isStar()) { + newOutput.add((NamedExpression) ne.withChildren(TypeCoercionUtils + .castIfNotSameType(new BigIntLiteral(1), f.getDataType()))); + } else { + newOutput.add((NamedExpression) ne.withChildren( + new If(new IsNull(f.child(0)), new BigIntLiteral(0), + new BigIntLiteral(1)))); } - return PlanUtils.projectOrSelf(newOutput.build(), child); - }).toRule(RuleType.ELIMINATE_GROUP_BY); + } else if (f instanceof Sum0) { + Coalesce coalesce = new Coalesce(f.child(0), + Literal.convertToTypedLiteral(0, f.child(0).getDataType())); + newOutput.add((NamedExpression) ne.withChildren( + TypeCoercionUtils.castIfNotSameType(coalesce, f.getDataType()))); + } else if (supportedTwoArgsFunctions.contains(f.getClass())) { + If ifFunc = new If(new IsNull(f.child(1)), new NullLiteral(f.child(0).getDataType()), + f.child(0)); + 
newOutput.add((NamedExpression) ne.withChildren( + TypeCoercionUtils.castIfNotSameType(ifFunc, f.getDataType()))); + } else if (supportedDevLikeFunctions.contains(f.getClass())) { + If ifFunc = new If(new IsNull(f.child(0)), new NullLiteral(DoubleType.INSTANCE), + new DoubleLiteral(0)); + newOutput.add((NamedExpression) ne.withChildren(ifFunc)); + } else { + return null; + } + } else { + newOutput.add(ne); + } + } + return PlanUtils.projectOrSelf(newOutput.build(), child); + } + + private boolean canRewrite(AggregateFunction f) { + if (allFunctionsExceptCount.contains(f.getClass())) { + return true; + } + if (f instanceof Count) { + return ((Count) f).isStar() || 1 == f.arity(); + } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java index 22e2f2ce60b..784f3fba0f3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.shape.LeafExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.CharType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; @@ -37,8 +38,12 @@ import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DecimalV2Type; import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; import 
org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.types.coercion.IntegralType; @@ -672,4 +677,21 @@ public abstract class Literal extends Expression implements LeafExpression { // different environment return new VarcharLiteral(new String(bytes, StandardCharsets.UTF_8)); } + + /**convertToTypedLiteral*/ + public static Literal convertToTypedLiteral(Object value, DataType dataType) { + Number number = (Number) value; + if (dataType.equals(TinyIntType.INSTANCE)) { + return new TinyIntLiteral(number.byteValue()); + } else if (dataType.equals(SmallIntType.INSTANCE)) { + return new SmallIntLiteral(number.shortValue()); + } else if (dataType.equals(IntegerType.INSTANCE)) { + return new IntegerLiteral(number.intValue()); + } else if (dataType.equals(BigIntType.INSTANCE)) { + return new BigIntLiteral(number.longValue()); + } else if (dataType.equals(DoubleType.INSTANCE)) { + return new DoubleLiteral(number.doubleValue()); + } + return null; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java index 391bc82021f..fbcf0c8028f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java @@ -69,7 +69,7 @@ class UniqueTest extends TestWithFeService { Assertions.assertTrue(plan.getLogicalProperties().getTrait() .isUnique(plan.getOutput().get(0))); plan = PlanChecker.from(connectContext) - .analyze("select id, sum(id), avg(id), max(id), min(id) from agg group by id") + .analyze("select id, sum(id), avg(id), max(id), min(id), topn(id,2) from agg group by id") .rewrite() .getPlan(); 
Assertions.assertTrue(plan.getLogicalProperties().getTrait() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java index f2a9e480f32..6d27469c469 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java @@ -98,4 +98,51 @@ class EliminateGroupByTest extends TestWithFeService implements MemoPatternMatch ) ); } + + @Test + void eliminateAvg() { + String sql = "select id, avg(age) from t group by id"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalEmptyRelation().when(p -> p.getProjects().get(0).toSql().equals("id") + && p.getProjects().get(1).toSql().equals("cast(age as DOUBLE) AS `avg(age)`") + && p.getProjects().get(1).getDataType().isDoubleType() + ) + ); + } + + @Test + void eliminateCountStar() { + String sql = "select id, count(*) from t group by id"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalEmptyRelation().when(p -> p.getProjects().get(0).toSql().equals("id") + && p.getProjects().get(1).toSql().equals("1 AS `count(*)`") + && p.getProjects().get(1).getDataType().isBigIntType() + ) + ); + } + + @Test + void eliminateExpr() { + String sql = "select id, avg(age+1), min(abs(age)) from t group by id"; + + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalEmptyRelation().when(p -> p.getProjects().get(0).toSql().equals("id") + && p.getProjects().get(1).toSql().equals("cast((age + 1) as DOUBLE) AS `avg(age+1)`") + && p.getProjects().get(1).getDataType().isDoubleType() + && p.getProjects().get(2).toSql().equals("abs(age) AS `min(abs(age))`") + && p.getProjects().get(2).getDataType().isBigIntType() + ) + ); + } } diff --git 
a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out new file mode 100644 index 00000000000..ab48af56fae Binary files /dev/null and b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out differ diff --git a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy new file mode 100644 index 00000000000..97a858a5c2e --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+suite("eliminate_group_by") { +// sql "set disable_nereids_rules='ELIMINATE_GROUP_BY'" + sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalQuickSort'" + sql "drop table if exists test_unique2;" + sql """create table test_unique2(a int not null, b int) unique key(a) distributed by hash(a) properties("replication_num"="1");""" + sql "insert into test_unique2 values(1,2),(2,2),(3,4),(4,4),(5,null);" + qt_count_star "select a,count(*) from test_unique2 group by a order by 1,2;" + qt_count_1 "select a,count(1) from test_unique2 group by a order by 1,2;" + qt_avg "select a,avg(b) from test_unique2 group by a order by 1,2;" + qt_expr "select a,max(a+1),avg(abs(a+100)),sum(a+b) from test_unique2 group by a order by 1,2,3,4;" + qt_window "select a,avg(sum(b) over(partition by b order by a)) from test_unique2 group by a order by 1,2" + qt_two_args_func_min_by "select min_by(3,b),min_by(b,2),min_by(b,a), min_by(null,a), min_by(b, null), min_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 " + qt_two_args_func_max_by "select max_by(3,a),max_by(a,2),max_by(b,a), max_by(null,a), max_by(b, null), max_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 " + qt_two_args_func_avg_weighted "select avg_weighted(b,2), avg_weighted(b,a), avg_weighted(null,a), avg_weighted(b, null), avg_weighted(null, null) from test_unique2 group by a order by 1,2,3,4,5 " + qt_two_args_func_percentile "select percentile(b, null), percentile(null, null),percentile(b, 0.3) from test_unique2 group by a order by 1,2,3" + qt_stddev "select a,stddev(b),stddev(null) from test_unique2 group by a order by 1,2,3;" + qt_stddev_samp "select a,stddev_samp(b),stddev_samp(null) from test_unique2 group by a order by 1,2,3;" + qt_variance "select a,variance(b),variance(null) from test_unique2 group by a order by 1,2,3;" + qt_variance_samp "select a,variance_samp(b),variance_samp(null) from test_unique2 group by a order by 1,2,3;" + qt_sum0 "select a,sum0(b),sum0(null) from test_unique2 
group by a order by 1,2,3;" + qt_median "select a,median(b),any_value(b),percentile(a,0.1),percentile(b,0.9),percentile(b,0.4) from test_unique2 group by a order by 1,2,3,4,5,6;" + + qt_count_star_shape "explain shape plan select a,count(*) from test_unique2 group by a order by 1,2;" + qt_count_1_shape "explain shape plan select a,count(1) from test_unique2 group by a order by 1,2;" + qt_avg_shape "explain shape plan select a,avg(b) from test_unique2 group by a order by 1,2;" + qt_expr_shape "explain shape plan select a,max(a+1),avg(abs(a+100)),sum(a+b) from test_unique2 group by a order by 1,2,3,4;" + qt_window_shape "explain shape plan select a,avg(sum(b) over(partition by b order by a)) from test_unique2 group by a order by 1,2" + qt_two_args_func_min_by_shape "explain shape plan select min_by(3,b),min_by(b,2),min_by(b,a), min_by(null,a), min_by(b, null), min_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 " + qt_two_args_func_max_by_shape "explain shape plan select max_by(3,a),max_by(a,2),max_by(b,a), max_by(null,a), max_by(b, null), max_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 " + qt_two_args_func_avg_weighted_shape "explain shape plan select avg_weighted(b,2), avg_weighted(b,a), avg_weighted(null,a), avg_weighted(b, null), avg_weighted(null, null) from test_unique2 group by a order by 1,2,3,4,5 " + qt_two_args_func_percentile_shape "explain shape plan select percentile(b, null), percentile(null, null),percentile(b, 0.3) from test_unique2 group by a order by 1,2,3" + qt_stddev_shape "explain shape plan select a,stddev(b),stddev(null) from test_unique2 group by a order by 1,2,3;" + qt_stddev_samp_shape "explain shape plan select a,stddev_samp(b),stddev_samp(null) from test_unique2 group by a order by 1,2,3;" + qt_variance_shape "explain shape plan select a,variance(b),variance(null) from test_unique2 group by a order by 1,2,3;" + qt_variance_samp_shape "explain shape plan select a,variance_samp(b),variance_samp(null) 
from test_unique2 group by a order by 1,2,3;" + qt_sum0_shape "explain shape plan select a,sum0(b),sum0(null) from test_unique2 group by a order by 1,2,3;" + qt_median_shape "explain shape plan select a,median(b),any_value(b),percentile(a,0.1),percentile(b,0.9),percentile(b,0.4) from test_unique2 group by a order by 1,2,3,4,5,6;" +} \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org