This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 858e8234d7 [feature](Nereids) add predicates push down on all join type (#12571) 858e8234d7 is described below commit 858e8234d742a60729df6c7a90f526d58ef0cc12 Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Thu Sep 15 15:18:42 2022 +0800 [feature](Nereids) add predicates push down on all join type (#12571) * [feature](Nereids) add predicates push down on all join type --- .../doris/nereids/jobs/batch/RewriteJob.java | 17 +- .../org/apache/doris/nereids/rules/RuleSet.java | 17 +- .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../logical/PushDownJoinOtherCondition.java | 99 +++++++++ ...ughJoin.java => PushPredicatesThroughJoin.java} | 91 +++++--- .../logical/FindHashConditionForJoinTest.java | 4 +- .../rules/rewrite/logical/LimitPushDownTest.java | 9 +- .../logical/PruneOlapScanPartitionTest.java | 8 +- .../logical/PushDownJoinOtherConditionTest.java | 196 ++++++++++++++++++ .../rewrite/logical/PushDownPredicateTest.java | 228 --------------------- .../logical/PushPredicateThroughJoinTest.java | 208 +++++++++++++++++++ .../apache/doris/nereids/util/PlanConstructor.java | 8 +- 12 files changed, 594 insertions(+), 292 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java index 242687ba6d..245095abc8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/RewriteJob.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.batch; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.Job; +import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; import org.apache.doris.nereids.rules.mv.SelectRollup; import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; @@ -27,14 +28,8 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit; import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin; import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; -import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate; import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition; -import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin; -import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject; -import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit; import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; import com.google.common.collect.ImmutableList; @@ -64,15 +59,9 @@ public class RewriteJob extends BatchRulesJob { .add(topDownBatch(ImmutableList.of(new ExpressionNormalization()))) .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(ImmutableList.of(new ReorderJoin()))) - .add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin()))) - .add(topDownBatch(ImmutableList.of(new NormalizeAggregate()))) .add(topDownBatch(ImmutableList.of(new ColumnPruning()))) - .add(topDownBatch(ImmutableList.of(new PushPredicateThroughJoin(), - new PushdownProjectThroughLimit(), - new PushdownFilterThroughProject(), - new MergeConsecutiveProjects(), - new MergeConsecutiveFilters(), - new MergeConsecutiveLimits()))) + .add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES)) + .add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin()))) .add(topDownBatch(ImmutableList.of(new AggregateDisassemble()))) .add(topDownBatch(ImmutableList.of(new LimitPushDown()))) .add(topDownBatch(ImmutableList.of(new EliminateLimit()))) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index fa480135ce..cf2c0ce311 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -33,9 +33,13 @@ import org.apache.doris.nereids.rules.implementation.LogicalOneRowRelationToPhys import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject; import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; -import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; +import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters; +import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits; import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; +import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition; +import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin; import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -55,9 +59,14 @@ public class RuleSet { .add(new MergeConsecutiveProjects()) .build(); - public static final List<Rule> REWRITE_RULES = planRuleFactories() - .add(new AggregateDisassemble()) - .build(); + public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of( + new PushDownJoinOtherCondition(), + new PushPredicatesThroughJoin(), + new PushdownProjectThroughLimit(), + new PushdownFilterThroughProject(), + new MergeConsecutiveProjects(), + new MergeConsecutiveFilters(), + new MergeConsecutiveLimits()); public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories() .add(new LogicalAggToPhysicalHashAgg()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 826acf522f..c9fe816970 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -79,6 +79,7 @@ public enum RuleType { EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE), // predicate push down rules PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE), PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE), // column prune rules, COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java new file mode 100644 index 0000000000..88744fbe65 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherCondition.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Set; + +/** + * Push the other join conditions in LogicalJoin to children. + */ +public class PushDownJoinOtherCondition extends OneRewriteRuleFactory { + private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_SEMI_JOIN, + JoinType.RIGHT_OUTER_JOIN, + JoinType.RIGHT_ANTI_JOIN, + JoinType.RIGHT_SEMI_JOIN, + JoinType.CROSS_JOIN + ); + + private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, + JoinType.LEFT_ANTI_JOIN, + JoinType.LEFT_SEMI_JOIN, + JoinType.RIGHT_SEMI_JOIN, + JoinType.CROSS_JOIN + ); + + @Override + public Rule build() { + return logicalJoin().then(join -> { + if (!join.getOtherJoinCondition().isPresent()) { + return null; + } + List<Expression> otherConjuncts = ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get()); + List<Expression> leftConjuncts = Lists.newArrayList(); + List<Expression> rightConjuncts = Lists.newArrayList(); + + for (Expression otherConjunct : otherConjuncts) { + if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()) + && allCoveredBy(otherConjunct, join.left().getOutputSet())) { + leftConjuncts.add(otherConjunct); + } + if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType()) + && allCoveredBy(otherConjunct, join.right().getOutputSet())) { + rightConjuncts.add(otherConjunct); + } + } + + if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) { + return null; + } + + otherConjuncts.removeAll(leftConjuncts); + otherConjuncts.removeAll(rightConjuncts); + + Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left()); + Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right()); + + return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), + ExpressionUtils.optionalAnd(otherConjuncts), left, right); + + }).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION); + } + + private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) { + return inputSlotSet.containsAll(predicate.getInputSlots()); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java similarity index 61% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java index 9859deb747..3cdfac4918 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicatesThroughJoin.java @@ -21,15 +21,17 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; @@ -37,17 +39,39 @@ import java.util.Objects; import java.util.Set; /** - * Push the predicate in the LogicalFilter or LogicalJoin to the join children. - * todo: Now, only support eq on condition for inner join, support other case later + * Push the predicate in the LogicalFilter to the join children. */ -public class PushPredicateThroughJoin extends OneRewriteRuleFactory { +public class PushPredicatesThroughJoin extends OneRewriteRuleFactory { + + private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, + JoinType.LEFT_SEMI_JOIN, + JoinType.LEFT_ANTI_JOIN, + JoinType.CROSS_JOIN + ); + + private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.RIGHT_OUTER_JOIN, + JoinType.RIGHT_SEMI_JOIN, + JoinType.RIGHT_ANTI_JOIN, + JoinType.CROSS_JOIN + ); + + private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO = ImmutableList.of( + JoinType.INNER_JOIN + ); + /* * For example: - * select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2 + * select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 + * where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2 + * * Logical plan tree: * project * | - * filter (a.k1 > 1 and b.k1 > 2) + * filter (a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2) * | * join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5) * / \ @@ -55,69 +79,72 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { * transformed: * project * | - * join (a.k1 = b.k1) + * filter(a.k2 > b.k2) + * | + * join (otherConditions: a.k1 = b.k1) * / \ - * filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5) + * filter(a.k1 > 1 and a.k2 > 2) filter(b.k1 > 2 and b.k2 > 5) * | | * scan scan */ @Override public Rule build() { - return logicalFilter(innerLogicalJoin()).then(filter -> { + return logicalFilter(logicalJoin()).then(filter -> { LogicalJoin<GroupPlan, GroupPlan> join = filter.child(); - Expression wherePredicates = filter.getPredicates(); - Expression onPredicates = join.getOtherJoinCondition().orElse(BooleanLiteral.TRUE); + Expression filterPredicates = filter.getPredicates(); - List<Expression> otherConditions = Lists.newArrayList(); - List<Expression> eqConditions = Lists.newArrayList(); + List<Expression> filterConditions = Lists.newArrayList(); + List<Expression> joinConditions = Lists.newArrayList(); Set<Slot> leftInput = join.left().getOutputSet(); Set<Slot> rightInput = join.right().getOutputSet(); - ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates)) + ExpressionUtils.extractConjunction(filterPredicates) .forEach(predicate -> { - if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) { - eqConditions.add(predicate); + if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput)) + && COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) { + joinConditions.add(predicate); } else { - otherConditions.add(predicate); + filterConditions.add(predicate); } }); List<Expression> leftPredicates = Lists.newArrayList(); List<Expression> rightPredicates = Lists.newArrayList(); - for (Expression p : otherConditions) { + for (Expression p : filterConditions) { Set<Slot> slots = p.getInputSlots(); if (slots.isEmpty()) { leftPredicates.add(p); rightPredicates.add(p); continue; } - if (leftInput.containsAll(slots)) { + if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) { leftPredicates.add(p); } - if (rightInput.containsAll(slots)) { + if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) { rightPredicates.add(p); } } - otherConditions.removeAll(leftPredicates); - otherConditions.removeAll(rightPredicates); - otherConditions.addAll(eqConditions); + filterConditions.removeAll(leftPredicates); + filterConditions.removeAll(rightPredicates); + join.getOtherJoinCondition().map(joinConditions::add); - return pushDownPredicate(join, otherConditions, leftPredicates, rightPredicates); + return PlanUtils.filterOrSelf(filterConditions, + pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates)); }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN); } - private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan, + private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join, List<Expression> joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) { // todo expr should optimize again using expr rewrite - Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, joinPlan.left()); - Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, joinPlan.right()); + Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left()); + Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right()); - return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(), + return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan); } @@ -128,13 +155,13 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { ComparisonPredicate comparison = (ComparisonPredicate) predicate; - Set<Slot> leftSlots = comparison.left().getInputSlots(); - Set<Slot> rightSlots = comparison.right().getInputSlots(); - - if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) { + if (!(comparison instanceof EqualTo)) { return null; } + Set<Slot> leftSlots = comparison.left().getInputSlots(); + Set<Slot> rightSlots = comparison.right().getInputSlots(); + if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots)) || (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) { return predicate; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java index 435942e562..025f2a39eb 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java @@ -57,8 +57,8 @@ import java.util.Optional; class FindHashConditionForJoinTest { @Test public void testFindHashCondition() { - Plan student = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of("")); - Plan score = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of("")); + Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); Slot studentId = student.getOutput().get(0); Slot gender = student.getOutput().get(1); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java index 714b5511a6..f71417e469 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -48,8 +49,8 @@ import java.util.function.Function; import java.util.stream.Collectors; class LimitPushDownTest extends TestWithFeService implements PatternMatchSupported { - private Plan scanScore = new LogicalOlapScan(PlanConstructor.score); - private Plan scanStudent = new LogicalOlapScan(PlanConstructor.student); + private Plan scanScore = new LogicalOlapScan(new RelationId(0), PlanConstructor.score); + private Plan scanStudent = new LogicalOlapScan(new RelationId(1), PlanConstructor.student); @Override protected void runBeforeAll() throws Exception { @@ -213,8 +214,8 @@ class LimitPushDownTest extends TestWithFeService implements PatternMatchSupport joinType, joinConditions, Optional.empty(), - new LogicalOlapScan(PlanConstructor.score), - new LogicalOlapScan(PlanConstructor.student) + new LogicalOlapScan(new RelationId(0), PlanConstructor.score), + new LogicalOlapScan(new RelationId(1), PlanConstructor.student) ); if (hasProject) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java index b6bd36b1f5..ca52fa2360 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PruneOlapScanPartitionTest.java @@ -88,7 +88,7 @@ class PruneOlapScanPartitionTest { olapTable.getName(); result = "tbl"; }}; - LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable); + LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); SlotReference slotRef = new SlotReference("col1", IntegerType.INSTANCE); Expression expression = new LessThan(slotRef, new IntegerLiteral(4)); LogicalFilter<LogicalOlapScan> filter = new LogicalFilter<>(expression, scan); @@ -104,7 +104,7 @@ class PruneOlapScanPartitionTest { Expression greaterThan6 = new GreaterThan(slotRef, new IntegerLiteral(6)); Or lessThan0OrGreaterThan6 = new Or(lessThan0, greaterThan6); filter = new LogicalFilter<>(lessThan0OrGreaterThan6, scan); - scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable); + scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); cascadesContext = MemoTestUtils.createCascadesContext(filter); rules = Lists.newArrayList(new PruneOlapScanPartition().build()); cascadesContext.topDownRewrite(rules); @@ -118,7 +118,7 @@ class PruneOlapScanPartitionTest { Expression lessThanEqual5 = new LessThanEqual(slotRef, new IntegerLiteral(5)); And greaterThanEqual0AndLessThanEqual5 = new And(greaterThanEqual0, lessThanEqual5); - scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable); + scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); filter = new LogicalFilter<>(greaterThanEqual0AndLessThanEqual5, scan); cascadesContext = MemoTestUtils.createCascadesContext(filter); rules = Lists.newArrayList(new PruneOlapScanPartition().build()); @@ -153,7 +153,7 @@ class PruneOlapScanPartitionTest { olapTable.getName(); result = "tbl"; }}; - LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextId(), olapTable); + LogicalOlapScan scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), olapTable); Expression left = new LessThan(new SlotReference("col1", IntegerType.INSTANCE), new IntegerLiteral(4)); Expression right = new GreaterThan(new SlotReference("col2", IntegerType.INSTANCE), new IntegerLiteral(11)); CompoundPredicate and = new And(left, right); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java new file mode 100644 index 0000000000..f6f5d664fe --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownJoinOtherConditionTest.java @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.Memo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.nereids.util.PlanRewriter; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.Optional; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class PushDownJoinOtherConditionTest { + + private Plan rStudent; + private Plan rScore; + + /** + * ut before. + */ + @BeforeAll + public final void beforeAll() { + rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); + } + + @Test + public void oneSide() { + oneSide(JoinType.CROSS_JOIN, false); + oneSide(JoinType.INNER_JOIN, false); + oneSide(JoinType.LEFT_OUTER_JOIN, true); + oneSide(JoinType.LEFT_SEMI_JOIN, true); + oneSide(JoinType.LEFT_ANTI_JOIN, true); + oneSide(JoinType.RIGHT_OUTER_JOIN, false); + oneSide(JoinType.RIGHT_SEMI_JOIN, false); + oneSide(JoinType.RIGHT_ANTI_JOIN, false); + } + + private void oneSide(JoinType joinType, boolean testRight) { + + Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); + Expression condition = ExpressionUtils.and(pushSide1, pushSide2); + + Plan left = rStudent; + Plan right = rScore; + if (testRight) { + left = rScore; + right = rStudent; + } + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right); + Plan root = new LogicalProject<>(Lists.newArrayList(), join); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + if (testRight) { + shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + } + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(shouldFilter instanceof LogicalFilter); + Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); + LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter; + + Assertions.assertEquals(condition, actualFilter.getPredicates()); + } + + @Test + public void bothSideToBothSide() { + bothSideToBothSide(JoinType.CROSS_JOIN); + bothSideToBothSide(JoinType.INNER_JOIN); + bothSideToBothSide(JoinType.LEFT_SEMI_JOIN); + bothSideToBothSide(JoinType.RIGHT_SEMI_JOIN); + } + + private void bothSideToBothSide(JoinType joinType) { + + Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression condition = ExpressionUtils.and(leftSide, rightSide); + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), rStudent, rScore); + Plan root = new LogicalProject<>(Lists.newArrayList(), join); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(leftFilter instanceof LogicalFilter); + Assertions.assertTrue(rightFilter instanceof LogicalFilter); + LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter; + LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter; + Assertions.assertEquals(leftSide, actualLeft.getPredicates()); + Assertions.assertEquals(rightSide, actualRight.getPredicates()); + } + + @Test + public void bothSideToOneSide() { + bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, true); + bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, true); + bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, false); + bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, false); + } + + private void bothSideToOneSide(JoinType joinType, boolean testRight) { + + Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression condition = ExpressionUtils.and(pushSide, reserveSide); + + Plan left = rStudent; + Plan right = rScore; + if (testRight) { + left = rScore; + right = rStudent; + } + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right); + Plan root = new LogicalProject<>(Lists.newArrayList(), join); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan shouldFilter = rootGroup.getLogicalExpression() + .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan shouldScan = rootGroup.getLogicalExpression() + .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); + if (testRight) { + shouldFilter = rootGroup.getLogicalExpression() + .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); + shouldScan = rootGroup.getLogicalExpression() + .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); + } + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(shouldFilter instanceof LogicalFilter); + Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); + LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter; + Assertions.assertEquals(pushSide, actualFilter.getPredicates()); + } + + private Memo rewrite(Plan plan) { + return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushDownJoinOtherCondition()); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java deleted file mode 100644 index 3224045c2c..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java +++ /dev/null @@ -1,228 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.nereids.memo.Group; -import org.apache.doris.nereids.memo.Memo; -import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization; -import org.apache.doris.nereids.trees.expressions.Add; -import org.apache.doris.nereids.trees.expressions.And; -import org.apache.doris.nereids.trees.expressions.Between; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.GreaterThan; -import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; -import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.Subtract; -import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.types.DoubleType; -import org.apache.doris.nereids.types.StringType; -import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.PlanRewriter; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; - -import java.util.ArrayList; -import java.util.Optional; - -/** - * plan rewrite ut. - */ -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -public class PushDownPredicateTest { - - private Plan rStudent; - private Plan rScore; - private Plan rCourse; - - /** - * ut before. - */ - @BeforeAll - public final void beforeAll() { - rStudent = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of("")); - - rScore = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of("")); - - rCourse = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.course, ImmutableList.of("")); - } - - @Test - public void pushDownPredicateIntoScanTest1() { - // select id,name,grade from student join score on student.id = score.sid and student.id > 1 - // and score.cid > 2 where student.age > 18 and score.grade > 60 - Expression onCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0)); - Expression onCondition2 = new GreaterThan(rStudent.getOutput().get(0), Literal.of(1)); - Expression onCondition3 = new GreaterThan(rScore.getOutput().get(0), Literal.of(2)); - Expression onCondition = ExpressionUtils.and(onCondition1, onCondition2, onCondition3); - - Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression whereCondition2 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2); - - Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.of(onCondition), rStudent, rScore); - Plan filter = new LogicalFilter(whereCondition, join); - - Plan root = new LogicalProject( - Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)), - filter - ); - - Memo memo = rewrite(root); - - Group rootGroup = memo.getRoot(); - - Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression() - .getPlan(); - Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression() - .getPlan(); - - Assertions.assertTrue(op1 instanceof LogicalJoin); - Assertions.assertTrue(op2 instanceof LogicalFilter); - Assertions.assertTrue(op3 instanceof LogicalFilter); - LogicalJoin join1 = (LogicalJoin) op1; - LogicalFilter filter1 = (LogicalFilter) op2; - LogicalFilter filter2 = (LogicalFilter) op3; - - Assertions.assertEquals(onCondition1, join1.getOtherJoinCondition().get()); - Assertions.assertEquals(ExpressionUtils.and(onCondition2, whereCondition1), filter1.getPredicates()); - Assertions.assertEquals(ExpressionUtils.and(onCondition3, - new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE))), - filter2.getPredicates()); - } - - @Test - public void pushDownPredicateIntoScanTest3() { - //select id,name,grade from student left join score on student.id + 1 = score.sid - 2 - //where student.age > 18 and score.grade > 60 - Expression whereCondition1 = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)), - new Subtract(rScore.getOutput().get(0), Literal.of(2))); - Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); - Expression whereCondition3 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3); - - Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.empty(), rStudent, rScore); - Plan filter = new LogicalFilter(whereCondition, join); - - Plan root = new LogicalProject( - Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)), - filter - ); - - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - - Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression() - .getPlan(); - Plan op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression() - .getPlan(); - - Assertions.assertTrue(op1 instanceof LogicalJoin); - Assertions.assertTrue(op2 instanceof LogicalFilter); - Assertions.assertTrue(op3 instanceof LogicalFilter); - LogicalJoin join1 = (LogicalJoin) op1; - LogicalFilter filter1 = (LogicalFilter) op2; - LogicalFilter filter2 = (LogicalFilter) op3; - Assertions.assertEquals(whereCondition1, join1.getOtherJoinCondition().get()); - Assertions.assertEquals(whereCondition2, filter1.getPredicates()); - Assertions.assertEquals( - new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)), - filter2.getPredicates()); - } - - @Test - public void pushDownPredicateIntoScanTest4() { - /* - select - student.name, - course.name, - score.grade - from student,score,course - where on student.id = score.sid and student.age between 18 and 20 and score.grade > 60 and student.id = score.sid - */ - - // student.id = score.sid - Expression whereCondition1 = new EqualTo(rStudent.getOutput().get(0), rScore.getOutput().get(0)); - // score.cid = course.cid - Expression whereCondition2 = new EqualTo(rScore.getOutput().get(1), rCourse.getOutput().get(0)); - // student.age between 18 and 20 - Expression whereCondition3 = new Between(rStudent.getOutput().get(2), Literal.of(18), Literal.of(20)); - // student.age >= 18 and student.age <= 20 - Expression whereCondition3result = new And( - new GreaterThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(18), StringType.INSTANCE)), - new LessThanEqual(rStudent.getOutput().get(2), new Cast(Literal.of(20), StringType.INSTANCE))); - - // score.grade > 60 - Expression whereCondition4 = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - - Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3, - whereCondition4); - - Plan join = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), rStudent, rScore); - Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, ImmutableList.of(), Optional.empty(), join, rCourse); - Plan filter = new LogicalFilter(whereCondition, join1); - - Plan root = new LogicalProject( - Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2)), - filter - ); - - Memo memo = rewrite(root); - Group rootGroup = memo.getRoot(); - Plan join2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); - Plan join3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression() - .getPlan(); - Plan op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression() - .child(0).getLogicalExpression().getPlan(); - Plan op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression() - .child(1).getLogicalExpression().getPlan(); - - Assertions.assertTrue(join2 instanceof LogicalJoin); - Assertions.assertTrue(join3 instanceof LogicalJoin); - Assertions.assertTrue(op1 instanceof LogicalFilter); - Assertions.assertTrue(op2 instanceof LogicalFilter); - - Assertions.assertEquals(whereCondition2, ((LogicalJoin) join2).getOtherJoinCondition().get()); - Assertions.assertEquals(whereCondition1, ((LogicalJoin) join3).getOtherJoinCondition().get()); - Assertions.assertEquals(whereCondition3result.toSql(), ((LogicalFilter) op1).getPredicates().toSql()); - Assertions.assertEquals( - new GreaterThan(rScore.getOutput().get(2), new Cast(Literal.of(60), DoubleType.INSTANCE)), - ((LogicalFilter) op2).getPredicates()); - } - - private Memo rewrite(Plan plan) { - Plan normalizedPlan = PlanRewriter.topDownRewrite(plan, new ConnectContext(), new ExpressionNormalization()); - return PlanRewriter.topDownRewriteMemo(normalizedPlan, new ConnectContext(), new PushPredicateThroughJoin()); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java new file mode 100644 index 0000000000..3374613f73 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoinTest.java @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.Memo; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.nereids.util.PlanRewriter; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.Optional; + +/** + * plan rewrite ut. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class PushPredicateThroughJoinTest { + + private Plan rStudent; + private Plan rScore; + + /** + * ut before. + */ + @BeforeAll + public final void beforeAll() { + rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); + } + + @Test + public void oneSide() { + oneSide(JoinType.CROSS_JOIN, false); + oneSide(JoinType.INNER_JOIN, false); + oneSide(JoinType.LEFT_OUTER_JOIN, false); + oneSide(JoinType.LEFT_SEMI_JOIN, false); + oneSide(JoinType.LEFT_ANTI_JOIN, false); + oneSide(JoinType.RIGHT_OUTER_JOIN, true); + oneSide(JoinType.RIGHT_SEMI_JOIN, true); + oneSide(JoinType.RIGHT_ANTI_JOIN, true); + } + + private void oneSide(JoinType joinType, boolean testRight) { + + Expression whereCondition1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression whereCondition2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); + Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2); + + Plan left = rStudent; + Plan right = rScore; + if (testRight) { + left = rScore; + right = rStudent; + } + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right); + Plan filter = new LogicalFilter<>(whereCondition, join); + Plan root = new LogicalProject<>(Lists.newArrayList(), filter); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + if (testRight) { + shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + } + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(shouldFilter instanceof LogicalFilter); + Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); + LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter; + + Assertions.assertEquals(whereCondition, actualFilter.getPredicates()); + } + + @Test + public void bothSideToBothSide() { + bothSideToBothSide(JoinType.INNER_JOIN); + } + + private void bothSideToBothSide(JoinType joinType) { + + Expression bothSideEqualTo = new EqualTo(new Add(rStudent.getOutput().get(0), Literal.of(1)), + new Subtract(rScore.getOutput().get(0), Literal.of(2))); + Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.and(bothSideEqualTo, leftSide, rightSide); + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), rStudent, rScore); + Plan filter = new LogicalFilter<>(whereCondition, join); + Plan root = new LogicalProject<>(Lists.newArrayList(), filter); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan leftFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan rightFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(1).getLogicalExpression().getPlan(); + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(leftFilter instanceof LogicalFilter); + Assertions.assertTrue(rightFilter instanceof LogicalFilter); + LogicalJoin<Plan, Plan> actualJoin = (LogicalJoin<Plan, Plan>) shouldJoin; + LogicalFilter<Plan> actualLeft = (LogicalFilter<Plan>) leftFilter; + LogicalFilter<Plan> actualRight = (LogicalFilter<Plan>) rightFilter; + Assertions.assertEquals(bothSideEqualTo, actualJoin.getOtherJoinCondition().get()); + Assertions.assertEquals(leftSide, actualLeft.getPredicates()); + Assertions.assertEquals(rightSide, actualRight.getPredicates()); + } + + @Test + public void bothSideToOneSide() { + bothSideToOneSide(JoinType.LEFT_OUTER_JOIN, false); + bothSideToOneSide(JoinType.LEFT_ANTI_JOIN, false); + bothSideToOneSide(JoinType.LEFT_SEMI_JOIN, false); + bothSideToOneSide(JoinType.RIGHT_OUTER_JOIN, true); + bothSideToOneSide(JoinType.RIGHT_ANTI_JOIN, true); + bothSideToOneSide(JoinType.RIGHT_SEMI_JOIN, true); + } + + private void bothSideToOneSide(JoinType joinType, boolean testRight) { + + Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); + Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.and(pushSide, reserveSide); + + Plan left = rStudent; + Plan right = rScore; + if (testRight) { + left = rScore; + right = rStudent; + } + + Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.empty(), left, right); + Plan filter = new LogicalFilter<>(whereCondition, join); + Plan root = new LogicalProject<>(Lists.newArrayList(), filter); + + Memo memo = rewrite(root); + Group rootGroup = memo.getRoot(); + + Plan shouldJoin = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().getPlan(); + Plan shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); + Plan shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); + if (testRight) { + shouldFilter = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().child(1).getLogicalExpression().getPlan(); + shouldScan = rootGroup.getLogicalExpression().child(0).getLogicalExpression() + .child(0).getLogicalExpression().child(0).getLogicalExpression().getPlan(); + } + + Assertions.assertTrue(shouldJoin instanceof LogicalJoin); + Assertions.assertTrue(shouldFilter instanceof LogicalFilter); + Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); + LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter; + Assertions.assertEquals(pushSide, actualFilter.getPredicates()); + } + + private Memo rewrite(Plan plan) { + return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicatesThroughJoin()); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java index 135aaaf419..c03d24445b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanConstructor.java @@ -37,7 +37,7 @@ public class PlanConstructor { public static OlapTable student; public static OlapTable score; public static OlapTable course; - private static final IdGenerator<RelationId> GENERATOR = RelationId.createGenerator(); + private static final IdGenerator<RelationId> RELATION_ID_GENERATOR = RelationId.createGenerator(); static { student = new OlapTable(0L, "student", @@ -102,14 +102,14 @@ public class PlanConstructor { // With OlapTable. // Warning: equals() of Table depends on tableId. public static LogicalOlapScan newLogicalOlapScan(long tableId, String tableName, int hashColumn) { - return new LogicalOlapScan(GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db")); + return new LogicalOlapScan(RELATION_ID_GENERATOR.getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db")); } public static LogicalOlapScan newLogicalOlapScanWithSameId(long tableId, String tableName, int hashColumn) { return new LogicalOlapScan(RelationId.createGenerator().getNextId(), newOlapTable(tableId, tableName, hashColumn), ImmutableList.of("db")); } - public static RelationId getNextId() { - return GENERATOR.getNextId(); + public static RelationId getNextRelationId() { + return RELATION_ID_GENERATOR.getNextId(); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org