This is an automated email from the ASF dual-hosted git repository. starocean999 pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push: new a21bb804563 [fix](nereids)modify agg function nullability in PhysicalHashAggregate (#42054) a21bb804563 is described below commit a21bb8045638baddf2ee2ddc44707aac73c9b01a Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Tue Oct 22 11:06:59 2024 +0800 [fix](nereids)modify agg function nullability in PhysicalHashAggregate (#42054) ## Proposed changes pick from master https://github.com/apache/doris/pull/41943 <!--Describe your changes.--> --- .../plans/physical/PhysicalHashAggregate.java | 43 ++++++++++++++++-- .../rules/rewrite/AggregateStrategiesTest.java | 51 +++++++++++++++++++++- 2 files changed, 88 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java index 17ad6516fa7..5708a377273 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java @@ -22,16 +22,20 @@ import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.RequireProperties; import org.apache.doris.nereids.properties.RequirePropertiesSupplier; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.Statistics; @@ -91,8 +95,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar super(PlanType.PHYSICAL_HASH_AGGREGATE, groupExpression, logicalProperties, child); this.groupByExpressions = ImmutableList.copyOf( Objects.requireNonNull(groupByExpressions, "groupByExpressions cannot be null")); - this.outputExpressions = ImmutableList.copyOf( - Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null")); + this.outputExpressions = adjustNullableForOutputs( + Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null"), + groupByExpressions.isEmpty()); this.partitionExpressions = Objects.requireNonNull( partitionExpressions, "partitionExpressions cannot be null"); this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregate param cannot be null"); @@ -118,8 +123,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar child); this.groupByExpressions = ImmutableList.copyOf( Objects.requireNonNull(groupByExpressions, "groupByExpressions cannot be null")); - this.outputExpressions = ImmutableList.copyOf( - Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null")); + this.outputExpressions = adjustNullableForOutputs( + Objects.requireNonNull(outputExpressions, "outputExpressions cannot be null"), + groupByExpressions.isEmpty()); this.partitionExpressions = Objects.requireNonNull( partitionExpressions, "partitionExpressions cannot be null"); this.aggregateParam = Objects.requireNonNull(aggregateParam, "aggregate param cannot be null"); @@ -299,4 +305,33 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar requireProperties, physicalProperties, statistics, child()); } + + /** + * sql: select sum(distinct c1) from t; + * assume c1 is not null, because there is no group by + * sum(distinct c1)'s nullable is alwasNullable in rewritten phase. + * But in implementation phase, we may create 3 phase agg with group by key c1. + * And the sum(distinct c1)'s nullability should be changed depending on if there is any group by expressions. + * This pr update the agg function's nullability accordingly + */ + private List<NamedExpression> adjustNullableForOutputs(List<NamedExpression> outputs, boolean alwaysNullable) { + return ExpressionUtils.rewriteDownShortCircuit(outputs, output -> { + if (output instanceof AggregateExpression) { + AggregateFunction function = ((AggregateExpression) output).getFunction(); + if (function instanceof NullableAggregateFunction + && ((NullableAggregateFunction) function).isAlwaysNullable() != alwaysNullable) { + AggregateParam param = ((AggregateExpression) output).getAggregateParam(); + Expression child = ((AggregateExpression) output).child(); + AggregateFunction newFunction = ((NullableAggregateFunction) function) + .withAlwaysNullable(alwaysNullable); + if (function == child) { + // function is also child + child = newFunction; + } + return new AggregateExpression(newFunction, param, child); + } + } + return output; + }); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 34c16309181..e1e03e64d98 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.AggMode; @@ -54,6 +56,7 @@ import org.junit.jupiter.api.TestInstance; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.Set; @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class AggregateStrategiesTest implements MemoPatternMatchSupported { @@ -138,7 +141,7 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported { Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, true, Optional.empty(), rStudent); - Sum localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot()); + Sum localOutput0 = new Sum(false, true, rStudent.getOutput().get(0).toSlot()); PlanChecker.from(MemoTestUtils.createConnectContext(), root) .applyImplementation(twoPhaseAggregateWithoutDistinct()) @@ -380,6 +383,40 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported { ); } + @Test + public void distinctApply4PhaseRuleNullableChange() { + Slot id = rStudent.getOutput().get(0).toSlot(); + List<Expression> groupExpressionList = Lists.newArrayList(); + List<NamedExpression> outputExpressionList = Lists.newArrayList( + new Alias(new Count(true, id), "count_id"), + new Alias(new Sum(id), "sum_id")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, + true, Optional.empty(), rStudent); + + // select count(distinct id), sum(id) from t; + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyImplementation(fourPhaseAggregateWithDistinct()) + .matches( + physicalHashAggregate( + physicalHashAggregate( + physicalHashAggregate( + physicalHashAggregate() + .when(agg -> agg.getAggPhase().equals(AggPhase.LOCAL)) + .when(agg -> agg.getGroupByExpressions().get(0).equals(id)) + .when(agg -> verifyAlwaysNullableFlag( + agg.getAggregateFunctions(), false))) + .when(agg -> agg.getAggPhase().equals(AggPhase.GLOBAL)) + .when(agg -> agg.getGroupByExpressions().get(0).equals(id)) + .when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(), + false))) + .when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL)) + .when(agg -> agg.getGroupByExpressions().isEmpty()) + .when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true))) + .when(agg -> agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL)) + .when(agg -> agg.getGroupByExpressions().isEmpty()) + .when(agg -> verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true))); + } + private Rule twoPhaseAggregateWithoutDistinct() { return new AggregateStrategies().buildRules() .stream() @@ -400,8 +437,18 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported { private Rule fourPhaseAggregateWithDistinct() { return new AggregateStrategies().buildRules() .stream() - .filter(rule -> rule.getRuleType() == RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT) + .filter(rule -> rule.getRuleType() == RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT) .findFirst() .get(); } + + private boolean verifyAlwaysNullableFlag(Set<AggregateFunction> functions, boolean alwaysNullable) { + for (AggregateFunction f : functions) { + if (f instanceof NullableAggregateFunction + && ((NullableAggregateFunction) f).isAlwaysNullable() != alwaysNullable) { + return false; + } + } + return true; + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org