This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch tpc_preview4 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 5ec08101d2bdb6ee4868cd671ca1f63d73586728 Author: minghong <[email protected]> AuthorDate: Tue Dec 2 19:22:11 2025 +0800 push down sum-if --- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../eageraggregation/PushdownSumIfAggregation.java | 160 ++++++++++ .../rewrite/eageraggregation/SumAggContext.java | 56 ++++ .../rewrite/eageraggregation/SumAggWriter.java | 323 +++++++++++++++++++++ .../doris/nereids/stats/ExpressionEstimation.java | 2 +- .../trees/plans/logical/LogicalSetOperation.java | 42 +++ .../nereids/trees/plans/logical/LogicalUnion.java | 8 + .../java/org/apache/doris/qe/SessionVariable.java | 10 + 8 files changed, 602 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 693ee3a2799..8c1a5a74df3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -170,6 +170,7 @@ import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning; import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin; import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply; import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply; +import org.apache.doris.nereids.rules.rewrite.eageraggregation.PushdownSumIfAggregation; import org.apache.doris.nereids.trees.plans.algebra.SetOperation; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalApply; @@ -658,6 +659,7 @@ public class Rewriter extends AbstractBatchJobExecutor { new PushDownAggThroughJoin() )), costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)), + custom(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN, PushdownSumIfAggregation::new), topDown(new PushCountIntoUnionAll()) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java new file mode 100644 index 00000000000..1b1de45153b --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushdownSumIfAggregation.java @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + * PushdownAggregationThroughJoinV2 + */ +public class PushdownSumIfAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Sum.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + int mode = ConnectContext.get().getSessionVariable().eagerAggregationMode; + if (mode < 0) { + return plan; + } else { + return plan.accept(this, jobContext); + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + // TODO : push down upper aggregations + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<NamedExpression> aliasToBePushDown = Lists.newArrayList(); + List<EqualTo> ifConditions = Lists.newArrayList(); + List<SlotReference> ifThenSlots = Lists.newArrayList(); + boolean patternMatch = true; + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + if (aggOutput instanceof Alias) { + Expression body = aggOutput.child(0); + if (body instanceof Sum) { + Expression sumBody = ((Sum) body).child(); + if (sumBody instanceof If) { + If ifBody = (If) sumBody; + if (ifBody.child(0) instanceof EqualTo + && ifBody.child(1) instanceof SlotReference + && ifBody.child(2) instanceof NullLiteral) { + ifConditions.add((EqualTo) ifBody.child(0)); + ifThenSlots.add((SlotReference) ifBody.child(1)); + aliasToBePushDown.add(aggOutput); + continue; + } + } + } + patternMatch = false; + } + } + if (!patternMatch) { + return agg; + } + if (ifThenSlots.isEmpty()) { + return agg; + } + ifThenSlots = Lists.newArrayList(Sets.newHashSet(ifThenSlots)); + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + if (SessionVariable.isFeDebug()) { + throw new RuntimeException("PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + } else { + return agg; + } + } + } + + SumAggContext sumAggContext = new SumAggContext(aliasToBePushDown, ifConditions, ifThenSlots, groupKeys); + SumAggWriter writer = new SumAggWriter(); + Plan child = agg.child().accept(writer, sumAggContext); + if (child != agg.child()) { + List<NamedExpression> outputExpressions = agg.getOutputExpressions(); + List<NamedExpression> newOutputExpressions = new ArrayList<>(); + for (NamedExpression output : outputExpressions) { + if (output instanceof SlotReference) { + newOutputExpressions.add(output); + } else if (output instanceof Alias + && output.child(0) instanceof Sum + && output.child(0).child(0) instanceof If + && output.child(0).child(0).child(1) instanceof SlotReference) { + SlotReference targetSlot = (SlotReference) output.child(0).child(0).child(1); + Slot toReplace = null; + for (Slot slot : child.getOutput()) { + if (slot.getExprId().equals(targetSlot.getExprId())) { + toReplace = slot; + } + } + if (toReplace != null) { + Alias newOutput = (Alias) ((Alias) output).withChildren( + new Sum( + new If( + output.child(0).child(0).child(0), + toReplace, + new NullLiteral() + ) + ) + ); + newOutputExpressions.add(newOutput); + } else { + return agg; + } + + } + } + return agg.withAggOutputChild(newOutputExpressions, child); + } + return agg; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java new file mode 100644 index 00000000000..7b3e7ee9482 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggContext.java @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; + +import com.google.common.collect.ImmutableList; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * SumAggContext + */ +public class SumAggContext { + public final List<NamedExpression> aliasToBePushDown; + public final List<EqualTo> ifConditions; + public final List<SlotReference> ifThenSlots; + public final List<SlotReference> groupKeys; + + public SumAggContext(List<NamedExpression> aliasToBePushDown, + List<EqualTo> ifConditions, List<SlotReference> ifThenSlots, + List<SlotReference> groupKeys) { + this.aliasToBePushDown = ImmutableList.copyOf(aliasToBePushDown); + this.ifConditions = ImmutableList.copyOf(ifConditions); + Set<SlotReference> distinct = new HashSet<>(ifThenSlots); + this.ifThenSlots = ImmutableList.copyOf(distinct); + this.groupKeys = ImmutableList.copyOf(groupKeys); + } + + public SumAggContext withIfThenSlots(List<SlotReference> ifThenSlots) { + return new SumAggContext(this.aliasToBePushDown, + this.ifConditions, + ifThenSlots, + this.groupKeys); + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java new file mode 100644 index 00000000000..27a2165f422 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/SumAggWriter.java @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.rules.rewrite.StatsDerive; +import org.apache.doris.nereids.stats.ExpressionEstimation; +import org.apache.doris.nereids.stats.StatsCalculator; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * SumAggWriter + */ +public class SumAggWriter extends DefaultPlanRewriter<SumAggContext> { + private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; + private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; + private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; + private final StatsDerive derive = new StatsDerive(true); + + @Override + public Plan visit(Plan plan, SumAggContext context) { + return plan; + } + + @Override + public Plan visitLogicalProject(LogicalProject<? extends Plan> project, SumAggContext context) { + if (project.getProjects().stream().allMatch(proj -> proj instanceof SlotReference + || (proj instanceof Alias && proj.child(0) instanceof SlotReference))) { + List<SlotReference> slotToPush = new ArrayList<>(); + for (SlotReference slot : context.ifThenSlots) { + slotToPush.add((SlotReference) project.pushDownExpressionPastProject(slot)); + } + List<SlotReference> groupBySlots = new ArrayList<>(); + for (SlotReference slot : context.groupKeys) { + groupBySlots.add((SlotReference) project.pushDownExpressionPastProject(slot)); + } + SumAggContext contextForChild = new SumAggContext( + context.aliasToBePushDown, + context.ifConditions, + slotToPush, + groupBySlots); + Plan child = project.child().accept(this, contextForChild); + if (child != project.child()) { + List<NamedExpression> newProjects = Lists.newArrayList(); + for (NamedExpression ne : project.getProjects()) { + newProjects.add((NamedExpression) replaceBySlots(ne, child.getOutput())); + } + return project.withProjects(newProjects).withChildren(child); + } + } + return project; + } + + private static Expression replaceBySlots(Expression expression, List<Slot> slots) { + Map<Slot, Slot> replaceMap = new HashMap<>(); + for (Slot slot1 : expression.getInputSlots()) { + for (Slot slot2 : slots) { + if (slot1.getExprId().asInt() == slot2.getExprId().asInt()) { + replaceMap.put(slot1, slot2); + } + } + } + Expression result = ExpressionUtils.replace(expression, replaceMap); + return result; + } + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, SumAggContext context) { + Set<Slot> leftOutput = join.left().getOutputSet(); + Set<SlotReference> conditionSlots = join.getConditionSlot().stream() + .map(slot -> (SlotReference) slot).collect(Collectors.toSet()); + for (Slot slot : context.ifThenSlots) { + if (conditionSlots.contains(slot)) { + return join; + } + } + Set<SlotReference> conditionSlotsFromLeft = Sets.newHashSet(conditionSlots); + conditionSlotsFromLeft.retainAll(leftOutput); + for (SlotReference slot : context.groupKeys) { + if (leftOutput.contains(slot)) { + conditionSlotsFromLeft.add(slot); + } + } + if (leftOutput.containsAll(context.ifThenSlots)) { + SumAggContext contextForChild = new SumAggContext( + context.aliasToBePushDown, + context.ifConditions, + context.ifThenSlots, + Lists.newArrayList(conditionSlotsFromLeft) + ); + Plan left = join.left().accept(this, contextForChild); + if (join.left() != left) { + return join.withChildren(left, join.right()); + } + } + return join; + } + + @Override + public Plan visitLogicalUnion(LogicalUnion union, SumAggContext context) { + if (!union.getOutputSet().containsAll(context.ifThenSlots)) { + return union; + } + if (!union.getConstantExprsList().isEmpty()) { + return union; + } + + if (!union.getOutputs().stream().allMatch(e -> e instanceof SlotReference)) { + return union; + } + List<Plan> newChildren = Lists.newArrayList(); + + boolean changed = false; + for (int i = 0; i < union.children().size(); i++) { + Plan child = union.children().get(i); + List<SlotReference> ifThenSlotsForChild = new ArrayList<>(); + // List<SlotReference> groupByForChild = new ArrayList<>(); + for (SlotReference slot : context.ifThenSlots) { + Expression pushed = union.pushDownExpressionPastSetOperator(slot, i); + if (pushed instanceof SlotReference) { + ifThenSlotsForChild.add((SlotReference) pushed); + } else { + return union; + } + } + int childIdx = i; + SumAggContext contextForChild = new SumAggContext( + context.aliasToBePushDown, + context.ifConditions, + ifThenSlotsForChild, + context.groupKeys.stream().map(slot + -> (SlotReference) union.pushDownExpressionPastSetOperator(slot, childIdx)) + .collect(Collectors.toList()) + ); + Plan newChild = child.accept(this, contextForChild); + if (newChild != child) { + changed = true; + } + newChildren.add(newChild); + } + if (changed) { + List<List<SlotReference>> newRegularChildrenOutputs = Lists.newArrayList(); + for (int i = 0; i < newChildren.size(); i++) { + List<SlotReference> childOutput = new ArrayList<>(); + for (SlotReference slot : union.getRegularChildOutput(i)) { + for (Slot c : newChildren.get(i).getOutput()) { + if (slot.equals(c)) { + childOutput.add((SlotReference) c); + break; + } + } + } + newRegularChildrenOutputs.add(childOutput); + } + List<NamedExpression> newOutputs = new ArrayList<>(); + for (int i = 0; i < union.getOutput().size(); i++) { + SlotReference originSlot = (SlotReference) union.getOutput().get(i); + DataType dataType = newRegularChildrenOutputs.get(0).get(i).getDataType(); + newOutputs.add(originSlot.withNullableAndDataType(originSlot.nullable(), dataType)); + } + return union.withChildrenAndOutputs(newChildren, newOutputs, newRegularChildrenOutputs); + } else { + return union; + } + } + + @Override + public Plan visitLogicalRelation(LogicalRelation relation, SumAggContext context) { + return genAggregate(relation, context); + } + + private Plan genAggregate(Plan child, SumAggContext context) { + if (checkStats(child, context)) { + List<NamedExpression> aggOutputExpressions = new ArrayList<>(); + for (SlotReference slot : context.ifThenSlots) { + Alias alias = new Alias(slot.getExprId(), new Sum(slot)); + aggOutputExpressions.add(alias); + } + aggOutputExpressions.addAll(context.groupKeys); + + LogicalAggregate genAgg = new LogicalAggregate(context.groupKeys, aggOutputExpressions, child); + return genAgg; + } else { + return child; + } + + } + + private boolean checkStats(Plan plan, SumAggContext context) { + if (ConnectContext.get() == null) { + return false; + } + int mode = ConnectContext.get().getSessionVariable().eagerAggregationMode; + if (mode < 0) { + return false; + } + if (mode > 0) { + return true; + } + Statistics stats = plan.getStats(); + if (stats == null) { + stats = plan.accept(derive, new StatsDerive.DeriveContext()); + } + if (stats.getRowCount() == 0) { + return false; + } + + List<ColumnStatistic> groupKeysStats = new ArrayList<>(); + + List<ColumnStatistic> lower = Lists.newArrayList(); + List<ColumnStatistic> medium = Lists.newArrayList(); + List<ColumnStatistic> high = Lists.newArrayList(); + + List<ColumnStatistic>[] cards = new List[] {lower, medium, high}; + + for (NamedExpression key : context.groupKeys) { + ColumnStatistic colStats = ExpressionEstimation.INSTANCE.estimate(key, stats); + if (colStats.isUnKnown) { + return false; + } + groupKeysStats.add(colStats); + cards[groupByCardinality(colStats, stats.getRowCount())].add(colStats); + } + + double lowerCartesian = 1.0; + for (ColumnStatistic colStats : lower) { + lowerCartesian = lowerCartesian * colStats.ndv; + } + + // pow(row_count/20, a half of lower column size) + double lowerUpper = Math.max(stats.getRowCount() / 20, 1); + lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1)); + + if (high.isEmpty() && (lower.size() + medium.size()) == 1) { + return true; + } + + if (high.isEmpty() && medium.isEmpty()) { + if (lower.size() == 1 && lowerCartesian * 20 <= stats.getRowCount()) { + return true; + } else if (lower.size() == 2 && lowerCartesian * 7 <= stats.getRowCount()) { + return true; + } else if (lower.size() <= 3 && lowerCartesian * 20 <= stats.getRowCount() && lowerCartesian < lowerUpper) { + return true; + } else { + return false; + } + } + + if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) { + return false; + } + + // 3. Extremely low cardinality for lower with at most one medium or high. + double lowerCartesianLowerBound = + stats.getRowCount() / LOWER_AGGREGATE_EFFECT_COEFFICIENT; + if (high.size() + medium.size() == 1 && lower.size() <= 2 && lowerCartesian <= lowerCartesianLowerBound) { + StatsCalculator statsCalculator = new StatsCalculator(null); + double estAggRowCount = statsCalculator.estimateGroupByRowCount( + context.groupKeys.stream().map(s -> (Expression) s).collect(Collectors.toList()), + stats); + return estAggRowCount < lowerCartesianLowerBound; + } + + return false; + } + + // high(2): row_count / cardinality < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT + // medium(1): row_count / cardinality >= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT and < LOW_AGGREGATE_EFFECT_COEFFICIENT + // lower(0): row_count / cardinality >= LOW_AGGREGATE_EFFECT_COEFFICIENT + private int groupByCardinality(ColumnStatistic colStats, double rowCount) { + if (rowCount == 0 || colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT > rowCount) { + return 2; + } else if (colStats.ndv * MEDIUM_AGGREGATE_EFFECT_COEFFICIENT <= rowCount + && colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT > rowCount) { + return 1; + } else if (colStats.ndv * LOW_AGGREGATE_EFFECT_COEFFICIENT <= rowCount) { + return 0; + } + return 2; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index 2eb3c3f88c2..3f68c8e9850 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -123,7 +123,7 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta public static final Logger LOG = LogManager.getLogger(ExpressionEstimation.class); public static final long DAYS_FROM_0_TO_1970 = 719528; public static final long DAYS_FROM_0_TO_9999 = 3652424; - private static final ExpressionEstimation INSTANCE = new ExpressionEstimation(); + public static final ExpressionEstimation INSTANCE = new ExpressionEstimation(); /** * returned columnStat is newly created or a copy of stats diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java index 25ff7d55720..69447253cba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.MapType; import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.qe.GlobalVariable; import org.apache.doris.qe.SessionVariable; @@ -312,4 +313,45 @@ public abstract class LogicalSetOperation extends AbstractLogicalPlan public Optional<Plan> processProject(List<NamedExpression> parentProjects) { return Optional.of(PushProjectThroughUnion.doPushProject(parentProjects, this)); } + + /** + * Push down expression past SetOperation to a specific child. + * + * This method maps the expression from the SetOperation's output slots + * to the corresponding child's output slots. + * + * Example: + * SetOperation outputs: [x, y] + * Child 0 outputs (regularChildrenOutputs[0]): [a, b] + * Child 1 outputs (regularChildrenOutputs[1]): [c, d] + * + * If expression is "x + 1": + * - For childIdx=0, return "a + 1" + * - For childIdx=1, return "c + 1" + * + * @param expression the expression to push down + * @param childIdx the index of the child to push down to + * @return the rewritten expression for the child, or null if childIdx is out of + * bounds + */ + public Expression pushDownExpressionPastSetOperator(Expression expression, int childIdx) { + // Check if childIdx is valid + if (childIdx < 0 || childIdx >= regularChildrenOutputs.size()) { + return null; + } + + // Build mapping from SetOperation output slots to child output slots + java.util.HashMap<Slot, Expression> slotMapping = new java.util.HashMap<>(); + List<SlotReference> childOutputs = regularChildrenOutputs.get(childIdx); + + // Map each output slot to the corresponding child slot + for (int i = 0; i < outputs.size() && i < childOutputs.size(); i++) { + Slot outputSlot = outputs.get(i).toSlot(); + SlotReference childSlot = childOutputs.get(i); + slotMapping.put(outputSlot, childSlot); + } + + // Replace slots in the expression using the mapping + return ExpressionUtils.replace(expression, slotMapping); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java index 1a1b3b2e84e..ae672cdd887 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java @@ -170,6 +170,14 @@ public class LogicalUnion extends LogicalSetOperation implements Union, OutputPr return new LogicalUnion(qualifier, outputs, childrenOutputs, constantExprsList, hasPushedFilter, children); } + public LogicalSetOperation withChildrenAndOutputs(List<Plan> children, List<NamedExpression> newOuptuts, + List<List<SlotReference>> childrenOutputs) { + Preconditions.checkArgument(children.size() == childrenOutputs.size(), + "children size %s is not equals with children outputs size %s", + children.size(), childrenOutputs.size()); + return new LogicalUnion(qualifier, newOuptuts, childrenOutputs, constantExprsList, hasPushedFilter, children); + } + @Override public LogicalUnion withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalUnion(qualifier, outputs, regularChildrenOutputs, constantExprsList, hasPushedFilter, diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index cbfc1d7981c..16262196e10 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -1576,6 +1576,16 @@ public class SessionVariable implements Serializable, Writable { ) public boolean enablePruneNestedColumns = true; + @VariableMgr.VarAttr(name = "eager_aggregation_mode", needForward = true, + description = {"0: 根据统计信息决定是使用eager aggregation," + + "1: 强制使用 eager aggregation," + + "-1: 禁止使用 eager aggregation", + "0: Determine eager aggregation by statistics, " + + "1: force eager aggregation, " + + "-1: Prohibit eager aggregation "} + ) + public int eagerAggregationMode = 0; + public boolean enableTopnLazyMaterialization() { return ConnectContext.get() != null && ConnectContext.get().getSessionVariable().topNLazyMaterializationThreshold > 0; --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
