This is an automated email from the ASF dual-hosted git repository. huajianlan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 3f3f2eb098 [Nereids][Improve] infer predicate after push down predicate (#12996) 3f3f2eb098 is described below commit 3f3f2eb098870acbfd29508db54fcab37e585f81 Author: shee <13843187+qz...@users.noreply.github.com> AuthorDate: Tue Nov 8 21:36:17 2022 +0800 [Nereids][Improve] infer predicate after push down predicate (#12996) This PR implements predicate inference. For example: ``` sql select * from student left join score on student.id = score.sid where score.sid > 1 ``` transformed logical plan tree: left join / \ filter(sid > 1) filter(id > 1) <---- inferred predicate | | scan scan See `InferPredicatesTest` for more cases. The logic is as follows: 1. pull up the bottom predicates, then infer additional predicates. For example: select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id (a) pull up the bottom predicate: select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 (b) infer: select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 and t2.id = 1 finally transformed sql: select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1 2. put these predicates into `otherJoinConjuncts`; they are processed in the next round of predicate push-down. Currently, only inference on `ComparisonPredicate` is supported. 
TODO: We should determine whether `expression` satisfies the condition for replacement, e.g. whether `expression` is non-deterministic. --- .../org/apache/doris/nereids/jobs/JobType.java | 1 + .../doris/nereids/jobs/batch/BatchRulesJob.java | 7 + .../jobs/batch/NereidsRewriteJobExecutor.java | 3 + .../nereids/jobs/rewrite/VisitorRewriteJob.java | 56 +++ .../rules/rewrite/logical/InferPredicates.java | 117 +++++ .../rewrite/logical/PredicatePropagation.java | 105 ++++ .../rules/rewrite/logical/PullUpPredicates.java | 165 +++++++ .../nereids/trees/expressions/Expression.java | 4 + .../nereids/trees/plans/logical/LogicalJoin.java | 6 + .../rules/rewrite/logical/InferPredicatesTest.java | 531 +++++++++++++++++++++ 10 files changed, 995 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/JobType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/JobType.java index 6531376177..528109babb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/JobType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/JobType.java @@ -29,6 +29,7 @@ public enum JobType { APPLY_RULE, DERIVE_STATS, TOP_DOWN_REWRITE, + VISITOR_REWRITE, BOTTOM_UP_REWRITE ; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java index c77c0ec8e2..52014728ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/BatchRulesJob.java @@ -19,11 +19,14 @@ package org.apache.doris.nereids.jobs.batch; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.Job; +import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob; import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob; import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob; 
+import org.apache.doris.nereids.jobs.rewrite.VisitorRewriteJob; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleFactory; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import java.util.ArrayList; import java.util.List; @@ -71,6 +74,10 @@ public abstract class BatchRulesJob { cascadesContext.getCurrentJobContext(), once); } + protected Job visitorJob(DefaultPlanRewriter<JobContext> planRewriter) { + return new VisitorRewriteJob(cascadesContext, planRewriter, true); + } + protected Job optimize() { return new OptimizeGroupJob( cascadesContext.getMemo().getRoot(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java index 52311db55a..634e4ca9fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject; import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction; import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; +import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates; import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown; import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition; @@ -65,9 +66,11 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob { .add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) + 
.add(visitorJob(new InferPredicates())) .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) .add(topDownBatch(ImmutableList.of(new ColumnPruning()))) .add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false)) + .add(visitorJob(new InferPredicates())) .add(topDownBatch(ImmutableList.of(PushFilterInsideJoin.INSTANCE))) .add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin()))) .add(topDownBatch(ImmutableList.of(new LimitPushDown()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java new file mode 100644 index 0000000000..f579de23ad --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/VisitorRewriteJob.java @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.jobs.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.jobs.Job; +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.jobs.JobType; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; + +import java.util.Objects; + +/** + * Use visitor to rewrite the plan. + */ +public class VisitorRewriteJob extends Job { + private final Group group; + + private final DefaultPlanRewriter<JobContext> planRewriter; + + /** + * Constructor. + */ + public VisitorRewriteJob(CascadesContext cascadesContext, DefaultPlanRewriter<JobContext> rewriter, boolean once) { + super(JobType.VISITOR_REWRITE, cascadesContext.getCurrentJobContext(), once); + this.group = Objects.requireNonNull(cascadesContext.getMemo().getRoot(), "group cannot be null"); + this.planRewriter = Objects.requireNonNull(rewriter, "planRewriter cannot be null"); + } + + @Override + public void execute() { + GroupExpression logicalExpression = group.getLogicalExpression(); + Plan root = context.getCascadesContext().getMemo().copyOut(logicalExpression, true); + Plan rewrittenRoot = root.accept(planRewriter, context); + context.getCascadesContext().getMemo().copyIn(rewrittenRoot, group, true); + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicates.java new file mode 100644 index 0000000000..d1e5ab77ef --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicates.java @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * infer additional predicates for `LogicalFilter` and `LogicalJoin`. + * The logic is as follows: + * 1. poll up bottom predicate then infer additional predicates + * for example: + * select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id + * 1. poll up bottom predicate + * select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 + * 2. 
infer + * select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 and t2.id = 1 + * finally transformed sql: + * select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1 + * 2. put these predicates into `otherJoinConjuncts` , these predicates are processed in the next + * round of predicate push-down + */ +public class InferPredicates extends DefaultPlanRewriter<JobContext> { + private final PredicatePropagation propagation = new PredicatePropagation(); + private final PullUpPredicates pollUpPredicates = new PullUpPredicates(); + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) { + join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context); + Plan left = join.left(); + Plan right = join.right(); + Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition()); + List<Expression> otherJoinConjuncts = Lists.newArrayList(join.getOtherJoinConjuncts()); + switch (join.getJoinType()) { + case INNER_JOIN: + case CROSS_JOIN: + case LEFT_SEMI_JOIN: + case RIGHT_SEMI_JOIN: + otherJoinConjuncts.addAll(inferNewPredicate(left, expressions)); + otherJoinConjuncts.addAll(inferNewPredicate(right, expressions)); + break; + case LEFT_OUTER_JOIN: + case LEFT_ANTI_JOIN: + otherJoinConjuncts.addAll(inferNewPredicate(right, expressions)); + break; + case RIGHT_OUTER_JOIN: + case RIGHT_ANTI_JOIN: + otherJoinConjuncts.addAll(inferNewPredicate(left, expressions)); + break; + default: + return join; + } + return join.withOtherJoinConjuncts(otherJoinConjuncts); + } + + @Override + public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext context) { + filter = (LogicalFilter<? 
extends Plan>) super.visit(filter, context); + Set<Expression> filterPredicates = pullUpPredicates(filter); + filterPredicates.removeAll(pullUpPredicates(filter.child())); + filter.getConjuncts().forEach(filterPredicates::remove); + if (!filterPredicates.isEmpty()) { + filterPredicates.addAll(filter.getConjuncts()); + return new LogicalFilter<>(ExpressionUtils.and(Lists.newArrayList(filterPredicates)), filter.child()); + } + return filter; + } + + private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) { + Set<Expression> baseExpressions = pullUpPredicates(left); + baseExpressions.addAll(pullUpPredicates(right)); + condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); + baseExpressions.addAll(propagation.infer(baseExpressions)); + return baseExpressions; + } + + private Set<Expression> pullUpPredicates(Plan plan) { + return Sets.newHashSet(plan.accept(pollUpPredicates, null)); + } + + private List<Expression> inferNewPredicate(Plan plan, Set<Expression> expressions) { + List<Expression> predicates = expressions.stream() + .filter(c -> !c.getInputSlots().isEmpty() && plan.getOutputSet().containsAll( + c.getInputSlots())).collect(Collectors.toList()); + predicates.removeAll(plan.accept(pollUpPredicates, null)); + return predicates; + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PredicatePropagation.java new file mode 100644 index 0000000000..4c8b72ef58 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PredicatePropagation.java @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.Sets; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * derive additional predicates. + * for example: + * a = b and a = 1 => b = 1 + */ +public class PredicatePropagation { + + /** + * infer additional predicates. + */ + public Set<Expression> infer(Set<Expression> predicates) { + Set<Expression> inferred = Sets.newHashSet(); + for (Expression predicate : predicates) { + if (canEquivalentInfer(predicate)) { + List<Expression> newInferred = predicates.stream() + .filter(p -> !p.equals(predicate)) + .map(p -> doInfer(predicate, p)) + .collect(Collectors.toList()); + inferred.addAll(newInferred); + } + } + inferred.removeAll(predicates); + return inferred; + } + + /** + * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` + * Now only support infer `ComparisonPredicate`. 
+ * TODO: We should determine whether `expression` satisfies the condition for replacement + * eg: Satisfy `expression` is non-deterministic + */ + private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) { + return expression.accept(new DefaultExpressionRewriter<Void>() { + + @Override + public Expression visit(Expression expr, Void context) { + return expr; + } + + @Override + public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { + if (cp.left().isSlot() && cp.right().isConstant()) { + return replaceSlot(cp); + } else if (cp.left().isConstant() && cp.right().isSlot()) { + return replaceSlot(cp); + } + return super.visit(cp, context); + } + + private Expression replaceSlot(Expression expr) { + return expr.rewriteUp(e -> { + if (e.equals(leftSlotEqualToRightSlot.child(0))) { + return leftSlotEqualToRightSlot.child(1); + } else if (e.equals(leftSlotEqualToRightSlot.child(1))) { + return leftSlotEqualToRightSlot.child(0); + } else { + return e; + } + }); + } + }, null); + } + + /** + * Currently only equivalence derivation is supported + * and requires that the left and right sides of an expression must be slot + */ + private boolean canEquivalentInfer(Expression predicate) { + return predicate instanceof EqualTo + && predicate.children().stream().allMatch(e -> e instanceof SlotReference) + && predicate.child(0).getDataType().equals(predicate.child(1).getDataType()); + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpPredicates.java new file mode 100644 index 0000000000..9d86415344 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpPredicates.java @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.util.Collection; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * poll up effective predicates from operator's children. 
+ */ +public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> { + + PredicatePropagation propagation = new PredicatePropagation(); + Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>(); + + @Override + public ImmutableSet<Expression> visit(Plan plan, Void context) { + if (plan.arity() == 1) { + return plan.child(0).accept(this, context); + } + return ImmutableSet.of(); + } + + @Override + public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) { + return cacheOrElse(filter, () -> { + List<Expression> predicates = Lists.newArrayList(filter.getConjuncts()); + predicates.addAll(filter.child().accept(this, context)); + return getAvailableExpressions(predicates, filter); + }); + } + + @Override + public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) { + return cacheOrElse(join, () -> { + Set<Expression> predicates = Sets.newHashSet(); + ImmutableSet<Expression> leftPredicates = join.left().accept(this, context); + ImmutableSet<Expression> rightPredicates = join.right().accept(this, context); + switch (join.getJoinType()) { + case INNER_JOIN: + case CROSS_JOIN: + predicates.addAll(leftPredicates); + predicates.addAll(rightPredicates); + join.getOnClauseCondition().map(on -> predicates.addAll(ExpressionUtils.extractConjunction(on))); + break; + case LEFT_SEMI_JOIN: + predicates.addAll(leftPredicates); + join.getOnClauseCondition().map(on -> predicates.addAll(ExpressionUtils.extractConjunction(on))); + break; + case RIGHT_SEMI_JOIN: + predicates.addAll(rightPredicates); + join.getOnClauseCondition().map(on -> predicates.addAll(ExpressionUtils.extractConjunction(on))); + break; + case LEFT_OUTER_JOIN: + case LEFT_ANTI_JOIN: + predicates.addAll(leftPredicates); + break; + case RIGHT_OUTER_JOIN: + case RIGHT_ANTI_JOIN: + predicates.addAll(rightPredicates); + break; + default: + } + return getAvailableExpressions(predicates, join); 
+ }); + } + + @Override + public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Plan> project, Void context) { + return cacheOrElse(project, () -> { + ImmutableSet<Expression> childPredicates = project.child().accept(this, context); + Map<Expression, Slot> expressionSlotMap = project.getAliasToProducer() + .entrySet() + .stream() + .collect(Collectors.toMap(Entry::getValue, Entry::getKey)); + Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)), + expressionSlotMap); + List<Expression> predicates = ExpressionUtils.extractConjunction(expression); + return getAvailableExpressions(predicates, project); + }); + } + + @Override + public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) { + return cacheOrElse(aggregate, () -> { + ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context); + Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions() + .stream() + .filter(this::hasAgg) + .collect(Collectors.toMap( + namedExpr -> { + if (namedExpr instanceof Alias) { + return ((Alias) namedExpr).child(); + } else { + return namedExpr; + } + }, NamedExpression::toSlot) + ); + Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)), + expressionSlotMap); + List<Expression> predicates = ExpressionUtils.extractConjunction(expression); + return getAvailableExpressions(predicates, aggregate); + }); + } + + private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Expression>> predicatesSupplier) { + ImmutableSet<Expression> predicates = cache.get(plan); + if (predicates != null) { + return predicates; + } + predicates = predicatesSupplier.get(); + cache.put(plan, predicates); + return predicates; + } + + private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) { + Set<Expression> expressions = 
Sets.newHashSet(predicates); + expressions.addAll(propagation.infer(expressions)); + return expressions.stream() + .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) + .collect(ImmutableSet.toImmutableSet()); + } + + private boolean hasAgg(Expression expression) { + return expression.anyMatch(AggregateFunction.class::isInstance); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index b1d037c282..9ffa29b45e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -139,6 +139,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements return this instanceof NullLiteral; } + public boolean isSlot() { + return this instanceof Slot; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index d0279b921e..638973d2aa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -258,4 +258,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends public LogicalJoin withJoinType(JoinType joinType) { return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, left(), right(), joinReorderContext); } + + public LogicalJoin withOtherJoinConjuncts(List<Expression> otherJoinConjuncts) { + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, left(), right(), + joinReorderContext); + } } + diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java new file mode 100644 index 0000000000..3c31b0466c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/InferPredicatesTest.java @@ -0,0 +1,531 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class InferPredicatesTest extends TestWithFeService implements PatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + + createTable("create table test.student (\n" + + "id int not null,\n" + + "name varchar(128),\n" + + "age int,sex int)\n" + + "distributed by hash(id) buckets 10\n" + + "properties('replication_num' = '1');"); + + createTable("create table test.score (\n" + + "sid int not null, \n" + + "cid int not null, \n" + + "grade double)\n" + + "distributed by hash(sid,cid) buckets 10\n" + + "properties('replication_num' = '1');"); + + createTable("create table test.course (\n" + + "id int not null, \n" + + "name varchar(128), \n" + + "teacher varchar(128))\n" + + "distributed by hash(id) buckets 10\n" + + "properties('replication_num' = '1');"); + + createTables("create table test.subquery1\n" + + "(k1 bigint, k2 bigint)\n" + + "duplicate key(k1)\n" + + "distributed by hash(k2) buckets 1\n" + + "properties('replication_num' = '1');\n", + "create table test.subquery2\n" + + "(k1 varchar(10), k2 bigint)\n" + + "partition by range(k2)\n" + + "(partition p1 values less than(\"10\"))\n" + + "distributed by hash(k2) buckets 1\n" + + "properties('replication_num' = '1');", + "create table test.subquery3\n" + + "(k1 int not null, k2 varchar(128), k3 bigint, v1 bigint, v2 bigint)\n" + + "distributed by hash(k2) buckets 1\n" + + "properties('replication_num' = '1');", + "create table test.subquery4\n" + + "(k1 bigint, k2 bigint)\n" + + "duplicate key(k1)\n" + + "distributed by hash(k2) buckets 1\n" + + "properties('replication_num' = '1');"); + + 
connectContext.setDatabase("default_cluster:test"); + } + + @Test + public void inferPredicatesTest01() { + String sql = "select * from student join score on student.id = score.sid where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest02() { + String sql = "select * from student join score on student.id = score.sid"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ); + } + + @Test + public void inferPredicatesTest03() { + String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id IN (1, 2, 3)")), + logicalOlapScan() + ) + ); + } + + @Test + public void inferPredicatesTest04() { + String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + 
).when(filter -> filter.getPredicates().toSql().contains("id IN (1, 2, 3)")), + logicalOlapScan() + ) + ); + } + + @Test + public void inferPredicatesTest05() { + String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest06() { + String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest07() { + String sql = "select * from student left join score on student.id = score.sid where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + 
logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest08() { + String sql = "select * from student left join score on student.id = score.sid and student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalOlapScan(), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest09() { + // convert left join to inner join + String sql = "select * from student left join score on student.id = score.sid where score.sid > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest10() { + String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("id > 1")) + ), + logicalFilter( + logicalOlapScan() + ).when(filer -> 
filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest11() { + String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalProject( + logicalOlapScan() + ), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ); + } + + @Test + public void inferPredicatesTest12() { + String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("id > 1")), + logicalProject( + logicalAggregate( + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + )) + ) + ) + ); + } + + @Test + public void inferPredicatesTest13() { + String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("id = 1")) + ), + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid = 1")) + ) + ); + } + + @Test + public void 
inferPredicatesTest14() { + String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ) + ); + } + + @Test + public void inferPredicatesTest15() { + String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filer -> filer.getPredicates().toSql().contains("sid > 1")) + ) + ) + ); + } + + @Test + public void inferPredicatesTest16() { + String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalOlapScan(), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ) + ) + ); + } + + @Test + public void inferPredicatesTest17() { + String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1"; + Plan plan = 
PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalOlapScan(), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ) + ) + ); + } + + @Test + public void inferPredicatesTest18() { + String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ) + ) + ); + } + + @Test + public void inferPredicatesTest19() { + String sql = "select * from subquery1 left semi join (select t1.k3 from (select * from subquery3 left semi join (select k1 from subquery4 where k1 = 3) t on subquery3.k3 = t.k1) t1 inner join (select k2,sum(k2) as sk2 from subquery2 group by k2) t2 on t2.k2 = t1.v1 and t1.v2 > t2.sk2) t3 on t3.k3 = subquery1.k1"; + Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan(); + System.out.println(plan.treeString()); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")), + logicalProject( + logicalJoin( + logicalJoin( + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("k3 = 3")) + ), + logicalProject( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")) 
+ ) + ), + logicalProject() + ) + ) + ) + ); + } + + @Test + public void inferPredicatesTest20() { + String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalJoin( + logicalOlapScan(), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ), + logicalOlapScan() + ) + ); + } + + @Test + public void inferPredicatesTest21() { + String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matchesFromRoot( + logicalJoin( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("sid > 1")) + ), + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getPredicates().toSql().contains("id > 1")) + ) + ); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org