This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 0de8d7de476ff93004c5f50f1f407d7737bd65ff
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Fri Aug 25 11:59:28 2023 +0800

    [fix](Nereids) infer predicates generate wrong result (#23456)
    
    We use two facilities to do predicate inference: PredicatePropagation and
    PullUpPredicates. In the previous implementation, we used a set to save
    the intermediate results of PredicatePropagation. The purpose was to infer
    new predicates through two equality relations. However, this is the wrong
    approach, because it can infer wrong predicates through outer joins. For
    example:
    
    ```sql
    select a.c1
       from a
       left join b on a.c2 = b.c2 and a.c1 = '1'
       left join c on a.c2 = c.c2 and a.c1 = '2'
       inner join d on a.c3=d.c3
    ```
    
    The predicates `a.c1 = '1'` and `a.c1 = '2'` should not be inferred as
    filters on relation `a`.
    
    This PR:
    1. Reverts the change from PR #22145, commit 3c58e9ba.
    2. Removes the unreasonable restriction in PullUpPredicates.
    3. Uses a new Filter node, rather than new otherCondition entries on the
       join node, to hold the inferred predicates.
---
 .../nereids/rules/rewrite/InferPredicates.java     | 29 ++++++++++++----------
 .../rules/rewrite/PredicatePropagation.java        |  9 -------
 .../nereids/rules/rewrite/PullUpPredicates.java    | 28 ++-------------------
 .../nereids/rules/rewrite/InferPredicatesTest.java | 25 ++++++++++++++++++-
 .../infer_predicate/infer_predicate.groovy         |  1 +
 5 files changed, 43 insertions(+), 49 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
index 9736db8482..3c4593df54 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
@@ -25,12 +25,12 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
-import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -65,28 +65,31 @@ public class InferPredicates extends 
DefaultPlanRewriter<JobContext> implements
         Plan left = join.left();
         Plan right = join.right();
         Set<Expression> expressions = getAllExpressions(left, right, 
join.getOnClauseCondition());
-        List<Expression> otherJoinConjuncts = 
Lists.newArrayList(join.getOtherJoinConjuncts());
         switch (join.getJoinType()) {
             case INNER_JOIN:
             case CROSS_JOIN:
             case LEFT_SEMI_JOIN:
             case RIGHT_SEMI_JOIN:
-                otherJoinConjuncts.addAll(inferNewPredicate(left, 
expressions));
-                otherJoinConjuncts.addAll(inferNewPredicate(right, 
expressions));
+                left = inferNewPredicate(left, expressions);
+                right = inferNewPredicate(right, expressions);
                 break;
             case LEFT_OUTER_JOIN:
             case LEFT_ANTI_JOIN:
             case NULL_AWARE_LEFT_ANTI_JOIN:
-                otherJoinConjuncts.addAll(inferNewPredicate(right, 
expressions));
+                right = inferNewPredicate(right, expressions);
                 break;
             case RIGHT_OUTER_JOIN:
             case RIGHT_ANTI_JOIN:
-                otherJoinConjuncts.addAll(inferNewPredicate(left, 
expressions));
+                left = inferNewPredicate(left, expressions);
                 break;
             default:
-                return join;
+                break;
+        }
+        if (left != join.left() || right != join.right()) {
+            return join.withChildren(ImmutableList.of(left, right));
+        } else {
+            return join;
         }
-        return join.withOtherJoinConjuncts(otherJoinConjuncts);
     }
 
     @Override
@@ -114,12 +117,12 @@ public class InferPredicates extends 
DefaultPlanRewriter<JobContext> implements
         return Sets.newHashSet(plan.accept(pollUpPredicates, null));
     }
 
-    private List<Expression> inferNewPredicate(Plan plan, Set<Expression> 
expressions) {
-        List<Expression> predicates = expressions.stream()
+    private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
+        Set<Expression> predicates = expressions.stream()
                 .filter(c -> !c.getInputSlots().isEmpty() && 
plan.getOutputSet().containsAll(
-                        c.getInputSlots())).collect(Collectors.toList());
+                        c.getInputSlots())).collect(Collectors.toSet());
         predicates.removeAll(plan.accept(pollUpPredicates, null));
