This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 59bc048ddb6 [feature](Nereids): refresh view hypergraph after inferring join (#29469)
59bc048ddb6 is described below

commit 59bc048ddb6424d4319c1e23a35d78f1567c2958
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Mon Jan 8 15:30:14 2024 +0800

    [feature](Nereids): refresh view hypergraph after inferring join (#29469)
---
 .../joinorder/hypergraph/ConflictRulesMaker.java   | 108 +++++++++++++
 .../jobs/joinorder/hypergraph/HyperGraph.java      |  73 +--------
 .../jobs/joinorder/hypergraph/edge/FilterEdge.java |   4 +
 .../jobs/joinorder/hypergraph/edge/JoinEdge.java   |   5 +
 .../rules/exploration/mv/HyperGraphComparator.java | 140 +++++++++++++----
 .../exploration/mv/HyperGraphComparatorTest.java   | 172 +++++++++++++++++++++
 6 files changed, 399 insertions(+), 103 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java
new file mode 100644
index 00000000000..11db05843fe
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.jobs.joinorder.hypergraph;
+
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge;
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
+import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
+import org.apache.doris.nereids.trees.plans.JoinType;
+
+import java.util.BitSet;
+import java.util.List;
+
+/**
+ * Builds the conflict rules of hypergraph edges.
+ *
+ * <p>For a join edge it widens the left/right extended node sets and records reject edges on
+ * the join edges below it (CD-C), and on the filter edges that may not be moved across it.
+ * Stateless utility: every method is static and only mutates edges reachable from its arguments.
+ */
+public class ConflictRulesMaker {
+    private ConflictRulesMaker() {}
+
+    /**
+     * Record {@code joinEdge} as a reject edge on every filter edge lying entirely on one side of
+     * it, when the join type does not allow filters to be pushed through that side.
+     *
+     * <p>Push-through rules come from {@link PushDownFilterThroughJoin#COULD_PUSH_THROUGH_LEFT}
+     * and {@link PushDownFilterThroughJoin#COULD_PUSH_THROUGH_RIGHT}.
+     *
+     * @param joinEdge the join edge whose filter conflicts are computed
+     * @param joinEdges join edges used to resolve the left/right sub-tree nodes of {@code joinEdge}
+     * @param filterEdges filter edges to check; matching ones get a left/right reject edge added
+     */
+    public static void makeFilterConflictRules(
+            JoinEdge joinEdge, List<JoinEdge> joinEdges, List<FilterEdge> filterEdges) {
+        long leftNodes = joinEdge.getLeftSubNodes(joinEdges);
+        long rightNodes = joinEdge.getRightSubNodes(joinEdges);
+        JoinType joinType = joinEdge.getJoinType();
+        boolean blocksLeft = !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinType);
+        boolean blocksRight = !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinType);
+        for (FilterEdge filter : filterEdges) {
+            long referenced = filter.getReferenceNodes();
+            if (blocksLeft && LongBitmap.isSubset(referenced, leftNodes)) {
+                filter.addLeftRejectEdge(joinEdge);
+            }
+            if (blocksRight && LongBitmap.isSubset(referenced, rightNodes)) {
+                filter.addRightRejectEdge(joinEdge);
+            }
+        }
+    }
+
+    /**
+     * Compute the conflict rules of {@code edgeB} with the CD-C algorithm from "On the correct and
+     * complete enumeration of the core search space" (Moerkotte, Fender, Eich).
+     *
+     * <p>Starting from the required nodes of {@code edgeB}, each join edge found below its left
+     * side that fails {@link JoinType#isAssoc} / {@code isLAssoc} with it contributes that child's
+     * left / right sub-tree nodes to the left extended set. Symmetrically, children on the right
+     * side that fail {@code isAssoc} / {@code isRAssoc} contribute to the right extended set. Every
+     * contributing child gets {@code edgeB} as a left/right reject edge.
+     *
+     * @param edgeB the join edge whose extended nodes are (re)computed
+     * @param joinEdges join edges looked up by their index; must contain every child edge of
+     *         {@code edgeB}
+     */
+    public static void makeJoinConflictRules(JoinEdge edgeB, List<JoinEdge> joinEdges) {
+        BitSet leftChildren = subTreeEdges(edgeB.getLeftChildEdges(), joinEdges);
+        BitSet rightChildren = subTreeEdges(edgeB.getRightChildEdges(), joinEdges);
+        long leftExtended = edgeB.getLeftRequiredNodes();
+        long rightExtended = edgeB.getRightRequiredNodes();
+        JoinType parentType = edgeB.getJoinType();
+
+        for (int idx = leftChildren.nextSetBit(0); idx >= 0; idx = leftChildren.nextSetBit(idx + 1)) {
+            JoinEdge child = joinEdges.get(idx);
+            JoinType childType = child.getJoinType();
+            if (!JoinType.isAssoc(childType, parentType)) {
+                leftExtended = LongBitmap.newBitmapUnion(leftExtended, child.getLeftSubNodes(joinEdges));
+                child.addLeftRejectEdge(edgeB);
+            }
+            if (!JoinType.isLAssoc(childType, parentType)) {
+                leftExtended = LongBitmap.newBitmapUnion(leftExtended, child.getRightSubNodes(joinEdges));
+                child.addLeftRejectEdge(edgeB);
+            }
+        }
+
+        for (int idx = rightChildren.nextSetBit(0); idx >= 0; idx = rightChildren.nextSetBit(idx + 1)) {
+            JoinEdge child = joinEdges.get(idx);
+            JoinType childType = child.getJoinType();
+            // argument order is (parent, child) on the right side, mirroring (child, parent) on the left
+            if (!JoinType.isAssoc(parentType, childType)) {
+                rightExtended = LongBitmap.newBitmapUnion(rightExtended, child.getRightSubNodes(joinEdges));
+                child.addRightRejectEdge(edgeB);
+            }
+            if (!JoinType.isRAssoc(parentType, childType)) {
+                rightExtended = LongBitmap.newBitmapUnion(rightExtended, child.getLeftSubNodes(joinEdges));
+                child.addRightRejectEdge(edgeB);
+            }
+        }
+
+        edgeB.setLeftExtendedNodes(leftExtended);
+        edgeB.setRightExtendedNodes(rightExtended);
+    }
+
+    /**
+     * Collect the indices of the join edges for which
+     * {@code LongBitmap.isSubset(edge.getSubTreeNodes(), candidate.getReferenceNodes())} holds.
+     *
+     * <p>NOTE(review): makeFilterConflictRules calls isSubset(referenceNodes, subNodes), the
+     * opposite argument order; confirm this direction is intended.
+     */
+    private static BitSet subTreeEdge(Edge edge, List<JoinEdge> joinEdges) {
+        long subTreeNodes = edge.getSubTreeNodes();
+        BitSet result = new BitSet();
+        for (JoinEdge candidate : joinEdges) {
+            if (LongBitmap.isSubset(subTreeNodes, candidate.getReferenceNodes())) {
+                result.set(candidate.getIndex());
+            }
+        }
+        return result;
+    }
+
+    /** Union of {@link #subTreeEdge} over every join edge whose index is set in {@code edgeSet}. */
+    private static BitSet subTreeEdges(BitSet edgeSet, List<JoinEdge> joinEdges) {
+        BitSet union = new BitSet();
+        for (int idx = edgeSet.nextSetBit(0); idx >= 0; idx = edgeSet.nextSetBit(idx + 1)) {
+            union.or(subTreeEdge(joinEdges.get(idx), joinEdges));
+        }
+        return union;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
index d78374d1347..8a0bd8daaab 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java
@@ -28,7 +28,6 @@ import 
org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
 import org.apache.doris.nereids.memo.Group;
 import org.apache.doris.nereids.memo.GroupExpression;
-import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -244,8 +243,9 @@ public class HyperGraph {
             joinEdges.add(edge);
         }
         curJoinEdges.stream().forEach(i -> 
joinEdges.get(i).addCurJoinEdges(curJoinEdges));
-        curJoinEdges.stream().forEach(i -> 
makeJoinConflictRules(joinEdges.get(i)));
-        curJoinEdges.stream().forEach(i -> 
makeFilterConflictRules(joinEdges.get(i)));
+        curJoinEdges.stream().forEach(i -> 
ConflictRulesMaker.makeJoinConflictRules(joinEdges.get(i), joinEdges));
+        curJoinEdges.stream().forEach(i ->
+                ConflictRulesMaker.makeFilterConflictRules(joinEdges.get(i), 
joinEdges, filterEdges));
         return curJoinEdges;
         // In MySQL, each edge is reversed and store in edges again for 
reducing the branch miss
         // We don't implement this trick now.
@@ -260,73 +260,6 @@ public class HyperGraph {
         return bitSet;
     }
 
-    private void makeFilterConflictRules(JoinEdge joinEdge) {
-        long leftSubNodes = joinEdge.getLeftSubNodes(joinEdges);
-        long rightSubNodes = joinEdge.getRightSubNodes(joinEdges);
-        filterEdges.forEach(e -> {
-            if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes)
-                    && 
!PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType()))
 {
-                e.addLeftRejectEdge(joinEdge);
-            }
-            if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes)
-                    && 
!PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType()))
 {
-                e.addRightRejectEdge(joinEdge);
-            }
-        });
-    }
-
-    // Make edge with CD-C algorithm in
-    // On the correct and complete enumeration of the core search
-    private void makeJoinConflictRules(JoinEdge edgeB) {
-        BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges());
-        BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges());
-        long leftRequired = edgeB.getLeftRequiredNodes();
-        long rightRequired = edgeB.getRightRequiredNodes();
-
-        for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = 
leftSubTreeEdges.nextSetBit(i + 1)) {
-            JoinEdge childA = joinEdges.get(i);
-            if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
-                leftRequired = LongBitmap.newBitmapUnion(leftRequired, 
childA.getLeftSubNodes(joinEdges));
-                childA.addLeftRejectEdge(edgeB);
-            }
-            if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) 
{
-                leftRequired = LongBitmap.newBitmapUnion(leftRequired, 
childA.getRightSubNodes(joinEdges));
-                childA.addLeftRejectEdge(edgeB);
-            }
-        }
-
-        for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = 
rightSubTreeEdges.nextSetBit(i + 1)) {
-            JoinEdge childA = joinEdges.get(i);
-            if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
-                rightRequired = LongBitmap.newBitmapUnion(rightRequired, 
childA.getRightSubNodes(joinEdges));
-                childA.addRightRejectEdge(edgeB);
-            }
-            if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) 
{
-                rightRequired = LongBitmap.newBitmapUnion(rightRequired, 
childA.getLeftSubNodes(joinEdges));
-                childA.addRightRejectEdge(edgeB);
-            }
-        }
-        edgeB.setLeftExtendedNodes(leftRequired);
-        edgeB.setRightExtendedNodes(rightRequired);
-    }
-
-    private BitSet subTreeEdge(Edge edge) {
-        long subTreeNodes = edge.getSubTreeNodes();
-        BitSet subEdges = new BitSet();
-        joinEdges.stream()
-                .filter(e -> LongBitmap.isSubset(subTreeNodes, 
e.getReferenceNodes()))
-                .forEach(e -> subEdges.set(e.getIndex()));
-        return subEdges;
-    }
-
-    private BitSet subTreeEdges(BitSet edgeSet) {
-        BitSet bitSet = new BitSet();
-        edgeSet.stream()
-                .mapToObj(i -> subTreeEdge(joinEdges.get(i)))
-                .forEach(bitSet::or);
-        return bitSet;
-    }
-
     // Try to calculate the ends of an expression.
     // left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree
     // if left = 0, recursively calculate it in left tree
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java
index 57c6d9660d0..52c94abe3a1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java
@@ -47,4 +47,8 @@ public class FilterEdge extends Edge {
     public List<? extends Expression> getExpressions() {
         return filter.getExpressions();
     }
+
+    public FilterEdge clear() {
+        return new FilterEdge(filter, getIndex(), getLeftChildEdges(), 
getSubTreeNodes(), getLeftRequiredNodes());
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
index 635bf91f364..c23be5f16eb 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java
@@ -61,6 +61,11 @@ public class JoinEdge extends Edge {
         return join.getJoinType();
     }
 
+    /**
+     * Return a copy of this edge whose underlying join uses {@code joinType}.
+     *
+     * <p>Index, child edges, sub-tree nodes and required nodes are carried over. Conflict-rule
+     * state (reject edges, extended nodes) is not passed to the constructor, so the caller
+     * presumably rebuilds it ("CleanCR" = clean conflict rules).
+     */
+    public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) {
+        return new JoinEdge(
+                join.withJoinType(joinType),
+                getIndex(),
+                getLeftChildEdges(),
+                getRightChildEdges(),
+                getSubTreeNodes(),
+                getLeftRequiredNodes(),
+                getRightRequiredNodes());
+    }
+
     public LogicalJoin<? extends Plan, ? extends Plan> getJoin() {
         return join;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
index b92e352f5ce..7817475c7b8 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.rules.exploration.mv;
 
 import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.ConflictRulesMaker;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
@@ -40,7 +41,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -64,10 +64,12 @@ public class HyperGraphComparator {
     // record inferred edges when comparing mv
     private final HyperGraph queryHyperGraph;
     private final HyperGraph viewHyperGraph;
-    private final Map<JoinEdge, Pair<JoinType, Set<Slot>>> inferredViewEdgeMap 
= new HashMap<>();
     private final Map<Edge, List<? extends Expression>> 
pullUpQueryExprWithEdge = new HashMap<>();
     private final Map<Edge, List<? extends Expression>> pullUpViewExprWithEdge 
= new HashMap<>();
     private final LogicalCompatibilityContext logicalCompatibilityContext;
+    private final Map<JoinEdge, Pair<JoinType, Set<Slot>>> 
inferredViewEdgeWithCond = new HashMap<>();
+    private List<JoinEdge> viewJoinEdgesAfterInferring;
+    private List<FilterEdge> viewFilterEdgesAfterInferring;
 
     public HyperGraphComparator(HyperGraph queryHyperGraph, HyperGraph 
viewHyperGraph,
             LogicalCompatibilityContext logicalCompatibilityContext) {
@@ -89,9 +91,18 @@ public class HyperGraphComparator {
 
     private ComparisonResult isLogicCompatible() {
         // 1 try to construct a map which can be mapped from edge to edge
-        Map<Edge, Edge> queryToView = constructMapWithNode();
+        Map<Edge, Edge> queryToView = constructQueryToViewMapWithExpr();
+        if (!makeViewJoinCompatible(queryToView)) {
+            return ComparisonResult.newInvalidResWithErrorMessage("Join types 
are not compatible\n");
+        }
+        refreshViewEdges();
 
-        // 2. compare them by expression and extract residual expr
+        // 2. compare them by expression and nodes. Note compare edges after 
inferring for nodes
+        boolean matchNodes = queryToView.entrySet().stream()
+                .allMatch(e -> compareEdgeWithNode(e.getKey(), e.getValue()));
+        if (!matchNodes) {
+            return ComparisonResult.newInvalidResWithErrorMessage("Join nodes 
are not compatible\n");
+        }
         queryToView.forEach(this::compareEdgeWithExpr);
 
         // 3. process residual edges
@@ -122,12 +133,12 @@ public class HyperGraphComparator {
             List<? extends Expression> rawFilter = e.getValue().stream()
                     .filter(expr -> !ExpressionUtils.isInferred(expr))
                     .collect(Collectors.toList());
-            if (!rawFilter.isEmpty() && !canPullUp(e.getKey())) {
+            if (!rawFilter.isEmpty() && 
!canPullUp(getViewEdgeAfterInferring(e.getKey()))) {
                 return 
ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "with error 
edge\n" + e);
             }
             builder.addViewExpressions(rawFilter);
         }
-        for (Pair<JoinType, Set<Slot>> inferredCond : 
inferredViewEdgeMap.values()) {
+        for (Pair<JoinType, Set<Slot>> inferredCond : 
inferredViewEdgeWithCond.values()) {
             builder.addViewNoNullableSlot(inferredCond.second);
         }
         return builder.build();
@@ -145,7 +156,15 @@ public class HyperGraphComparator {
                 getViewJoinEdges(),
                 getQueryFilterEdges(),
                 getViewFilterEdges(),
-                inferredViewEdgeMap);
+                inferredViewEdgeWithCond);
+    }
+
+    /**
+     * Look up the refreshed counterpart of a view edge by its index. Join edges resolve against
+     * the refreshed join list; every other edge is treated as a filter edge.
+     */
+    private Edge getViewEdgeAfterInferring(Edge edge) {
+        int index = edge.getIndex();
+        return edge instanceof JoinEdge
+                ? viewJoinEdgesAfterInferring.get(index)
+                : viewFilterEdgesAfterInferring.get(index);
     }
 
     private boolean canPullUp(Edge edge) {
@@ -154,10 +173,10 @@ public class HyperGraphComparator {
             return false;
         }
         boolean pullFromLeft = edge.getLeftRejectEdge().stream()
-                .map(e -> inferredViewEdgeMap.getOrDefault(e, 
Pair.of(e.getJoinType(), null)))
+                .map(e -> inferredViewEdgeWithCond.getOrDefault(e, 
Pair.of(e.getJoinType(), null)))
                 .allMatch(e -> canPullFromLeft(edge, e.first));
         boolean pullFromRight = edge.getRightRejectEdge().stream()
-                .map(e -> inferredViewEdgeMap.getOrDefault(e, 
Pair.of(e.getJoinType(), null)))
+                .map(e -> inferredViewEdgeWithCond.getOrDefault(e, 
Pair.of(e.getJoinType(), null)))
                 .allMatch(e -> canPullFromRight(edge, e.first));
         return pullFromLeft && pullFromRight;
     }
@@ -198,6 +217,24 @@ public class HyperGraphComparator {
         return ImmutableSet.copyOf(queryHyperGraph.getFilterEdges());
     }
 
+    /** All query edges as one immutable list: join edges first, then filter edges. */
+    private List<Edge> getQueryEdges() {
+        ImmutableList.Builder<Edge> edges = ImmutableList.builder();
+        edges.addAll(getQueryJoinEdges());
+        edges.addAll(getQueryFilterEdges());
+        return edges.build();
+    }
+
+    /**
+     * Check every (query join edge, view join edge) pair with compareJoinEdgeOrInfer, which may
+     * record an inferred view join type as a side effect. Pairs involving filter edges are skipped.
+     *
+     * @return false as soon as one join pair is incompatible, true otherwise
+     */
+    private boolean makeViewJoinCompatible(Map<Edge, Edge> queryToView) {
+        for (Entry<Edge, Edge> pair : queryToView.entrySet()) {
+            Edge queryEdge = pair.getKey();
+            Edge viewEdge = pair.getValue();
+            if (!(queryEdge instanceof JoinEdge) || !(viewEdge instanceof JoinEdge)) {
+                continue;
+            }
+            if (!compareJoinEdgeOrInfer((JoinEdge) queryEdge, (JoinEdge) viewEdge)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     private Set<FilterEdge> getViewFilterEdgeSet() {
         return ImmutableSet.copyOf(viewHyperGraph.getFilterEdges());
     }
@@ -214,6 +251,12 @@ public class HyperGraphComparator {
         return viewHyperGraph.getFilterEdges();
     }
 
+    /** All view edges as one immutable list: join edges first, then filter edges. */
+    private List<Edge> getViewEdges() {
+        ImmutableList.Builder<Edge> edges = ImmutableList.builder();
+        edges.addAll(getViewJoinEdges());
+        edges.addAll(getViewFilterEdges());
+        return edges.build();
+    }
+
     private Map<Expression, Expression> getQueryToViewExprMap() {
         return 
logicalCompatibilityContext.getQueryToViewEdgeExpressionMapping();
     }
@@ -222,43 +265,74 @@ public class HyperGraphComparator {
         return logicalCompatibilityContext.getQueryToViewNodeIDMapping();
     }
 
-    private Map<Edge, Edge> constructMapWithNode() {
-        // TODO use hash map to reduce loop
-        Map<Edge, Edge> joinEdgeMap = getQueryJoinEdges().stream().map(qe -> {
-            Optional<JoinEdge> viewEdge = getViewJoinEdges().stream()
-                    .filter(ve -> compareEdgeWithNode(qe, ve)).findFirst();
-            return Pair.of(qe, viewEdge);
-        }).filter(e -> 
e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> 
p.second.get()));
-        Map<Edge, Edge> filterEdgeMap = getQueryFilterEdges().stream().map(qe 
-> {
-            Optional<FilterEdge> viewEdge = getViewFilterEdges().stream()
-                    .filter(ve -> compareEdgeWithNode(qe, ve)).findFirst();
-            return Pair.of(qe, viewEdge);
-        }).filter(e -> 
e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> 
p.second.get()));
-        return ImmutableMap.<Edge, 
Edge>builder().putAll(joinEdgeMap).putAll(filterEdgeMap).build();
+    /**
+     * Map every query edge (join or filter) to the view edge holding the view-side counterpart of
+     * one of its expressions. Query edges with no counterpart are left out of the result.
+     *
+     * <p>An expression can appear in more than one edge, and the expressions of one query edge can
+     * map into different view edges. The two-argument {@code ImmutableMap.toImmutableMap} collector
+     * throws IllegalArgumentException on such duplicate keys, which aborted the whole comparison.
+     * Both collectors here keep the first match instead.
+     */
+    private Map<Edge, Edge> constructQueryToViewMapWithExpr() {
+        Map<Expression, Edge> viewExprToEdge = getViewEdges().stream()
+                .flatMap(e -> e.getExpressions().stream().map(expr -> Pair.of(expr, e)))
+                .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second, (kept, dropped) -> kept));
+        return getQueryEdges().stream()
+                .flatMap(queryEdge -> queryEdge.getExpressions().stream()
+                        // an unmapped query expression yields a null key; ImmutableMap.get(null) is null
+                        .map(queryExpr -> viewExprToEdge.get(getViewExprFromQueryExpr(queryExpr)))
+                        .filter(viewEdge -> viewEdge != null)
+                        .map(viewEdge -> Pair.of(queryEdge, viewEdge)))
+                .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second, (kept, dropped) -> kept));
+    }
+
+    /** Translate a query edge expression into its view-side counterpart, or null if unmapped. */
+    private Expression getViewExprFromQueryExpr(Expression query) {
+        Map<Expression, Expression> queryToView = getQueryToViewExprMap();
+        return queryToView.get(query);
+    }
+
+    /**
+     * Rebuild the view's join and filter edges so that their conflict rules reflect the join types
+     * inferred while matching against the query. Results are stored in
+     * viewJoinEdgesAfterInferring / viewFilterEdgesAfterInferring.
+     */
+    private void refreshViewEdges() {
+        ImmutableList.Builder<FilterEdge> filterBuilder = ImmutableList.builder();
+        for (FilterEdge filterEdge : getViewFilterEdges()) {
+            filterBuilder.add(filterEdge.clear());
+        }
+        List<FilterEdge> refreshedFilters = filterBuilder.build();
+
+        List<JoinEdge> refreshedJoins = new ArrayList<>();
+        for (JoinEdge viewJoin : getViewJoinEdges()) {
+            // inferred entries always hold a non-null Pair, so containsKey/get matches getOrDefault
+            JoinType effectiveType = inferredViewEdgeWithCond.containsKey(viewJoin)
+                    ? inferredViewEdgeWithCond.get(viewJoin).first
+                    : viewJoin.getJoinType();
+            JoinEdge refreshed = viewJoin.withJoinTypeAndCleanCR(effectiveType);
+            // Append before computing rules: ConflictRulesMaker looks child edges up by index in this
+            // list, so only edges rebuilt so far (this one included) are visible to it.
+            refreshedJoins.add(refreshed);
+            ConflictRulesMaker.makeJoinConflictRules(refreshed, refreshedJoins);
+            ConflictRulesMaker.makeFilterConflictRules(refreshed, refreshedJoins, refreshedFilters);
+        }
+
+        viewJoinEdgesAfterInferring = ImmutableList.copyOf(refreshedJoins);
+        viewFilterEdgesAfterInferring = refreshedFilters;
     }
 
     private boolean compareEdgeWithNode(Edge query, Edge view) {
         if (query instanceof FilterEdge && view instanceof FilterEdge) {
-            return compareEdgeWithFilter((FilterEdge) query, (FilterEdge) 
view);
+            return compareFilterEdgeWithNode((FilterEdge) query, 
viewFilterEdgesAfterInferring.get(view.getIndex()));
         } else if (query instanceof JoinEdge && view instanceof JoinEdge) {
-            return compareJoinEdge((JoinEdge) query, (JoinEdge) view);
+            return compareJoinEdgeWithNode((JoinEdge) query, 
viewJoinEdgesAfterInferring.get(view.getIndex()));
         }
         return false;
     }
 
-    private boolean compareEdgeWithFilter(FilterEdge query, FilterEdge view) {
-        long qChild = query.getReferenceNodes();
-        long vChild = view.getReferenceNodes();
-        return rewriteQueryNodeMap(qChild) == vChild;
+    /** A filter edge matches when its reference nodes, mapped into view node ids, are identical. */
+    private boolean compareFilterEdgeWithNode(FilterEdge query, FilterEdge view) {
+        long mappedQueryNodes = rewriteQueryNodeMap(query.getReferenceNodes());
+        return mappedQueryNodes == view.getReferenceNodes();
+    }
+
+    /**
+     * A join edge matches when its extended nodes, mapped into view node ids, equal the view edge's
+     * extended nodes, either side-for-side or, if the view join type equals the swapped query join
+     * type, with left and right exchanged.
+     */
+    private boolean compareJoinEdgeWithNode(JoinEdge query, JoinEdge view) {
+        boolean swappedMatch = query.getJoinType().swap() == view.getJoinType()
+                && rewriteQueryNodeMap(query.getLeftExtendedNodes()) == view.getRightExtendedNodes()
+                && rewriteQueryNodeMap(query.getRightExtendedNodes()) == view.getLeftExtendedNodes();
+        // evaluated unconditionally, as with the original non-short-circuit "|="
+        boolean directMatch = rewriteQueryNodeMap(query.getLeftExtendedNodes()) == view.getLeftExtendedNodes()
+                && rewriteQueryNodeMap(query.getRightExtendedNodes()) == view.getRightExtendedNodes();
+        return swappedMatch || directMatch;
     }
 
-    private boolean compareJoinEdge(JoinEdge query, JoinEdge view) {
+    private boolean compareJoinEdgeOrInfer(JoinEdge query, JoinEdge view) {
         if (query.getJoinType().equals(view.getJoinType())
                 || 
canInferredJoinTypeMap.containsKey(Pair.of(query.getJoinType(), 
view.getJoinType()))) {
             if (tryInferEdge(query, view)) {
                 return true;
             }
-
         }
 
         if (query.getJoinType().swap().equals(view.getJoinType())
@@ -272,8 +346,8 @@ public class HyperGraphComparator {
     }
 
     private boolean tryInferEdge(JoinEdge query, JoinEdge view) {
-        if (rewriteQueryNodeMap(query.getLeftExtendedNodes()) != 
view.getLeftExtendedNodes()
-                || rewriteQueryNodeMap(query.getRightExtendedNodes()) != 
view.getRightExtendedNodes()) {
+        if (rewriteQueryNodeMap(query.getLeftRequiredNodes()) != 
view.getLeftRequiredNodes()
+                || rewriteQueryNodeMap(query.getRightRequiredNodes()) != 
view.getRightRequiredNodes()) {
             return false;
         }
         if (!query.getJoinType().equals(view.getJoinType())) {
@@ -286,7 +360,7 @@ public class HyperGraphComparator {
                     noNullableChild.first ? 
view.getJoin().left().getOutputSet() : ImmutableSet.of(),
                     noNullableChild.second ? 
view.getJoin().right().getOutputSet() : ImmutableSet.of()
             );
-            inferredViewEdgeMap.put(view, Pair.of(query.getJoinType(), 
noNullableSlot));
+            inferredViewEdgeWithCond.put(view, Pair.of(query.getJoinType(), 
noNullableSlot));
         }
         return true;
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java
new file mode 100644
index 00000000000..77b7fd67294
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration.mv;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
+import org.apache.doris.nereids.rules.RuleSet;
+import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
+import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
+import org.apache.doris.nereids.sqltest.SqlTestBase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+class HyperGraphComparatorTest extends SqlTestBase {
+    @Test
+    void testInnerJoinAndLOJ() {
+        
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1 inner join T2 "
+                        + "on T1.id = T2.id "
+                        + "inner join T3 on T1.id = T3.id",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        CascadesContext c2 = createCascadesContext(
+                "select * from T1 left outer join T2 "
+                        + "on T1.id = T2.id "
+                        + "left outer join T3 on T1.id = T3.id",
+                connectContext
+        );
+        Plan p2 = PlanChecker.from(c2)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
+        HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0);
+        HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0);
+        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
+        Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertEquals(2, res.getViewNoNullableSlot().size());
+    }
+
+    @Test
+    void testIJAndLojAssoc() {
+        
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1 inner join T3 "
+                        + "on T1.id = T3.id "
+                        + "inner join T2 on T1.id = T2.id",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        CascadesContext c2 = createCascadesContext(
+                "select * from T1 left outer join T2 "
+                        + "on T1.id = T2.id "
+                        + "left outer join T3 on T1.id = T3.id",
+                connectContext
+        );
+        Plan p2 = PlanChecker.from(c2)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
+        HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0);
+        HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0);
+        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
+        Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertEquals(2, res.getViewNoNullableSlot().size());
+    }
+
+    @Test
+    void testIJAndLojAssocWithFilter() {
+        
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1 inner join T3 "
+                        + "on T1.id = T3.id "
+                        + "inner join T2 on T1.id = T2.id",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        CascadesContext c2 = createCascadesContext(
+                "select * from T1 left outer join "
+                        + "(select * from T2 where T2.id = 1) T2 "
+                        + "on T1.id = T2.id "
+                        + "left outer join T3 on T1.id = T3.id",
+                connectContext
+        );
+        Plan p2 = PlanChecker.from(c2)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
+        HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0);
+        HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0);
+        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
+        Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertEquals(2, res.getViewNoNullableSlot().size());
+    }
+
+    @Disabled
+    @Test
+    void testIJAndLojAssocWithJoinCond() {
+        
connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1 inner join T3 "
+                        + "on T1.id = T3.id "
+                        + "inner join T2 on T1.id = T2.id",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        CascadesContext c2 = createCascadesContext(
+                "select * from T1 left outer join "
+                        + "("
+                        + "select T1.* from T1 left outer join T3 "
+                        + "on T1.id = T3.id and T1.score = T3.score "
+                        + ") T2 "
+                        + "on T1.id = T2.id ",
+                connectContext
+        );
+        Plan p2 = PlanChecker.from(c2)
+                .analyze()
+                .rewrite()
+                .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                .getAllPlan().get(0).child(0);
+        HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0);
+        HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0);
+        ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
+        Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertEquals(2, res.getViewNoNullableSlot().size());
+    }
+
+    LogicalCompatibilityContext constructContext(Plan p1, Plan p2) {
+        StructInfo st1 = AbstractMaterializedViewRule.extractStructInfo(p1,
+                null).get(0);
+        StructInfo st2 = AbstractMaterializedViewRule.extractStructInfo(p2,
+                null).get(0);
+        RelationMapping rm = RelationMapping.generate(st1.getRelations(), 
st2.getRelations()).get(0);
+        SlotMapping sm = SlotMapping.generate(rm);
+        return LogicalCompatibilityContext.from(rm, sm, st1, st2);
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to