This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b6133f4d0 [feature](Nereids): pushdown complex project through inner/outer Join. (#17365)
2b6133f4d0 is described below

commit 2b6133f4d0742ffa2446a46eee278cb909cd6de4
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Wed Mar 8 12:00:56 2023 +0800

    [feature](Nereids): pushdown complex project through inner/outer Join. (#17365)
---
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../rules/exploration/join/JoinReorderUtils.java   |  23 ++--
 .../join/PushdownProjectThroughInnerJoin.java      | 104 ++++++++++++++
 .../join/PushdownProjectThroughSemiJoin.java       |  36 +++--
 .../join/PushdownProjectThroughInnerJoinTest.java  | 151 +++++++++++++++++++++
 .../join/PushdownProjectThroughSemiJoinTest.java   |  27 ++++
 6 files changed, 313 insertions(+), 29 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 13c47df6a4..c439a13ce5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -235,6 +235,7 @@ public enum RuleType {
     LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
     LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
     PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
+    PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
 
     // implementation rules
     LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
index b723e2e4a1..36c71cf01f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java
@@ -38,15 +38,6 @@ import java.util.stream.Stream;
  * Common
  */
 class JoinReorderUtils {
-    /**
-     * check project Expression Input Slot just contains one slot, like:
-     * - one SlotReference like a.id
-     * - Input Slot size == 1, like abs(a.id) + 1
-     */
-    static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
-        return project.getProjects().stream().allMatch(expr -> expr.getInputSlotExprIds().size() == 1);
-    }
-
     static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
         return project.getProjects().stream().allMatch(expr -> expr instanceof Slot);
     }
@@ -78,6 +69,13 @@ class JoinReorderUtils {
         return new LogicalProject<>(projectExprs, plan);
     }
 
+    public static Plan projectOrSelfInOrder(List<NamedExpression> projectExprs, Plan plan) {
+        if (projectExprs.isEmpty() || projectExprs.equals(plan.getOutput())) {
+            return plan;
+        }
+        return new LogicalProject<>(projectExprs, plan);
+    }
+
     /**
      * replace JoinConjuncts by using slots map.
      */
@@ -111,4 +109,11 @@ class JoinReorderUtils {
             }
         });
     }
+
+    public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) {
+        Set<Slot> childSlots = left ? join.left().getOutputSet() : join.right().getOutputSet();
+        return join.getConditionSlot().stream()
+                .filter(childSlots::contains)
+                .collect(Collectors.toSet());
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java
new file mode 100644
index 0000000000..c46f2252a2
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Rule for pushing a (complex) project down through an inner join; outer joins are not matched.
+ */
+public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
+    public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin();
+
+    /*
+     *    Project                   Join
+     *      |            ──►       /    \
+     *     Join               Project  Project
+     *    /   \                  |       |
+     *   A     B                 A       B
+     */
+    @Override
+    public Rule build() {
+        return logicalProject(logicalJoin())
+            .when(project -> project.child().getJoinType().isInnerJoin())
+            .whenNot(project -> project.child().hasJoinHint())
+            .then(project -> {
+                LogicalJoin<GroupPlan, GroupPlan> join = project.child();
+                Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet();
+                Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet();
+
+                // reject hyper edge: each project expr must only use slots from one join child.
+                if (!project.getProjects().stream().allMatch(expr -> {
+                    Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
+                    return aOutputExprIdSet.containsAll(inputSlotExprIds)
+                            || bOutputExprIdSet.containsAll(inputSlotExprIds);
+                })) {
+                    return null;
+                }
+
+                Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(project.getProjects(),
+                        join.left());
+                List<NamedExpression> aProjects = map.get(true);
+                List<NamedExpression> bProjects = map.get(false);
+
+                boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot));
+                boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot));
+                // Only the right side has complex exprs: skip, JoinCommute yields the mirrored case.
+                if (!leftContains) {
+                    return null;
+                }
+
+                Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects);
+                Set<Slot> aConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
+                Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
+                aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add);
+                Plan newLeft = JoinReorderUtils.projectOrSelf(newAProject.build(), join.left());
+
+                if (!rightContains) {
+                    Plan newJoin = join.withChildren(newLeft, join.right());
+                    return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
+                }
+
+                Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects);
+                Set<Slot> bConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, false);
+                Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
+                bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add);
+                Plan newRight = JoinReorderUtils.projectOrSelf(newBProject.build(), join.right());
+
+                Plan newJoin = join.withChildren(newLeft, newRight);
+                return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
+            }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
index 57ea9df15d..172de009a7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java
@@ -49,27 +49,23 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
     @Override
     public Rule build() {
         return logicalProject(logicalJoin())
-                .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
-                .when(JoinReorderUtils::isOneSlotProject)
-                // Just pushdown project with non-column expr like (t.id + 1)
-                .whenNot(JoinReorderUtils::isAllSlotProject)
-                .whenNot(project -> project.child().hasJoinHint())
-                .then(project -> {
-                    LogicalJoin<GroupPlan, GroupPlan> join = project.child();
-                    Set<Slot> aOutputExprIdSet = join.left().getOutputSet();
-                    Set<Slot> conditionLeftSlots = join.getConditionSlot().stream()
-                            .filter(aOutputExprIdSet::contains)
-                            .collect(Collectors.toSet());
+            .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
+            // Just pushdown project with non-column expr like (t.id + 1)
+            .whenNot(JoinReorderUtils::isAllSlotProject)
+            .whenNot(project -> project.child().hasJoinHint())
+            .then(project -> {
+                LogicalJoin<GroupPlan, GroupPlan> join = project.child();
+                Set<Slot> conditionLeftSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
 
-                    List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
-                    Set<Slot> projectUsedSlots = project.getProjects().stream()
-                            .map(NamedExpression::toSlot).collect(Collectors.toSet());
-                    conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot))
-                            .forEach(newProject::add);
-                    Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
-                    Plan newJoin = join.withChildren(newLeft, join.right());
-                    return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
-                }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
+                List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
+                Set<Slot> projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot)
+                        .collect(Collectors.toSet());
+                conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add);
+                Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
+
+                Plan newJoin = join.withChildren(newLeft, join.right());
+                return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
+            }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
     }
 
     List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java
