This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 409bd76999 [improve](Nereids): ReorderJoin eliminate this recursion (#13505) 409bd76999 is described below commit 409bd76999236e9e624e6f2b1f3428191e86ea57 Author: jakevin <jakevin...@gmail.com> AuthorDate: Mon Oct 24 17:11:43 2022 +0800 [improve](Nereids): ReorderJoin eliminate this recursion (#13505) --- .../nereids/rules/rewrite/logical/ReorderJoin.java | 82 ++++++++++++---------- .../rules/rewrite/logical/ReorderJoinTest.java | 33 ++++++++- .../doris/nereids/sqltest/MultiJoinTest.java | 39 +++++++++- .../org/apache/doris/nereids/util/PlanChecker.java | 1 + 4 files changed, 112 insertions(+), 43 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java index 1cbdc370e2..c0c8622348 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java @@ -38,6 +38,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -79,8 +80,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { /** * Recursively convert to - * {@link LogicalJoin} or - * {@link LogicalFilter}--{@link LogicalJoin} + * {@link LogicalJoin} or {@link LogicalFilter}--{@link LogicalJoin} * --> {@link MultiJoin} */ public Plan joinToMultiJoin(Plan plan) { @@ -182,20 +182,20 @@ public class ReorderJoin extends OneRewriteRuleFactory { * <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D)) * </ul> * </p> - * <p> * Graphic presentation: + * <pre> * A JOIN B JOIN C LEFT JOIN D JOIN F * left left│ * A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF) - * <p> + * * A JOIN B RIGHT JOIN C JOIN D JOIN F * right │right * A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F) - * <p> + * * (A JOIN B JOIN C) FULL JOIN (D JOIN F) * full │ * A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F)) - * </p> + * </pre> */ public Plan multiJoinToJoin(MultiJoin multiJoin) { if (multiJoin.arity() == 1) { @@ -272,24 +272,22 @@ public class ReorderJoin extends OneRewriteRuleFactory { } // following this multiJoin just contain INNER/CROSS. - List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter(); + Set<Expression> joinFilter = new HashSet<>(multiJoinHandleChildren.getJoinFilter()); Plan left = multiJoinHandleChildren.child(0); - List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()); - - LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter); - List<Plan> newInputs = Lists.newArrayList(); - newInputs.add(join); - newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList())); - - joinFilter.removeAll(join.getHashJoinConjuncts()); - joinFilter.removeAll(join.getOtherJoinConjuncts()); - // TODO(wj): eliminate this recursion. - return multiJoinToJoin(new MultiJoin( - newInputs, - joinFilter, - JoinType.INNER_JOIN, - ExpressionUtils.EMPTY_CONDITION)); + Set<Integer> usedPlansIndex = new HashSet<>(); + usedPlansIndex.add(0); + + while (usedPlansIndex.size() != multiJoinHandleChildren.children().size()) { + LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, multiJoinHandleChildren.children(), + joinFilter, usedPlansIndex); + join.getHashJoinConjuncts().forEach(joinFilter::remove); + join.getOtherJoinConjuncts().forEach(joinFilter::remove); + + left = join; + } + + return PlanUtils.filterOrSelf(new ArrayList<>(joinFilter), left); } /** @@ -319,9 +317,14 @@ public class ReorderJoin extends OneRewriteRuleFactory { * @return InnerJoin or CrossJoin{left, last of [candidates]} */ private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates, - List<Expression> joinFilter) { + Set<Expression> joinFilter, Set<Integer> usedPlansIndex) { + List<Expression> otherJoinConditions = Lists.newArrayList(); Set<Slot> leftOutputSet = left.getOutputSet(); for (int i = 0; i < candidates.size(); i++) { + if (usedPlansIndex.contains(i)) { + continue; + } + Plan candidate = candidates.get(i); Set<Slot> rightOutputSet = candidate.getOutputSet(); @@ -330,34 +333,35 @@ public class ReorderJoin extends OneRewriteRuleFactory { List<Expression> currentJoinFilter = joinFilter.stream() .filter(expr -> { Set<Slot> exprInputSlots = expr.getInputSlots(); - Preconditions.checkState(exprInputSlots.size() > 1, - "Predicate like table.col > 1 must have pushdown."); - if (leftOutputSet.containsAll(exprInputSlots)) { - return false; - } - if (rightOutputSet.containsAll(exprInputSlots)) { - return false; - } - - return joinOutput.containsAll(exprInputSlots); + return !leftOutputSet.containsAll(exprInputSlots) + && !rightOutputSet.containsAll(exprInputSlots) + && joinOutput.containsAll(exprInputSlots); }).collect(Collectors.toList()); Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable( left.getOutput(), candidate.getOutput(), currentJoinFilter); List<Expression> hashJoinConditions = pair.first; - List<Expression> otherJoinConditions = pair.second; + otherJoinConditions = pair.second; if (!hashJoinConditions.isEmpty()) { + usedPlansIndex.add(i); return new LogicalJoin<>(JoinType.INNER_JOIN, hashJoinConditions, otherJoinConditions, left, candidate); } - - if (i == candidates.size() - 1) { - return new LogicalJoin<>(JoinType.CROSS_JOIN, - hashJoinConditions, otherJoinConditions, - left, candidate); + } + // All { left -> one in [candidates] } is CrossJoin + // Generate a CrossJoin + for (int j = candidates.size() - 1; j >= 0; j--) { + if (usedPlansIndex.contains(j)) { + continue; } + usedPlansIndex.add(j); + return new LogicalJoin<>(JoinType.CROSS_JOIN, + ExpressionUtils.EMPTY_CONDITION, + otherJoinConditions, + left, candidates.get(j)); } + throw new RuntimeException("findInnerJoin: can't reach here"); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java index ffb7e16510..da386f8911 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java @@ -130,17 +130,44 @@ class ReorderJoinTest implements PatternMatchSupported { check(plans); } - public void check(List<LogicalPlan> plans) { + @Test + public void testCrossJoin() { + ImmutableList<LogicalPlan> plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan2, JoinType.CROSS_JOIN) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))) + .build() + ); + for (LogicalPlan plan : plans) { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyBottomUp(new ReorderJoin()) + .matchesFromRoot( + logicalJoin( + logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).when(join -> join.getJoinType().isCrossJoin()) + ); + } + } + + public void check(List<LogicalPlan> plans) { + for (LogicalPlan plan : plans) { + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .rewrite() + .printlnTree() .matchesFromRoot( logicalJoin( logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), leafPlan() ).whenNot(join -> join.getJoinType().isCrossJoin()) - ) - .printlnTree(); + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java index 230c9cc245..5beb12445c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; import org.apache.doris.nereids.util.PlanChecker; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.List; @@ -29,8 +30,9 @@ public class MultiJoinTest extends SqlTestBase { @Test void testMultiJoinEliminateCross() { List<String> sqls = ImmutableList.<String>builder() - .add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id") .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id") + .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0") + .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id AND T1.score > 0 AND T1.id + T2.id + T3.id > 0") .build(); for (String sql : sqls) { @@ -47,6 +49,41 @@ public class MultiJoinTest extends SqlTestBase { } } + @Test + @Disabled + // TODO: MultiJoin And EliminateOuter + void testEliminateBelowOuter() { + String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id"; + PlanChecker.from(connectContext) + .analyze(sql) + .applyBottomUp(new ReorderJoin()) + .printlnTree(); + } + + @Test + void testPushdownAndEliminateOuter() { + String sql = "SELECT * FROM T1 LEFT JOIN T2 ON T1.id = T2.id WHERE T2.score > 0"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .printlnTree() + .matches( + logicalJoin().when(join -> join.getJoinType().isInnerJoin()) + ); + + String sql1 = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id AND T3.score > 0"; + PlanChecker.from(connectContext) + .analyze(sql1) + .rewrite() + .printlnTree() + .matches( + logicalJoin( + logicalJoin().when(join -> join.getJoinType().isInnerJoin()), + any() + ).when(join -> join.getJoinType().isInnerJoin()) + ); + } + @Test void testMultiJoinExistCross() { List<String> sqls = ImmutableList.<String>builder() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index ee364c7854..3ef68ead77 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -363,6 +363,7 @@ public class PlanChecker { public PlanChecker printlnTree() { System.out.println(cascadesContext.getMemo().copyOut().treeString()); + System.out.println("-----------------------------"); return this; } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org