This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch release-2.0.2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5827ca902ea2716cf94d68a7cee5459dc0e52d81
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Fri Oct 27 09:36:10 2023 +0800

    [fix](nereids) push down subquery exprs in non-distinct agg functions 
(#25955)
---
 .../nereids/rules/analysis/NormalizeAggregate.java | 26 ++++++++++++++--------
 .../subquery/test_subquery_in_project.out          |  6 +++++
 .../subquery/test_subquery_in_project.groovy       |  8 +++++++
 3 files changed, 31 insertions(+), 9 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index 180e9b915c3..c287f2dffe9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -20,12 +20,14 @@ package org.apache.doris.nereids.rules.analysis;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
+import 
org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
 import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
@@ -101,11 +103,17 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
 
             List<NamedExpression> aggregateOutput = 
aggregate.getOutputExpressions();
             Set<Alias> existsAlias = 
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+            // we need to push down subquery exprs inside non-distinct agg 
functions
+            Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(
+                    
Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput,
+                            expr -> expr instanceof AggregateFunction
+                                    && !((AggregateFunction) 
expr).isDistinct())),
+                    SubqueryExpr.class::isInstance);
             Set<Expression> groupingByExprs = 
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
-            NormalizeToSlotContext groupByToSlotContext =
-                    NormalizeToSlotContext.buildContext(existsAlias, 
groupingByExprs);
-            Set<NamedExpression> bottomGroupByProjects =
-                    
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);
+            NormalizeToSlotContext bottomSlotContext =
+                    NormalizeToSlotContext.buildContext(existsAlias, 
Sets.union(groupingByExprs, subqueryExprs));
+            Set<NamedExpression> bottomOutputs =
+                    
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, 
subqueryExprs));
 
             List<AggregateFunction> aggFuncs = Lists.newArrayList();
             aggregateOutput.forEach(o -> 
o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
@@ -119,8 +127,8 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             // after normalize:
             // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 
1)[#1])
             // +-- project((a[#0] + 1)[#1])
-            List<AggregateFunction> normalizedAggFuncs = 
groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
-            Set<NamedExpression> bottomProjects = 
Sets.newHashSet(bottomGroupByProjects);
+            List<AggregateFunction> normalizedAggFuncs = 
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
+            Set<NamedExpression> bottomProjects = 
Sets.newHashSet(bottomOutputs);
             // TODO: if we have distinct agg, we must push down its children,
             //   because need use it to generate distribution enforce
             // step 1: split agg functions into 2 parts: distinct and not 
distinct
@@ -174,7 +182,7 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
                     NormalizeToSlotContext.buildContext(existsAlias, 
normalizedAggFuncs);
             // agg output include 2 part, normalized group by slots and 
normalized agg functions
             List<NamedExpression> normalizedAggOutput = 
ImmutableList.<NamedExpression>builder()
-                    
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
+                    
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
                     
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
                     .build();
             // add normalized agg's input slots to bottom projects
@@ -188,7 +196,7 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
                     .collect(Collectors.toSet());
             bottomProjects.addAll(aggInputSlots);
             // build group by exprs
-            List<Expression> normalizedGroupExprs = 
groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
+            List<Expression> normalizedGroupExprs = 
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
 
             Plan bottomPlan;
             if (!bottomProjects.isEmpty()) {
@@ -198,7 +206,7 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             }
 
             List<NamedExpression> upperProjects = 
normalizeOutput(aggregateOutput,
-                    groupByToSlotContext, normalizedAggFuncsToSlotContext);
+                    bottomSlotContext, normalizedAggFuncsToSlotContext);
 
             return new LogicalProject<>(upperProjects,
                     aggregate.withNormalized(normalizedGroupExprs, 
normalizedAggOutput, bottomPlan));
diff --git 
a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out 
b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
index 5b979356390..4d8bd4c7361 100644
--- a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
+++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
@@ -48,3 +48,9 @@ true
 \N     2.0
 2020-09-09     2.0
 
+-- !sql15 --
+12
+
+-- !sql16 --
+12
+
diff --git 
a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy 
b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
index 0521334d8ae..b9de14e530b 100644
--- a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
+++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
@@ -116,5 +116,13 @@ suite("test_subquery_in_project") {
                 end 'test'  from test_sql group by cube(dt) order by dt;
     """
 
+    qt_sql15 """
+        select sum(age + (select sum(age) from test_sql)) from test_sql;
+    """
+
+    qt_sql16 """
+        select sum(distinct age + (select sum(age) from test_sql)) from 
test_sql;
+    """
+
     sql """drop table if exists test_sql;"""
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to