This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch tpc_preview4
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5ec08101d2bdb6ee4868cd671ca1f63d73586728
Author: minghong <[email protected]>
AuthorDate: Tue Dec 2 19:22:11 2025 +0800

    push down sum-if
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   2 +
 .../eageraggregation/PushdownSumIfAggregation.java | 160 ++++++++++
 .../rewrite/eageraggregation/SumAggContext.java    |  56 ++++
 .../rewrite/eageraggregation/SumAggWriter.java     | 323 +++++++++++++++++++++
 .../doris/nereids/stats/ExpressionEstimation.java  |   2 +-
 .../trees/plans/logical/LogicalSetOperation.java   |  42 +++
 .../nereids/trees/plans/logical/LogicalUnion.java  |   8 +
 .../java/org/apache/doris/qe/SessionVariable.java  |  10 +
 8 files changed, 602 insertions(+), 1 deletion(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 693ee3a2799..8c1a5a74df3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -170,6 +170,7 @@ import 
org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning;
 import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin;
 import 
org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply;
 import 
org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply;
+import 
org.apache.doris.nereids.rules.rewrite.eageraggregation.PushdownSumIfAggregation;
 import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
@@ -658,6 +659,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
                                 new PushDownAggThroughJoin()
                         )),
                         
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, 
PushDownDistinctThroughJoin::new)),
+                        custom(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN, 
PushdownSumIfAggregation::new),
                         topDown(new PushCountIntoUnionAll())
                 ),
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java
new file mode 100644
index 00000000000..1b1de45153b
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.eageraggregation;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.SessionVariable;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
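+/**
+ * PushdownSumIfAggregation: custom rewriter that tries to push {@code sum(if(a = b, slot, null))}
+ * aggregate functions below the child of a LogicalAggregate (eager aggregation).
+ * The rule is skipped when the session variable eager_aggregation_mode is negative.
+ */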
+public class PushdownSumIfAggregation extends DefaultPlanRewriter<JobContext> 
implements CustomRewriter {
+    private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet(
+            Sum.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        int mode = 
ConnectContext.get().getSessionVariable().eagerAggregationMode;
+        if (mode < 0) {
+            return plan;
+        } else {
+            return plan.accept(this, jobContext);
+        }
+    }
+
+    /**
+     * Tries to push {@code sum(if(a = b, slot, null))} aggregate functions below the child of {@code agg}.
+     *
+     * <p>The push down is attempted only when
+     * <ul>
+     *   <li>the child subtree was not changed by this visitor,</li>
+     *   <li>the aggregate does not come from a repeat,</li>
+     *   <li>every aliased output matches the sum-if pattern,</li>
+     *   <li>every group-by expression is a slot, and</li>
+     *   <li>no summed slot is also needed as a grouping or condition slot.</li>
+     * </ul>
+     * The rewrite of the child is delegated to {@link SumAggWriter}. When the child changes, every
+     * sum-if output is rebuilt on top of the child output slot that carries the same ExprId as the
+     * original "then" slot.
+     *
+     * @param agg the aggregate to rewrite
+     * @param context the job context, only forwarded to the child traversal
+     * @return the rewritten aggregate, or {@code agg} itself when nothing was pushed down
+     */
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) {
+        Plan newChild = agg.child().accept(this, context);
+        if (newChild != agg.child()) {
+            // TODO : push down upper aggregations
+            return agg.withChildren(newChild);
+        }
+
+        if (agg.getSourceRepeat().isPresent()) {
+            return agg;
+        }
+
+        List<NamedExpression> aliasToBePushDown = Lists.newArrayList();
+        List<EqualTo> ifConditions = Lists.newArrayList();
+        List<SlotReference> ifThenSlots = Lists.newArrayList();
+        for (NamedExpression aggOutput : agg.getOutputExpressions()) {
+            if (!(aggOutput instanceof Alias)) {
+                continue;
+            }
+            Expression body = aggOutput.child(0);
+            Expression sumBody = body instanceof Sum ? ((Sum) body).child() : null;
+            if (!(sumBody instanceof If)
+                    || !(sumBody.child(0) instanceof EqualTo)
+                    || !(sumBody.child(1) instanceof SlotReference)
+                    || !(sumBody.child(2) instanceof NullLiteral)) {
+                // every aliased output has to be sum(if(a = b, slot, null))
+                return agg;
+            }
+            ifConditions.add((EqualTo) sumBody.child(0));
+            ifThenSlots.add((SlotReference) sumBody.child(1));
+            aliasToBePushDown.add(aggOutput);
+        }
+        if (ifThenSlots.isEmpty()) {
+            return agg;
+        }
+        ifThenSlots = Lists.newArrayList(Sets.newHashSet(ifThenSlots));
+
+        List<SlotReference> groupKeys = new ArrayList<>();
+        for (Expression groupKey : agg.getGroupByExpressions()) {
+            if (groupKey instanceof SlotReference) {
+                groupKeys.add((SlotReference) groupKey);
+            } else {
+                if (SessionVariable.isFeDebug()) {
+                    throw new RuntimeException("PushDownAggregation failed: agg is not normalized\n "
+                            + agg.treeString());
+                } else {
+                    return agg;
+                }
+            }
+        }
+
+        // The upper aggregate still evaluates the if-conditions after the push down, so every slot
+        // they read has to survive the generated lower aggregate. SumAggWriter only outputs the
+        // partial sums and the grouping keys, hence the condition slots are added as grouping keys.
+        Set<Slot> knownKeys = Sets.newHashSet(groupKeys);
+        for (EqualTo condition : ifConditions) {
+            for (Slot slot : condition.getInputSlots()) {
+                if (!(slot instanceof SlotReference)) {
+                    return agg;
+                }
+                if (knownKeys.add(slot)) {
+                    groupKeys.add((SlotReference) slot);
+                }
+            }
+        }
+        // The lower aggregate re-uses the ExprId of a "then" slot for its partial sum. If the same
+        // slot also had to be kept as a grouping/condition slot, two outputs would share one
+        // ExprId and the upper expressions would read the partial sum instead of the raw value.
+        Set<Slot> thenSlotSet = Sets.newHashSet(ifThenSlots);
+        for (SlotReference key : groupKeys) {
+            if (thenSlotSet.contains(key)) {
+                return agg;
+            }
+        }
+
+        SumAggContext sumAggContext = new SumAggContext(aliasToBePushDown, ifConditions, ifThenSlots, groupKeys);
+        Plan child = agg.child().accept(new SumAggWriter(), sumAggContext);
+        if (child == agg.child()) {
+            return agg;
+        }
+
+        List<NamedExpression> newOutputExpressions = new ArrayList<>();
+        for (NamedExpression output : agg.getOutputExpressions()) {
+            if (output instanceof SlotReference) {
+                newOutputExpressions.add(output);
+                continue;
+            }
+            boolean isSumIf = output instanceof Alias
+                    && output.child(0) instanceof Sum
+                    && output.child(0).child(0) instanceof If
+                    && output.child(0).child(0).child(1) instanceof SlotReference;
+            if (!isSumIf) {
+                // never drop an output silently: give up the rewrite instead
+                return agg;
+            }
+            Expression originIf = output.child(0).child(0);
+            SlotReference targetSlot = (SlotReference) originIf.child(1);
+            Slot toReplace = null;
+            for (Slot slot : child.getOutput()) {
+                if (slot.getExprId().equals(targetSlot.getExprId())) {
+                    toReplace = slot;
+                    break;
+                }
+            }
+            if (toReplace == null) {
+                return agg;
+            }
+            // sum the partial sums produced below, guarded by the original condition
+            Alias newOutput = (Alias) ((Alias) output).withChildren(
+                    new Sum(new If(originIf.child(0), toReplace, new NullLiteral())));
+            newOutputExpressions.add(newOutput);
+        }
+        return agg.withAggOutputChild(newOutputExpressions, child);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java
new file mode 100644
index 00000000000..7b3e7ee9482
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.eageraggregation;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Immutable context handed down by {@link SumAggWriter} while it pushes sum-if aggregation
+ * through the plan: the aliases being pushed, their if-conditions, the de-duplicated slots that
+ * are summed, and the slots the pushed-down aggregate groups by.
+ */
+public class SumAggContext {
+    public final List<NamedExpression> aliasToBePushDown;
+    public final List<EqualTo> ifConditions;
+    public final List<SlotReference> ifThenSlots;
+    public final List<SlotReference> groupKeys;
+
+    /**
+     * Copies all arguments into immutable lists; duplicated entries of {@code ifThenSlots} are removed.
+     */
+    public SumAggContext(List<NamedExpression> aliasToBePushDown,
+            List<EqualTo> ifConditions, List<SlotReference> ifThenSlots,
+            List<SlotReference> groupKeys) {
+        // de-duplicate first, the remaining fields are plain defensive copies
+        Set<SlotReference> uniqueThenSlots = new HashSet<>(ifThenSlots);
+        this.ifThenSlots = ImmutableList.copyOf(uniqueThenSlots);
+        this.aliasToBePushDown = ImmutableList.copyOf(aliasToBePushDown);
+        this.ifConditions = ImmutableList.copyOf(ifConditions);
+        this.groupKeys = ImmutableList.copyOf(groupKeys);
+    }
+
+    /**
+     * Returns a copy of this context whose summed slots are replaced by {@code ifThenSlots}.
+     */
+    public SumAggContext withIfThenSlots(List<SlotReference> ifThenSlots) {
+        return new SumAggContext(aliasToBePushDown, ifConditions, ifThenSlots, groupKeys);
+    }
+
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java
new file mode 100644
index 00000000000..27a2165f422
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java
@@ -0,0 +1,323 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.eageraggregation;
+
+import org.apache.doris.nereids.rules.rewrite.StatsDerive;
+import org.apache.doris.nereids.stats.ExpressionEstimation;
+import org.apache.doris.nereids.stats.StatsCalculator;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
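+/**
+ * SumAggWriter: plan rewriter used by PushdownSumIfAggregation. It descends through projects,
+ * joins (left side only) and unions, and puts a sum aggregate on top of a relation when
+ * checkStats accepts it. Every other plan node is returned unchanged.
+ */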
+public class SumAggWriter extends DefaultPlanRewriter<SumAggContext> {
+    private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
+    private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
+    private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
+    private final StatsDerive derive = new StatsDerive(true);
+
+    // Default case: plan nodes without a dedicated visit method below are returned as-is,
+    // their children are not visited.
+    @Override
+    public Plan visit(Plan plan, SumAggContext context) {
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalProject(LogicalProject<? extends Plan> project, 
SumAggContext context) {
+        if (project.getProjects().stream().allMatch(proj -> proj instanceof 
SlotReference
+                || (proj instanceof Alias && proj.child(0) instanceof 
SlotReference))) {
+            List<SlotReference> slotToPush = new ArrayList<>();
+            for (SlotReference slot : context.ifThenSlots) {
+                slotToPush.add((SlotReference) 
project.pushDownExpressionPastProject(slot));
+            }
+            List<SlotReference> groupBySlots = new ArrayList<>();
+            for (SlotReference slot : context.groupKeys) {
+                groupBySlots.add((SlotReference) 
project.pushDownExpressionPastProject(slot));
+            }
+            SumAggContext contextForChild = new SumAggContext(
+                    context.aliasToBePushDown,
+                    context.ifConditions,
+                    slotToPush,
+                    groupBySlots);
+            Plan child = project.child().accept(this, contextForChild);
+            if (child != project.child()) {
+                List<NamedExpression> newProjects = Lists.newArrayList();
+                for (NamedExpression ne : project.getProjects()) {
+                    newProjects.add((NamedExpression) replaceBySlots(ne, 
child.getOutput()));
+                }
+                return project.withProjects(newProjects).withChildren(child);
+            }
+        }
+        return project;
+    }
+
+    /**
+     * Replaces every input slot of {@code expression} by the slot of {@code slots} that has the
+     * same ExprId (compared by its int value). Input slots without a match are left untouched;
+     * when several candidates match, the last one in {@code slots} is used.
+     */
+    private static Expression replaceBySlots(Expression expression, List<Slot> slots) {
+        Map<Slot, Slot> substitution = new HashMap<>();
+        for (Slot inputSlot : expression.getInputSlots()) {
+            int wantedId = inputSlot.getExprId().asInt();
+            slots.stream()
+                    .filter(candidate -> candidate.getExprId().asInt() == wantedId)
+                    .forEach(candidate -> substitution.put(inputSlot, candidate));
+        }
+        return ExpressionUtils.replace(expression, substitution);
+    }
+
+    @Override
+    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, SumAggContext context) {
+        Set<Slot> leftOutput = join.left().getOutputSet();
+        Set<SlotReference> conditionSlots = join.getConditionSlot().stream()
+                .map(slot -> (SlotReference) slot).collect(Collectors.toSet());
+        for (Slot slot : context.ifThenSlots) {
+            if (conditionSlots.contains(slot)) {
+                return join;
+            }
+        }
+        Set<SlotReference> conditionSlotsFromLeft = 
Sets.newHashSet(conditionSlots);
+        conditionSlotsFromLeft.retainAll(leftOutput);
+        for (SlotReference slot : context.groupKeys) {
+            if (leftOutput.contains(slot)) {
+                conditionSlotsFromLeft.add(slot);
+            }
+        }
+        if (leftOutput.containsAll(context.ifThenSlots)) {
+            SumAggContext contextForChild = new SumAggContext(
+                    context.aliasToBePushDown,
+                    context.ifConditions,
+                    context.ifThenSlots,
+                    Lists.newArrayList(conditionSlotsFromLeft)
+            );
+            Plan left = join.left().accept(this, contextForChild);
+            if (join.left() != left) {
+                return join.withChildren(left, join.right());
+            }
+        }
+        return join;
+    }
+
+    /**
+     * Pushes the aggregation into every child of a union.
+     *
+     * <p>The union is left untouched when it does not output all summed slots, when it carries
+     * constant rows, when one of its outputs is not a plain slot, or when a summed slot or a
+     * grouping key cannot be mapped to a slot of a child. After the children have been visited,
+     * the rewrite is only applied if every new child still exposes all regular outputs of the
+     * union and all children agree on the data type of every column; otherwise the original
+     * union is returned.
+     *
+     * <p>NOTE(review): the set-operation qualifier is not inspected here; this presumably relies on
+     * the union being UNION ALL at this point of the rewrite - confirm.
+     *
+     * @return a union over the rewritten children with outputs re-typed after them, or {@code union}
+     */
+    @Override
+    public Plan visitLogicalUnion(LogicalUnion union, SumAggContext context) {
+        if (!union.getOutputSet().containsAll(context.ifThenSlots)) {
+            return union;
+        }
+        if (!union.getConstantExprsList().isEmpty()) {
+            return union;
+        }
+        if (!union.getOutputs().stream().allMatch(e -> e instanceof SlotReference)) {
+            return union;
+        }
+
+        List<Plan> newChildren = Lists.newArrayList();
+        boolean changed = false;
+        for (int i = 0; i < union.children().size(); i++) {
+            Plan child = union.children().get(i);
+            List<SlotReference> ifThenSlotsForChild = new ArrayList<>();
+            for (SlotReference slot : context.ifThenSlots) {
+                Expression pushed = union.pushDownExpressionPastSetOperator(slot, i);
+                if (pushed instanceof SlotReference) {
+                    ifThenSlotsForChild.add((SlotReference) pushed);
+                } else {
+                    return union;
+                }
+            }
+            // Same check for the grouping keys: a blind cast could throw a ClassCastException or
+            // store a null key when the slot cannot be mapped to this child.
+            List<SlotReference> groupKeysForChild = new ArrayList<>();
+            for (SlotReference slot : context.groupKeys) {
+                Expression pushed = union.pushDownExpressionPastSetOperator(slot, i);
+                if (pushed instanceof SlotReference) {
+                    groupKeysForChild.add((SlotReference) pushed);
+                } else {
+                    return union;
+                }
+            }
+            SumAggContext contextForChild = new SumAggContext(
+                    context.aliasToBePushDown,
+                    context.ifConditions,
+                    ifThenSlotsForChild,
+                    groupKeysForChild);
+            Plan newChild = child.accept(this, contextForChild);
+            if (newChild != child) {
+                changed = true;
+            }
+            newChildren.add(newChild);
+        }
+        if (!changed) {
+            return union;
+        }
+
+        int outputSize = union.getOutput().size();
+        List<List<SlotReference>> newRegularChildrenOutputs = Lists.newArrayList();
+        for (int i = 0; i < newChildren.size(); i++) {
+            List<SlotReference> childOutput = new ArrayList<>();
+            for (SlotReference slot : union.getRegularChildOutput(i)) {
+                for (Slot c : newChildren.get(i).getOutput()) {
+                    if (slot.equals(c)) {
+                        childOutput.add((SlotReference) c);
+                        break;
+                    }
+                }
+            }
+            if (childOutput.size() != outputSize) {
+                // The rewritten child lost a column the union needs. The outputs would be
+                // misaligned (or the lookup below would run out of bounds), so keep the original.
+                return union;
+            }
+            newRegularChildrenOutputs.add(childOutput);
+        }
+
+        List<SlotReference> firstChildOutput = newRegularChildrenOutputs.get(0);
+        List<NamedExpression> newOutputs = new ArrayList<>();
+        for (int i = 0; i < outputSize; i++) {
+            DataType dataType = firstChildOutput.get(i).getDataType();
+            // Only some children may have been rewritten (the decision is taken per child); the
+            // union cannot mix the original column type with the type of the partial sum.
+            for (List<SlotReference> childOutput : newRegularChildrenOutputs) {
+                if (!dataType.equals(childOutput.get(i).getDataType())) {
+                    return union;
+                }
+            }
+            SlotReference originSlot = (SlotReference) union.getOutput().get(i);
+            newOutputs.add(originSlot.withNullableAndDataType(originSlot.nullable(), dataType));
+        }
+        return union.withChildrenAndOutputs(newChildren, newOutputs, newRegularChildrenOutputs);
+    }
+
+    // A relation ends the push down: genAggregate either wraps it in a sum aggregate or
+    // returns it unchanged, depending on checkStats.
+    @Override
+    public Plan visitLogicalRelation(LogicalRelation relation, SumAggContext 
context) {
+        return genAggregate(relation, context);
+    }
+
+    /**
+     * Builds the pushed-down aggregate on top of {@code child} when {@link #checkStats} accepts it.
+     *
+     * <p>The aggregate groups by {@code context.groupKeys} and outputs one {@code sum(slot)} per
+     * summed slot followed by the grouping keys. Each sum is aliased with the ExprId of the slot
+     * it sums.
+     *
+     * @return the new aggregate, or {@code child} itself when the stats check rejects the push down
+     */
+    private Plan genAggregate(Plan child, SumAggContext context) {
+        if (!checkStats(child, context)) {
+            return child;
+        }
+        List<NamedExpression> aggOutputExpressions = new ArrayList<>();
+        for (SlotReference slot : context.ifThenSlots) {
+            aggOutputExpressions.add(new Alias(slot.getExprId(), new Sum(slot)));
+        }
+        aggOutputExpressions.addAll(context.groupKeys);
+
+        // typed copy so the aggregate is built without raw types / unchecked conversions
+        List<Expression> groupByExpressions = new ArrayList<>(context.groupKeys);
+        return new LogicalAggregate<>(groupByExpressions, aggOutputExpressions, child);
+    }
+
+    /**
+     * Decides whether an aggregate on {@code context.groupKeys} should be generated above {@code plan}.
+     *
+     * <p>Without a ConnectContext the answer is {@code false}. A negative
+     * {@code eagerAggregationMode} always rejects and a positive one always accepts. With mode 0
+     * the decision is based on the statistics of the plan: every grouping key is classified by
+     * {@link #groupByCardinality} and the combination of lower / medium / high keys decides.
+     * Missing statistics, an empty input or an unknown column statistic reject the push down.
+     *
+     * @param plan the plan the aggregate would be placed on
+     * @param context supplies the grouping keys
+     * @return {@code true} when the aggregate should be generated
+     */
+    private boolean checkStats(Plan plan, SumAggContext context) {
+        ConnectContext connectContext = ConnectContext.get();
+        if (connectContext == null) {
+            return false;
+        }
+        int mode = connectContext.getSessionVariable().eagerAggregationMode;
+        if (mode < 0) {
+            return false;
+        }
+        if (mode > 0) {
+            return true;
+        }
+        Statistics stats = plan.getStats();
+        if (stats == null) {
+            stats = plan.accept(derive, new StatsDerive.DeriveContext());
+        }
+        if (stats == null || stats.getRowCount() == 0) {
+            return false;
+        }
+        double rowCount = stats.getRowCount();
+
+        List<ColumnStatistic> lower = Lists.newArrayList();
+        List<ColumnStatistic> medium = Lists.newArrayList();
+        List<ColumnStatistic> high = Lists.newArrayList();
+        for (NamedExpression key : context.groupKeys) {
+            ColumnStatistic colStats = ExpressionEstimation.INSTANCE.estimate(key, stats);
+            if (colStats.isUnKnown) {
+                return false;
+            }
+            switch (groupByCardinality(colStats, rowCount)) {
+                case 0:
+                    lower.add(colStats);
+                    break;
+                case 1:
+                    medium.add(colStats);
+                    break;
+                default:
+                    high.add(colStats);
+                    break;
+            }
+        }
+
+        // 1. A single lower or medium key is always accepted.
+        if (high.isEmpty() && (lower.size() + medium.size()) == 1) {
+            return true;
+        }
+
+        double lowerCartesian = 1.0;
+        for (ColumnStatistic colStats : lower) {
+            lowerCartesian = lowerCartesian * colStats.ndv;
+        }
+
+        // 2. Only lower keys (the single-key case was already accepted above).
+        if (high.isEmpty() && medium.isEmpty()) {
+            if (lower.size() == 2 && lowerCartesian * 7 <= rowCount) {
+                return true;
+            }
+            // pow(row_count/20, a half of lower column size); the half is an integer division
+            double lowerUpper = Math.pow(Math.max(rowCount / 20, 1), Math.max(lower.size() / 2, 1));
+            return lower.size() <= 3 && lowerCartesian * 20 <= rowCount && lowerCartesian < lowerUpper;
+        }
+
+        if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) {
+            return false;
+        }
+
+        // 3. Extremely low cardinality for lower with at most one medium or high.
+        double lowerCartesianLowerBound = rowCount / LOWER_AGGREGATE_EFFECT_COEFFICIENT;
+        if (high.size() + medium.size() == 1 && lower.size() <= 2 && lowerCartesian <= lowerCartesianLowerBound) {
+            StatsCalculator statsCalculator = new StatsCalculator(null);
+            double estAggRowCount = statsCalculator.estimateGroupByRowCount(
+                    context.groupKeys.stream().map(s -> (Expression) s).collect(Collectors.toList()),
+                    stats);
+            return estAggRowCount < lowerCartesianLowerBound;
+        }
+
+        return false;
+    }
+
+    /**
+     * Classifies a grouping column by comparing its ndv with the row count.
+     * <ul>
+     *   <li>high(2): row_count / cardinality &lt; MEDIUM_AGGREGATE_EFFECT_COEFFICIENT</li>
+     *   <li>medium(1): row_count / cardinality &gt;= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT
+     *       and &lt; LOW_AGGREGATE_EFFECT_COEFFICIENT</li>
+     *   <li>lower(0): row_count / cardinality &gt;= LOW_AGGREGATE_EFFECT_COEFFICIENT</li>
+     * </ul>
+     * A row count of 0 is classified as high(2).
+     */
+    private int groupByCardinality(ColumnStatistic colStats, double rowCount) {
+        double mediumScaledNdv = colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT;
+        double lowScaledNdv = colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT;
+        if (rowCount == 0 || mediumScaledNdv > rowCount) {
+            return 2;
+        }
+        if (mediumScaledNdv <= rowCount && lowScaledNdv > rowCount) {
+            return 1;
+        }
+        // high(2) stays the fallback when none of the comparisons holds
+        return lowScaledNdv <= rowCount ? 0 : 2;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
index 2eb3c3f88c2..3f68c8e9850 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
@@ -123,7 +123,7 @@ public class ExpressionEstimation extends 
ExpressionVisitor<ColumnStatistic, Sta
     public static final Logger LOG = 
LogManager.getLogger(ExpressionEstimation.class);
     public static final long DAYS_FROM_0_TO_1970 = 719528;
     public static final long DAYS_FROM_0_TO_9999 = 3652424;
-    private static final ExpressionEstimation INSTANCE = new 
ExpressionEstimation();
+    public static final ExpressionEstimation INSTANCE = new 
ExpressionEstimation();
 
     /**
      * returned columnStat is newly created or a copy of stats
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java
index 25ff7d55720..69447253cba 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java
@@ -40,6 +40,7 @@ import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.MapType;
 import org.apache.doris.nereids.types.StructField;
 import org.apache.doris.nereids.types.StructType;
+import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.TypeCoercionUtils;
 import org.apache.doris.qe.GlobalVariable;
 import org.apache.doris.qe.SessionVariable;
@@ -312,4 +313,45 @@ public abstract class LogicalSetOperation extends 
AbstractLogicalPlan
     public Optional<Plan> processProject(List<NamedExpression> parentProjects) 
{
         return 
Optional.of(PushProjectThroughUnion.doPushProject(parentProjects, this));
     }
+
+    /**
+     * Push down expression past SetOperation to a specific child.
+     *
+     * This method maps the expression from the SetOperation's output slots
+     * to the corresponding child's output slots.
+     *
+     * Example:
+     * SetOperation outputs: [x, y]
+     * Child 0 outputs (regularChildrenOutputs[0]): [a, b]
+     * Child 1 outputs (regularChildrenOutputs[1]): [c, d]
+     *
+     * If expression is "x + 1":
+     * - For childIdx=0, return "a + 1"
+     * - For childIdx=1, return "c + 1"
+     *
+     * @param expression the expression to push down
+     * @param childIdx   the index of the child to push down to
+     * @return the rewritten expression for the child, or null if childIdx is 
out of
+     *         bounds
+     */
+    public Expression pushDownExpressionPastSetOperator(Expression expression, 
int childIdx) {
+        // Check if childIdx is valid
+        if (childIdx < 0 || childIdx >= regularChildrenOutputs.size()) {
+            return null;
+        }
+
+        // Build mapping from SetOperation output slots to child output slots
+        java.util.HashMap<Slot, Expression> slotMapping = new 
java.util.HashMap<>();
+        List<SlotReference> childOutputs = 
regularChildrenOutputs.get(childIdx);
+
+        // Map each output slot to the corresponding child slot
+        for (int i = 0; i < outputs.size() && i < childOutputs.size(); i++) {
+            Slot outputSlot = outputs.get(i).toSlot();
+            SlotReference childSlot = childOutputs.get(i);
+            slotMapping.put(outputSlot, childSlot);
+        }
+
+        // Replace slots in the expression using the mapping
+        return ExpressionUtils.replace(expression, slotMapping);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
index 1a1b3b2e84e..ae672cdd887 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java
@@ -170,6 +170,14 @@ public class LogicalUnion extends LogicalSetOperation 
implements Union, OutputPr
         return new LogicalUnion(qualifier, outputs, childrenOutputs, 
constantExprsList, hasPushedFilter, children);
     }
 
+    /**
+     * Creates a LogicalUnion with new children, new outputs and new per-child outputs; qualifier,
+     * constantExprsList and hasPushedFilter are taken from this union.
+     *
+     * @param children the new children
+     * @param newOuptuts the outputs of the new union
+     * @param childrenOutputs the output slots of each child, must have one entry per child
+     * @return the new union
+     * @throws IllegalArgumentException if children and childrenOutputs differ in size
+     */
+    public LogicalSetOperation withChildrenAndOutputs(List<Plan> children, 
List<NamedExpression> newOuptuts,
+            List<List<SlotReference>> childrenOutputs) {
+        Preconditions.checkArgument(children.size() == childrenOutputs.size(),
+                "children size %s is not equals with children outputs size %s",
+                children.size(), childrenOutputs.size());
+        return new LogicalUnion(qualifier, newOuptuts, childrenOutputs, 
constantExprsList, hasPushedFilter, children);
+    }
+
     @Override
     public LogicalUnion withGroupExpression(Optional<GroupExpression> 
groupExpression) {
         return new LogicalUnion(qualifier, outputs, regularChildrenOutputs, 
constantExprsList, hasPushedFilter,
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index cbfc1d7981c..16262196e10 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -1576,6 +1576,16 @@ public class SessionVariable implements Serializable, 
Writable {
     )
     public boolean enablePruneNestedColumns = true;
 
+    // Switch for the eager aggregation rewrite, read by PushdownSumIfAggregation.rewriteRoot and
+    // SumAggWriter.checkStats: a negative value disables it, a positive value forces it, and 0
+    // (the default) lets the statistics decide.
+    @VariableMgr.VarAttr(name = "eager_aggregation_mode", needForward = true,
+            description = {"0: 根据统计信息决定是使用eager aggregation,"
+                    + "1: 强制使用 eager aggregation,"
+                    + "-1: 禁止使用 eager aggregation",
+                    "0: Determine eager aggregation by statistics, "
+                            + "1: force eager aggregation, "
+                            + "-1: Prohibit eager aggregation "}
+    )
+    public int eagerAggregationMode = 0;
+
     public boolean enableTopnLazyMaterialization() {
         return ConnectContext.get() != null
                 && 
ConnectContext.get().getSessionVariable().topNLazyMaterializationThreshold > 0;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to