This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new db8bc80c36 [feature](Nereids): semi join transpose (#12590) db8bc80c36 is described below commit db8bc80c36629c0c5d71f6a7cd78aefeb6733e2a Author: jakevin <jakevin...@gmail.com> AuthorDate: Thu Sep 15 21:32:50 2022 +0800 [feature](Nereids): semi join transpose (#12590) * [feature](Nereids): semi join transpose and enable ZIG_ZAG join reorder. --- .../org/apache/doris/nereids/rules/RuleSet.java | 11 ++ .../org/apache/doris/nereids/rules/RuleType.java | 7 +- .../rules/exploration/join/JoinLAsscomProject.java | 2 +- .../join/SemiJoinLogicalJoinTranspose.java | 50 ++++++-- .../join/SemiJoinLogicalJoinTransposeProject.java | 111 ++++++++++------- .../join/SemiJoinSemiJoinTranspose.java | 3 +- .../logical/PushdownProjectThroughLimit.java | 2 +- .../SemiJoinLogicalJoinTransposeProjectTest.java | 134 +++++++++++++++++++++ .../join/SemiJoinLogicalJoinTransposeTest.java | 126 +++++++++++++++++++ ...est.java => SemiJoinSemiJoinTransposeTest.java} | 13 +- .../doris/nereids/util/LogicalPlanBuilder.java | 6 +- 11 files changed, 394 insertions(+), 71 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index cf2c0ce311..7f11734e1f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -21,6 +21,9 @@ import org.apache.doris.nereids.rules.exploration.join.JoinCommute; import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject; import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom; import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject; +import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose; +import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject; +import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows; import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation; @@ -55,6 +58,9 @@ public class RuleSet { .add(JoinCommuteProject.LEFT_DEEP) .add(JoinLAsscom.INNER) .add(JoinLAsscomProject.INNER) + .add(SemiJoinLogicalJoinTranspose.LEFT_DEEP) + .add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP) + .add(SemiJoinSemiJoinTranspose.INSTANCE) .add(new PushdownFilterThroughProject()) .add(new MergeConsecutiveProjects()) .build(); @@ -140,6 +146,11 @@ public class RuleSet { return this; } + public RuleFactories addAll(List<Rule> rules) { + this.rules.addAll(rules); + return this; + } + public List<Rule> build() { return rules.build(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index c9fe816970..c40419dfea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -109,8 +109,7 @@ public enum RuleType { OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE), // Pushdown filter PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE), - LOGICAL_LIMIT_TO_LOGICAL_EMPTY_RELATION_RULE(RuleTypeClass.REWRITE), - SWAP_LIMIT_PROJECT(RuleTypeClass.REWRITE), + PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE), REWRITE_SENTINEL(RuleTypeClass.REWRITE), // limit push down @@ -122,7 +121,11 @@ public enum RuleType { LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION), LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION), + LOGICAL_JOIN_L_ASSCOM_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION), // implementation rules LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java index 5bbd120b52..8c45afaa04 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java @@ -75,6 +75,6 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory { return null; } return helper.newTopJoin(); - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM_PROJECT); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java index f5c5580d8d..b9ebadfae1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -27,6 +28,7 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.base.Preconditions; +import java.util.List; import java.util.Set; /** @@ -42,9 +44,21 @@ import java.util.Set; * which operands actually participate in the semi-join. */ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { + + public static final SemiJoinLogicalJoinTranspose LEFT_DEEP = new SemiJoinLogicalJoinTranspose(true); + + public static final SemiJoinLogicalJoinTranspose ALL = new SemiJoinLogicalJoinTranspose(false); + + private final boolean leftDeep; + + public SemiJoinLogicalJoinTranspose(boolean leftDeep) { + this.leftDeep = leftDeep; + } + @Override public Rule build() { return leftSemiLogicalJoin(logicalJoin(), group()) + .whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin()) .when(this::conditionChecker) .then(topSemiJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left(); @@ -52,7 +66,14 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { GroupPlan b = bottomJoin.right(); GroupPlan c = topSemiJoin.right(); - boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput()); + List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); + Set<Slot> aOutputSet = a.getOutputSet(); + + boolean lasscom = false; + for (Expression hashJoinConjunct : hashJoinConjuncts) { + Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); + lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom; + } if (lasscom) { /* @@ -81,20 +102,27 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin); } - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE); } // bottomJoin just return A OR B, else return false. - private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) { - Set<Slot> bottomOutputSet = topJoin.left().getOutputSet(); - - Set<Slot> aOutputSet = topJoin.left().left().getOutputSet(); - Set<Slot> bOutputSet = topJoin.left().right().getOutputSet(); + private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topSemiJoin) { + List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); - boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet); - boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet); + List<Slot> aOutput = topSemiJoin.left().left().getOutput(); + List<Slot> bOutput = topSemiJoin.left().right().getOutput(); - Preconditions.checkState(isProjectA || isProjectB, "join output must contain child"); - return !(isProjectA && isProjectB); + boolean hashContainsA = false; + boolean hashContainsB = false; + for (Expression hashJoinConjunct : hashJoinConjuncts) { + Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); + hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA; + hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB; + } + if (leftDeep && hashContainsB) { + return false; + } + Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child"); + return !(hashContainsA && hashContainsB); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index 45cdc7a19e..183a1218d1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -20,15 +20,17 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.base.Preconditions; +import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -45,9 +47,20 @@ import java.util.Set; * which operands actually participate in the semi-join. */ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory { + public static final SemiJoinLogicalJoinTransposeProject LEFT_DEEP = new SemiJoinLogicalJoinTransposeProject(true); + + public static final SemiJoinLogicalJoinTransposeProject ALL = new SemiJoinLogicalJoinTransposeProject(false); + + private final boolean leftDeep; + + public SemiJoinLogicalJoinTransposeProject(boolean leftDeep) { + this.leftDeep = leftDeep; + } + @Override public Rule build() { return leftSemiLogicalJoin(logicalProject(logicalJoin()), group()) + .whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin()) .when(this::conditionChecker) .then(topSemiJoin -> { LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left(); @@ -56,67 +69,77 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto GroupPlan b = bottomJoin.right(); GroupPlan c = topSemiJoin.right(); - boolean lasscom = a.getOutputSet().containsAll(project.getOutput()); + Set<Slot> aOutputSet = a.getOutputSet(); + + List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); + + boolean lasscom = false; + for (Expression hashJoinConjunct : hashJoinConjuncts) { + Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); + lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom; + } if (lasscom) { /*- - * topSemiJoin newTopProject - * / \ | + * topSemiJoin project + * / \ | * project C newTopJoin * | -> / \ - * bottomJoin newBottomSemiJoin B + * bottomJoin newBottomSemiJoin B * / \ / \ - * A B aNewProject C - * | - * A + * A B A C */ - List<NamedExpression> projects = project.getProjects(); - LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a); - LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( + LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinCondition(), aNewProject, c); - LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin - = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b); - return new LogicalProject<>(projects, newTopJoin); + topSemiJoin.getOtherJoinCondition(), a, c); + + LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), + bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(), + newBottomSemiJoin, b); + + return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); } else { /*- - * topSemiJoin newTopProject - * / \ | - * project C newTopJoin - * | / \ - * bottomJoin C --> A newBottomSemiJoin - * / \ / \ - * A B bNewProject C - * | - * B + * topSemiJoin project + * / \ | + * project C newTopJoin + * | / \ + * bottomJoin C --> A newBottomSemiJoin + * / \ / \ + * A B B C */ - List<NamedExpression> projects = project.getProjects(); - LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b); - LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( + LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinCondition(), bNewProject, c); + topSemiJoin.getOtherJoinCondition(), b, c); + + LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), + bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(), + a, newBottomSemiJoin); - LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin - = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin); - return new LogicalProject<>(projects, newTopJoin); + return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); } - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT); } - // bottomJoin just return A OR B, else return false. + // project of bottomJoin just return A OR B, else return false. private boolean conditionChecker( - LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) { - Set<Slot> projectOutputSet = topJoin.left().getOutputSet(); - - Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet(); - Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet(); + LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) { + List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); - boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet); - boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet); + List<Slot> aOutput = topSemiJoin.left().child().left().getOutput(); + List<Slot> bOutput = topSemiJoin.left().child().right().getOutput(); - Preconditions.checkState(isProjectA || isProjectB, "project must contain child"); - return !(isProjectA && isProjectB); + boolean hashContainsA = false; + boolean hashContainsB = false; + for (Expression hashJoinConjunct : hashJoinConjuncts) { + Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); + hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA; + hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB; + } + if (leftDeep && hashContainsB) { + return false; + } + Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child"); + return !(hashContainsA && hashContainsB); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java index 31d326612d..beab255b89 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java @@ -37,6 +37,7 @@ import java.util.Set; * LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z)) */ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory { + public static final SemiJoinSemiJoinTranspose INSTANCE = new SemiJoinSemiJoinTranspose(); public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of( Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN), @@ -69,7 +70,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory { newBottomJoin, b); return newTopJoin; - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE); } private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java index 230e5f2e98..bd56224d26 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownProjectThroughLimit.java @@ -54,6 +54,6 @@ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory { return new LogicalLimit<LogicalProject<GroupPlan>>(logicalLimit.getLimit(), logicalLimit.getOffset(), new LogicalProject<>(logicalProject.getProjects(), logicalLimit.child())); - }).toRule(RuleType.SWAP_LIMIT_PROJECT); + }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGHT_LIMIT); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java new file mode 100644 index 0000000000..8111f5b92e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProjectTest.java @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class SemiJoinLogicalJoinTransposeProjectTest { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + + @Test + public void testSemiJoinLogicalTransposeProjectLAsscom() { + /*- + * topSemiJoin project + * / \ | + * project C newTopJoin + * | -> / \ + * bottomJoin newBottomSemiJoin B + * / \ / \ + * A B A C + */ + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .project(ImmutableList.of(0)) + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); + + LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0); + LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left(); + Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); + Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); + + LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); + LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); + LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right(); + + Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName()); + Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); + Assertions.assertEquals("t2", newTopJoinRight.getTable().getName()); + }); + } + + @Test + public void testSemiJoinLogicalTransposeProjectLAsscomFail() { + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .project(ImmutableList.of(0, 2)) // t1.id, t2.id + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(1, root.getLogicalExpressions().size()); + }); + } + + @Test + public void testSemiJoinLogicalTransposeProjectAll() { + /*- + * topSemiJoin project + * / \ | + * project C newTopJoin + * | / \ + * bottomJoin C --> A newBottomSemiJoin + * / \ / \ + * A B B C + */ + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .project(ImmutableList.of(0, 2)) // t1.id, t2.id + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTransposeProject.ALL.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); + + LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0); + LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right(); + Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); + Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); + + LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); + LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); + LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left(); + + Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName()); + Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName()); + Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); + }); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java new file mode 100644 index 0000000000..29ba945dbb --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeTest.java @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class SemiJoinLogicalJoinTransposeTest { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + + @Test + public void testSemiJoinLogicalTransposeLAsscom() { + /* + * topSemiJoin newTopJoin + * / \ / \ + * bottomJoin C --> newBottomSemiJoin B + * / \ / \ + * A B A C + */ + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); + + LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan; + LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left(); + Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); + Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); + + LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); + LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); + LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right(); + + Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName()); + Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); + Assertions.assertEquals("t2", newTopJoinRight.getTable().getName()); + }); + } + + @Test + public void testSemiJoinLogicalTransposeLAsscomFail() { + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(1, root.getLogicalExpressions().size()); + }); + } + + @Test + public void testSemiJoinLogicalTransposeAll() { + /* + * topSemiJoin newTopJoin + * / \ / \ + * bottomJoin C --> A newBottomSemiJoin + * / \ / \ + * A B B C + */ + LogicalPlan topJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(SemiJoinLogicalJoinTranspose.ALL.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); + + LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan; + LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right(); + Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType()); + Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType()); + + LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left(); + LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left(); + LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right(); + + Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName()); + Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName()); + Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName()); + }); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java similarity index 83% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java index c7aa852449..3091d44613 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinTransposeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeTest.java @@ -29,11 +29,10 @@ import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; -import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -public class SemiJoinTransposeTest { +public class SemiJoinSemiJoinTransposeTest { public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); @@ -41,21 +40,19 @@ public class SemiJoinTransposeTest { @Test public void testSemiJoinLogicalTransposeCommute() { LogicalPlan topJoin = new LogicalPlanBuilder(scan1) - .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) - .project(ImmutableList.of(0)) + .hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0)) .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) - .transform((new SemiJoinLogicalJoinTransposeProject()).build()) + .transform(SemiJoinSemiJoinTranspose.INSTANCE.build()) .checkMemo(memo -> { Group root = memo.getRoot(); Assertions.assertEquals(2, root.getLogicalExpressions().size()); - Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false); + Plan join = memo.copyOut(root.getLogicalExpressions().get(1), false); - Plan join = plan.child(0); Assertions.assertTrue(join instanceof LogicalJoin); - Assertions.assertEquals(JoinType.INNER_JOIN, ((LogicalJoin<?, ?>) join).getJoinType()); + Assertions.assertEquals(JoinType.LEFT_ANTI_JOIN, ((LogicalJoin<?, ?>) join).getJoinType()); Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, ((LogicalJoin<?, ?>) ((LogicalJoin<?, ?>) join).left()).getJoinType()); }); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index 950647a674..29ea98d708 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -58,10 +58,10 @@ public class LogicalPlanBuilder { return from(project); } - public LogicalPlanBuilder project(List<Integer> slots) { + public LogicalPlanBuilder project(List<Integer> slotsIndex) { List<NamedExpression> projectExprs = Lists.newArrayList(); - for (int i = 0; i < slots.size(); i++) { - projectExprs.add(this.plan.getOutput().get(i)); + for (Integer index : slotsIndex) { + projectExprs.add(this.plan.getOutput().get(index)); } LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan); return from(project); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org