This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 8ccd8b4337 [fix](Nereids) fix ends calculation when there are constant project (#22265) 8ccd8b4337 is described below commit 8ccd8b43379482d7a268b969884d3b7409387ff6 Author: 谢健 <jianx...@gmail.com> AuthorDate: Mon Jul 31 14:10:44 2023 +0800 [fix](Nereids) fix ends calculation when there are constant project (#22265) --- .../doris/nereids/jobs/joinorder/JoinOrderJob.java | 26 +-- .../nereids/jobs/joinorder/hypergraph/Edge.java | 124 +++++++++---- .../jobs/joinorder/hypergraph/GraphSimplifier.java | 62 ++++--- .../jobs/joinorder/hypergraph/HyperGraph.java | 200 ++++++++++----------- .../joinorder/hypergraph/SubgraphEnumerator.java | 12 +- .../hypergraph/receiver/PlanReceiver.java | 55 +++--- .../jobs/joinorder/hypergraph/OtherJoinTest.java | 31 +++- .../hypergraph/SubgraphEnumeratorTest.java | 4 +- .../doris/nereids/sqltest/JoinOrderJobTest.java | 14 ++ .../org/apache/doris/nereids/sqltest/JoinTest.java | 20 +-- .../doris/nereids/util/HyperGraphBuilder.java | 27 ++- .../org/apache/doris/nereids/util/PlanChecker.java | 21 +-- 12 files changed, 339 insertions(+), 257 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java index acc1ebb96d..a48604dc02 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.Job; @@ -26,6 +27,7 @@ import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; import org.apache.doris.nereids.jobs.joinorder.hypergraph.GraphSimplifier; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.SubgraphEnumerator; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.PlanReceiver; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; @@ -111,20 +113,22 @@ public class JoinOrderJob extends Job { * * @param group root group, should be join type * @param hyperGraph build hyperGraph + * + * @return return edges of group's child and subTreeNodes of this group */ - public BitSet buildGraph(Group group, HyperGraph hyperGraph) { + public Pair<BitSet, Long> buildGraph(Group group, HyperGraph hyperGraph) { if (group.isProjectGroup()) { - BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph); - processProjectPlan(hyperGraph, group); - return edgeMap; + Pair<BitSet, Long> res = buildGraph(group.getLogicalExpression().child(0), hyperGraph); + processProjectPlan(hyperGraph, group, res.second); + return res; } if (!group.isValidJoinGroup()) { - hyperGraph.addNode(optimizePlan(group)); - return new BitSet(); + int idx = hyperGraph.addNode(optimizePlan(group)); + return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); } - BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph); - BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph); - return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap); + Pair<BitSet, Long> left = buildGraph(group.getLogicalExpression().child(0), hyperGraph); + Pair<BitSet, Long> right = buildGraph(group.getLogicalExpression().child(1), hyperGraph); + return Pair.of(hyperGraph.addEdge(group, left, right), LongBitmap.or(left.second, right.second)); } /** @@ -133,14 +137,14 @@ public class JoinOrderJob extends Job { * 2. If it's an alias that may be used in the join operator, we need to add it to graph * 3. If it's other expression, we can ignore them and add it after optimizing */ - private void processProjectPlan(HyperGraph hyperGraph, Group group) { + private void processProjectPlan(HyperGraph hyperGraph, Group group, long subTreeNodes) { LogicalProject<? extends Plan> logicalProject = (LogicalProject<? extends Plan>) group.getLogicalExpression() .getPlan(); for (NamedExpression expr : logicalProject.getProjects()) { if (expr instanceof Alias) { - hyperGraph.addAlias((Alias) expr); + hyperGraph.addAlias((Alias) expr, subTreeNodes); } else if (!expr.isSlot()) { otherProject.add(expr); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java index b2bdd4583a..9ebc1ed3fa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -26,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import com.google.common.base.Preconditions; +import java.util.BitSet; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -38,23 +40,36 @@ public class Edge { final LogicalJoin<? extends Plan, ? extends Plan> join; final double selectivity; - // The endpoints (hyperNodes) of this hyperEdge. - // left and right may not overlap, and both must have at least one bit set. - private long left = LongBitmap.newBitmap(); - private long right = LongBitmap.newBitmap(); - - private long originalLeft = LongBitmap.newBitmap(); - private long originalRight = LongBitmap.newBitmap(); + // "RequiredNodes" refers to the nodes that can activate this edge based on + // specific requirements. These requirements are established during the building process. + // "ExtendNodes" encompasses both the "RequiredNodes" and any additional nodes + // added by the graph simplifier. + private long leftRequiredNodes = LongBitmap.newBitmap(); + private long rightRequiredNodes = LongBitmap.newBitmap(); + private long leftExtendedNodes = LongBitmap.newBitmap(); + private long rightExtendedNodes = LongBitmap.newBitmap(); private long referenceNodes = LongBitmap.newBitmap(); + // record the left child edges and right child edges in origin plan tree + private BitSet leftChildEdges; + private BitSet rightChildEdges; + + // record the edges in the same operator + private BitSet curJoinEdges = new BitSet(); + // record all sub nodes behind in this operator. It's T function in paper + private Long subTreeNodes; + /** * Create simple edge. */ - public Edge(LogicalJoin join, int index) { + public Edge(LogicalJoin join, int index, BitSet leftChildEdges, BitSet rightChildEdges, Long subTreeNodes) { this.index = index; this.join = join; this.selectivity = 1.0; + this.leftChildEdges = leftChildEdges; + this.rightChildEdges = rightChildEdges; + this.subTreeNodes = subTreeNodes; } public LogicalJoin getJoin() { @@ -66,65 +81,107 @@ public class Edge { } public boolean isSimple() { - return LongBitmap.getCardinality(left) == 1 && LongBitmap.getCardinality(right) == 1; + return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1; } public void addLeftNode(long left) { - this.left = LongBitmap.or(this.left, left); + this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left); referenceNodes = LongBitmap.or(referenceNodes, left); } public void addLeftNodes(long... bitmaps) { for (long bitmap : bitmaps) { - this.left = LongBitmap.or(this.left, bitmap); + this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, bitmap); referenceNodes = LongBitmap.or(referenceNodes, bitmap); } } public void addRightNode(long right) { - this.right = LongBitmap.or(this.right, right); + this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, right); referenceNodes = LongBitmap.or(referenceNodes, right); } public void addRightNodes(long... bitmaps) { for (long bitmap : bitmaps) { - LongBitmap.or(this.right, bitmap); + LongBitmap.or(this.rightExtendedNodes, bitmap); LongBitmap.or(referenceNodes, bitmap); } } - public long getLeft() { - return left; + public long getSubTreeNodes() { + return this.subTreeNodes; + } + + public long getLeftExtendedNodes() { + return leftExtendedNodes; + } + + public BitSet getLeftChildEdges() { + return leftChildEdges; } - public void setLeft(long left) { + public Pair<BitSet, Long> getLeftEdgeNodes(List<Edge> edges) { + return Pair.of(leftChildEdges, getLeftSubNodes(edges)); + } + + public Pair<BitSet, Long> getRightEdgeNodes(List<Edge> edges) { + return Pair.of(rightChildEdges, getRightSubNodes(edges)); + } + + public long getLeftSubNodes(List<Edge> edges) { + if (leftChildEdges.isEmpty()) { + return leftRequiredNodes; + } + return edges.get(leftChildEdges.nextSetBit(0)).getSubTreeNodes(); + } + + public long getRightSubNodes(List<Edge> edges) { + if (rightChildEdges.isEmpty()) { + return rightRequiredNodes; + } + return edges.get(rightChildEdges.nextSetBit(0)).getSubTreeNodes(); + } + + public void setLeftExtendedNodes(long leftExtendedNodes) { referenceNodes = LongBitmap.clear(referenceNodes); - this.left = left; + this.leftExtendedNodes = leftExtendedNodes; + } + + public long getRightExtendedNodes() { + return rightExtendedNodes; } - public long getRight() { - return right; + public BitSet getRightChildEdges() { + return rightChildEdges; } - public void setRight(long right) { + public void setRightExtendedNodes(long rightExtendedNodes) { referenceNodes = LongBitmap.clear(referenceNodes); - this.right = right; + this.rightExtendedNodes = rightExtendedNodes; } - public long getOriginalLeft() { - return originalLeft; + public long getLeftRequiredNodes() { + return leftRequiredNodes; } - public void setOriginalLeft(long left) { - this.originalLeft = left; + public void setLeftRequiredNodes(long left) { + this.leftRequiredNodes = left; } - public long getOriginalRight() { - return originalRight; + public long getRightRequiredNodes() { + return rightRequiredNodes; } - public void setOriginalRight(long right) { - this.originalRight = right; + public void setRightRequiredNodes(long right) { + this.rightRequiredNodes = right; + } + + public void addCurJoinEdges(BitSet edges) { + curJoinEdges.or(edges); + } + + public BitSet getCurJoinEdges() { + return curJoinEdges; } public boolean isSub(Edge edge) { @@ -135,11 +192,15 @@ public class Edge { public long getReferenceNodes() { if (LongBitmap.getCardinality(referenceNodes) == 0) { - referenceNodes = LongBitmap.newBitmapUnion(left, right); + referenceNodes = LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes); } return referenceNodes; } + public long getRequireNodes() { + return LongBitmap.newBitmapUnion(leftRequiredNodes, rightRequiredNodes); + } + public int getIndex() { return index; } @@ -165,7 +226,8 @@ public class Edge { @Override public String toString() { - return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right)); + return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString( + rightExtendedNodes)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index efe9896dd2..cf613ac174 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -108,10 +108,10 @@ public class GraphSimplifier { Edge edge1 = graph.getEdge(i); Edge edge2 = graph.getEdge(j); List<Long> superset = new ArrayList<>(); - tryGetSuperset(edge1.getLeft(), edge2.getLeft(), superset); - tryGetSuperset(edge1.getLeft(), edge2.getRight(), superset); - tryGetSuperset(edge1.getRight(), edge2.getLeft(), superset); - tryGetSuperset(edge1.getRight(), edge2.getRight(), superset); + tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset); + tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset); + tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getLeftExtendedNodes(), superset); + tryGetSuperset(edge1.getRightExtendedNodes(), edge2.getRightExtendedNodes(), superset); if (!circleDetector.checkCircleWithEdge(i, j) && !circleDetector.checkCircleWithEdge(j, i) && !edge2.isSub(edge1) && !edge1.isSub(edge2) && !superset.isEmpty()) { return false; @@ -213,8 +213,8 @@ public class GraphSimplifier { BestSimplification bestSimplification = priorityQueue.poll(); bestSimplification.isInQueue = false; SimplificationStep bestStep = bestSimplification.getStep(); - while (bestSimplification.bestNeighbor == -1 || !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex, - bestStep.afterIndex)) { + while (bestSimplification.bestNeighbor == -1 + || !circleDetector.tryAddDirectedEdge(bestStep.beforeIndex, bestStep.afterIndex)) { processNeighbors(bestStep.afterIndex, 0, edgeSize); if (priorityQueue.isEmpty()) { return null; @@ -307,10 +307,14 @@ public class GraphSimplifier { || circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)) { return Optional.empty(); } - long left1 = edge1.getLeft(); - long right1 = edge1.getRight(); - long left2 = edge2.getLeft(); - long right2 = edge2.getRight(); + long left1 = edge1.getLeftExtendedNodes(); + long right1 = edge1.getRightExtendedNodes(); + long left2 = edge2.getLeftExtendedNodes(); + long right2 = edge2.getRightExtendedNodes(); + if (!cacheStats.containsKey(left1) || !cacheStats.containsKey(right1) + || !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) { + return Optional.empty(); + } Pair<Statistics, Edge> edge1Before2; Pair<Statistics, Edge> edge2Before1; List<Long> superBitset = new ArrayList<>(); @@ -351,13 +355,14 @@ public class GraphSimplifier { Statistics leftStats = JoinEstimation.estimate(cacheStats.get(bitmap1), cacheStats.get(bitmap2), edge1.getJoin()); Statistics joinStats = JoinEstimation.estimate(leftStats, cacheStats.get(bitmap3), edge2.getJoin()); - Edge edge = new Edge(edge2.getJoin(), -1); + Edge edge = new Edge( + edge2.getJoin(), -1, edge2.getLeftChildEdges(), edge2.getRightChildEdges(), edge2.getSubTreeNodes()); long newLeft = LongBitmap.newBitmapUnion(bitmap1, bitmap2); // To avoid overlapping the left and the right, the newLeft is calculated, Note the // newLeft is not totally include the bitset1 and bitset2, we use circle detector to trace the dependency newLeft = LongBitmap.andNot(newLeft, bitmap3); edge.addLeftNodes(newLeft); - edge.addRightNode(edge2.getRight()); + edge.addRightNode(edge2.getRightExtendedNodes()); cacheStats.put(newLeft, leftStats); cacheCost.put(newLeft, calCost(edge2, leftStats, cacheStats.get(bitmap1), cacheStats.get(bitmap2))); return Pair.of(joinStats, edge); @@ -370,11 +375,12 @@ public class GraphSimplifier { Statistics rightStats = JoinEstimation.estimate(cacheStats.get(bitmap2), cacheStats.get(bitmap3), edge2.getJoin()); Statistics joinStats = JoinEstimation.estimate(cacheStats.get(bitmap1), rightStats, edge1.getJoin()); - Edge edge = new Edge(edge1.getJoin(), -1); + Edge edge = new Edge( + edge1.getJoin(), -1, edge1.getLeftChildEdges(), edge1.getRightChildEdges(), edge1.getSubTreeNodes()); long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3); newRight = LongBitmap.andNot(newRight, bitmap1); - edge.addLeftNode(edge1.getLeft()); + edge.addLeftNode(edge1.getLeftExtendedNodes()); edge.addRightNode(newRight); cacheStats.put(newRight, rightStats); cacheCost.put(newRight, calCost(edge2, rightStats, cacheStats.get(bitmap2), cacheStats.get(bitmap3))); @@ -384,11 +390,11 @@ public class GraphSimplifier { private SimplificationStep orderJoin(Pair<Statistics, Edge> edge1Before2, Pair<Statistics, Edge> edge2Before1, int edgeIndex1, int edgeIndex2) { Cost cost1Before2 = calCost(edge1Before2.second, edge1Before2.first, - cacheStats.get(edge1Before2.second.getLeft()), - cacheStats.get(edge1Before2.second.getRight())); + cacheStats.get(edge1Before2.second.getLeftExtendedNodes()), + cacheStats.get(edge1Before2.second.getRightExtendedNodes())); Cost cost2Before1 = calCost(edge2Before1.second, edge1Before2.first, - cacheStats.get(edge1Before2.second.getLeft()), - cacheStats.get(edge1Before2.second.getRight())); + cacheStats.get(edge1Before2.second.getLeftExtendedNodes()), + cacheStats.get(edge1Before2.second.getRightExtendedNodes())); double benefit = Double.MAX_VALUE; SimplificationStep step; // Choose the plan with smaller cost and make the simplification step to replace the old edge by it. @@ -397,17 +403,17 @@ public class GraphSimplifier { benefit = cost2Before1.getValue() / cost1Before2.getValue(); } // choose edge1Before2 - step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeft(), - edge1Before2.second.getRight(), graph.getEdge(edgeIndex2).getLeft(), - graph.getEdge(edgeIndex2).getRight()); + step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.second.getLeftExtendedNodes(), + edge1Before2.second.getRightExtendedNodes(), graph.getEdge(edgeIndex2).getLeftExtendedNodes(), + graph.getEdge(edgeIndex2).getRightExtendedNodes()); } else { if (cost2Before1.getValue() != 0) { benefit = cost1Before2.getValue() / cost2Before1.getValue(); } // choose edge2Before1 - step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeft(), - edge2Before1.second.getRight(), graph.getEdge(edgeIndex1).getLeft(), - graph.getEdge(edgeIndex1).getRight()); + step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.second.getLeftExtendedNodes(), + edge2Before1.second.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(), + graph.getEdge(edgeIndex1).getRightExtendedNodes()); } return step; } @@ -438,8 +444,8 @@ public class GraphSimplifier { join.left(), join.right()); cost = CostCalculator.calculateCost(nestedLoopJoin, planContext); - cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeft()), 0); - cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRight()), 1); + cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0); + cost = CostCalculator.addChildCost(nestedLoopJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1); } else { PhysicalHashJoin hashJoin = new PhysicalHashJoin<>( join.getJoinType(), @@ -451,8 +457,8 @@ public class GraphSimplifier { join.left(), join.right()); cost = CostCalculator.calculateCost(hashJoin, planContext); - cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeft()), 0); - cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRight()), 1); + cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getLeftExtendedNodes()), 0); + cost = CostCalculator.addChildCost(hashJoin, cost, cacheCost.get(edge.getRightExtendedNodes()), 1); } return cost; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 2bc55d8ed2..67f246cd80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; @@ -88,7 +87,7 @@ public class HyperGraph { * * @param alias The alias Expression in project Operator */ - public boolean addAlias(Alias alias) { + public boolean addAlias(Alias alias, long subTreeNodes) { Slot aliasSlot = alias.toSlot(); if (slotToNodeMap.containsKey(aliasSlot)) { return true; @@ -97,12 +96,21 @@ public class HyperGraph { for (Slot slot : alias.getInputSlots()) { bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); } + // The case hit when there are some constant aliases such as: + // select * from t1 join ( + // select *, 1 as b1 from t2) + // on t1.b = b1 + // just reference them all for this slot + if (bitmap == 0) { + bitmap = subTreeNodes; + } + Preconditions.checkArgument(bitmap > 0, "slot must belong to some table"); slotToNodeMap.put(aliasSlot, bitmap); if (!complexProject.containsKey(bitmap)) { complexProject.put(bitmap, new ArrayList<>()); - } else if (!(alias.child() instanceof SlotReference)) { - alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0); } + alias = (Alias) PlanUtils.mergeProjections(complexProject.get(bitmap), Lists.newArrayList(alias)).get(0); + complexProject.get(bitmap).add(alias); return true; } @@ -111,8 +119,9 @@ public class HyperGraph { * add end node to HyperGraph * * @param group The group that is the end node in graph + * @return return the node index */ - public void addNode(Group group) { + public int addNode(Group group) { Preconditions.checkArgument(!group.isValidJoinGroup()); for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) { Preconditions.checkArgument(!slotToNodeMap.containsKey(slot)); @@ -120,6 +129,7 @@ public class HyperGraph { } nodeSet.add(group); nodes.add(new Node(nodes.size(), group)); + return nodes.size() - 1; } public boolean isNodeGroup(Group group) { @@ -135,138 +145,126 @@ public class HyperGraph { * * @param group The join group */ - public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) { + public BitSet addEdge(Group group, Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) { Preconditions.checkArgument(group.isValidJoinGroup()); LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan(); HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>(); - for (Expression expression : join.getHashJoinConjuncts()) { - Pair<Long, Long> ends = findEnds(expression); + // TODO: avoid calling calculateEnds if calNodeMap's results are same + Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, + rightEdgeNodes); if (!conjuncts.containsKey(ends)) { conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); } conjuncts.get(ends).first.add(expression); } for (Expression expression : join.getOtherJoinConjuncts()) { - Pair<Long, Long> ends = findEnds(expression); + Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes, + rightEdgeNodes); if (!conjuncts.containsKey(ends)) { conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); } conjuncts.get(ends).second.add(expression); } - BitSet edgeMap = new BitSet(); - edgeMap.or(leftEdgeMap); - edgeMap.or(rightEdgeMap); - + BitSet curJoinEdges = new BitSet(); for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts .entrySet()) { LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(), Lists.newArrayList(join.left(), join.right())); - Edge edge = new Edge(singleJoin, edges.size()); + Edge edge = new Edge(singleJoin, edges.size(), leftEdgeNodes.first, rightEdgeNodes.first, + LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second)); Pair<Long, Long> ends = entry.getKey(); - initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap); + edge.setLeftRequiredNodes(ends.first); + edge.setLeftExtendedNodes(ends.first); + edge.setRightRequiredNodes(ends.second); + edge.setRightExtendedNodes(ends.second); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } - edgeMap.set(edge.getIndex()); + curJoinEdges.set(edge.getIndex()); edges.add(edge); } - - return edgeMap; + curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges)); + curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges)); + curJoinEdges.stream().forEach(i -> makeConflictRules(edges.get(i))); + return curJoinEdges; // In MySQL, each edge is reversed and store in edges again for reducing the branch miss // We don't implement this trick now. } - // Make edge with CD-A algorithm in + // Make edge with CD-C algorithm in // On the correct and complete enumeration of the core search - private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) { - long left = ends.first; - long right = ends.second; - for (int i = leftEdges.nextSetBit(0); i >= 0; i = leftEdges.nextSetBit(i + 1)) { - Edge lEdge = edges.get(i); - if (!JoinType.isAssoc(lEdge.getJoinType(), edge.getJoinType())) { - left = LongBitmap.or(left, lEdge.getLeft()); + private void makeConflictRules(Edge edgeB) { + BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges()); + BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges()); + long leftRequired = edgeB.getLeftRequiredNodes(); + long rightRequired = edgeB.getRightRequiredNodes(); + + for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) { + Edge childA = edges.get(i); + if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) { + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(edges)); } - if (!JoinType.isLAssoc(lEdge.getJoinType(), edge.getJoinType())) { - left = LongBitmap.or(left, lEdge.getRight()); + if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) { + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(edges)); } } - for (int i = rightEdges.nextSetBit(0); i >= 0; i = rightEdges.nextSetBit(i + 1)) { - Edge rEdge = edges.get(i); - if (!JoinType.isAssoc(rEdge.getJoinType(), edge.getJoinType())) { - right = LongBitmap.or(right, rEdge.getRight()); + + for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) { + Edge childA = edges.get(i); + if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) { + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(edges)); } - if (!JoinType.isRAssoc(rEdge.getJoinType(), edge.getJoinType())) { - right = LongBitmap.or(right, rEdge.getLeft()); + if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) { + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(edges)); } } - - edge.setOriginalLeft(left); - edge.setOriginalRight(right); - edge.setLeft(left); - edge.setRight(right); + edgeB.setLeftRequiredNodes(leftRequired); + edgeB.setRightRequiredNodes(rightRequired); + edgeB.setLeftExtendedNodes(leftRequired); + edgeB.setRightExtendedNodes(rightRequired); } - private int findRoot(List<Integer> parent, int idx) { - int root = parent.get(idx); - if (root != idx) { - root = findRoot(parent, root); - } - parent.set(idx, root); - return root; + private BitSet subTreeEdges(Edge edge) { + BitSet bitSet = new BitSet(); + bitSet.or(subTreeEdges(edge.getLeftChildEdges())); + bitSet.or(subTreeEdges(edge.getRightChildEdges())); + bitSet.set(edge.getIndex()); + return bitSet; } - private boolean isConnected(long bitmap, long excludeBitmap) { - if (LongBitmap.getCardinality(bitmap) == 1) { - return true; - } - - // use unionSet to check whether the bitmap is connected - List<Integer> parent = new ArrayList<>(); - for (int i = 0; i < nodes.size(); i++) { - parent.add(i, i); - } - for (Edge edge : edges) { - if (LongBitmap.isOverlap(edge.getLeft(), excludeBitmap) - || LongBitmap.isOverlap(edge.getRight(), excludeBitmap)) { - continue; - } - - int root = findRoot(parent, LongBitmap.nextSetBit(edge.getLeft(), 0)); - for (int idx : LongBitmap.getIterator(edge.getLeft())) { - parent.set(idx, root); - } - for (int idx : LongBitmap.getIterator(edge.getRight())) { - parent.set(idx, root); - } - } - - int root = findRoot(parent, LongBitmap.nextSetBit(bitmap, 0)); - for (int idx : LongBitmap.getIterator(bitmap)) { - if (root != findRoot(parent, idx)) { - return false; - } - } - return true; + private BitSet subTreeEdges(BitSet edgeSet) { + BitSet bitSet = new BitSet(); + edgeSet.stream() + .mapToObj(i -> subTreeEdges(edges.get(i))) + .forEach(b -> bitSet.or(b)); + return bitSet; } - private Pair<Long, Long> findEnds(Expression expression) { - long bitmap = calNodeMap(expression.getInputSlots()); - int cardinality = LongBitmap.getCardinality(bitmap); - Preconditions.checkArgument(cardinality > 1); - for (long subset : LongBitmap.getSubsetIterator(bitmap)) { - long left = subset; - long right = LongBitmap.newBitmapDiff(bitmap, left); - // when the graph without right node has a connected-sub-graph contains left nodes - // and the graph without left node has a connected-sub-graph contains right nodes. - // we can generate an edge for this expression - if (isConnected(left, right) && isConnected(right, left)) { - return Pair.of(left, right); - } + // Try to calculate the ends of an expression. + // left = ref_nodes \cap left_tree , right = ref_nodes \cap right_tree + // if left = 0, recursively calculate it in left tree + private Pair<Long, Long> calculateEnds(long allNodes, Pair<BitSet, Long> leftEdgeNodes, + Pair<BitSet, Long> rightEdgeNodes) { + long left = LongBitmap.newBitmapIntersect(allNodes, leftEdgeNodes.second); + long right = LongBitmap.newBitmapIntersect(allNodes, rightEdgeNodes.second); + if (left == 0) { + Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0, + "the number of the table which expression reference is less 2"); + Pair<BitSet, Long> llEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges); + Pair<BitSet, Long> lrEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges); + return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes); + } + if (right == 0) { + Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0, + "the number of the table which expression reference is less 2"); + Pair<BitSet, Long> rlEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges); + Pair<BitSet, Long> rrEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges); + return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes); } - throw new RuntimeException("DPhyper meets unconnected subgraph"); + return Pair.of(left, right); } private long calNodeMap(Set<Slot> slots) { @@ -291,10 +289,10 @@ public class HyperGraph { // For these nodes that are only in the old edge, we need remove the edge from them // For these nodes that are only in the new edge, we need to add the edge to them Edge edge = edges.get(edgeIndex); - updateEdges(edge, edge.getLeft(), newLeft); - updateEdges(edge, edge.getRight(), newRight); - edges.get(edgeIndex).setLeft(newLeft); - edges.get(edgeIndex).setRight(newRight); + updateEdges(edge, edge.getLeftExtendedNodes(), newLeft); + updateEdges(edge, edge.getRightExtendedNodes(), newRight); + edges.get(edgeIndex).setLeftExtendedNodes(newLeft); + edges.get(edgeIndex).setRightExtendedNodes(newRight); } private void updateEdges(Edge edge, long oldNodes, long newNodes) { @@ -339,8 +337,8 @@ public class HyperGraph { arrowHead = ",arrowhead=none"; } - int leftIndex = LongBitmap.lowestOneIndex(edge.getLeft()); - int rightIndex = LongBitmap.lowestOneIndex(edge.getRight()); + int leftIndex = LongBitmap.lowestOneIndex(edge.getLeftExtendedNodes()); + int rightIndex = LongBitmap.lowestOneIndex(edge.getRightExtendedNodes()); builder.append(String.format("%s -> %s [label=\"%s\"%s]\n", graphvisNodes.get(leftIndex), graphvisNodes.get(rightIndex), label, arrowHead)); } else { @@ -349,7 +347,7 @@ public class HyperGraph { String leftLabel = ""; String rightLabel = ""; - if (LongBitmap.getCardinality(edge.getLeft()) == 1) { + if (LongBitmap.getCardinality(edge.getLeftExtendedNodes()) == 1) { rightLabel = label; } else { leftLabel = label; @@ -357,13 +355,13 @@ public class HyperGraph { int finalI = i; String finalLeftLabel = leftLabel; - for (int nodeIndex : LongBitmap.getIterator(edge.getLeft())) { + for (int nodeIndex : LongBitmap.getIterator(edge.getLeftExtendedNodes())) { builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n", graphvisNodes.get(nodeIndex), finalI, finalLeftLabel)); } String finalRightLabel = rightLabel; - for (int nodeIndex : LongBitmap.getIterator(edge.getRight())) { + for (int nodeIndex : LongBitmap.getIterator(edge.getRightExtendedNodes())) { builder.append(String.format("%s -> e%d [arrowhead=none, label=\"%s\"]\n", graphvisNodes.get(nodeIndex), finalI, finalRightLabel)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java index cc5a9cea96..cbfa62d1f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java @@ -225,8 +225,8 @@ public class SubgraphEnumerator { neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes); forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods); for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) { - long left = edge.getLeft(); - long right = edge.getRight(); + long left = edge.getLeftExtendedNodes(); + long right = edge.getRightExtendedNodes(); if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, forbiddenNodes)) { neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(right)); } else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, forbiddenNodes)) { @@ -362,14 +362,14 @@ public class SubgraphEnumerator { } private boolean isContainEdge(long subgraph, Edge edge) { - int containLeft = LongBitmap.isSubset(edge.getLeft(), subgraph) ? 0 : 1; - int containRight = LongBitmap.isSubset(edge.getRight(), subgraph) ? 0 : 1; + int containLeft = LongBitmap.isSubset(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1; + int containRight = LongBitmap.isSubset(edge.getRightExtendedNodes(), subgraph) ? 0 : 1; return containLeft + containRight == 1; } private boolean isOverlapEdge(long subgraph, Edge edge) { - int overlapLeft = LongBitmap.isOverlap(edge.getLeft(), subgraph) ? 0 : 1; - int overlapRight = LongBitmap.isOverlap(edge.getRight(), subgraph) ? 0 : 1; + int overlapLeft = LongBitmap.isOverlap(edge.getLeftExtendedNodes(), subgraph) ? 0 : 1; + int overlapRight = LongBitmap.isOverlap(edge.getRightExtendedNodes(), subgraph) ? 0 : 1; return overlapLeft + overlapRight == 1; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index 2c1b7ebed5..51c16e24f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -60,6 +60,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; /** @@ -102,15 +103,14 @@ public class PlanReceiver implements AbstractReceiver { public boolean emitCsgCmp(long left, long right, List<Edge> edges) { Preconditions.checkArgument(planTable.containsKey(left)); Preconditions.checkArgument(planTable.containsKey(right)); - processMissedEdges(left, right, edges); - Memo memo = jobContext.getCascadesContext().getMemo(); emitCount += 1; if (emitCount > limit) { return false; } + Memo memo = jobContext.getCascadesContext().getMemo(); GroupPlan leftPlan = new GroupPlan(planTable.get(left)); GroupPlan rightPlan = new GroupPlan(planTable.get(right)); @@ -118,6 +118,7 @@ public class PlanReceiver implements AbstractReceiver { // In this step, we don't generate logical expression because they are useless in DPhyp. List<Expression> hashConjuncts = new ArrayList<>(); List<Expression> otherConjuncts = new ArrayList<>(); + JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); if (joinType == null) { return true; @@ -126,6 +127,7 @@ public class PlanReceiver implements AbstractReceiver { List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts, otherConjuncts); + List<Plan> physicalPlans = proposeProject(physicalJoins, edges, left, right); // Second, we copy all physical plan to Group and generate properties and calculate cost @@ -188,7 +190,7 @@ public class PlanReceiver implements AbstractReceiver { // find the edge which is not in usedEdgesBitmap and its referenced nodes is subset of allReferenceNodes for (Edge edge : hyperGraph.getEdges()) { long referenceNodes = - LongBitmap.newBitmapUnion(edge.getOriginalLeft(), edge.getOriginalRight()); + LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes()); if (LongBitmap.isSubset(referenceNodes, allReferenceNodes) && !usedEdgesBitmap.get(edge.getIndex())) { // add the missed edge to edges @@ -344,7 +346,6 @@ public class PlanReceiver implements AbstractReceiver { long fullKey = LongBitmap.newBitmapUnion(left, right); List<Slot> outputs = allChild.get(0).getOutput(); Set<Slot> outputSet = allChild.get(0).getOutputSet(); - List<NamedExpression> allProjects = Lists.newArrayList(); List<NamedExpression> complexProjects = new ArrayList<>(); // Calculate complex expression should be done by current(fullKey) node @@ -358,43 +359,23 @@ public class PlanReceiver implements AbstractReceiver { // complexProjectMap is created by a bottom up traverse of join tree, so child node is put before parent node // in the bitmaps + bitmaps.sort(Long::compare); for (long bitmap : bitmaps) { if (complexProjects.isEmpty()) { - complexProjects = complexProjectMap.get(bitmap); + complexProjects.addAll(complexProjectMap.get(bitmap)); } else { - // The top project of (T1, T2, T3) is different after reorder - // we need merge Project1 and Project2 as Project4 after reorder - // T1 join T2 join T3: - // Project1(a, e + f) - // join(a = e) - // Project2(a, b + d as e) - // join(a = c) - // T1(a, b) - // T2(c, d) - // T3(e, f) - // - // after reorder: - // T1 join T3 join T2: - // Project4(a, b + d + f) - // join(a = c) - // Project3(a, b, f) - // join(a = e) - // T1(a, b) - // T3(e, f) - // T2(c, d) - // - complexProjects = - PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap)); + // Rewrite project expression by its children + complexProjects.addAll( + PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap))); } } - allProjects.addAll(complexProjects); // calculate required columns by all parents Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges); - - // add output slots belong to required slots to project list - allProjects.addAll(outputs.stream().filter(e -> requireSlots.contains(e)) - .collect(Collectors.toList())); + List<NamedExpression> allProjects = Stream.concat( + outputs.stream().filter(e -> requireSlots.contains(e)), + complexProjects.stream().filter(e -> requireSlots.contains(e.toSlot())) + ).collect(Collectors.toList()); // propose physical project if (allProjects.isEmpty()) { @@ -416,7 +397,13 @@ public class PlanReceiver implements AbstractReceiver { .map(c -> new PhysicalProject<>(projects, projectProperties, c)) .collect(Collectors.toList()); } - Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size()); + if (!(!projects.isEmpty() && projects.size() == allProjects.size())) { + Set<NamedExpression> s1 = projects.stream().collect(Collectors.toSet()); + List<NamedExpression> s2 = allProjects.stream().filter(e -> !s1.contains(e)).collect(Collectors.toList()); + System.out.println(s2); + } + Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size(), + " there are some projects left " + projects + allProjects); return allChild; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java index feeb971b15..a4062d2edf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.datasets.tpch.TPCHTestBase; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.HyperGraphBuilder; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -32,23 +33,37 @@ import java.util.Set; public class OtherJoinTest extends TPCHTestBase { @Test - public void randomTest() { + public void test() { + for (int t = 3; t < 10; t++) { + for (int e = t - 1; e <= (t * (t - 1)) / 2; e++) { + for (int i = 0; i < 10; i++) { + System.out.println(String.valueOf(t) + " " + e + ": " + i); + randomTest(t, e); + } + } + } + } + + private void randomTest(int tableNum, int edgeNum) { HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(); Plan plan = hyperGraphBuilder - .randomBuildPlanWith(10, 20); - Set<List<Integer>> res1 = hyperGraphBuilder.evaluate(plan); + .randomBuildPlanWith(tableNum, edgeNum); + plan = new LogicalProject(plan.getOutput(), plan); + Set<List<String>> res1 = hyperGraphBuilder.evaluate(plan); CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan); hyperGraphBuilder.initStats(cascadesContext); Plan optimizedPlan = PlanChecker.from(cascadesContext) - .dpHypOptimize() - .getBestPlanTree(); + .dpHypOptimize() + .getBestPlanTree(); - Set<List<Integer>> res2 = hyperGraphBuilder.evaluate(optimizedPlan); + Set<List<String>> res2 = hyperGraphBuilder.evaluate(optimizedPlan); if (!res1.equals(res2)) { - System.out.println(res1); - System.out.println(res2); System.out.println(plan.treeString()); System.out.println(optimizedPlan.treeString()); + cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan); + PlanChecker.from(cascadesContext).dpHypOptimize().getBestPlanTree(); + System.out.println(res1); + System.out.println(res2); } Assertions.assertTrue(res1.equals(res2)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java index 6e559efbbe..47dc68c9e0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java @@ -130,8 +130,8 @@ public class SubgraphEnumeratorTest { visited.add(left); visited.add(right); for (Edge edge : hyperGraph.getEdges()) { - if ((LongBitmap.isSubset(edge.getLeft(), left) && LongBitmap.isSubset(edge.getRight(), right)) || ( - LongBitmap.isSubset(edge.getLeft(), right) && LongBitmap.isSubset(edge.getRight(), left))) { + if ((LongBitmap.isSubset(edge.getLeftExtendedNodes(), left) && LongBitmap.isSubset(edge.getRightExtendedNodes(), right)) || ( + LongBitmap.isSubset(edge.getLeftExtendedNodes(), right) && LongBitmap.isSubset(edge.getRightExtendedNodes(), left))) { count += countAndCheck(left, hyperGraph, counter, cache) * countAndCheck(right, hyperGraph, counter, cache); break; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java index af272f3d5d..66f747aae8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinOrderJobTest.java @@ -87,6 +87,20 @@ public class JoinOrderJobTest extends SqlTestBase { .dpHypOptimize(); } + @Test + protected void testConstantJoin() { + String sql = "select count(*) \n" + + "from \n" + + "T1 \n" + + " join (\n" + + "select * , now() as t from T2 \n" + + ") subTable on T1.id = t; \n"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .dpHypOptimize(); + } + @Test protected void testCountJoin() { String sql = "select count(*) \n" diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinTest.java index 72f8ec0879..21283fcf98 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/JoinTest.java @@ -17,9 +17,9 @@ package org.apache.doris.nereids.sqltest; -import org.apache.doris.nereids.properties.DistributionSpecGather; import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.rewrite.ReorderJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; @@ -50,12 +50,8 @@ public class JoinTest extends SqlTestBase { .getBestPlanTree(); // generate colocate join plan without physicalDistribute System.out.println(plan.treeString()); - Assertions.assertFalse(plan.anyMatch(p -> { - if (p instanceof PhysicalDistribute) { - return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather); - } - return false; - })); + Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute + && ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash)); sql = "select * from T1 join T0 on T1.score = T0.score and T1.id = T0.id;"; plan = PlanChecker.from(connectContext) .analyze(sql) @@ -63,12 +59,8 @@ public class JoinTest extends SqlTestBase { .optimize() .getBestPlanTree(); // generate colocate join plan without physicalDistribute - Assertions.assertFalse(plan.anyMatch(p -> { - if (p instanceof PhysicalDistribute) { - return !(((PhysicalDistribute<?>) p).getDistributionSpec() instanceof DistributionSpecGather); - } - return false; - })); + Assertions.assertFalse(plan.anyMatch(p -> p instanceof PhysicalDistribute + && ((PhysicalDistribute) p).getDistributionSpec() instanceof DistributionSpecHash)); } @Test @@ -100,7 +92,7 @@ public class JoinTest extends SqlTestBase { .analyze(sql) .rewrite() .optimize() - .getBestPlanTree(); + .getBestPlanTree(PhysicalProperties.ANY); Assertions.assertEquals( ShuffleType.NATURAL, ((DistributionSpecHash) ((PhysicalPlan) (plan.child(0).child(0))) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java index 7fe8b97977..5a1d88fe51 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.util; +import org.apache.doris.catalog.Env; import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; @@ -26,6 +27,7 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -35,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.Statistics; +import org.apache.doris.statistics.StatisticsCacheKey; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -90,6 +93,12 @@ public class HyperGraphBuilder { return buildHyperGraph(plan); } + public Plan buildPlan() { + assert plans.size() == 1 : "there are cross join"; + Plan plan = plans.values().iterator().next(); + return plan; + } + public Plan buildJoinPlan() { assert plans.size() == 1 : "there are cross join"; Plan plan = plans.values().iterator().next(); @@ -166,9 +175,14 @@ public class HyperGraphBuilder { for (Group group : context.getMemo().getGroups()) { GroupExpression groupExpression = group.getLogicalExpression(); if (groupExpression.getPlan() instanceof LogicalOlapScan) { + LogicalOlapScan scan = (LogicalOlapScan) groupExpression.getPlan(); Statistics stats = injectRowcount((LogicalOlapScan) groupExpression.getPlan()); - groupExpression.setStatDerived(true); - group.setStatistics(stats); + for (Expression expr : stats.columnStatistics().keySet()) { + SlotReference slot = (SlotReference) expr; + Env.getCurrentEnv().getStatisticsCache().putCache( + new StatisticsCacheKey(scan.getTable().getId(), -1, slot.getName()), + stats.columnStatistics().get(expr)); + } } } } @@ -364,7 +378,7 @@ public class HyperGraphBuilder { return hashConjunts; } - public Set<List<Integer>> evaluate(Plan plan) { + public Set<List<String>> evaluate(Plan plan) { JoinEvaluator evaluator = new JoinEvaluator(rowCounts); Map<Slot, List<Integer>> res = evaluator.evaluate(plan); int rowCount = 0; @@ -376,11 +390,12 @@ public class HyperGraphBuilder { (slot1, slot2) -> String.CASE_INSENSITIVE_ORDER.compare(slot1.toString(), slot2.toString())) .collect(Collectors.toList()); - Set<List<Integer>> tuples = new HashSet<>(); + Set<List<String>> tuples = new HashSet<>(); + tuples.add(keySet.stream().map(s -> s.toString()).collect(Collectors.toList())); for (int i = 0; i < rowCount; i++) { - List<Integer> tuple = new ArrayList<>(); + List<String> tuple = new ArrayList<>(); for (Slot key : keySet) { - tuple.add(res.get(key).get(i)); + tuple.add(String.valueOf(res.get(key).get(i))); } tuples.add(tuple); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 6c88f66056..77d0db9195 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -33,7 +33,6 @@ import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob; import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob; import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob; -import org.apache.doris.nereids.memo.CopyInResult; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.memo.Memo; @@ -243,23 +242,13 @@ public class PlanChecker { public PlanChecker dpHypOptimize() { double now = System.currentTimeMillis(); + cascadesContext.getStatementContext().setDpHyp(true); + cascadesContext.getConnectContext().getSessionVariable().enableDPHypOptimizer = true; Group root = cascadesContext.getMemo().getRoot(); - boolean changeRoot = false; - if (root.isValidJoinGroup()) { - // If the root group is join group, DPHyp can change the root group. - // To keep the root group is not changed, we add a dummy project operator above join - List<Slot> outputs = root.getLogicalExpression().getPlan().getOutput(); - LogicalPlan plan = new LogicalProject(outputs, root.getLogicalExpression().getPlan()); - CopyInResult copyInResult = cascadesContext.getMemo().copyIn(plan, null, false); - root = copyInResult.correspondingExpression.getOwnerGroup(); - changeRoot = true; - } cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext())); + cascadesContext.pushJob(new DeriveStatsJob(root.getLogicalExpression(), + cascadesContext.getCurrentJobContext())); cascadesContext.getJobScheduler().executeJobPool(cascadesContext); - if (changeRoot) { - cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0)); - } - // if the root is not join, we need to optimize again. optimize(); System.out.println("DPhyp:" + (System.currentTimeMillis() - now)); return this; @@ -602,7 +591,7 @@ public class PlanChecker { } public PhysicalPlan getBestPlanTree() { - return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.ANY); + return chooseBestPlan(cascadesContext.getMemo().getRoot(), PhysicalProperties.GATHER); } public PhysicalPlan getBestPlanTree(PhysicalProperties properties) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org