This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 3891083474 [fix](Nereids): fix some bugs in DpHyper (#16282) 3891083474 is described below commit 3891083474743b54ebf6dd1334e4373800029adb Author: 谢健 <jianx...@gmail.com> AuthorDate: Fri Feb 3 18:19:48 2023 +0800 [fix](Nereids): fix some bugs in DpHyper (#16282) --- .../org/apache/doris/nereids/NereidsPlanner.java | 9 +- .../doris/nereids/jobs/joinorder/JoinOrderJob.java | 16 +--- .../jobs/joinorder/hypergraph/HyperGraph.java | 105 ++++++++++++++++----- .../nereids/jobs/joinorder/hypergraph/Node.java | 4 + .../joinorder/hypergraph/SubgraphEnumerator.java | 4 +- .../hypergraph/receiver/PlanReceiver.java | 100 +++++++++++++------- .../processor/post/MergeProjectPostProcessor.java | 12 ++- .../nereids/processor/post/PlanPostProcessors.java | 1 + .../java/org/apache/doris/qe/SessionVariable.java | 2 +- 9 files changed, 166 insertions(+), 87 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index d24cde8e8c..e80a0018f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -219,7 +219,6 @@ public class NereidsPlanner extends Planner { private void dpHypOptimize() { Group root = getRoot(); - boolean changeRoot = false; if (root.isJoinGroup()) { // If the root group is join group, DPHyp can change the root group. // To keep the root group is not changed, we add a project operator above join @@ -227,16 +226,10 @@ public class NereidsPlanner extends Planner { LogicalPlan plan = new LogicalProject(outputs, root.getLogicalExpression().getPlan()); CopyInResult copyInResult = cascadesContext.getMemo().copyIn(plan, null, false); root = copyInResult.correspondingExpression.getOwnerGroup(); - changeRoot = true; } cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext())); cascadesContext.getJobScheduler().executeJobPool(cascadesContext); - if (changeRoot) { - cascadesContext.getMemo().setRoot(root.getLogicalExpression().child(0)); - } else { - // if the root is not join, we need to optimize again. - optimize(); - } + optimize(); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java index e5c9fa440e..6e82a0faa9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java @@ -81,7 +81,7 @@ public class JoinOrderJob extends Job { HyperGraph hyperGraph = new HyperGraph(); buildGraph(group, hyperGraph); // TODO: Right now, we just hardcode the limit with 10000, maybe we need a better way to set it - int limit = 10000; + int limit = 1000; PlanReceiver planReceiver = new PlanReceiver(this.context, limit, hyperGraph, group.getLogicalProperties().getOutputSet()); SubgraphEnumerator subgraphEnumerator = new SubgraphEnumerator(planReceiver, hyperGraph); @@ -113,13 +113,8 @@ public class JoinOrderJob extends Job { */ public void buildGraph(Group group, HyperGraph hyperGraph) { if (group.isProjectGroup()) { - Group childGroup = group.getLogicalExpression().child(0); - if (childGroup.isJoinGroup()) { - buildGraph(group.getLogicalExpression().child(0), hyperGraph); - processProjectPlan(hyperGraph, group); - } else { - hyperGraph.addNode(optimizePlan(group)); - } + buildGraph(group.getLogicalExpression().child(0), hyperGraph); + processProjectPlan(hyperGraph, group); return; } if (!group.isJoinGroup()) { @@ -136,7 +131,6 @@ public class JoinOrderJob extends Job { * 1. If it's a simple expression for column pruning, we just ignore it * 2. If it's an alias that may be used in the join operator, we need to add it to graph * 3. If it's other expression, we can ignore them and add it after optimizing - * 4. If it's a project only associate with one table, it's seen as an endNode just like a table */ private void processProjectPlan(HyperGraph hyperGraph, Group group) { LogicalProject<? extends Plan> logicalProject @@ -145,9 +139,7 @@ public class JoinOrderJob extends Job { for (NamedExpression expr : logicalProject.getProjects()) { if (expr.isAlias()) { - if (!hyperGraph.addAlias((Alias) expr, group)) { - break; - } + hyperGraph.addAlias((Alias) expr); } else if (!expr.isSlot()) { otherProject.add(expr); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 917938fc8e..2144f8bee6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.trees.expressions.Alias; @@ -49,7 +50,7 @@ public class HyperGraph { // Record the complex project expression for some subgraph // e.g. project (a + b) // |-- join(t1.a = t2.b) - private final HashMap<Long, NamedExpression> complexProject = new HashMap<>(); + private final HashMap<Long, List<NamedExpression>> complexProject = new HashMap<>(); public List<Edge> getEdges() { return edges; @@ -73,32 +74,29 @@ public class HyperGraph { /** * Store the relation between Alias Slot and Original Slot and its expression - * e.g. a = b - * |--- project((c + d) as b) - * Note if the alias only associated with one endNode, - * e.g. a = b - * |--- project((c + 1) as b) - * we need to replace the group of that node with this project group. + * e.g., + * a = b + * |--- project((c + d) as b) + * <p> + * a = b + * |--- project((c + 1) as b) * * @param alias The alias Expression in project Operator */ - public boolean addAlias(Alias alias, Group group) { + public boolean addAlias(Alias alias) { + Slot aliasSlot = alias.toSlot(); + if (slotToNodeMap.containsKey(aliasSlot)) { + return true; + } long bitmap = LongBitmap.newBitmap(); for (Slot slot : alias.getInputSlots()) { bitmap = LongBitmap.or(bitmap, slotToNodeMap.get(slot)); } - Slot aliasSlot = alias.toSlot(); - Preconditions.checkArgument(!slotToNodeMap.containsKey(aliasSlot)); slotToNodeMap.put(aliasSlot, bitmap); - if (LongBitmap.getCardinality(bitmap) == 1) { - // This means the alias only associate with one endNode - int index = LongBitmap.lowestOneIndex(bitmap); - nodeSet.remove(nodes.get(index).getGroup()); - nodeSet.add(group); - nodes.get(index).replaceGroupWith(group); - return false; + if (!complexProject.containsKey(bitmap)) { + complexProject.put(bitmap, new ArrayList<>()); } - complexProject.put(bitmap, alias); + complexProject.get(bitmap).add(alias); return true; } @@ -121,7 +119,7 @@ public class HyperGraph { return nodeSet.contains(group); } - public HashMap<Long, NamedExpression> getComplexProject() { + public HashMap<Long, List<NamedExpression>> getComplexProject() { return complexProject; } @@ -137,12 +135,9 @@ public class HyperGraph { LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), ImmutableList.of(expression), join.left(), join.right()); Edge edge = new Edge(singleJoin, edges.size()); - Preconditions.checkArgument(expression.children().size() == 2); - // TODO: use connected property to calculate edge - long left = calNodeMap(expression.child(0).getInputSlots()); - edge.setLeft(left); - long right = calNodeMap(expression.child(1).getInputSlots()); - edge.setRight(right); + Pair<Long, Long> ends = findEnds(expression); + edge.setLeft(ends.first); + edge.setRight(ends.second); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } @@ -152,6 +147,66 @@ public class HyperGraph { // We don't implement this trick now. } + private int findRoot(List<Integer> parent, int idx) { + int root = parent.get(idx); + if (root != idx) { + root = findRoot(parent, root); + } + parent.set(idx, root); + return root; + } + + private boolean isConnected(long bitmap, long excludeBitmap) { + if (LongBitmap.getCardinality(bitmap) == 1) { + return true; + } + + // use unionSet to check whether the bitmap is connected + List<Integer> parent = new ArrayList<>(); + for (int i = 0; i < nodes.size(); i++) { + parent.add(i, i); + } + for (Edge edge : edges) { + if (LongBitmap.isOverlap(edge.getLeft(), excludeBitmap) + || LongBitmap.isOverlap(edge.getRight(), excludeBitmap)) { + continue; + } + + int root = findRoot(parent, LongBitmap.nextSetBit(edge.getLeft(), 0)); + for (int idx : LongBitmap.getIterator(edge.getLeft())) { + parent.set(idx, root); + } + for (int idx : LongBitmap.getIterator(edge.getRight())) { + parent.set(idx, root); + } + } + + int root = findRoot(parent, LongBitmap.nextSetBit(bitmap, 0)); + for (int idx : LongBitmap.getIterator(bitmap)) { + if (root != findRoot(parent, idx)) { + return false; + } + } + return true; + } + + private Pair<Long, Long> findEnds(Expression expression) { + long bitmap = calNodeMap(expression.getInputSlots()); + int cardinality = LongBitmap.getCardinality(bitmap); + Preconditions.checkArgument(cardinality > 1); + for (long subset : LongBitmap.getSubsetIterator(bitmap)) { + long left = subset; + long right = LongBitmap.newBitmapDiff(bitmap, left); + // when the graph without right node has a connected-sub-graph contains left nodes + // and the graph without left node has a connected-sub-graph contains right nodes. + // we can generate an edge for this expression + if (isConnected(left, right) && isConnected(right, left)) { + return Pair.of(left, right); + } + } + throw new RuntimeException("DPhyper meets unconnected subgraph"); + } + private long calNodeMap(Set<Slot> slots) { Preconditions.checkArgument(slots.size() != 0); long bitmap = LongBitmap.newBitmap(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java index ee26e31629..7a3854431c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Node.java @@ -37,6 +37,10 @@ public class Node { this.index = index; } + public List<Edge> getEdges() { + return edges; + } + public void replaceGroupWith(Group group) { this.group = group; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java index 9670732216..342f1bd1d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java @@ -89,8 +89,8 @@ public class SubgraphEnumerator { LongBitmapSubsetIterator subsetIterator = LongBitmap.getSubsetIterator(neighborhood); for (long subset : subsetIterator) { long newCsg = LongBitmap.newBitmapUnion(csg, subset); + edgeCalculator.unionEdges(csg, subset); if (receiver.contain(newCsg)) { - edgeCalculator.unionEdges(csg, subset); if (!emitCsg(newCsg)) { return false; } @@ -113,8 +113,8 @@ public class SubgraphEnumerator { for (long subset : subsetIterator) { long newCmp = LongBitmap.newBitmapUnion(cmp, subset); // We need to check whether Cmp is connected and then try to find hyper edge + edgeCalculator.unionEdges(cmp, subset); if (receiver.contain(newCmp)) { - edgeCalculator.unionEdges(cmp, subset); // We check all edges for finding an edge. List<Edge> edges = edgeCalculator.connectCsgCmp(csg, newCmp); if (edges.isEmpty()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index 102bb17299..09ff637436 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -67,6 +67,7 @@ public class PlanReceiver implements AbstractReceiver { // limit define the max number of csg-cmp pair in this Receiver HashMap<Long, Group> planTable = new HashMap<>(); HashMap<Long, BitSet> usdEdges = new HashMap<>(); + HashMap<Long, List<NamedExpression>> projectsOnSubgraph = new HashMap<>(); int limit; int emitCount = 0; @@ -140,19 +141,24 @@ public class PlanReceiver implements AbstractReceiver { private Set<Slot> calculateRequiredSlots(long left, long right, List<Edge> edges) { Set<Slot> outputSlots = new HashSet<>(this.finalOutputs); - BitSet bitSet = new BitSet(); - bitSet.or(usdEdges.get(left)); - bitSet.or(usdEdges.get(right)); + BitSet usedEdgesBitmap = new BitSet(); + usedEdgesBitmap.or(usdEdges.get(left)); + usedEdgesBitmap.or(usdEdges.get(right)); for (Edge edge : edges) { - bitSet.set(edge.getIndex()); + usedEdgesBitmap.set(edge.getIndex()); } // required output slots = final outputs + slot of unused edges - usdEdges.put(LongBitmap.newBitmapUnion(left, right), bitSet); + usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap); for (Edge edge : hyperGraph.getEdges()) { - if (!bitSet.get(edge.getIndex())) { + if (!usedEdgesBitmap.get(edge.getIndex())) { outputSlots.addAll(edge.getExpression().getInputSlots()); } } + hyperGraph.getComplexProject() + .values() + .stream() + .flatMap(l -> l.stream()) + .forEach(expr -> outputSlots.addAll(expr.getInputSlots())); return outputSlots; } @@ -206,6 +212,13 @@ public class PlanReceiver implements AbstractReceiver { @Override public void addGroup(long bitmap, Group group) { + Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1); + usdEdges.put(bitmap, new BitSet()); + Plan plan = proposeProject(Lists.newArrayList(new GroupPlan(group)), new ArrayList<>(), bitmap, bitmap).get(0); + if (!(plan instanceof GroupPlan)) { + CopyInResult copyInResult = jobContext.getCascadesContext().getMemo().copyIn(plan, null, false); + group = copyInResult.correspondingExpression.getOwnerGroup(); + } planTable.put(bitmap, group); usdEdges.put(bitmap, new BitSet()); } @@ -275,40 +288,59 @@ public class PlanReceiver implements AbstractReceiver { } } - private List<Plan> proposeProject(List<Plan> allChild, List<Edge> usedEdges, long left, long right) { - List<Plan> res = new ArrayList<>(); - - // calculate required columns - Set<Slot> requireSlots = calculateRequiredSlots(left, right, usedEdges); + private List<Plan> proposeProject(List<Plan> allChild, List<Edge> edges, long left, long right) { + long fullKey = LongBitmap.newBitmapUnion(left, right); List<Slot> outputs = allChild.get(0).getOutput(); - List<NamedExpression> projects = outputs.stream().filter(e -> requireSlots.contains(e)).collect( - Collectors.toList()); + Set<Slot> outputSet = allChild.get(0).getOutputSet(); + if (!projectsOnSubgraph.containsKey(fullKey)) { + List<NamedExpression> projects = new ArrayList<>(); + // Calculate complex expression + Map<Long, List<NamedExpression>> complexExpressionMap = hyperGraph.getComplexProject(); + List<Long> bitmaps = complexExpressionMap.keySet().stream() + .filter(bitmap -> LongBitmap.isSubset(bitmap, fullKey)).collect(Collectors.toList()); + + for (long bitmap : bitmaps) { + projects.addAll(complexExpressionMap.get(bitmap)); + complexExpressionMap.remove(bitmap); + } - // Calculate complex expression - long fullKey = LongBitmap.newBitmapUnion(left, right); - Map<Long, NamedExpression> complexExpressionMap = hyperGraph.getComplexProject(); - List<Long> bitmaps = complexExpressionMap.keySet().stream() - .filter(bitmap -> LongBitmap.isSubset(bitmap, fullKey)).collect(Collectors.toList()); - - boolean addComplexProject = false; - for (long bitmap : bitmaps) { - projects.add(complexExpressionMap.get(bitmap)); - complexExpressionMap.remove(bitmap); - addComplexProject = true; - } + // calculate required columns + Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges); + outputs.stream() + .filter(e -> requireSlots.contains(e)) + .forEach(e -> projects.add(e)); - // propose physical project - if (projects.isEmpty()) { - projects.add(ExpressionUtils.selectMinimumColumn(outputs)); - } else if (projects.size() == outputs.size() && !addComplexProject) { + // propose physical project + if (projects.isEmpty()) { + projects.add(ExpressionUtils.selectMinimumColumn(outputs)); + } + projectsOnSubgraph.put(fullKey, projects); + } + List<NamedExpression> allProjects = projectsOnSubgraph.get(fullKey); + if (outputSet.equals(new HashSet<>(allProjects))) { return allChild; } - LogicalProperties projectProperties = new LogicalProperties( - () -> projects.stream().map(p -> p.toSlot()).collect(Collectors.toList())); - for (Plan child : allChild) { - res.add(new PhysicalProject<>(projects, projectProperties, child)); + while (true) { + Set<Slot> childOutputSet = allChild.get(0).getOutputSet(); + List<NamedExpression> projects = allProjects.stream() + .filter(expr -> + childOutputSet.containsAll(expr.getInputSlots()) || childOutputSet.contains(expr.toSlot())) + .collect(Collectors.toList()); + if (!outputSet.equals(new HashSet<>(projects))) { + LogicalProperties projectProperties = new LogicalProperties( + () -> projects.stream().map(p -> p.toSlot()).collect(Collectors.toList())); + allChild = allChild.stream() + .map(c -> new PhysicalProject<>(projects, projectProperties, c)) + .collect(Collectors.toList()); + } + if (projects.size() == 0) { + throw new RuntimeException("dphyer fail process project"); + } + if (projects.size() == allProjects.size()) { + break; + } } - return res; + return allChild; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/MergeProjectPostProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/MergeProjectPostProcessor.java index d5cbbf79ae..4a5ad49243 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/MergeProjectPostProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/MergeProjectPostProcessor.java @@ -22,6 +22,8 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import com.google.common.collect.Lists; + import java.util.List; /** @@ -32,11 +34,11 @@ public class MergeProjectPostProcessor extends PlanPostProcessor { @Override public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) { Plan child = project.child(); - child = child.accept(this, ctx); - if (child instanceof PhysicalProject) { - List<NamedExpression> projections = project.mergeProjections((PhysicalProject) child); - return project.withProjectionsAndChild(projections, child.child(0)); + Plan newChild = child.accept(this, ctx); + if (newChild instanceof PhysicalProject) { + List<NamedExpression> projections = project.mergeProjections((PhysicalProject) newChild); + return project.withProjectionsAndChild(projections, newChild.child(0)); } - return project; + return child != newChild ? project.withChildren(Lists.newArrayList(newChild)) : project; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java index 090f523b18..507d572d70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java @@ -58,6 +58,7 @@ public class PlanPostProcessors { public List<PlanPostProcessor> getProcessors() { // add processor if we need Builder<PlanPostProcessor> builder = ImmutableList.builder(); + builder.add(new MergeProjectPostProcessor()); if (cascadesContext.getConnectContext().getSessionVariable().isEnableNereidsRuntimeFilter() && !cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode() .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index dcb4054673..3085fcb1b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -552,7 +552,7 @@ public class SessionVariable implements Serializable, Writable { private boolean checkOverflowForDecimal = false; @VariableMgr.VarAttr(name = ENABLE_DPHYP_OPTIMIZER) - private boolean enableDPHypOptimizer = false; + private boolean enableDPHypOptimizer = true; /** * as the new optimizer is not mature yet, use this var * to control whether to use new optimizer, remove it when --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org