new file mode 100644
index 0000000000..7d94fa876c
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.join;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported {
+    private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+    private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+
+    @Test
+    public void pushBothSide() {
+        // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
+        List<NamedExpression> projectExprs = ImmutableList.of(
+                new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
+                scan1.getOutput().get(1),
+                new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"),
+                scan2.getOutput().get(1)
+        );
+        // complex projections contain t1.id / t2.id, which aren't in the Join Condition
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(1, 1))
+                .projectExprs(projectExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalJoin(
+                                logicalProject().when(project -> project.getProjects().size() == 2),
+                                logicalProject().when(project -> project.getProjects().size() == 2)
+                        )
+                );
+    }
+
+    @Test
+    public void pushdownProjectInCondition() {
+        // project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
+        List<NamedExpression> projectExprs = ImmutableList.of(
+                new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
+                scan1.getOutput().get(1),
+                new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"),
+                scan2.getOutput().get(1)
+        );
+        // complex projections contain t1.id / t2.id, which are in the Join Condition
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .projectExprs(projectExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalProject(
+                                logicalJoin(
+                                        logicalProject().when(project -> project.getProjects().size() == 3),
+                                        logicalProject().when(project -> project.getProjects().size() == 3)
+                                )
+                        )
+                );
+    }
+
+    @Test
+    void pushComplexProject() {
+        // project (t1.id + t1.name) as complex1, (t2.id + t2.name) as complex2
+        List<NamedExpression> projectExprs = ImmutableList.of(
+                new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex1"),
+                new Alias(new Add(scan2.getOutput().get(0), scan2.getOutput().get(1)), "complex2")
+        );
+        // complex projections contain t1.id / t2.id, which are in the Join Condition
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .projectExprs(projectExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                    logicalProject(
+                        logicalJoin(
+                            logicalProject()
+                                .when(project ->
+                                        project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`")
+                                        && project.getProjects().get(1).toSql().equals("id")),
+                            logicalProject()
+                                .when(project ->
+                                        project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`")
+                                        && project.getProjects().get(1).toSql().equals("id"))
+                        )
+                    ).when(project -> project.getProjects().get(0).toSql().equals("complex1")
+                            && project.getProjects().get(1).toSql().equals("complex2")
+                    )
+                );
+    }
+
+    @Test
+    void rejectHyperEdgeProject() {
+        // project (t1.id + t2.id) as alias
+        List<NamedExpression> projectExprs = ImmutableList.of(
+                new Alias(new Add(scan1.getOutput().get(0), scan2.getOutput().get(0)), "alias")
+        );
+        // the projection uses both t1.id and t2.id (a hyper edge), so the rule must not fire
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .projectExprs(projectExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
+                .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
index 0a3b7a04ba..b47910f748 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java
@@ -95,4 +95,31 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported {
                         ).when(project -> project.getProjects().size() == 2)
                 );
     }
+
+    @Test
+    void pushComplexProject() {
+        // project (t1.id + t1.name) as complex
+        List<NamedExpression> projectExprs = ImmutableList.of(
+                new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex"));
+        // complex projection contains t1.id, which is in the Join Condition
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+                .projectExprs(projectExprs)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                    logicalProject(
+                        leftSemiLogicalJoin(
+                            logicalProject()
+                                .when(project -> project.getProjects().get(0).toSql().equals("(id + name) AS `complex`")
+                                    && project.getProjects().get(1).toSql().equals("id")),
+                            logicalOlapScan()
+                        )
+                    ).when(project -> project.getProjects().get(0).toSql().equals("complex"))
+                );
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to