-        return predicates;
+        return PlanUtils.filterOrSelf(predicates, plan);
     }
 }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index f6f04e899b..cc45952817 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -40,17 +40,11 @@ import java.util.stream.Collectors;
  */
 public class PredicatePropagation {
 
-    /**
-     * equal predicate with literal in one side would be chosen to be source 
predicates and used to infer all predicates
-     */
-    private Set<Expression> sourcePredicates = Sets.newHashSet();
-
     /**
      * infer additional predicates.
      */
     public Set<Expression> infer(Set<Expression> predicates) {
         Set<Expression> inferred = Sets.newHashSet();
-        predicates.addAll(sourcePredicates);
         for (Expression predicate : predicates) {
             if (canEquivalentInfer(predicate)) {
                 List<Expression> newInferred = predicates.stream()
@@ -61,7 +55,6 @@ public class PredicatePropagation {
             }
         }
         inferred.removeAll(predicates);
-        sourcePredicates.addAll(inferred);
         return inferred;
     }
 
@@ -83,10 +76,8 @@ public class PredicatePropagation {
             public Expression visitComparisonPredicate(ComparisonPredicate cp, 
Void context) {
                 // we need to get expression covered by cast, because we want 
to infer different datatype
                 if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) 
&& (cp.right().isConstant())) {
-                    sourcePredicates.add(cp);
                     return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
                 } else if 
(ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && 
cp.left().isConstant()) {
-                    sourcePredicates.add(cp);
                     return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
                 }
                 return super.visit(cp, context);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 781d056422..1a198c76ea 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -73,32 +73,8 @@ public class PullUpPredicates extends 
PlanVisitor<ImmutableSet<Expression>, Void
             Set<Expression> predicates = Sets.newHashSet();
             ImmutableSet<Expression> leftPredicates = join.left().accept(this, 
context);
             ImmutableSet<Expression> rightPredicates = 
join.right().accept(this, context);
-            switch (join.getJoinType()) {
-                case INNER_JOIN:
-                case CROSS_JOIN:
-                    predicates.addAll(leftPredicates);
-                    predicates.addAll(rightPredicates);
-                    join.getOnClauseCondition().map(on -> 
predicates.addAll(ExpressionUtils.extractConjunction(on)));
-                    break;
-                case LEFT_SEMI_JOIN:
-                    predicates.addAll(leftPredicates);
-                    join.getOnClauseCondition().map(on -> 
predicates.addAll(ExpressionUtils.extractConjunction(on)));
-                    break;
-                case RIGHT_SEMI_JOIN:
-                    predicates.addAll(rightPredicates);
-                    join.getOnClauseCondition().map(on -> 
predicates.addAll(ExpressionUtils.extractConjunction(on)));
-                    break;
-                case LEFT_OUTER_JOIN:
-                case LEFT_ANTI_JOIN:
-                case NULL_AWARE_LEFT_ANTI_JOIN:
-                    predicates.addAll(leftPredicates);
-                    break;
-                case RIGHT_OUTER_JOIN:
-                case RIGHT_ANTI_JOIN:
-                    predicates.addAll(rightPredicates);
-                    break;
-                default:
-            }
+            predicates.addAll(leftPredicates);
+            predicates.addAll(rightPredicates);
             return getAvailableExpressions(predicates, join);
         });
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index 04613f7e75..adc67ca835 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.utframe.TestWithFeService;
@@ -604,5 +605,27 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
                     )
                 );
     }
-}
 
+    /**
+     * in this case, filter on relation s1 should not contain s1.id = 1.
+     */
+    @Test
+    public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+        String sql = "select * from student s1"
+                + " left join (select sid as id1, sid as id2, grade from 
score) s2 on s1.id = s2.id1 and s1.id = 1"
+                + " join (select sid as id1, sid as id2, grade from score) s3 
on s1.id = s3.id1 where s1.id = 2";
+        PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalJoin(
+                                logicalFilter(
+                                        logicalOlapScan()
+                                ).when(filter -> filter.getConjuncts().size() 
== 1
+                                        && 
filter.getPredicate().toSql().contains("id = 2")),
+                                any()
+                        ).when(join -> join.getJoinType() == 
JoinType.LEFT_OUTER_JOIN)
+                );
+    }
+}
diff --git 
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy 
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index a1b0ad3b0f..a1621f1c23 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -53,5 +53,6 @@ suite("test_infer_predicate") {
         sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = 
infer_tb2.k3 left join infer_tb3 on " +
                 "infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;"
         contains "PREDICATES: k3"
+        contains "PREDICATES: k2"
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to