This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 53c624ffa0924545ae870e190910a6e83840ca7b Author: 谢健 <jianx...@gmail.com> AuthorDate: Tue Jan 30 11:53:07 2024 +0800 [feat](Nereids): support alias when eliminate join for partially mv rewriting #30498 --- .../jobs/joinorder/hypergraph/GraphSimplifier.java | 3 +- .../jobs/joinorder/hypergraph/HyperGraph.java | 39 ++++++- .../jobs/joinorder/hypergraph/edge/Edge.java | 4 + .../jobs/joinorder/hypergraph/edge/JoinEdge.java | 19 +++- .../rules/exploration/mv/HyperGraphComparator.java | 116 +++++++++++++-------- .../rules/exploration/mv/EliminateJoinTest.java | 25 ++++- 6 files changed, 154 insertions(+), 52 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index 5539e894cc7..61738f2fca6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import java.util.ArrayDeque; import java.util.ArrayList; @@ -417,7 +418,7 @@ public class GraphSimplifier { JoinEdge newEdge = new JoinEdge(join, edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes(), - edge.getLeftRequiredNodes(), edge.getRightRequiredNodes()); + edge.getLeftRequiredNodes(), edge.getRightRequiredNodes(), ImmutableSet.of(), ImmutableSet.of()); newEdge.addLeftExtendNode(leftNodes); newEdge.addRightExtendNode(rightNodes); return newEdge; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 
8472837d79e..0099ab30afc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -46,14 +46,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.BitSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import javax.annotation.Nullable; /** * The graph is a join graph, whose node is the leaf plan and edge is a join operator. @@ -288,6 +291,36 @@ public class HyperGraph { return new HyperGraph.Builder().buildHyperGraphForMv(plan); } + /** + * map output to requires output and construct named expressions + */ + public @Nullable List<NamedExpression> getNamedExpressions( + long nodeMap, Set<Slot> outputSet, Set<Slot> requireOutputs) { + List<NamedExpression> output = new ArrayList<>(); + List<NamedExpression> projects = getComplexProject().get(nodeMap); + if (projects == null) { + return null; + } + for (Slot slot : requireOutputs) { + if (outputSet.contains(slot)) { + output.add(slot); + } else { + Optional<NamedExpression> expr = projects.stream() + .filter(p -> p.toSlot().equals(slot)) + .findFirst(); + if (!expr.isPresent()) { + return null; + } + // TODO: consider cascades alias + if (!outputSet.containsAll(expr.get().getInputSlots())) { + return null; + } + output.add(expr.get()); + } + } + return output; + } + /** * Builder of HyperGraph */ @@ -509,6 +542,10 @@ public class HyperGraph { } BitSet curJoinEdges = new BitSet(); + Set<Slot> leftInputSlots = ImmutableSet.copyOf( + Sets.intersection(join.getInputSlots(), join.left().getOutputSet())); + Set<Slot> rightInputSlots = 
ImmutableSet.copyOf( + Sets.intersection(join.getInputSlots(), join.right().getOutputSet())); for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts .entrySet()) { LogicalJoin<?, ?> singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, @@ -518,7 +555,7 @@ public class HyperGraph { Pair<Long, Long> ends = entry.getKey(); JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first, LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second), - ends.first, ends.second); + ends.first, ends.second, leftInputSlots, rightInputSlots); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java index 2530074931e..c41d9e270ac 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java @@ -79,6 +79,10 @@ public abstract class Edge { return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1; } + public boolean isRightSimple() { + return LongBitmap.getCardinality(rightExtendedNodes) == 1; + } + public void addLeftRejectEdge(JoinEdge edge) { leftRejectEdges.add(edge); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java index c23be5f16eb..94f4b30e8d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java @@ -37,12 +37,16 @@ import 
javax.annotation.Nullable; public class JoinEdge extends Edge { private final LogicalJoin<? extends Plan, ? extends Plan> join; + private final Set<Slot> leftInputSlots; + private final Set<Slot> rightInputSlots; public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index, BitSet leftChildEdges, BitSet rightChildEdges, long subTreeNodes, - long leftRequireNodes, long rightRequireNodes) { + long leftRequireNodes, long rightRequireNodes, Set<Slot> leftInputSlots, Set<Slot> rightInputSlots) { super(index, leftChildEdges, rightChildEdges, subTreeNodes, leftRequireNodes, rightRequireNodes); this.join = join; + this.leftInputSlots = leftInputSlots; + this.rightInputSlots = rightInputSlots; } /** @@ -51,7 +55,8 @@ public class JoinEdge extends Edge { public JoinEdge swap() { JoinEdge swapEdge = new JoinEdge(join.swap(), getIndex(), getRightChildEdges(), - getLeftChildEdges(), getSubTreeNodes(), getRightRequiredNodes(), getLeftRequiredNodes()); + getLeftChildEdges(), getSubTreeNodes(), getRightRequiredNodes(), getLeftRequiredNodes(), + this.rightInputSlots, this.leftInputSlots); swapEdge.addLeftRejectEdges(getLeftRejectEdge()); swapEdge.addRightRejectEdges(getRightRejectEdge()); return swapEdge; @@ -63,7 +68,7 @@ public class JoinEdge extends Edge { public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) { return new JoinEdge(join.withJoinType(joinType), getIndex(), getLeftChildEdges(), getRightChildEdges(), - getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes()); + getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes(), leftInputSlots, rightInputSlots); } public LogicalJoin<? extends Plan, ? 
extends Plan> getJoin() { @@ -112,4 +117,12 @@ public class JoinEdge extends Edge { join.getExpressions().forEach(expression -> slots.addAll(expression.getInputSlots())); return slots; } + + public Set<Slot> getLeftInputSlots() { + return leftInputSlots; + } + + public Set<Slot> getRightInputSlots() { + return rightInputSlots; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java index 3339d009c79..341caa88094 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java @@ -27,9 +27,11 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; @@ -46,6 +48,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.stream.Collectors; +import javax.annotation.Nullable; /** * HyperGraphComparator @@ -144,60 +147,81 @@ public class HyperGraphComparator { return buildComparisonRes(); } - private boolean tryEliminateNodesAndEdge() { - boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream() - .filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) == 1) - .anyMatch(e -> 
LongBitmap.isSubset(e.getReferenceNodes(), eliminateViewNodesMap)); - if (hasFilterEdgeAbove) { - // If there is some filter edge above the eliminated node, we should rebuild a plan - // Right now, just refuse it. + private @Nullable Plan constructViewPlan(long nodeBitmap, Set<Slot> requireOutputs) { + if (LongBitmap.getCardinality(nodeBitmap) != 1) { + return null; + } + Plan basePlan = viewHyperGraph.getNode(LongBitmap.lowestOneIndex(nodeBitmap)).getPlan(); + if (basePlan.getOutputSet().containsAll(requireOutputs)) { + return basePlan; + } + List<NamedExpression> projects = viewHyperGraph + .getNamedExpressions(nodeBitmap, basePlan.getOutputSet(), requireOutputs); + if (projects == null) { + return null; + } + return new LogicalProject<>(projects, basePlan); + } + + private boolean canEliminatePrimaryByForeign(long primaryNodes, long foreignNodes, + Set<Slot> primarySlots, Set<Slot> foreignSlots, JoinEdge joinEdge) { + Plan foreign = constructViewPlan(foreignNodes, foreignSlots); + Plan primary = constructViewPlan(primaryNodes, primarySlots); + if (foreign == null || primary == null) { return false; } - for (JoinEdge joinEdge : viewHyperGraph.getJoinEdges()) { - if (!LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap)) { - continue; + return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign); + } + + private boolean canEliminateViewEdge(JoinEdge joinEdge) { + // eliminate by unique + if (joinEdge.getJoinType().isLeftOuterJoin() && joinEdge.isRightSimple()) { + long eliminatedRight = + LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap); + if (LongBitmap.getCardinality(eliminatedRight) != 1) { + return false; } - // eliminate by unique - if (joinEdge.getJoinType().isLeftOuterJoin()) { - long eliminatedRight = - LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap); - if (LongBitmap.getCardinality(eliminatedRight) != 1) { - return false; - } - Plan rigthPlan = 
viewHyperGraph - .getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan(); - return JoinUtils.canEliminateByLeft(joinEdge.getJoin(), - rigthPlan.getLogicalProperties().getFunctionalDependencies()); + Plan rigthPlan = constructViewPlan(joinEdge.getRightExtendedNodes(), joinEdge.getRightInputSlots()); + if (rigthPlan == null) { + return false; } - // eliminate by pk fk - if (joinEdge.getJoinType().isInnerJoin()) { - if (!joinEdge.isSimple()) { - return false; - } - long eliminatedLeft = - LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), eliminateViewNodesMap); - long eliminatedRight = - LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap); - if (LongBitmap.getCardinality(eliminatedLeft) == 0 - && LongBitmap.getCardinality(eliminatedRight) == 1) { - Plan foreign = viewHyperGraph - .getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan(); - Plan primary = viewHyperGraph - .getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan(); - return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign); - } else if (LongBitmap.getCardinality(eliminatedLeft) == 1 - && LongBitmap.getCardinality(eliminatedRight) == 0) { - Plan foreign = viewHyperGraph - .getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan(); - Plan primary = viewHyperGraph - .getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan(); - return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign); - } + return JoinUtils.canEliminateByLeft(joinEdge.getJoin(), + rigthPlan.getLogicalProperties().getFunctionalDependencies()); + } + // eliminate by pk fk + if (joinEdge.getJoinType().isInnerJoin()) { + if (!joinEdge.isSimple()) { return false; } + long eliminatedLeft = + LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), eliminateViewNodesMap); + long eliminatedRight = + 
LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap); + if (LongBitmap.getCardinality(eliminatedLeft) == 0 + && LongBitmap.getCardinality(eliminatedRight) == 1) { + return canEliminatePrimaryByForeign(joinEdge.getRightExtendedNodes(), joinEdge.getLeftExtendedNodes(), + joinEdge.getRightInputSlots(), joinEdge.getLeftInputSlots(), joinEdge); + } else if (LongBitmap.getCardinality(eliminatedLeft) == 1 + && LongBitmap.getCardinality(eliminatedRight) == 0) { + return canEliminatePrimaryByForeign(joinEdge.getLeftExtendedNodes(), joinEdge.getRightExtendedNodes(), + joinEdge.getLeftInputSlots(), joinEdge.getRightInputSlots(), joinEdge); + } + } + return false; + } + private boolean tryEliminateNodesAndEdge() { + boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream() + .filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) == 1) + .anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), eliminateViewNodesMap)); + if (hasFilterEdgeAbove) { + // If there is some filter edge above the eliminated node, we should rebuild a plan + // Right now, just reject it. 
+ return false; } - return true; + return viewHyperGraph.getJoinEdges().stream() + .filter(joinEdge -> LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap)) + .allMatch(this::canEliminateViewEdge); } private boolean compareNodeWithExpr(StructInfoNode query, StructInfoNode view) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java index 3e245741178..cc8b7147423 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java @@ -52,10 +52,22 @@ class EliminateJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); + CascadesContext c3 = createCascadesContext( + "select * from T1 left outer join (select id as id2 from T2 group by id) T2 " + + "on T1.id = T2.id2 ", + connectContext + ); + Plan p3 = PlanChecker.from(c3) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); + HyperGraph h3 = HyperGraph.builderForMv(p3).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); + Assertions.assertTrue(!HyperGraphComparator.isLogicCompatible(h1, h3, constructContext(p1, p2)).isInvalid()); Assertions.assertTrue(res.getViewExpressions().isEmpty()); } @@ -112,11 +124,23 @@ class EliminateJoinTest extends SqlTestBase { .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) .getAllPlan().get(0).child(0); + CascadesContext c3 = createCascadesContext( + "select * from T1 inner join (select id as id2 from T2) T2 " + + 
"on T1.id = T2.id2 ", + connectContext + ); + Plan p3 = PlanChecker.from(c3) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0); HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0); + HyperGraph h3 = HyperGraph.builderForMv(p3).buildAll().get(0); ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2)); Assertions.assertTrue(!res.isInvalid()); Assertions.assertTrue(res.getViewExpressions().isEmpty()); + Assertions.assertTrue(!HyperGraphComparator.isLogicCompatible(h1, h3, constructContext(p1, p2)).isInvalid()); dropConstraint("alter table T2 drop constraint pk"); } @@ -154,7 +178,6 @@ class EliminateJoinTest extends SqlTestBase { dropConstraint("alter table T3 drop constraint uk"); } - @Disabled @Test void testLOJWithPKFKAndUK2() throws Exception { connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org