This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 2b6133f4d0 [feature](Nereids): pushdown complex project through inner/outer Join. (#17365) 2b6133f4d0 is described below commit 2b6133f4d0742ffa2446a46eee278cb909cd6de4 Author: jakevin <jakevin...@gmail.com> AuthorDate: Wed Mar 8 12:00:56 2023 +0800 [feature](Nereids): pushdown complex project through inner/outer Join. (#17365) --- .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../rules/exploration/join/JoinReorderUtils.java | 23 ++-- .../join/PushdownProjectThroughInnerJoin.java | 104 ++++++++++++++ .../join/PushdownProjectThroughSemiJoin.java | 36 +++-- .../join/PushdownProjectThroughInnerJoinTest.java | 151 +++++++++++++++++++++ .../join/PushdownProjectThroughSemiJoinTest.java | 27 ++++ 6 files changed, 313 insertions(+), 29 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 13c47df6a4..c439a13ce5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -235,6 +235,7 @@ public enum RuleType { LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), + PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION), // implementation rules LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java index b723e2e4a1..36c71cf01f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java @@ -38,15 +38,6 @@ import java.util.stream.Stream; * Common */ class JoinReorderUtils { - /** - * check project Expression Input Slot just contains one slot, like: - * - one SlotReference like a.id - * - Input Slot size == 1, like abs(a.id) + 1 - */ - static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) { - return project.getProjects().stream().allMatch(expr -> expr.getInputSlotExprIds().size() == 1); - } - static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) { return project.getProjects().stream().allMatch(expr -> expr instanceof Slot); } @@ -78,6 +69,13 @@ class JoinReorderUtils { return new LogicalProject<>(projectExprs, plan); } + public static Plan projectOrSelfInOrder(List<NamedExpression> projectExprs, Plan plan) { + if (projectExprs.isEmpty() || projectExprs.equals(plan.getOutput())) { + return plan; + } + return new LogicalProject<>(projectExprs, plan); + } + /** * replace JoinConjuncts by using slots map. */ @@ -111,4 +109,11 @@ class JoinReorderUtils { } }); } + + public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) { + Set<Slot> childSlots = left ? join.left().getOutputSet() : join.right().getOutputSet(); + return join.getConditionSlot().stream() + .filter(childSlots::contains) + .collect(Collectors.toSet()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java new file mode 100644 index 0000000000..c46f2252a2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * rule for pushdown project through inner/outer join + */ +public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory { + public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin(); + + /* + * Project Join + * | ──► / \ + * Join Project Project + * / \ | | + * A B A B + */ + @Override + public Rule build() { + return logicalProject(logicalJoin()) + .when(project -> project.child().getJoinType().isInnerJoin()) + .whenNot(project -> project.child().hasJoinHint()) + .then(project -> { + LogicalJoin<GroupPlan, GroupPlan> join = project.child(); + Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet(); + Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet(); + + // reject hyper edge in Project. + if (!project.getProjects().stream().allMatch(expr -> { + Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds(); + return aOutputExprIdSet.containsAll(inputSlotExprIds) + || bOutputExprIdSet.containsAll(inputSlotExprIds); + })) { + return null; + } + + Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(project.getProjects(), + join.left()); + List<NamedExpression> aProjects = map.get(true); + List<NamedExpression> bProjects = map.get(false); + + boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot)); + boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot)); + // due to JoinCommute, we don't need to consider just right contains. + if (!leftContains) { + return null; + } + + Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects); + Set<Slot> aConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, true); + Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); + aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add); + Plan newLeft = JoinReorderUtils.projectOrSelf(newAProject.build(), join.left()); + + if (!rightContains) { + Plan newJoin = join.withChildren(newLeft, join.right()); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + } + + Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects); + Set<Slot> bConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, false); + Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); + bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add); + Plan newRight = JoinReorderUtils.projectOrSelf(newBProject.build(), join.right()); + + Plan newJoin = join.withChildren(newLeft, newRight); + return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java index 57ea9df15d..172de009a7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java @@ -49,27 +49,23 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory { @Override public Rule build() { return logicalProject(logicalJoin()) - .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin()) - .when(JoinReorderUtils::isOneSlotProject) - // Just pushdown project with non-column expr like (t.id + 1) - .whenNot(JoinReorderUtils::isAllSlotProject) - .whenNot(project -> project.child().hasJoinHint()) - .then(project -> { - LogicalJoin<GroupPlan, GroupPlan> join = project.child(); - Set<Slot> aOutputExprIdSet = join.left().getOutputSet(); - Set<Slot> conditionLeftSlots = join.getConditionSlot().stream() - .filter(aOutputExprIdSet::contains) - .collect(Collectors.toSet()); + .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin()) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(JoinReorderUtils::isAllSlotProject) + .whenNot(project -> project.child().hasJoinHint()) + .then(project -> { + LogicalJoin<GroupPlan, GroupPlan> join = project.child(); + Set<Slot> conditionLeftSlots = JoinReorderUtils.joinChildConditionSlots(join, true); - List<NamedExpression> newProject = new ArrayList<>(project.getProjects()); - Set<Slot> projectUsedSlots = project.getProjects().stream() - .map(NamedExpression::toSlot).collect(Collectors.toSet()); - conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)) - .forEach(newProject::add); - Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left()); - Plan newJoin = join.withChildren(newLeft, join.right()); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN); + List<NamedExpression> newProject = new ArrayList<>(project.getProjects()); + Set<Slot> projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot) + .collect(Collectors.toSet()); + conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add); + Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left()); + + Plan newJoin = join.withChildren(newLeft, join.right()); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN); } List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java new file mode 100644 index 0000000000..7d94fa876c --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Test + public void pushBothSide() { + // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name + List<NamedExpression> projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"), + scan1.getOutput().get(1), + new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"), + scan2.getOutput().get(1) + ); + // complex projection contain ti.id, which isn't in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(1, 1)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalJoin( + logicalProject().when(project -> project.getProjects().size() == 2), + logicalProject().when(project -> project.getProjects().size() == 2) + ) + ); + } + + @Test + public void pushdownProjectInCondition() { + // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name + List<NamedExpression> projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"), + scan1.getOutput().get(1), + new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"), + scan2.getOutput().get(1) + ); + // complex projection contain ti.id, which is in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalProject( + logicalJoin( + logicalProject().when(project -> project.getProjects().size() == 3), + logicalProject().when(project -> project.getProjects().size() == 3) + ) + ) + ); + } + + @Test + void pushComplexProject() { + // project (t1.id + t1.name) as complex1, (t2.id + t2.name) as complex2 + List<NamedExpression> projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex1"), + new Alias(new Add(scan2.getOutput().get(0), scan2.getOutput().get(1)), "complex2") + ); + // complex projection contain ti.id, which is in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalProject( + logicalJoin( + logicalProject() + .when(project -> + project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`") + && project.getProjects().get(1).toSql().equals("id")), + logicalProject() + .when(project -> + project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`") + && project.getProjects().get(1).toSql().equals("id")) + ) + ).when(project -> project.getProjects().get(0).toSql().equals("complex1") + && project.getProjects().get(1).toSql().equals("complex2") + ) + ); + } + + @Test + void rejectHyperEdgeProject() { + // project (t1.id + t2.id) as alias + List<NamedExpression> projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), scan2.getOutput().get(0)), "alias") + ); + // complex projection contain ti.id, which is in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java index 0a3b7a04ba..b47910f748 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java @@ -95,4 +95,31 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { ).when(project -> project.getProjects().size() == 2) ); } + + @Test + void pushComplexProject() { + // project (t1.id + t1.name) as complex + List<NamedExpression> projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex")); + // complex projection contain ti.id, which is in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalProject( + leftSemiLogicalJoin( + logicalProject() + .when(project -> project.getProjects().get(0).toSql().equals("(id + name) AS `complex`") + && project.getProjects().get(1).toSql().equals("id")), + logicalOlapScan() + ) + ).when(project -> project.getProjects().get(0).toSql().equals("complex")) + ); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org