This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 0de8d7de476ff93004c5f50f1f407d7737bd65ff Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Fri Aug 25 11:59:28 2023 +0800 [fix](Nereids) infer predicates generate wrong result (#23456) We use two facilities to do predicate inference: PredicatePropagation and PullUpPredicates. In the previous implementation, we used a set to save the intermediate results of PredicatePropagation. The purpose was to infer new predicates through two equality relations. However, this approach is wrong, because it could infer wrong predicates through an outer join. For example ```sql select a.c1 from a left join b on a.c2 = b.c2 and a.c1 = '1' left join c on a.c2 = c.c2 and a.c1 = '2' inner join d on a.c3=d.c3 ``` the predicates `a.c1 = '1'` and `a.c1 = '2'` should not be inferred as filters on relation `a`. This PR: 1. Reverts the change from PR #22145, commit 3c58e9ba. 2. Removes the unreasonable restriction in PullUpPredicates. 3. Uses a new Filter node rather than a new otherCondition on the join node to save the inferred predicates. --- .../nereids/rules/rewrite/InferPredicates.java | 29 ++++++++++++---------- .../rules/rewrite/PredicatePropagation.java | 9 ------- .../nereids/rules/rewrite/PullUpPredicates.java | 28 ++------------------- .../nereids/rules/rewrite/InferPredicatesTest.java | 25 ++++++++++++++++++- .../infer_predicate/infer_predicate.groovy | 1 + 5 files changed, 43 insertions(+), 49 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 9736db8482..3c4593df54 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -25,12 +25,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import 
org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.List; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -65,28 +65,31 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements Plan left = join.left(); Plan right = join.right(); Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition()); - List<Expression> otherJoinConjuncts = Lists.newArrayList(join.getOtherJoinConjuncts()); switch (join.getJoinType()) { case INNER_JOIN: case CROSS_JOIN: case LEFT_SEMI_JOIN: case RIGHT_SEMI_JOIN: - otherJoinConjuncts.addAll(inferNewPredicate(left, expressions)); - otherJoinConjuncts.addAll(inferNewPredicate(right, expressions)); + left = inferNewPredicate(left, expressions); + right = inferNewPredicate(right, expressions); break; case LEFT_OUTER_JOIN: case LEFT_ANTI_JOIN: case NULL_AWARE_LEFT_ANTI_JOIN: - otherJoinConjuncts.addAll(inferNewPredicate(right, expressions)); + right = inferNewPredicate(right, expressions); break; case RIGHT_OUTER_JOIN: case RIGHT_ANTI_JOIN: - otherJoinConjuncts.addAll(inferNewPredicate(left, expressions)); + left = inferNewPredicate(left, expressions); break; default: - return join; + break; + } + if (left != join.left() || right != join.right()) { + return join.withChildren(ImmutableList.of(left, right)); + } else { + return join; } - return join.withOtherJoinConjuncts(otherJoinConjuncts); } @Override @@ -114,12 +117,12 @@ public class InferPredicates extends DefaultPlanRewriter<JobContext> implements return Sets.newHashSet(plan.accept(pollUpPredicates, null)); } - private List<Expression> inferNewPredicate(Plan plan, Set<Expression> expressions) { - 
List<Expression> predicates = expressions.stream() + private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) { + Set<Expression> predicates = expressions.stream() .filter(c -> !c.getInputSlots().isEmpty() && plan.getOutputSet().containsAll( - c.getInputSlots())).collect(Collectors.toList()); + c.getInputSlots())).collect(Collectors.toSet()); predicates.removeAll(plan.accept(pollUpPredicates, null)); - return predicates; + return PlanUtils.filterOrSelf(predicates, plan); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index f6f04e899b..cc45952817 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -40,17 +40,11 @@ import java.util.stream.Collectors; */ public class PredicatePropagation { - /** - * equal predicate with literal in one side would be chosen to be source predicates and used to infer all predicates - */ - private Set<Expression> sourcePredicates = Sets.newHashSet(); - /** * infer additional predicates. 
*/ public Set<Expression> infer(Set<Expression> predicates) { Set<Expression> inferred = Sets.newHashSet(); - predicates.addAll(sourcePredicates); for (Expression predicate : predicates) { if (canEquivalentInfer(predicate)) { List<Expression> newInferred = predicates.stream() @@ -61,7 +55,6 @@ public class PredicatePropagation { } } inferred.removeAll(predicates); - sourcePredicates.addAll(inferred); return inferred; } @@ -83,10 +76,8 @@ public class PredicatePropagation { public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { // we need to get expression covered by cast, because we want to infer different datatype if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { - sourcePredicates.add(cp); return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { - sourcePredicates.add(cp); return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); } return super.visit(cp, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 781d056422..1a198c76ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -73,32 +73,8 @@ public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void Set<Expression> predicates = Sets.newHashSet(); ImmutableSet<Expression> leftPredicates = join.left().accept(this, context); ImmutableSet<Expression> rightPredicates = join.right().accept(this, context); - switch (join.getJoinType()) { - case INNER_JOIN: - case CROSS_JOIN: - predicates.addAll(leftPredicates); - predicates.addAll(rightPredicates); - join.getOnClauseCondition().map(on -> 
predicates.addAll(ExpressionUtils.extractConjunction(on))); - break; - case LEFT_SEMI_JOIN: - predicates.addAll(leftPredicates); - join.getOnClauseCondition().map(on -> predicates.addAll(ExpressionUtils.extractConjunction(on))); - break; - case RIGHT_SEMI_JOIN: - predicates.addAll(rightPredicates); - join.getOnClauseCondition().map(on -> predicates.addAll(ExpressionUtils.extractConjunction(on))); - break; - case LEFT_OUTER_JOIN: - case LEFT_ANTI_JOIN: - case NULL_AWARE_LEFT_ANTI_JOIN: - predicates.addAll(leftPredicates); - break; - case RIGHT_OUTER_JOIN: - case RIGHT_ANTI_JOIN: - predicates.addAll(rightPredicates); - break; - default: - } + predicates.addAll(leftPredicates); + predicates.addAll(rightPredicates); return getAvailableExpressions(predicates, join); }); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java index 04613f7e75..adc67ca835 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; @@ -604,5 +605,27 @@ public class InferPredicatesTest extends TestWithFeService implements MemoPatter ) ); } -} + /** + * in this case, filter on relation s1 should not contain s1.id = 1. 
+ */ + @Test + public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() { + String sql = "select * from student s1" + + " left join (select sid as id1, sid as id2, grade from score) s2 on s1.id = s2.id1 and s1.id = 1" + + " join (select sid as id1, sid as id2, grade from score) s3 on s1.id = s3.id1 where s1.id = 2"; + PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree(); + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalJoin( + logicalFilter( + logicalOlapScan() + ).when(filter -> filter.getConjuncts().size() == 1 + && filter.getPredicate().toSql().contains("id = 2")), + any() + ).when(join -> join.getJoinType() == JoinType.LEFT_OUTER_JOIN) + ); + } +} diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index a1b0ad3b0f..a1621f1c23 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -53,5 +53,6 @@ suite("test_infer_predicate") { sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = infer_tb2.k3 left join infer_tb3 on " + "infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;" contains "PREDICATES: k3" + contains "PREDICATES: k2" } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org