This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 35c2a529fac [fix](Nereids): when predicate contains right output, don't convert outer to anti join (#30276) 35c2a529fac is described below commit 35c2a529fac1b3aa805cb7fd365565c86f327844 Author: 谢健 <jianx...@gmail.com> AuthorDate: Thu Jan 25 14:01:27 2024 +0800 [fix](Nereids): when predicate contains right output, don't convert outer to anti join (#30276) --- .../rules/rewrite/ConvertOuterJoinToAntiJoin.java | 6 +- .../rewrite/ConvertOuterJoinToAntiJoinTest.java | 82 +++++++++++++++++++++- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java index ebc4630d78e..74bd7e29142 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java @@ -81,7 +81,8 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory { && rightAlwaysNullSlots.containsAll(p.getInputSlots()))) .collect(ImmutableSet.toImmutableSet()); boolean containRightSlot = predicates.stream() - .anyMatch(s -> join.right().getOutputSet().containsAll(s.getInputSlots())); + .flatMap(p -> p.getInputSlots().stream()) + .anyMatch(join.right().getOutputSet()::contains); if (!containRightSlot) { res = join.withJoinType(JoinType.LEFT_ANTI_JOIN); res = predicates.isEmpty() ? res : filter.withConjuncts(predicates).withChildren(res); @@ -94,7 +95,8 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory { && leftAlwaysNullSlots.containsAll(p.getInputSlots()))) .collect(ImmutableSet.toImmutableSet()); boolean containLeftSlot = predicates.stream() - .anyMatch(s -> join.left().getOutputSet().containsAll(s.getInputSlots())); + .flatMap(p -> p.getInputSlots().stream()) + .anyMatch(join.left().getOutputSet()::contains); if (!containLeftSlot) { res = join.withJoinType(JoinType.RIGHT_ANTI_JOIN); res = predicates.isEmpty() ? res : filter.withConjuncts(predicates).withChildren(res); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java index 49960c77ee4..20b36d3272e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java @@ -18,8 +18,12 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -30,6 +34,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; import org.junit.jupiter.api.Test; class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { @@ -48,7 +53,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .filter(new IsNull(scan2.getOutput().get(0))) - .project(ImmutableList.of(0, 1)) + .projectExprs(ImmutableList.copyOf(scan1.getOutput())) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) @@ -63,7 +68,7 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .filter(new IsNull(scan1.getOutput().get(0))) - .project(ImmutableList.of(2, 3)) + .projectExprs(ImmutableList.copyOf(scan2.getOutput())) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) @@ -72,4 +77,77 @@ class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported { .printlnTree() .matches(logicalJoin().when(join -> join.getJoinType().isRightAntiJoin())); } + + @Test + void testEliminateLeftWithLeftPredicate() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(Sets.newHashSet( + new IsNull(scan2.getOutput().get(0)), + new EqualTo(scan1.getOutput().get(0), new IntegerLiteral(1))) + ) + .projectExprs(ImmutableList.copyOf(scan1.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferFilterNotNull()) + .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftAntiJoin())); + } + + @Test + void testEliminateLeftWithRightPredicate() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(Sets.newHashSet( + new IsNull(scan2.getOutput().get(0)), + new EqualTo(scan2.getOutput().get(0), new IntegerLiteral(1))) + ) + .projectExprs(ImmutableList.copyOf(scan1.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferFilterNotNull()) + .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); + } + + @Test + void testEliminateLeftWithOrPredicate() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(Sets.newHashSet( + new IsNull(scan1.getOutput().get(0)), + new Or(new IsNull(scan1.getOutput().get(0)), new IsNull(scan2.getOutput().get(0)))) + ) + .projectExprs(ImmutableList.copyOf(scan1.getOutput())) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferFilterNotNull()) + .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); + } + + @Test + void testEliminateLeftWithAndPredicate() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(Sets.newHashSet( + new IsNull(scan1.getOutput().get(0)), + new And(new EqualTo(scan1.getOutput().get(0), new IntegerLiteral(1)), + new EqualTo(scan1.getOutput().get(0), new IntegerLiteral(1)))) + ) + .project(ImmutableList.of(2, 3)) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new InferFilterNotNull()) + .applyTopDown(new ConvertOuterJoinToAntiJoin()) + .printlnTree() + .matches(logicalJoin().when(join -> join.getJoinType().isLeftOuterJoin())); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org