This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 59bc048ddb6 [feature](Nereids): refresh view hypergraph after inferring join (#29469) 59bc048ddb6 is described below commit 59bc048ddb6424d4319c1e23a35d78f1567c2958 Author: 谢健 <jianx...@gmail.com> AuthorDate: Mon Jan 8 15:30:14 2024 +0800 [feature](Nereids): refresh view hypergraph after inferring join (#29469) --- .../joinorder/hypergraph/ConflictRulesMaker.java | 108 +++++++++++++ .../jobs/joinorder/hypergraph/HyperGraph.java | 73 +-------- .../jobs/joinorder/hypergraph/edge/FilterEdge.java | 4 + .../jobs/joinorder/hypergraph/edge/JoinEdge.java | 5 + .../rules/exploration/mv/HyperGraphComparator.java | 140 +++++++++++++---- .../exploration/mv/HyperGraphComparatorTest.java | 172 +++++++++++++++++++++ 6 files changed, 399 insertions(+), 103 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java new file mode 100644 index 00000000000..11db05843fe --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/ConflictRulesMaker.java @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.jobs.joinorder.hypergraph; + +import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; +import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin; +import org.apache.doris.nereids.trees.plans.JoinType; + +import java.util.BitSet; +import java.util.List; + +/** + * This is a conflict rule maker to + */ +public class ConflictRulesMaker { + private ConflictRulesMaker() {} + + /** + * make conflict rules for filter edges + */ + public static void makeFilterConflictRules( + JoinEdge joinEdge, List<JoinEdge> joinEdges, List<FilterEdge> filterEdges) { + long leftSubNodes = joinEdge.getLeftSubNodes(joinEdges); + long rightSubNodes = joinEdge.getRightSubNodes(joinEdges); + filterEdges.forEach(e -> { + if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes) + && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) { + e.addLeftRejectEdge(joinEdge); + } + if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes) + && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) { + e.addRightRejectEdge(joinEdge); + } + }); + } + + /** + * Make edge with CD-C algorithm in + * On the correct and complete enumeration of the core search + */ + public static void makeJoinConflictRules(JoinEdge edgeB, List<JoinEdge> joinEdges) { + BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges(), joinEdges); + BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges(), joinEdges); + long leftRequired = edgeB.getLeftRequiredNodes(); + long rightRequired = edgeB.getRightRequiredNodes(); + + for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) { + JoinEdge childA = joinEdges.get(i); + if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) { + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges)); + childA.addLeftRejectEdge(edgeB); + } + if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) { + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges)); + childA.addLeftRejectEdge(edgeB); + } + } + + for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) { + JoinEdge childA = joinEdges.get(i); + if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) { + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges)); + childA.addRightRejectEdge(edgeB); + } + if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) { + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges)); + childA.addRightRejectEdge(edgeB); + } + } + edgeB.setLeftExtendedNodes(leftRequired); + edgeB.setRightExtendedNodes(rightRequired); + } + + private static BitSet subTreeEdge(Edge edge, List<JoinEdge> joinEdges) { + long subTreeNodes = edge.getSubTreeNodes(); + BitSet subEdges = new BitSet(); + joinEdges.stream() + .filter(e -> LongBitmap.isSubset(subTreeNodes, e.getReferenceNodes())) + .forEach(e -> subEdges.set(e.getIndex())); + return subEdges; + } + + private static BitSet subTreeEdges(BitSet edgeSet, List<JoinEdge> joinEdges) { + BitSet bitSet = new BitSet(); + edgeSet.stream() + .mapToObj(i -> subTreeEdge(joinEdges.get(i), joinEdges)) + .forEach(bitSet::or); + return bitSet; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index d78374d1347..8a0bd8daaab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -244,8 +243,9 @@ public class HyperGraph { joinEdges.add(edge); } curJoinEdges.stream().forEach(i -> joinEdges.get(i).addCurJoinEdges(curJoinEdges)); - curJoinEdges.stream().forEach(i -> makeJoinConflictRules(joinEdges.get(i))); - curJoinEdges.stream().forEach(i -> makeFilterConflictRules(joinEdges.get(i))); + curJoinEdges.stream().forEach(i -> ConflictRulesMaker.makeJoinConflictRules(joinEdges.get(i), joinEdges)); + curJoinEdges.stream().forEach(i -> + ConflictRulesMaker.makeFilterConflictRules(joinEdges.get(i), joinEdges, filterEdges)); return curJoinEdges; // In MySQL, each edge is reversed and store in edges again for reducing the branch miss // We don't implement this trick now. @@ -260,73 +260,6 @@ public class HyperGraph { return bitSet; } - private void makeFilterConflictRules(JoinEdge joinEdge) { - long leftSubNodes = joinEdge.getLeftSubNodes(joinEdges); - long rightSubNodes = joinEdge.getRightSubNodes(joinEdges); - filterEdges.forEach(e -> { - if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes) - && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) { - e.addLeftRejectEdge(joinEdge); - } - if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes) - && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) { - e.addRightRejectEdge(joinEdge); - } - }); - } - - // Make edge with CD-C algorithm in - // On the correct and complete enumeration of the core search - private void makeJoinConflictRules(JoinEdge edgeB) { - BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges()); - BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges()); - long leftRequired = edgeB.getLeftRequiredNodes(); - long rightRequired = edgeB.getRightRequiredNodes(); - - for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) { - JoinEdge childA = joinEdges.get(i); - if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) { - leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges)); - childA.addLeftRejectEdge(edgeB); - } - if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) { - leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges)); - childA.addLeftRejectEdge(edgeB); - } - } - - for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) { - JoinEdge childA = joinEdges.get(i); - if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) { - rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges)); - childA.addRightRejectEdge(edgeB); - } - if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) { - rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges)); - childA.addRightRejectEdge(edgeB); - } - } - edgeB.setLeftExtendedNodes(leftRequired); - edgeB.setRightExtendedNodes(rightRequired); - } - - private BitSet subTreeEdge(Edge edge) { - long subTreeNodes = edge.getSubTreeNodes(); - BitSet subEdges = new BitSet(); - joinEdges.stream() - .filter(e -> LongBitmap.isSubset(subTreeNodes, e.getReferenceNodes())) - .forEach(e -> subEdges.set(e.getIndex())); - return subEdges; - } - - private BitSet subTreeEdges(BitSet edgeSet) { - BitSet bitSet = new BitSet(); - edgeSet.stream() - .mapToObj(i -> subTreeEdge(joinEdges.get(i))) - .forEach(bitSet::or); - return bitSet; - } - // Try to calculate the ends of an expression. // left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree // if left = 0, recursively calculate it in left tree diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java index 57c6d9660d0..52c94abe3a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java @@ -47,4 +47,8 @@ public class FilterEdge extends Edge { public List<? extends Expression> getExpressions() { return filter.getExpressions(); } + + public FilterEdge clear() { + return new FilterEdge(filter, getIndex(), getLeftChildEdges(), getSubTreeNodes(), getLeftRequiredNodes()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java index 635bf91f364..c23be5f16eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java @@ -61,6 +61,11 @@ public class JoinEdge extends Edge { return join.getJoinType(); } + public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) { + return new JoinEdge(join.withJoinType(joinType), getIndex(), getLeftChildEdges(), getRightChildEdges(), + getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes()); + } + public LogicalJoin<? extends Plan, ? extends Plan> getJoin() { return join; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java index b92e352f5ce..7817475c7b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.exploration.mv; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.ConflictRulesMaker; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; @@ -40,7 +41,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -64,10 +64,12 @@ public class HyperGraphComparator { // record inferred edges when comparing mv private final HyperGraph queryHyperGraph; private final HyperGraph viewHyperGraph; - private final Map<JoinEdge, Pair<JoinType, Set<Slot>>> inferredViewEdgeMap = new HashMap<>(); private final Map<Edge, List<? extends Expression>> pullUpQueryExprWithEdge = new HashMap<>(); private final Map<Edge, List<? extends Expression>> pullUpViewExprWithEdge = new HashMap<>(); private final LogicalCompatibilityContext logicalCompatibilityContext; + private final Map<JoinEdge, Pair<JoinType, Set<Slot>>> inferredViewEdgeWithCond = new HashMap<>(); + private List<JoinEdge> viewJoinEdgesAfterInferring; + private List<FilterEdge> viewFilterEdgesAfterInferring; public HyperGraphComparator(HyperGraph queryHyperGraph, HyperGraph viewHyperGraph, LogicalCompatibilityContext logicalCompatibilityContext) { @@ -89,9 +91,18 @@ public class HyperGraphComparator { private ComparisonResult isLogicCompatible() { // 1 try to construct a map which can be mapped from edge to edge - Map<Edge, Edge> queryToView = constructMapWithNode(); + Map<Edge, Edge> queryToView = constructQueryToViewMapWithExpr(); + if (!makeViewJoinCompatible(queryToView)) { + return ComparisonResult.newInvalidResWithErrorMessage("Join types are not compatible\n"); + } + refreshViewEdges(); - // 2. compare them by expression and extract residual expr + // 2. compare them by expression and nodes. Note compare edges after inferring for nodes + boolean matchNodes = queryToView.entrySet().stream() + .allMatch(e -> compareEdgeWithNode(e.getKey(), e.getValue())); + if (!matchNodes) { + return ComparisonResult.newInvalidResWithErrorMessage("Join nodes are not compatible\n"); + } queryToView.forEach(this::compareEdgeWithExpr); // 3. process residual edges @@ -122,12 +133,12 @@ public class HyperGraphComparator { List<? extends Expression> rawFilter = e.getValue().stream() .filter(expr -> !ExpressionUtils.isInferred(expr)) .collect(Collectors.toList()); - if (!rawFilter.isEmpty() && !canPullUp(e.getKey())) { + if (!rawFilter.isEmpty() && !canPullUp(getViewEdgeAfterInferring(e.getKey()))) { return ComparisonResult.newInvalidResWithErrorMessage(getErrorMessage() + "with error edge\n" + e); } builder.addViewExpressions(rawFilter); } - for (Pair<JoinType, Set<Slot>> inferredCond : inferredViewEdgeMap.values()) { + for (Pair<JoinType, Set<Slot>> inferredCond : inferredViewEdgeWithCond.values()) { builder.addViewNoNullableSlot(inferredCond.second); } return builder.build(); @@ -145,7 +156,15 @@ public class HyperGraphComparator { getViewJoinEdges(), getQueryFilterEdges(), getViewFilterEdges(), - inferredViewEdgeMap); + inferredViewEdgeWithCond); + } + + private Edge getViewEdgeAfterInferring(Edge edge) { + if (edge instanceof JoinEdge) { + return viewJoinEdgesAfterInferring.get(edge.getIndex()); + } else { + return viewFilterEdgesAfterInferring.get(edge.getIndex()); + } } private boolean canPullUp(Edge edge) { @@ -154,10 +173,10 @@ public class HyperGraphComparator { return false; } boolean pullFromLeft = edge.getLeftRejectEdge().stream() - .map(e -> inferredViewEdgeMap.getOrDefault(e, Pair.of(e.getJoinType(), null))) + .map(e -> inferredViewEdgeWithCond.getOrDefault(e, Pair.of(e.getJoinType(), null))) .allMatch(e -> canPullFromLeft(edge, e.first)); boolean pullFromRight = edge.getRightRejectEdge().stream() - .map(e -> inferredViewEdgeMap.getOrDefault(e, Pair.of(e.getJoinType(), null))) + .map(e -> inferredViewEdgeWithCond.getOrDefault(e, Pair.of(e.getJoinType(), null))) .allMatch(e -> canPullFromRight(edge, e.first)); return pullFromLeft && pullFromRight; } @@ -198,6 +217,24 @@ public class HyperGraphComparator { return ImmutableSet.copyOf(queryHyperGraph.getFilterEdges()); } + private List<Edge> getQueryEdges() { + return ImmutableList.<Edge>builder() + .addAll(getQueryJoinEdges()) + .addAll(getQueryFilterEdges()).build(); + } + + private boolean makeViewJoinCompatible(Map<Edge, Edge> queryToView) { + for (Entry<Edge, Edge> entry : queryToView.entrySet()) { + if (entry.getKey() instanceof JoinEdge && entry.getValue() instanceof JoinEdge) { + boolean res = compareJoinEdgeOrInfer((JoinEdge) entry.getKey(), (JoinEdge) entry.getValue()); + if (!res) { + return false; + } + } + } + return true; + } + private Set<FilterEdge> getViewFilterEdgeSet() { return ImmutableSet.copyOf(viewHyperGraph.getFilterEdges()); } @@ -214,6 +251,12 @@ public class HyperGraphComparator { return viewHyperGraph.getFilterEdges(); } + private List<Edge> getViewEdges() { + return ImmutableList.<Edge>builder() + .addAll(getViewJoinEdges()) + .addAll(getViewFilterEdges()).build(); + } + private Map<Expression, Expression> getQueryToViewExprMap() { return logicalCompatibilityContext.getQueryToViewEdgeExpressionMapping(); } @@ -222,43 +265,74 @@ public class HyperGraphComparator { return logicalCompatibilityContext.getQueryToViewNodeIDMapping(); } - private Map<Edge, Edge> constructMapWithNode() { - // TODO use hash map to reduce loop - Map<Edge, Edge> joinEdgeMap = getQueryJoinEdges().stream().map(qe -> { - Optional<JoinEdge> viewEdge = getViewJoinEdges().stream() - .filter(ve -> compareEdgeWithNode(qe, ve)).findFirst(); - return Pair.of(qe, viewEdge); - }).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get())); - Map<Edge, Edge> filterEdgeMap = getQueryFilterEdges().stream().map(qe -> { - Optional<FilterEdge> viewEdge = getViewFilterEdges().stream() - .filter(ve -> compareEdgeWithNode(qe, ve)).findFirst(); - return Pair.of(qe, viewEdge); - }).filter(e -> e.second.isPresent()).collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second.get())); - return ImmutableMap.<Edge, Edge>builder().putAll(joinEdgeMap).putAll(filterEdgeMap).build(); + private Map<Edge, Edge> constructQueryToViewMapWithExpr() { + Map<Expression, Edge> viewExprToEdge = getViewEdges().stream() + .flatMap(e -> e.getExpressions().stream().map(expr -> Pair.of(expr, e))) + .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second)); + Map<Expression, Edge> queryExprToEdge = getQueryEdges().stream() + .flatMap(e -> e.getExpressions().stream().map(expr -> Pair.of(expr, e))) + .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second)); + return queryExprToEdge.entrySet().stream() + .filter(entry -> viewExprToEdge.containsKey(getViewExprFromQueryExpr(entry.getKey()))) + .map(entry -> Pair.of(entry.getValue(), + viewExprToEdge.get(getViewExprFromQueryExpr(entry.getKey())))) + .distinct() + .collect(ImmutableMap.toImmutableMap(p -> p.first, p -> p.second)); + } + + private Expression getViewExprFromQueryExpr(Expression query) { + return logicalCompatibilityContext.getQueryToViewEdgeExpressionMapping().get(query); + } + + private void refreshViewEdges() { + List<FilterEdge> newFilterEdges = getViewFilterEdges().stream() + .map(FilterEdge::clear) + .collect(ImmutableList.toImmutableList()); + + List<JoinEdge> newJoinEdges = new ArrayList<>(); + for (JoinEdge joinEdge : getViewJoinEdges()) { + JoinType newJoinType = inferredViewEdgeWithCond + .getOrDefault(joinEdge, Pair.of(joinEdge.getJoinType(), null)).first; + JoinEdge newJoinEdge = joinEdge.withJoinTypeAndCleanCR(newJoinType); + newJoinEdges.add(newJoinEdge); + ConflictRulesMaker.makeJoinConflictRules(newJoinEdge, newJoinEdges); + ConflictRulesMaker.makeFilterConflictRules(newJoinEdge, newJoinEdges, newFilterEdges); + } + + viewJoinEdgesAfterInferring = ImmutableList.copyOf(newJoinEdges); + viewFilterEdgesAfterInferring = ImmutableList.copyOf(newFilterEdges); } private boolean compareEdgeWithNode(Edge query, Edge view) { if (query instanceof FilterEdge && view instanceof FilterEdge) { - return compareEdgeWithFilter((FilterEdge) query, (FilterEdge) view); + return compareFilterEdgeWithNode((FilterEdge) query, viewFilterEdgesAfterInferring.get(view.getIndex())); } else if (query instanceof JoinEdge && view instanceof JoinEdge) { - return compareJoinEdge((JoinEdge) query, (JoinEdge) view); + return compareJoinEdgeWithNode((JoinEdge) query, viewJoinEdgesAfterInferring.get(view.getIndex())); } return false; } - private boolean compareEdgeWithFilter(FilterEdge query, FilterEdge view) { - long qChild = query.getReferenceNodes(); - long vChild = view.getReferenceNodes(); - return rewriteQueryNodeMap(qChild) == vChild; + private boolean compareFilterEdgeWithNode(FilterEdge query, FilterEdge view) { + return rewriteQueryNodeMap(query.getReferenceNodes()) == view.getReferenceNodes(); + } + + private boolean compareJoinEdgeWithNode(JoinEdge query, JoinEdge view) { + boolean res = false; + if (query.getJoinType().swap() == view.getJoinType()) { + res |= rewriteQueryNodeMap(query.getLeftExtendedNodes()) == view.getRightExtendedNodes() + && rewriteQueryNodeMap(query.getRightExtendedNodes()) == view.getLeftExtendedNodes(); + } + res |= rewriteQueryNodeMap(query.getLeftExtendedNodes()) == view.getLeftExtendedNodes() + && rewriteQueryNodeMap(query.getRightExtendedNodes()) == view.getRightExtendedNodes(); + return res; } - private boolean compareJoinEdge(JoinEdge query, JoinEdge view) { + private boolean compareJoinEdgeOrInfer(JoinEdge query, JoinEdge view) { if (query.getJoinType().equals(view.getJoinType()) || canInferredJoinTypeMap.containsKey(Pair.of(query.getJoinType(), view.getJoinType()))) { if (tryInferEdge(query, view)) { return true; } - } if (query.getJoinType().swap().equals(view.getJoinType()) @@ -272,8 +346,8 @@ public class HyperGraphComparator { } private boolean tryInferEdge(JoinEdge query, JoinEdge view) { - if (rewriteQueryNodeMap(query.getLeftExtendedNodes()) != view.getLeftExtendedNodes() - || rewriteQueryNodeMap(query.getRightExtendedNodes()) != view.getRightExtendedNodes()) { + if (rewriteQueryNodeMap(query.getLeftRequiredNodes()) != view.getLeftRequiredNodes() + || rewriteQueryNodeMap(query.getRightRequiredNodes()) != view.getRightRequiredNodes()) { return false; } if (!query.getJoinType().equals(view.getJoinType())) { @@ -286,7 +360,7 @@ public class HyperGraphComparator { noNullableChild.first ? view.getJoin().left().getOutputSet() : ImmutableSet.of(), noNullableChild.second ? view.getJoin().right().getOutputSet() : ImmutableSet.of() ); - inferredViewEdgeMap.put(view, Pair.of(query.getJoinType(), noNullableSlot)); + inferredViewEdgeWithCond.put(view, Pair.of(query.getJoinType(), noNullableSlot)); } return true; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java new file mode 100644 index 00000000000..77b7fd67294 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparatorTest.java @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.mv; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; +import org.apache.doris.nereids.rules.RuleSet; +import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping; +import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; +import org.apache.doris.nereids.sqltest.SqlTestBase; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +class HyperGraphComparatorTest extends SqlTestBase { + @Test + void testInnerJoinAndLOJ() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T2 " + + "on T1.id = T2.id " + + "inner join T3 on T1.id = T3.id", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join T2 " + + "on T1.id = T2.id " + + "left outer join T3 on T1.id = T3.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); + Assertions.assertTrue(!res.isInvalid()); + Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); + } + + @Test + void testIJAndLojAssoc() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T3 " + + "on T1.id = T3.id " + + "inner join T2 on T1.id = T2.id", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join T2 " + + "on T1.id = T2.id " + + "left outer join T3 on T1.id = T3.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); + Assertions.assertTrue(!res.isInvalid()); + Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); + } + + @Test + void testIJAndLojAssocWithFilter() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T3 " + + "on T1.id = T3.id " + + "inner join T2 on T1.id = T2.id", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join " + + "(select * from T2 where T2.id = 1) T2 " + + "on T1.id = T2.id " + + "left outer join T3 on T1.id = T3.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); + Assertions.assertTrue(!res.isInvalid()); + Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); + } + + @Disabled + @Test + void testIJAndLojAssocWithJoinCond() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T3 " + + "on T1.id = T3.id " + + "inner join T2 on T1.id = T2.id", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join " + + "(" + + "select T1.* from T1 left outer join T3 " + + "on T1.id = T3.id and T1.score = T3.score " + + ") T2 " + + "on T1.id = T2.id ", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); + Assertions.assertTrue(!res.isInvalid()); + Assertions.assertEquals(2, res.getViewNoNullableSlot().size()); + } + + LogicalCompatibilityContext constructContext(Plan p1, Plan p2) { + StructInfo st1 = AbstractMaterializedViewRule.extractStructInfo(p1, + null).get(0); + StructInfo st2 = AbstractMaterializedViewRule.extractStructInfo(p2, + null).get(0); + RelationMapping rm = RelationMapping.generate(st1.getRelations(), st2.getRelations()).get(0); + SlotMapping sm = SlotMapping.generate(rm); + return LogicalCompatibilityContext.from(rm, sm, st1, st2); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org