This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 35c2a529fac [fix](Nereids): when predicate contains right output, don't convert outer to anti join (#30276)
35c2a529fac is described below

commit 35c2a529fac1b3aa805cb7fd365565c86f327844
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Thu Jan 25 14:01:27 2024 +0800

    [fix](Nereids): when predicate contains right output, don't convert outer to anti join (#30276)
---
 .../rules/rewrite/ConvertOuterJoinToAntiJoin.java  |  6 +-
 .../rewrite/ConvertOuterJoinToAntiJoinTest.java    | 82 +++++++++++++++++++++-
 2 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
index ebc4630d78e..74bd7e29142 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoin.java
@@ -81,7 +81,8 @@ public class ConvertOuterJoinToAntiJoin extends 
OneRewriteRuleFactory {
                             && 
rightAlwaysNullSlots.containsAll(p.getInputSlots())))
                     .collect(ImmutableSet.toImmutableSet());
             boolean containRightSlot = predicates.stream()
-                    .anyMatch(s -> 
join.right().getOutputSet().containsAll(s.getInputSlots()));
+                    .flatMap(p -> p.getInputSlots().stream())
+                    .anyMatch(join.right().getOutputSet()::contains);
             if (!containRightSlot) {
                 res = join.withJoinType(JoinType.LEFT_ANTI_JOIN);
                 res = predicates.isEmpty() ? res : 
filter.withConjuncts(predicates).withChildren(res);
@@ -94,7 +95,8 @@ public class ConvertOuterJoinToAntiJoin extends 
OneRewriteRuleFactory {
                             && 
leftAlwaysNullSlots.containsAll(p.getInputSlots())))
                     .collect(ImmutableSet.toImmutableSet());
             boolean containLeftSlot = predicates.stream()
-                    .anyMatch(s -> 
join.left().getOutputSet().containsAll(s.getInputSlots()));
+                    .flatMap(p -> p.getInputSlots().stream())
+                    .anyMatch(join.left().getOutputSet()::contains);
             if (!containLeftSlot) {
                 res = join.withJoinType(JoinType.RIGHT_ANTI_JOIN);
                 res = predicates.isEmpty() ? res : 
filter.withConjuncts(predicates).withChildren(res);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
index 49960c77ee4..20b36d3272e 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConvertOuterJoinToAntiJoinTest.java
@@ -18,8 +18,12 @@
 package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -30,6 +34,7 @@ import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Test;
 
 class ConvertOuterJoinToAntiJoinTest implements MemoPatternMatchSupported {
@@ -48,7 +53,7 @@ class ConvertOuterJoinToAntiJoinTest implements 
MemoPatternMatchSupported {
         LogicalPlan plan = new LogicalPlanBuilder(scan1)
                 .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
                 .filter(new IsNull(scan2.getOutput().get(0)))
-                .project(ImmutableList.of(0, 1))
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
@@ -63,7 +68,7 @@ class ConvertOuterJoinToAntiJoinTest implements 
MemoPatternMatchSupported {
         LogicalPlan plan = new LogicalPlanBuilder(scan1)
                 .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
                 .filter(new IsNull(scan1.getOutput().get(0)))
-                .project(ImmutableList.of(2, 3))
+                .projectExprs(ImmutableList.copyOf(scan2.getOutput()))
                 .build();
 
         PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
@@ -72,4 +77,77 @@ class ConvertOuterJoinToAntiJoinTest implements 
MemoPatternMatchSupported {
                 .printlnTree()
                 .matches(logicalJoin().when(join -> 
join.getJoinType().isRightAntiJoin()));
     }
+
+    @Test
+    void testEliminateLeftWithLeftPredicate() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
+                .filter(Sets.newHashSet(
+                        new IsNull(scan2.getOutput().get(0)),
+                        new EqualTo(scan1.getOutput().get(0), new 
IntegerLiteral(1)))
+                )
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferFilterNotNull())
+                .applyTopDown(new ConvertOuterJoinToAntiJoin())
+                .printlnTree()
+                .matches(logicalJoin().when(join -> 
join.getJoinType().isLeftAntiJoin()));
+    }
+
+    @Test
+    void testEliminateLeftWithRightPredicate() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
+                .filter(Sets.newHashSet(
+                        new IsNull(scan2.getOutput().get(0)),
+                        new EqualTo(scan2.getOutput().get(0), new 
IntegerLiteral(1)))
+                )
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferFilterNotNull())
+                .applyTopDown(new ConvertOuterJoinToAntiJoin())
+                .printlnTree()
+                .matches(logicalJoin().when(join -> 
join.getJoinType().isLeftOuterJoin()));
+    }
+
+    @Test
+    void testEliminateLeftWithOrPredicate() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
+                .filter(Sets.newHashSet(
+                        new IsNull(scan1.getOutput().get(0)),
+                        new Or(new IsNull(scan1.getOutput().get(0)), new 
IsNull(scan2.getOutput().get(0))))
+                )
+                .projectExprs(ImmutableList.copyOf(scan1.getOutput()))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferFilterNotNull())
+                .applyTopDown(new ConvertOuterJoinToAntiJoin())
+                .printlnTree()
+                .matches(logicalJoin().when(join -> 
join.getJoinType().isLeftOuterJoin()));
+    }
+
+    @Test
+    void testEliminateLeftWithAndPredicate() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))  // 
t1.id = t2.id
+                .filter(Sets.newHashSet(
+                        new IsNull(scan1.getOutput().get(0)),
+                        new And(new EqualTo(scan1.getOutput().get(0), new 
IntegerLiteral(1)),
+                                new EqualTo(scan1.getOutput().get(0), new 
IntegerLiteral(1))))
+                )
+                .project(ImmutableList.of(2, 3))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferFilterNotNull())
+                .applyTopDown(new ConvertOuterJoinToAntiJoin())
+                .printlnTree()
+                .matches(logicalJoin().when(join -> 
join.getJoinType().isLeftOuterJoin()));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to