This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 53c624ffa0924545ae870e190910a6e83840ca7b
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Tue Jan 30 11:53:07 2024 +0800

    [feat](Nereids): support alias when eliminating joins for partial mv
rewriting #30498
---
 .../jobs/joinorder/hypergraph/GraphSimplifier.java |   3 +-
 .../jobs/joinorder/hypergraph/HyperGraph.java      |  39 ++++++-
 .../jobs/joinorder/hypergraph/edge/Edge.java       |   4 +
 .../jobs/joinorder/hypergraph/edge/JoinEdge.java   |  19 +++-
 .../rules/exploration/mv/HyperGraphComparator.java | 116 +++++++++++++--------
 .../rules/exploration/mv/EliminateJoinTest.java    |  25 ++++-
 6 files changed, 154 insertions(+), 52 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
index 5539e894cc7..61738f2fca6 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java
@@ -33,6 +33,7 @@ import org.apache.doris.nereids.util.JoinUtils;
 import org.apache.doris.statistics.Statistics;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableSet;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
@@ -417,7 +418,7 @@ public class GraphSimplifier {
 
         JoinEdge newEdge = new JoinEdge(join, edge.getIndex(),
                 edge.getLeftChildEdges(), edge.getRightChildEdges(), 
edge.getSubTreeNodes(),
-                edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
+                edge.getLeftRequiredNodes(), edge.getRightRequiredNodes(), 
ImmutableSet.of(), ImmutableSet.of());
         newEdge.addLeftExtendNode(leftNodes);
         newEdge.addRightExtendNode(rightNodes);
         return newEdge;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
index 8472837d79e..0099ab30afc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
@@ -46,14 +46,17 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 import java.util.ArrayList;
 import java.util.BitSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 
 /**
  * The graph is a join graph, whose node is the leaf plan and edge is a join 
operator.
@@ -288,6 +291,36 @@ public class HyperGraph {
         return new HyperGraph.Builder().buildHyperGraphForMv(plan);
     }
 
+    /**
+     * map output to requires output and construct named expressions
+     */
+    public @Nullable List<NamedExpression> getNamedExpressions(
+            long nodeMap, Set<Slot> outputSet, Set<Slot> requireOutputs) {
+        List<NamedExpression> output = new ArrayList<>();
+        List<NamedExpression> projects = getComplexProject().get(nodeMap);
+        if (projects == null) {
+            return null;
+        }
+        for (Slot slot : requireOutputs) {
+            if (outputSet.contains(slot)) {
+                output.add(slot);
+            } else {
+                Optional<NamedExpression> expr = projects.stream()
+                        .filter(p -> p.toSlot().equals(slot))
+                        .findFirst();
+                if (!expr.isPresent()) {
+                    return null;
+                }
+                // TODO: consider cascades alias
+                if (!outputSet.containsAll(expr.get().getInputSlots())) {
+                    return null;
+                }
+                output.add(expr.get());
+            }
+        }
+        return output;
+    }
+
     /**
      * Builder of HyperGraph
      */
@@ -509,6 +542,10 @@ public class HyperGraph {
             }
 
             BitSet curJoinEdges = new BitSet();
+            Set<Slot> leftInputSlots = ImmutableSet.copyOf(
+                    Sets.intersection(join.getInputSlots(), 
join.left().getOutputSet()));
+            Set<Slot> rightInputSlots = ImmutableSet.copyOf(
+                    Sets.intersection(join.getInputSlots(), 
join.right().getOutputSet()));
             for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, 
List<Expression>>> entry : conjuncts
                     .entrySet()) {
                 LogicalJoin<?, ?> singleJoin = new 
LogicalJoin<>(join.getJoinType(), entry.getValue().first,
@@ -518,7 +555,7 @@ public class HyperGraph {
                 Pair<Long, Long> ends = entry.getKey();
                 JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), 
leftEdgeNodes.first, rightEdgeNodes.first,
                         LongBitmap.newBitmapUnion(leftEdgeNodes.second, 
rightEdgeNodes.second),
-                        ends.first, ends.second);
+                        ends.first, ends.second, leftInputSlots, 
rightInputSlots);
                 for (int nodeIndex : 
LongBitmap.getIterator(edge.getReferenceNodes())) {
                     nodes.get(nodeIndex).attachEdge(edge);
                 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java
index 2530074931e..c41d9e270ac 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java
@@ -79,6 +79,10 @@ public abstract class Edge {
         return LongBitmap.getCardinality(leftExtendedNodes) == 1 && 
LongBitmap.getCardinality(rightExtendedNodes) == 1;
     }
 
+    public boolean isRightSimple() {
+        return LongBitmap.getCardinality(rightExtendedNodes) == 1;
+    }
+
     public void addLeftRejectEdge(JoinEdge edge) {
         leftRejectEdges.add(edge);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
index c23be5f16eb..94f4b30e8d4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
@@ -37,12 +37,16 @@ import javax.annotation.Nullable;
 public class JoinEdge extends Edge {
 
     private final LogicalJoin<? extends Plan, ? extends Plan> join;
+    private final Set<Slot> leftInputSlots;
+    private final Set<Slot> rightInputSlots;
 
     public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int 
index,
             BitSet leftChildEdges, BitSet rightChildEdges, long subTreeNodes,
-            long leftRequireNodes, long rightRequireNodes) {
+            long leftRequireNodes, long rightRequireNodes, Set<Slot> 
leftInputSlots, Set<Slot> rightInputSlots) {
         super(index, leftChildEdges, rightChildEdges, subTreeNodes, 
leftRequireNodes, rightRequireNodes);
         this.join = join;
+        this.leftInputSlots = leftInputSlots;
+        this.rightInputSlots = rightInputSlots;
     }
 
     /**
@@ -51,7 +55,8 @@ public class JoinEdge extends Edge {
     public JoinEdge swap() {
         JoinEdge swapEdge = new
                 JoinEdge(join.swap(), getIndex(), getRightChildEdges(),
-                getLeftChildEdges(), getSubTreeNodes(), 
getRightRequiredNodes(), getLeftRequiredNodes());
+                getLeftChildEdges(), getSubTreeNodes(), 
getRightRequiredNodes(), getLeftRequiredNodes(),
+                this.rightInputSlots, this.leftInputSlots);
         swapEdge.addLeftRejectEdges(getLeftRejectEdge());
         swapEdge.addRightRejectEdges(getRightRejectEdge());
         return swapEdge;
@@ -63,7 +68,7 @@ public class JoinEdge extends Edge {
 
     public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) {
         return new JoinEdge(join.withJoinType(joinType), getIndex(), 
getLeftChildEdges(), getRightChildEdges(),
-                getSubTreeNodes(), getLeftRequiredNodes(), 
getRightRequiredNodes());
+                getSubTreeNodes(), getLeftRequiredNodes(), 
getRightRequiredNodes(), leftInputSlots, rightInputSlots);
     }
 
     public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
@@ -112,4 +117,12 @@ public class JoinEdge extends Edge {
         join.getExpressions().forEach(expression -> 
slots.addAll(expression.getInputSlots()));
         return slots;
     }
+
+    public Set<Slot> getLeftInputSlots() {
+        return leftInputSlots;
+    }
+
+    public Set<Slot> getRightInputSlots() {
+        return rightInputSlots;
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
index 3339d009c79..341caa88094 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
@@ -27,9 +27,11 @@ import 
org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
 import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.JoinUtils;
 
@@ -46,6 +48,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 
 /**
  * HyperGraphComparator
@@ -144,60 +147,81 @@ public class HyperGraphComparator {
         return buildComparisonRes();
     }
 
-    private boolean tryEliminateNodesAndEdge() {
-        boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream()
-                .filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) 
== 1)
-                .anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), 
eliminateViewNodesMap));
-        if (hasFilterEdgeAbove) {
-            // If there is some filter edge above the eliminated node, we 
should rebuild a plan
-            // Right now, just refuse it.
+    private @Nullable Plan constructViewPlan(long nodeBitmap, Set<Slot> 
requireOutputs) {
+        if (LongBitmap.getCardinality(nodeBitmap) != 1) {
+            return null;
+        }
+        Plan basePlan = 
viewHyperGraph.getNode(LongBitmap.lowestOneIndex(nodeBitmap)).getPlan();
+        if (basePlan.getOutputSet().containsAll(requireOutputs)) {
+            return basePlan;
+        }
+        List<NamedExpression> projects = viewHyperGraph
+                .getNamedExpressions(nodeBitmap, basePlan.getOutputSet(), 
requireOutputs);
+        if (projects == null) {
+            return null;
+        }
+        return new LogicalProject<>(projects, basePlan);
+    }
+
+    private boolean canEliminatePrimaryByForeign(long primaryNodes, long 
foreignNodes,
+            Set<Slot> primarySlots, Set<Slot> foreignSlots, JoinEdge joinEdge) 
{
+        Plan foreign = constructViewPlan(foreignNodes, foreignSlots);
+        Plan primary = constructViewPlan(primaryNodes, primarySlots);
+        if (foreign == null || primary == null) {
             return false;
         }
-        for (JoinEdge joinEdge : viewHyperGraph.getJoinEdges()) {
-            if (!LongBitmap.isOverlap(joinEdge.getReferenceNodes(), 
eliminateViewNodesMap)) {
-                continue;
+        return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, 
foreign);
+    }
+
+    private boolean canEliminateViewEdge(JoinEdge joinEdge) {
+        // eliminate by unique
+        if (joinEdge.getJoinType().isLeftOuterJoin() && 
joinEdge.isRightSimple()) {
+            long eliminatedRight =
+                    
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), 
eliminateViewNodesMap);
+            if (LongBitmap.getCardinality(eliminatedRight) != 1) {
+                return false;
             }
-            // eliminate by unique
-            if (joinEdge.getJoinType().isLeftOuterJoin()) {
-                long eliminatedRight =
-                        
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), 
eliminateViewNodesMap);
-                if (LongBitmap.getCardinality(eliminatedRight) != 1) {
-                    return false;
-                }
-                Plan rigthPlan = viewHyperGraph
-                        
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
-                return JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
-                        
rigthPlan.getLogicalProperties().getFunctionalDependencies());
+            Plan rigthPlan = 
constructViewPlan(joinEdge.getRightExtendedNodes(), 
joinEdge.getRightInputSlots());
+            if (rigthPlan == null) {
+                return false;
             }
-            // eliminate by pk fk
-            if (joinEdge.getJoinType().isInnerJoin()) {
-                if (!joinEdge.isSimple()) {
-                    return false;
-                }
-                long eliminatedLeft =
-                        
LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), 
eliminateViewNodesMap);
-                long eliminatedRight =
-                        
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), 
eliminateViewNodesMap);
-                if (LongBitmap.getCardinality(eliminatedLeft) == 0
-                        && LongBitmap.getCardinality(eliminatedRight) == 1) {
-                    Plan foreign = viewHyperGraph
-                            
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
-                    Plan primary = viewHyperGraph
-                            
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
-                    return JoinUtils.canEliminateByFk(joinEdge.getJoin(), 
primary, foreign);
-                } else if (LongBitmap.getCardinality(eliminatedLeft) == 1
-                        && LongBitmap.getCardinality(eliminatedRight) == 0) {
-                    Plan foreign = viewHyperGraph
-                            
.getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
-                    Plan primary = viewHyperGraph
-                            
.getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
-                    return JoinUtils.canEliminateByFk(joinEdge.getJoin(), 
primary, foreign);
-                }
+            return JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
+                    
rigthPlan.getLogicalProperties().getFunctionalDependencies());
+        }
+        // eliminate by pk fk
+        if (joinEdge.getJoinType().isInnerJoin()) {
+            if (!joinEdge.isSimple()) {
                 return false;
             }
+            long eliminatedLeft =
+                    
LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), 
eliminateViewNodesMap);
+            long eliminatedRight =
+                    
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), 
eliminateViewNodesMap);
+            if (LongBitmap.getCardinality(eliminatedLeft) == 0
+                    && LongBitmap.getCardinality(eliminatedRight) == 1) {
+                return 
canEliminatePrimaryByForeign(joinEdge.getRightExtendedNodes(), 
joinEdge.getLeftExtendedNodes(),
+                        joinEdge.getRightInputSlots(), 
joinEdge.getLeftInputSlots(), joinEdge);
+            } else if (LongBitmap.getCardinality(eliminatedLeft) == 1
+                    && LongBitmap.getCardinality(eliminatedRight) == 0) {
+                return 
canEliminatePrimaryByForeign(joinEdge.getLeftExtendedNodes(), 
joinEdge.getRightExtendedNodes(),
+                        joinEdge.getLeftInputSlots(), 
joinEdge.getRightInputSlots(), joinEdge);
+            }
+        }
+        return false;
+    }
 
+    private boolean tryEliminateNodesAndEdge() {
+        boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream()
+                .filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) 
== 1)
+                .anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), 
eliminateViewNodesMap));
+        if (hasFilterEdgeAbove) {
+            // If there is some filter edge above the eliminated node, we 
should rebuild a plan
+            // Right now, just reject it.
+            return false;
         }
-        return true;
+        return viewHyperGraph.getJoinEdges().stream()
+                .filter(joinEdge -> 
LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap))
+                .allMatch(this::canEliminateViewEdge);
     }
 
     private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode 
view) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
index 3e245741178..cc8b7147423 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
@@ -52,10 +52,22 @@ class EliminateJoinTest extends SqlTestBase {
                 .rewrite()
                 .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
                 .getAllPlan().get(0).child(0);
+        CascadesContext c3 = createCascadesContext(
+                "select * from T1 left outer join (select id as id2 from T2 
group by id) T2 "
+                        + "on T1.id = T2.id2 ",
+                connectContext
+        );
+        Plan p3 = PlanChecker.from(c3)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
         HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
         HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
+        HyperGraph h3 = HyperGraph.builderForMv(p3).buildAll().get(0);
         ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
         Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertTrue(!HyperGraphComparator.isLogicCompatible(h1, h3, 
constructContext(p1, p2)).isInvalid());
         Assertions.assertTrue(res.getViewExpressions().isEmpty());
     }
 
@@ -112,11 +124,23 @@ class EliminateJoinTest extends SqlTestBase {
                 .rewrite()
                 .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
                 .getAllPlan().get(0).child(0);
+        CascadesContext c3 = createCascadesContext(
+                "select * from T1 inner join (select id as id2 from T2) T2 "
+                        + "on T1.id = T2.id2 ",
+                connectContext
+        );
+        Plan p3 = PlanChecker.from(c3)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
         HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
         HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
+        HyperGraph h3 = HyperGraph.builderForMv(p3).buildAll().get(0);
         ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
         Assertions.assertTrue(!res.isInvalid());
         Assertions.assertTrue(res.getViewExpressions().isEmpty());
+        Assertions.assertTrue(!HyperGraphComparator.isLogicCompatible(h1, h3, 
constructContext(p1, p2)).isInvalid());
         dropConstraint("alter table T2 drop constraint pk");
     }
 
@@ -154,7 +178,6 @@ class EliminateJoinTest extends SqlTestBase {
         dropConstraint("alter table T3 drop constraint uk");
     }
 
-    @Disabled
     @Test
     void testLOJWithPKFKAndUK2() throws Exception {
         
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to