This is an automated email from the ASF dual-hosted git repository.

starocean999 pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new a21bb804563 [fix](nereids)modify agg function nullability in PhysicalHashAggregate (#42054)
a21bb804563 is described below

commit a21bb8045638baddf2ee2ddc44707aac73c9b01a
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Tue Oct 22 11:06:59 2024 +0800

    [fix](nereids)modify agg function nullability in PhysicalHashAggregate (#42054)
    
    ## Proposed changes
    pick from master https://github.com/apache/doris/pull/41943
    
    <!--Describe your changes.-->
---
 .../plans/physical/PhysicalHashAggregate.java      | 43 ++++++++++++++++--
 .../rules/rewrite/AggregateStrategiesTest.java     | 51 +++++++++++++++++++++-
 2 files changed, 88 insertions(+), 6 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
index 17ad6516fa7..5708a377273 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
@@ -22,16 +22,20 @@ import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.properties.RequireProperties;
 import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
 import org.apache.doris.nereids.trees.plans.AggMode;
 import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.PlanType;
 import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
 import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.Utils;
 import org.apache.doris.statistics.Statistics;
 
@@ -91,8 +95,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
         super(PlanType.PHYSICAL_HASH_AGGREGATE, groupExpression, 
logicalProperties, child);
         this.groupByExpressions = ImmutableList.copyOf(
                 Objects.requireNonNull(groupByExpressions, "groupByExpressions 
cannot be null"));
-        this.outputExpressions = ImmutableList.copyOf(
-                Objects.requireNonNull(outputExpressions, "outputExpressions 
cannot be null"));
+        this.outputExpressions = adjustNullableForOutputs(
+                Objects.requireNonNull(outputExpressions, "outputExpressions 
cannot be null"),
+                groupByExpressions.isEmpty());
         this.partitionExpressions = Objects.requireNonNull(
                 partitionExpressions, "partitionExpressions cannot be null");
         this.aggregateParam = Objects.requireNonNull(aggregateParam, 
"aggregate param cannot be null");
@@ -118,8 +123,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
                 child);
         this.groupByExpressions = ImmutableList.copyOf(
                 Objects.requireNonNull(groupByExpressions, "groupByExpressions 
cannot be null"));
-        this.outputExpressions = ImmutableList.copyOf(
-                Objects.requireNonNull(outputExpressions, "outputExpressions 
cannot be null"));
+        this.outputExpressions = adjustNullableForOutputs(
+                Objects.requireNonNull(outputExpressions, "outputExpressions 
cannot be null"),
+                groupByExpressions.isEmpty());
         this.partitionExpressions = Objects.requireNonNull(
                 partitionExpressions, "partitionExpressions cannot be null");
         this.aggregateParam = Objects.requireNonNull(aggregateParam, 
"aggregate param cannot be null");
@@ -299,4 +305,33 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar
                 requireProperties, physicalProperties, statistics,
                 child());
     }
+
+    /**
+     * sql: select sum(distinct c1) from t;
+     * assume c1 is not null, because there is no group by
+     * sum(distinct c1)'s nullable is alwasNullable in rewritten phase.
+     * But in implementation phase, we may create 3 phase agg with group by 
key c1.
+     * And the sum(distinct c1)'s nullability should be changed depending on 
if there is any group by expressions.
+     * This pr update the agg function's nullability accordingly
+     */
+    private List<NamedExpression> 
adjustNullableForOutputs(List<NamedExpression> outputs, boolean alwaysNullable) 
{
+        return ExpressionUtils.rewriteDownShortCircuit(outputs, output -> {
+            if (output instanceof AggregateExpression) {
+                AggregateFunction function = ((AggregateExpression) 
output).getFunction();
+                if (function instanceof NullableAggregateFunction
+                        && ((NullableAggregateFunction) 
function).isAlwaysNullable() != alwaysNullable) {
+                    AggregateParam param = ((AggregateExpression) 
output).getAggregateParam();
+                    Expression child = ((AggregateExpression) output).child();
+                    AggregateFunction newFunction = 
((NullableAggregateFunction) function)
+                            .withAlwaysNullable(alwaysNullable);
+                    if (function == child) {
+                        // function is also child
+                        child = newFunction;
+                    }
+                    return new AggregateExpression(newFunction, param, child);
+                }
+            }
+            return output;
+        });
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
index 34c16309181..e1e03e64d98 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
@@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.plans.AggMode;
@@ -54,6 +56,7 @@ import org.junit.jupiter.api.TestInstance;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
 public class AggregateStrategiesTest implements MemoPatternMatchSupported {
@@ -138,7 +141,7 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported {
         Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList,
                 true, Optional.empty(), rStudent);
 
-        Sum localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
+        Sum localOutput0 = new Sum(false, true, 
rStudent.getOutput().get(0).toSlot());
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), root)
                 .applyImplementation(twoPhaseAggregateWithoutDistinct())
@@ -380,6 +383,40 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported {
                 );
     }
 
+    @Test
+    public void distinctApply4PhaseRuleNullableChange() {
+        Slot id = rStudent.getOutput().get(0).toSlot();
+        List<Expression> groupExpressionList = Lists.newArrayList();
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                new Alias(new Count(true, id), "count_id"),
+                new Alias(new Sum(id), "sum_id"));
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList,
+                true, Optional.empty(), rStudent);
+
+        // select count(distinct id), sum(id) from t;
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyImplementation(fourPhaseAggregateWithDistinct())
+                .matches(
+                        physicalHashAggregate(
+                                physicalHashAggregate(
+                                        physicalHashAggregate(
+                                                physicalHashAggregate()
+                                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.LOCAL))
+                                                        .when(agg -> 
agg.getGroupByExpressions().get(0).equals(id))
+                                                        .when(agg -> 
verifyAlwaysNullableFlag(
+                                                                
agg.getAggregateFunctions(), false)))
+                                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.GLOBAL))
+                                                .when(agg -> 
agg.getGroupByExpressions().get(0).equals(id))
+                                                .when(agg -> 
verifyAlwaysNullableFlag(agg.getAggregateFunctions(),
+                                                        false)))
+                                        .when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+                                        .when(agg -> 
agg.getGroupByExpressions().isEmpty())
+                                        .when(agg -> 
verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)))
+                                .when(agg -> 
agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL))
+                                .when(agg -> 
agg.getGroupByExpressions().isEmpty())
+                                .when(agg -> 
verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)));
+    }
+
     private Rule twoPhaseAggregateWithoutDistinct() {
         return new AggregateStrategies().buildRules()
                 .stream()
@@ -400,8 +437,18 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported {
     private Rule fourPhaseAggregateWithDistinct() {
         return new AggregateStrategies().buildRules()
                 .stream()
-                .filter(rule -> rule.getRuleType() == 
RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT)
+                .filter(rule -> rule.getRuleType() == 
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT)
                 .findFirst()
                 .get();
     }
+
+    private boolean verifyAlwaysNullableFlag(Set<AggregateFunction> functions, 
boolean alwaysNullable) {
+        for (AggregateFunction f : functions) {
+            if (f instanceof NullableAggregateFunction
+                    && ((NullableAggregateFunction) f).isAlwaysNullable() != 
alwaysNullable) {
+                return false;
+            }
+        }
+        return true;
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to