This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new a1da57c63e [opt](Nereids)(WIP) optimize agg and window normalization step 2 #19305 a1da57c63e is described below commit a1da57c63ecef45a02b914ed377742925b76ed92 Author: Zhang Wenxin <101034200+morrys...@users.noreply.github.com> AuthorDate: Fri May 12 14:00:13 2023 +0800 [opt](Nereids)(WIP) optimize agg and window normalization step 2 #19305 1. refactor aggregate normalization to avoid data amplification before aggregate 2. remove useless aggregate processing in ExtractAndNormalizeWindowExpression 3. only push distinct aggregate function children TODO: 1. push down redundant expression in aggregate functions 2. refactor normalize repeat rule 3. move expression normalization and optimization after plan normalization to avoid unexpected expression optimization. --- .../doris/nereids/jobs/batch/NereidsRewriter.java | 26 +- .../rules/analysis/ProjectToGlobalAggregate.java | 6 +- .../rules/expression/rules/FunctionBinder.java | 1 + .../rules/implementation/AggregateStrategies.java | 22 +- .../rewrite/logical/EliminateGroupByConstant.java | 8 +- .../ExtractAndNormalizeWindowExpression.java | 62 ++--- .../rules/rewrite/logical/NormalizeAggregate.java | 289 +++++++++------------ .../rules/rewrite/logical/NormalizeToSlot.java | 112 ++++++-- .../org/apache/doris/nereids/trees/TreeNode.java | 27 -- .../trees/expressions/WindowExpression.java | 7 - .../functions/agg/AggregateFunction.java | 12 +- .../ExtractAndNormalizeWindowExpressionTest.java | 2 +- .../rewrite/logical/NormalizeAggregateTest.java | 9 +- .../suites/nereids_syntax_p0/explain.groovy | 1 - 14 files changed, 289 insertions(+), 295 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java index 07c8903334..b43621562c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java 
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java @@ -84,27 +84,27 @@ import java.util.List; */ public class NereidsRewriter extends BatchRewriteJob { private static final List<RewriteJob> REWRITE_JOBS = jobs( - topic("Normalization", + topic("Plan Normalization", topDown( new EliminateOrderByConstant(), new EliminateGroupByConstant(), - // MergeProjects depends on this rule new LogicalSubQueryAliasToLogicalProject(), - - // rewrite expressions, no depends + // TODO: we should do expression normalization after plan normalization + // because some rewritten depends on sub expression tree matching + // such as group by key matching and replaced + // but we need to do some normalization before subquery unnesting, + // such as extract common expression. new ExpressionNormalization(), new ExpressionOptimization(), new AvgDistinctToSumDivCount(), new CountDistinctRewrite(), - new ExtractFilterFromCrossJoin() ), - - // ExtractSingleTableExpressionFromDisjunction conflict to InPredicateToEqualToRule - // in the ExpressionNormalization, so must invoke in another job, or else run into - // dead loop topDown( + // ExtractSingleTableExpressionFromDisjunction conflict to InPredicateToEqualToRule + // in the ExpressionNormalization, so must invoke in another job, or else run into + // dead loop new ExtractSingleTableExpressionFromDisjunction() ) ), @@ -131,15 +131,15 @@ public class NereidsRewriter extends BatchRewriteJob { ) ), + // we should eliminate hint again because some hint maybe exist in the CTE or subquery. + // so this rule should invoke after "Subquery unnesting" + custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new), + // please note: this rule must run before NormalizeAggregate topDown( new AdjustAggregateNullableForEmptySet() ), - // we should eliminate hint again because some hint maybe exist in the CTE or subquery. 
- // so this rule should invoke after "Subquery unnesting" - custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new), - // The rule modification needs to be done after the subquery is unnested, // because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase, // but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java index a4cf1d1a8c..66371ae000 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java @@ -49,7 +49,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory { logicalProject().then(project -> { boolean needGlobalAggregate = project.getProjects() .stream() - .anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE, null)); + .anyMatch(p -> p.accept(ContainsAggregateChecker.INSTANCE, null)); if (needGlobalAggregate) { return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child()); @@ -60,9 +60,9 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory { ); } - private static class NeedAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> { + private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> { - private static final NeedAggregateChecker INSTANCE = new NeedAggregateChecker(); + private static final ContainsAggregateChecker INSTANCE = new ContainsAggregateChecker(); @Override public Boolean visit(Expression expr, Void context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java index cc64666e60..c2c48de051 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java @@ -58,6 +58,7 @@ import java.util.stream.Collectors; * function binder */ public class FunctionBinder extends AbstractExpressionRewriteRule { + public static final FunctionBinder INSTANCE = new FunctionBinder(); @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 50bbbc0b20..1a1f085c5c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -215,6 +216,25 @@ public class AggregateStrategies implements ImplementationRuleFactory { return canNotPush; } + // TODO: refactor this to process slot reference or expression together + boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream() + .map(ExpressionTrait::getArguments) + .flatMap(List::stream) + .allMatch(argument -> { + if (argument instanceof SlotReference) { + return true; + } + if 
(argument instanceof Cast) { + return argument.child(0) instanceof SlotReference + && argument.getDataType().isNumericType() + && argument.child(0).getDataType().isNumericType(); + } + return false; + }); + if (!onlyContainsSlotOrNumericCastSlot) { + return canNotPush; + } + // we already normalize the arguments to slotReference List<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream() .flatMap(aggregateFunction -> aggregateFunction.getArguments().stream()) @@ -228,7 +248,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .collect(ImmutableList.toImmutableList()); } - boolean onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction + onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction .stream() .allMatch(argument -> { if (argument instanceof SlotReference) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java index a9f3650cb4..5ba3689edd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java @@ -57,8 +57,12 @@ public class EliminateGroupByConstant extends OneRewriteRuleFactory { Set<Expression> slotGroupByExprs = Sets.newLinkedHashSet(); Expression lit = null; for (Expression expression : groupByExprs) { - expression = FoldConstantRule.INSTANCE.rewrite(expression, context); - if (!(expression instanceof Literal)) { + // NOTICE: we should not use the expression after fold as new aggregate's output or group expr + // because we rely on expression matching to replace subtree that same as group by expr in output + // if we do constant folding before normalize aggregate, the subtree will change and matching fail + // such as: select a + 1 + 2 + 3, sum(b) from t group by a + 1 + 
2 + Expression foldExpression = FoldConstantRule.INSTANCE.rewrite(expression, context); + if (!(foldExpression instanceof Literal)) { slotGroupByExprs.add(expression); } else { lit = expression; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java index 9282ef3825..da972816e6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java @@ -25,9 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.WindowExpression; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; @@ -41,7 +39,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; /** - * extract window expressions from LogicalProject.projects and Normalize LogicalWindow + * extract window expressions from LogicalProject#projects and Normalize LogicalWindow */ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory implements NormalizeToSlot { @@ -60,15 +58,8 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i if (bottomProjects.isEmpty()) { normalizedChild = project.child(); } else { - boolean needAggregate = bottomProjects.stream().anyMatch(expr -> - expr.anyMatch(AggregateFunction.class::isInstance)); - if 
(needAggregate) { - normalizedChild = new LogicalAggregate<>(ImmutableList.of(), - ImmutableList.copyOf(bottomProjects), project.child()); - } else { - normalizedChild = project.withProjectsAndChild( - ImmutableList.copyOf(bottomProjects), project.child()); - } + normalizedChild = project.withProjectsAndChild( + ImmutableList.copyOf(bottomProjects), project.child()); } // 2. handle window's outputs and windowExprs @@ -96,35 +87,32 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i // bottomProjects includes: // 1. expressions from function and WindowSpec's partitionKeys and orderKeys // 2. other slots of outputExpressions - /* - avg(c) / sum(a+1) over (order by avg(b)) group by a - win(x/sum(z) over y) - prj(x, y, a+1 as z) - agg(avg(c) x, avg(b) y, a) - proj(a b c) - toBePushDown = {avg(c), a+1, avg(b)} - */ + // + // avg(c) / sum(a+1) over (order by avg(b)) group by a + // win(x/sum(z) over y) + // prj(x, y, a+1 as z) + // agg(avg(c) x, avg(b) y, a) + // proj(a b c) + // toBePushDown = {avg(c), a+1, avg(b)} return expressions.stream() .flatMap(expression -> { if (expression.anyMatch(WindowExpression.class::isInstance)) { - Set<Slot> inputSlots = expression.getInputSlots().stream().collect(Collectors.toSet()); + Set<Slot> inputSlots = Sets.newHashSet(expression.getInputSlots()); Set<WindowExpression> collects = expression.collect(WindowExpression.class::isInstance); - Set<Slot> windowInputSlots = collects.stream().flatMap( - win -> win.getInputSlots().stream() - ).collect(Collectors.toSet()); - /* - substr( - ref_1.cp_type, - max( - cast(ref_1.`cp_catalog_page_number` as int)) over (...) - ), - 1) - - in above case, ref_1.cp_type should be pushed down. 
ref_1.cp_type is in - substr.inputSlots, but not in windowExpression.inputSlots - - inputSlots= {ref_1.cp_type} - */ + Set<Slot> windowInputSlots = collects.stream() + .flatMap(win -> win.getInputSlots().stream()) + .collect(Collectors.toSet()); + // substr( + // ref_1.cp_type, + // max( + // cast(ref_1.`cp_catalog_page_number` as int)) over (...) + // ), + // 1) + // + // in above case, ref_1.cp_type should be pushed down. ref_1.cp_type is in + // substr.inputSlots, but not in windowExpression.inputSlots + // + // inputSlots= {ref_1.cp_type} inputSlots.removeAll(windowInputSlots); return Stream.concat( collects.stream().flatMap(windowExpression -> diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index fccd933094..b9859a14cf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -20,14 +20,14 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -36,12 +36,12 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; +import com.google.common.collect.Maps; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * normalize aggregate's group keys and AggregateFunction's child to SlotReference @@ -95,173 +95,144 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali @Override public Rule build() { return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> { - // push expression to bottom project - Set<Alias> existsAliases = ExpressionUtils.mutableCollect( - aggregate.getOutputExpressions(), Alias.class::isInstance); - Set<AggregateFunction> aggregateFunctionsInWindow = collectAggregateFunctionsInWindow( - aggregate.getOutputExpressions()); - Set<Expression> existsAggAlias = existsAliases.stream().map(UnaryNode::child) - .filter(AggregateFunction.class::isInstance) - .collect(Collectors.toSet()); - - /* - * agg-functions inside window function is regarded as an output of aggregate. - * select sum(avg(c)) over ... - * is regarded as - * select avg(c), sum(avg(c)) over ... - * - * the plan: - * project(sum(y) over) - * Aggregate(avg(c) as y) - * - * after Aggregate, the 'y' is removed by upper project. 
- * - * aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))} - */ - List<Alias> aliasOfAggFunInWindowUsedAsAggOutput = Lists.newArrayList(); - for (AggregateFunction aggFun : aggregateFunctionsInWindow) { - if (!existsAggAlias.contains(aggFun)) { - Alias alias = new Alias(aggFun, aggFun.toSql()); - existsAliases.add(alias); - aliasOfAggFunInWindowUsedAsAggOutput.add(alias); + List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions(); + Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance); + Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions()); + NormalizeToSlotContext groupByToSlotContext = + NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs); + Set<NamedExpression> bottomGroupByProjects = + groupByToSlotContext.pushDownToNamedExpression(groupingByExprs); + + List<AggregateFunction> aggFuncs = Lists.newArrayList(); + aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); + // use group by context to normalize agg functions to process + // sql like: select sum(a + 1) from t group by a + 1 + // + // before normalize: + // agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1]) + // +-- project(a[#0], (a[#0] + 1)[#1]) + // + // after normalize: + // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1]) + // +-- project((a[#0] + 1)[#1]) + List<AggregateFunction> normalizedAggFuncs = groupByToSlotContext.normalizeToUseSlotRef(aggFuncs); + List<NamedExpression> bottomProjects = Lists.newArrayList(bottomGroupByProjects); + // TODO: if we have distinct agg, we must push down its children, + // because need use it to generate distribution enforce + // step 1: split agg functions into 2 parts: distinct and not distinct + List<AggregateFunction> distinctAggFuncs = Lists.newArrayList(); + List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList(); + for (AggregateFunction aggregateFunction : 
normalizedAggFuncs) { + if (aggregateFunction.isDistinct()) { + distinctAggFuncs.add(aggregateFunction); + } else { + nonDistinctAggFuncs.add(aggregateFunction); } } - Set<Expression> needToSlots = collectGroupByAndArgumentsOfAggregateFunctions(aggregate); - NormalizeToSlotContext groupByAndArgumentToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, needToSlots); - Set<NamedExpression> bottomProjects = - groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots); - Plan normalizedChild = bottomProjects.isEmpty() - ? aggregate.child() - : new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); - - // begin normalize aggregate - - // replace groupBy and arguments of aggregate function to slot, may be this output contains - // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace - // the sum(value) to slot and move the `slot + 1` to the upper project later. - List<NamedExpression> normalizeOutputPhase1 = Stream.concat( - aggregate.getOutputExpressions().stream(), - aliasOfAggFunInWindowUsedAsAggOutput.stream()) - .map(expr -> groupByAndArgumentToSlotContext - .normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance)) - .collect(Collectors.toList()); - - Set<Slot> windowInputSlots = collectWindowInputSlots(aggregate.getOutputExpressions()); - Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots); - itemsInWindow.addAll(aggregateFunctionsInWindow); - NormalizeToSlotContext windowToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, itemsInWindow); - normalizeOutputPhase1 = normalizeOutputPhase1.stream() - .map(expr -> windowToSlotContext - .normalizeToUseSlotRefDown(expr, WindowExpression.class::isInstance, true)) - .collect(Collectors.toList()); - - Set<AggregateFunction> normalizedAggregateFunctions = - collectNonWindowedAggregateFunctions(normalizeOutputPhase1); - - existsAliases = ExpressionUtils.collect(normalizeOutputPhase1, 
Alias.class::isInstance); - - // now reuse the exists alias for the aggregate functions, - // or create new alias for the aggregate functions - NormalizeToSlotContext aggregateFunctionToSlotContext = - NormalizeToSlotContext.buildContext(existsAliases, normalizedAggregateFunctions); - - Set<NamedExpression> normalizedAggregateFunctionsWithAlias = - aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); - - List<Slot> normalizedGroupBy = - (List) groupByAndArgumentToSlotContext - .normalizeToUseSlotRef(aggregate.getGroupByExpressions()); - - // we can safely add all groupBy and aggregate functions to output, because we will - // add a project on it, and the upper project can protect the scope of visible of slot - List<NamedExpression> normalizedAggregateOutput = ImmutableList.<NamedExpression>builder() - .addAll(normalizedGroupBy) - .addAll(normalizedAggregateFunctionsWithAlias) + // step 2: if we only have one distinct agg function, we do push down for it + if (!distinctAggFuncs.isEmpty()) { + // process distinct normalize and put it back to normalizedAggFuncs + List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList(); + Map<Expression, Expression> replaceMap = Maps.newHashMap(); + Map<Expression, NamedExpression> aliasCache = Maps.newHashMap(); + for (AggregateFunction distinctAggFunc : distinctAggFuncs) { + List<Expression> newChildren = Lists.newArrayList(); + for (Expression child : distinctAggFunc.children()) { + if (child instanceof SlotReference) { + newChildren.add(child); + } else { + NamedExpression alias; + if (aliasCache.containsKey(child)) { + alias = aliasCache.get(child); + } else { + alias = new Alias(child, child.toSql()); + aliasCache.put(child, alias); + } + bottomProjects.add(alias); + newChildren.add(alias.toSlot()); + } + } + AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren); + replaceMap.put(distinctAggFunc, newDistinctAggFunc); + 
newDistinctAggFuncs.add(newDistinctAggFunc); + } + aggregateOutput = aggregateOutput.stream() + .map(e -> ExpressionUtils.replace(e, replaceMap)) + .map(NamedExpression.class::cast) + .collect(Collectors.toList()); + distinctAggFuncs = newDistinctAggFuncs; + } + normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs); + normalizedAggFuncs.addAll(distinctAggFuncs); + // TODO: process redundant expressions in aggregate functions children + // build normalized agg output + NormalizeToSlotContext normalizedAggFuncsToSlotContext = + NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs); + // agg output include 2 part, normalized group by slots and normalized agg functions + List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder() + .addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator()) + .addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)) .build(); + // add normalized agg's input slots to bottom projects + Set<Slot> bottomProjectSlots = bottomProjects.stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toSet()); + Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream() + .map(Expression::getInputSlots) + .flatMap(Set::stream) + .filter(e -> !bottomProjectSlots.contains(e)) + .collect(Collectors.toSet()); + bottomProjects.addAll(aggInputSlots); + // build group by exprs + List<Expression> normalizedGroupExprs = groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs); + // build upper project, use two context to do pop up, because agg output maybe contain two part: + // group by keys and agg expressions + List<NamedExpression> upperProjects = groupByToSlotContext + .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput); + upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); + // process Expression like Alias(SlotReference#0)#0 + upperProjects = upperProjects.stream().map(e -> { + if 
(e instanceof Alias) { + Alias alias = (Alias) e; + if (alias.child() instanceof SlotReference) { + SlotReference slotReference = (SlotReference) alias.child(); + if (slotReference.getExprId().equals(alias.getExprId())) { + return slotReference; + } + } + } + return e; + }).collect(Collectors.toList()); + + Plan bottomPlan; + if (!bottomProjects.isEmpty()) { + bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child()); + } else { + bottomPlan = aggregate.child(); + } - LogicalAggregate<Plan> normalizedAggregate = aggregate.withNormalized( - (List) normalizedGroupBy, normalizedAggregateOutput, normalizedChild); - - normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput); - // exclude same-name functions in WindowExpression - List<NamedExpression> upperProjects = normalizeOutputPhase1.stream() - .map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList()); - return new LogicalProject<>(upperProjects, normalizedAggregate); + return new LogicalProject<>(upperProjects, + aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan)); }).toRule(RuleType.NORMALIZE_AGGREGATE); } - private Set<Expression> collectGroupByAndArgumentsOfAggregateFunctions(LogicalAggregate<? extends Plan> aggregate) { - // 2 parts need push down: - // groupingByExpr, argumentsOfAggregateFunction - - Set<Expression> groupingByExpr = ImmutableSet.copyOf(aggregate.getGroupByExpressions()); - - Set<AggregateFunction> aggregateFunctions = collectNonWindowedAggregateFunctions( - aggregate.getOutputExpressions()); + private static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> { - Set<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream() - .flatMap(function -> function.getArguments().stream() - .map(expr -> expr instanceof OrderExpression ? 
expr.child(0) : expr)) - .collect(ImmutableSet.toImmutableSet()); + private static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); - Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions()); - - Set<Expression> needPushDown = ImmutableSet.<Expression>builder() - // group by should be pushed down, e.g. group by (k + 1), - // we should push down the `k + 1` to the bottom plan - .addAll(groupingByExpr) - // e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan - .addAll(argumentsOfAggregateFunction) - .addAll(windowFunctionKeys) - .build(); - return needPushDown; - } - - private Set<Expression> collectWindowFunctionKeys(List<NamedExpression> aggOutput) { - Set<Expression> windowInputs = Sets.newHashSet(); - for (Expression expr : aggOutput) { - Set<WindowExpression> windows = expr.collect(WindowExpression.class::isInstance); - for (WindowExpression win : windows) { - windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk -> pk.getInputSlots().stream()).collect( - Collectors.toList())); - windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> ok.getInputSlots().stream()).collect( - Collectors.toList())); + @Override + public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) { + for (Expression child : windowExpression.getExpressionsInWindowSpec()) { + child.accept(this, context); } + return null; } - return windowInputs; - } - /** - * select sum(c2), avg(min(c2)) over (partition by max(c1) order by count(c1)) from T ... - * extract {sum, min, max, count}. avg is not extracted. 
- */ - private Set<AggregateFunction> collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) { - return ExpressionUtils.collect(aggOutput, expr -> { - if (expr instanceof AggregateFunction) { - return !((AggregateFunction) expr).isWindowFunction(); - } - return false; - }); - } - - private Set<AggregateFunction> collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) { - - List<WindowExpression> windows = Lists.newArrayList( - ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance)); - return ExpressionUtils.collect(windows, expr -> { - if (expr instanceof AggregateFunction) { - return !((AggregateFunction) expr).isWindowFunction(); - } - return false; - }); - } - - private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput) { - List<WindowExpression> windows = Lists.newArrayList( - ExpressionUtils.collect(aggOutput, WindowExpression.class::isInstance)); - return windows.stream().flatMap(win -> win.getInputSlots().stream()).collect(Collectors.toSet()); + @Override + public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) { + context.add(aggregateFunction); + return null; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java index 8ef966496e..974655a80b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java @@ -21,17 +21,20 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.WindowExpression; +import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; -import java.util.function.Predicate; +import java.util.stream.Collectors; import javax.annotation.Nullable; /** NormalizeToSlot */ @@ -45,9 +48,16 @@ public interface NormalizeToSlot { this.normalizeToSlotMap = normalizeToSlotMap; } - /** buildContext */ + /** + * build normalization context by follow step. + * 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias + * 2. for all input source expressions, use existsAliasMap to construct triple: + * origin expr, pushed expr and alias to replace origin expr, + * see more detail in {@link NormalizeToSlotTriplet} + * 3. construct a map: original expr -> triple constructed by step 2 + */ public static NormalizeToSlotContext buildContext( - Set<Alias> existsAliases, Set<? extends Expression> sourceExpressions) { + Set<Alias> existsAliases, Collection<? extends Expression> sourceExpressions) { Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap(); Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap(); @@ -70,13 +80,21 @@ public interface NormalizeToSlot { return normalizeToUseSlotRef(ImmutableList.of(expression)).get(0); } - /** normalizeToUseSlotRef, no custom normalize */ - public <E extends Expression> List<E> normalizeToUseSlotRef(List<E> expressions) { + /** + * normalizeToUseSlotRef, no custom normalize. 
+ * This function use a lambda that always return original expression as customNormalize + * So always use normalizeToSlotMap to process normalization when we call this function + */ + public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions) { return normalizeToUseSlotRef(expressions, (context, expr) -> expr); } - /** normalizeToUseSlotRef */ - public <E extends Expression> List<E> normalizeToUseSlotRef(List<E> expressions, + /** + * normalizeToUseSlotRef. + * try to use customNormalize do normalization first. if customNormalize cannot handle current expression, + * use normalizeToSlotMap to get the default replaced expression. + */ + public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions, BiFunction<NormalizeToSlotContext, Expression, Expression> customNormalize) { return expressions.stream() .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { @@ -89,22 +107,11 @@ public interface NormalizeToSlot { })).collect(ImmutableList.toImmutableList()); } - public <E extends Expression> E normalizeToUseSlotRefUp(E expression, Predicate skip) { - return (E) expression.rewriteDownShortCircuitUp(child -> { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; - }, skip); - } - - /** - * rewrite subtrees whose root matches predicate border - * when we traverse to the node satisfies border predicate, aboveBorder becomes false - */ - public <E extends Expression> E normalizeToUseSlotRefDown(E expression, Predicate border, boolean aboveBorder) { - return (E) expression.rewriteDownShortCircuitDown(child -> { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - return normalizeToSlotTriplet == null ? 
child : normalizeToSlotTriplet.remainExpr; - }, border, aboveBorder); + public <E extends Expression> List<E> normalizeToUseSlotRefWithoutWindowFunction( + Collection<E> expressions) { + return expressions.stream() + .map(e -> (E) e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap)) + .collect(Collectors.toList()); } /** @@ -124,6 +131,54 @@ public interface NormalizeToSlot { } } + /** + * replace any expression except window function. + * because the window function could be same with aggregate function and should never be replaced. + */ + class NormalizeWithoutWindowFunction + extends DefaultExpressionRewriter<Map<Expression, NormalizeToSlotTriplet>> { + + public static final NormalizeWithoutWindowFunction INSTANCE = new NormalizeWithoutWindowFunction(); + + private NormalizeWithoutWindowFunction() { + } + + @Override + public Expression visit(Expression expr, Map<Expression, NormalizeToSlotTriplet> replaceMap) { + if (replaceMap.containsKey(expr)) { + return replaceMap.get(expr).remainExpr; + } + return super.visit(expr, replaceMap); + } + + @Override + public Expression visitWindow(WindowExpression windowExpression, + Map<Expression, NormalizeToSlotTriplet> replaceMap) { + if (replaceMap.containsKey(windowExpression)) { + return replaceMap.get(windowExpression).remainExpr; + } + List<Expression> newChildren = new ArrayList<>(); + Expression function = super.visit(windowExpression.getFunction(), replaceMap); + newChildren.add(function); + boolean hasNewChildren = function != windowExpression.getFunction(); + for (Expression partitionKey : windowExpression.getPartitionKeys()) { + Expression newChild = partitionKey.accept(this, replaceMap); + if (newChild != partitionKey) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + for (Expression orderKey : windowExpression.getOrderKeys()) { + Expression newChild = orderKey.accept(this, replaceMap); + if (newChild != orderKey) { + hasNewChildren = true; + } + newChildren.add(newChild); + } 
+ return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression; + } + } + /** NormalizeToSlotTriplet */ class NormalizeToSlotTriplet { // which expression need to normalized to slot? @@ -142,7 +197,12 @@ public interface NormalizeToSlot { this.pushedExpr = pushedExpr; } - /** toTriplet */ + /** + * construct triplet by three conditions. + * 1. already has exists alias: use this alias as pushed expr + * 2. expression is {@link NamedExpression}, use itself as pushed expr + * 3. other expression, construct a new Alias contains current expression as pushed expr + */ public static NormalizeToSlotTriplet toTriplet(Expression expression, @Nullable Alias existsAlias) { if (existsAlias != null) { return new NormalizeToSlotTriplet(expression, existsAlias.toSlot(), existsAlias); @@ -150,9 +210,7 @@ public interface NormalizeToSlot { if (expression instanceof NamedExpression) { NamedExpression namedExpression = (NamedExpression) expression; - NormalizeToSlotTriplet normalizeToSlotTriplet = - new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression); - return normalizeToSlotTriplet; + return new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression); } Alias alias = new Alias(expression, expression.toSql()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index 0394ebea87..3c64a043d6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -96,33 +96,6 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> { return currentNode; } - /** - * same as rewriteDownShortCircuit, - * except that subtrees, whose root satisfies predicate is satisfied, are not rewritten - */ - default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE> rewriteFunction, Predicate skip) { - NODE_TYPE 
currentNode = rewriteFunction.apply((NODE_TYPE) this); - if (skip.test(currentNode)) { - return currentNode; - } - if (currentNode == this) { - Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity()); - boolean changed = false; - for (NODE_TYPE child : children()) { - NODE_TYPE newChild = child.rewriteDownShortCircuitUp(rewriteFunction, skip); - if (child != newChild) { - changed = true; - } - newChildren.add(newChild); - } - - if (changed) { - currentNode = currentNode.withChildren(newChildren.build()); - } - } - return currentNode; - } - /** * similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies * border predicate are rewritten. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java index ffc0522498..831acbf5a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java @@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.UnaryNode; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; @@ -55,9 +54,6 @@ public class WindowExpression extends Expression { .addAll(orderKeys) .build().toArray(new Expression[0])); this.function = function; - if (function instanceof AggregateFunction) { - ((AggregateFunction) function).setWindowFunction(true); - } this.partitionKeys = ImmutableList.copyOf(partitionKeys); this.orderKeys = ImmutableList.copyOf(orderKeys); this.windowFrame = Optional.empty(); @@ -73,9 +69,6 @@ public class WindowExpression extends Expression { .add(windowFrame) 
.build().toArray(new Expression[0])); this.function = function; - if (function instanceof AggregateFunction) { - ((AggregateFunction) function).setWindowFunction(true); - } this.partitionKeys = ImmutableList.copyOf(partitionKeys); this.orderKeys = ImmutableList.copyOf(orderKeys); this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index a170ae0dd5..a7e523dfdb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -38,7 +38,6 @@ import java.util.stream.Collectors; public abstract class AggregateFunction extends BoundFunction implements ExpectsInputTypes { protected final boolean distinct; - protected boolean isWindowFunction = false; public AggregateFunction(String name, Expression... 
arguments) { this(name, false, arguments); @@ -78,14 +77,6 @@ public abstract class AggregateFunction extends BoundFunction implements Expects return distinct; } - public boolean isWindowFunction() { - return isWindowFunction; - } - - public void setWindowFunction(boolean windowFunction) { - isWindowFunction = windowFunction; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -95,8 +86,7 @@ public abstract class AggregateFunction extends BoundFunction implements Expects return false; } AggregateFunction that = (AggregateFunction) o; - return isWindowFunction == that.isWindowFunction - && Objects.equals(distinct, that.distinct) + return Objects.equals(distinct, that.distinct) && Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java index 96c833a1de..be7da80ed1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java @@ -197,7 +197,7 @@ public class ExtractAndNormalizeWindowExpressionTest implements MemoPatternMatch // when Window's function is same as AggregateFunction. // In this example, agg function [sum(id+1)] is same as Window's function [sum(id+1) over...] 
List<NamedExpression> projects = project.getProjects(); - return projects.get(1).child(0) instanceof SlotReference + return projects.get(1) instanceof SlotReference && projects.get(2).equals(windowAlias); }) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java index 254684eedb..ee0316e67f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java @@ -88,10 +88,9 @@ public class NormalizeAggregateTest implements MemoPatternMatchSupported { .equals(aggregateFunction.child(0))) .when(FieldChecker.check("normalized", true)) ).when(project -> project.getProjects().get(0).equals(key)) - .when(project -> project.getProjects().get(1) instanceof Alias) .when(project -> (project.getProjects().get(1)).getExprId() .equals(aggregateFunction.getExprId())) - .when(project -> project.getProjects().get(1).child(0) instanceof SlotReference) + .when(project -> project.getProjects().get(1) instanceof SlotReference) ); } @@ -102,8 +101,8 @@ public class NormalizeAggregateTest implements MemoPatternMatchSupported { * * after rewrite: * LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) + 2)`#4] ) - * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS `sum((id * 1))`#6], groupByExpr=[name#2] ) - * +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] ) + * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum(id#0 * 1) AS `sum((id * 1))`#6], groupByExpr=[name#2] ) + * +--LogicalProject ( projects=[name#2, id#0] ) * +--GroupPlan( GroupId#0 ) */ @Test @@ -126,8 +125,6 @@ public class NormalizeAggregateTest implements MemoPatternMatchSupported { logicalProject( logicalOlapScan() ).when(project -> project.getProjects().size() == 2) - 
.when(project -> project.getProjects().get(0) instanceof SlotReference) - .when(project -> project.getProjects().get(1).child(0).equals(multiply)) ).when(agg -> agg.getGroupByExpressions().equals( ImmutableList.of(rStudent.getOutput().get(2))) ) diff --git a/regression-test/suites/nereids_syntax_p0/explain.groovy b/regression-test/suites/nereids_syntax_p0/explain.groovy index 91a2abb95d..3a98d42f62 100644 --- a/regression-test/suites/nereids_syntax_p0/explain.groovy +++ b/regression-test/suites/nereids_syntax_p0/explain.groovy @@ -25,7 +25,6 @@ suite("nereids_explain") { explain { sql("select count(2) + 1, sum(2) + sum(lo_suppkey) from lineorder") contains "(sum(2) + sum(lo_suppkey))[#" - contains "project output tuple id: 1" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org