This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new f14c62b274 [enhance](Nereids): polish code. (#16309) f14c62b274 is described below commit f14c62b274579b31d96398fcc9602c845a0990bb Author: jakevin <jakevin...@gmail.com> AuthorDate: Wed Feb 1 19:41:10 2023 +0800 [enhance](Nereids): polish code. (#16309) --- .../doris/nereids/jobs/joinorder/JoinOrderJob.java | 2 +- .../jobs/joinorder/hypergraph/GraphSimplifier.java | 42 ++++--- .../rules/exploration/join/OuterJoinLAsscom.java | 2 +- .../joinorder/hypergraph/GraphSimplifierTest.java | 11 +- .../rules/rewrite/logical/MergeProjectsTest.java | 126 ++++++++++----------- 5 files changed, 87 insertions(+), 96 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java index 2ac21483a0..e5c9fa440e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/JoinOrderJob.java @@ -97,7 +97,7 @@ public class JoinOrderJob extends Job { // For other projects, such as project constant or project nullable, we construct a new project above root if (otherProject.size() != 0) { otherProject.addAll(optimized.getLogicalExpression().getPlan().getOutput()); - LogicalProject logicalProject = new LogicalProject(new ArrayList<>(otherProject), + LogicalProject logicalProject = new LogicalProject<>(new ArrayList<>(otherProject), optimized.getLogicalExpression().getPlan()); GroupExpression groupExpression = new GroupExpression(logicalProject, Lists.newArrayList(group)); optimized = context.getCascadesContext().getMemo().copyInGroupExpression(groupExpression); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index 80c89185f0..9f71e217cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -38,27 +38,31 @@ import java.util.Stack; import javax.annotation.Nullable; /** - * GraphSimplifier is used to simplify HyperGraph {@link HyperGraph} + * GraphSimplifier is used to simplify HyperGraph {@link HyperGraph}. + * <p> + * Related paper: + * - [Neu09] Neumann: “Query Simplification: Graceful Degradation for Join-Order Optimization”. + * - [Rad19] Radke and Neumann: “LinDP++: Generalizing Linearized DP to Crossproducts and Non-Inner Joins”. */ public class GraphSimplifier { // Note that each index in the graph simplifier is the half of the actual index private final int edgeSize; // Detect the circle when order join - private CircleDetector circleDetector; + private final CircleDetector circleDetector; // This is used for cache the intermediate results when calculate the benefit // Note that we store it for the after. Because if we put B after A (t1 Join_A t2 Join_B t3), // B is changed. Therefore, Any step that involves B need to be recalculated. - private List<BestSimplification> simplifications = new ArrayList<>(); - private PriorityQueue<BestSimplification> priorityQueue = new PriorityQueue<>(); + private final List<BestSimplification> simplifications = new ArrayList<>(); + private final PriorityQueue<BestSimplification> priorityQueue = new PriorityQueue<>(); // The graph we are simplifying - private HyperGraph graph; + private final HyperGraph graph; // It cached the plan in simplification. we don't store it in hyper graph, // because it's just used for simulating join. In fact, the graph simplifier // just generate the partial order of join operator. - private HashMap<Long, Plan> cachePlan = new HashMap<>(); + private final HashMap<Long, Plan> cachePlan = new HashMap<>(); - private Stack<SimplificationStep> appliedSteps = new Stack<SimplificationStep>(); - private Stack<SimplificationStep> unAppliedSteps = new Stack<SimplificationStep>(); + private final Stack<SimplificationStep> appliedSteps = new Stack<>(); + private final Stack<SimplificationStep> unAppliedSteps = new Stack<>(); /** * Create a graph simplifier @@ -76,12 +80,12 @@ public class GraphSimplifier { cachePlan.put(node.getNodeMap(), node.getPlan()); } circleDetector = new CircleDetector(edgeSize); + + // init first simplification step + initFirstStep(); } - /** - * This function init the first simplification step - */ - public void initFirstStep() { + private void initFirstStep() { extractJoinDependencies(); for (int i = 0; i < edgeSize; i += 1) { processNeighbors(i, i + 1, edgeSize); @@ -122,7 +126,6 @@ public class GraphSimplifier { */ public boolean simplifyGraph(int limit) { Preconditions.checkArgument(limit >= 1); - initFirstStep(); int lowerBound = 0; int upperBound = 1; @@ -265,7 +268,7 @@ public class GraphSimplifier { private boolean trySetSimplificationStep(SimplificationStep step, BestSimplification bestSimplification, int index, int neighborIndex) { - if (bestSimplification.bestNeighbor == -1 || bestSimplification.isInQueue == false + if (bestSimplification.bestNeighbor == -1 || !bestSimplification.isInQueue || bestSimplification.getBenefit() <= step.getBenefit()) { bestSimplification.bestNeighbor = neighborIndex; bestSimplification.setStep(step); @@ -441,13 +444,16 @@ public class GraphSimplifier { inputs.add(PhysicalProperties.ANY); groupExpression.updateLowestCostTable(PhysicalProperties.ANY, inputs, cost); - return (LogicalJoin) newJoin.withGroupExpression(Optional.of(groupExpression)); + return newJoin.withGroupExpression(Optional.of(groupExpression)); } + /** + * Put join dependencies into circle detector. + */ private void extractJoinDependencies() { for (int i = 0; i < edgeSize; i++) { + Edge edge1 = graph.getEdge(i); for (int j = i + 1; j < edgeSize; j++) { - Edge edge1 = graph.getEdge(i); Edge edge2 = graph.getEdge(j); if (edge1.isSub(edge2)) { Preconditions.checkArgument(circleDetector.tryAddDirectedEdge(i, j), @@ -460,7 +466,7 @@ public class GraphSimplifier { } } - class SimplificationStep { + static class SimplificationStep { double benefit; int beforeIndex; int afterIndex; @@ -496,7 +502,7 @@ public class GraphSimplifier { } } - class BestSimplification implements Comparable<BestSimplification> { + static class BestSimplification implements Comparable<BestSimplification> { int bestNeighbor = -1; Optional<SimplificationStep> step = Optional.empty(); // This data whether to be added to the queue diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java index ff9f6b1e64..b1124965e8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java @@ -88,7 +88,7 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory { } /** - * topHashConjunct possiblity: (A B) (A C) (B C) (A B C). + * topHashConjunct possibility: (A B) (A C) (B C) (A B C). * (A B) is forbidden, because it should be in bottom join. * (B C) (A B C) check failed, because it contains B. * So, just allow: top (A C), bottom (A B), we can exchange HashConjunct directly. diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java index 3b616a7569..3cf0b5f895 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java @@ -40,7 +40,6 @@ public class GraphSimplifierTest { .addEdge(JoinType.INNER_JOIN, 0, 4) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); @@ -68,7 +67,6 @@ public class GraphSimplifierTest { .addEdge(JoinType.INNER_JOIN, 2, 3) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); @@ -97,7 +95,6 @@ public class GraphSimplifierTest { .addEdge(JoinType.INNER_JOIN, 2, 3) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); @@ -131,7 +128,6 @@ public class GraphSimplifierTest { .addEdge(JoinType.INNER_JOIN, 0, 11) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); @@ -153,13 +149,12 @@ public class GraphSimplifierTest { HyperGraph hyperGraph = new HyperGraphBuilder().randomBuildWith(tableNum, edgeNum); double now = System.currentTimeMillis(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } totalTime += System.currentTimeMillis() - now; } - System.out.println(String.format("Simplify graph with %d nodes %d edges cost %f ms", tableNum, edgeNum, - totalTime / times)); + System.out.printf("Simplify graph with %d nodes %d edges cost %f ms%n", tableNum, edgeNum, + totalTime / times); } @Test @@ -176,7 +171,6 @@ public class GraphSimplifierTest { .addEdge(JoinType.INNER_JOIN, 0, 2) .build(); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); @@ -193,7 +187,6 @@ public class GraphSimplifierTest { for (int i = 0; i < 10; i++) { HyperGraph hyperGraph = new HyperGraphBuilder().randomBuildWith(6, 6); GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); - graphSimplifier.initFirstStep(); while (graphSimplifier.applySimplificationStep()) { } Counter counter = new Counter(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java index b515beb0a6..024890753f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/MergeProjectsTest.java @@ -17,97 +17,89 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.analyzer.UnboundRelation; -import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.RelationUtil; -import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; +import java.util.Objects; /** * MergeConsecutiveProjects ut */ -public class MergeProjectsTest { +public class MergeProjectsTest implements PatternMatchSupported { + LogicalOlapScan score = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.score); + @Test public void testMergeConsecutiveProjects() { - UnboundRelation relation = new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")); - NamedExpression colA = new SlotReference("a", IntegerType.INSTANCE, true, Lists.newArrayList("a")); - NamedExpression colB = new SlotReference("b", IntegerType.INSTANCE, true, Lists.newArrayList("b")); - NamedExpression colC = new SlotReference("c", IntegerType.INSTANCE, true, Lists.newArrayList("c")); - LogicalProject project1 = new LogicalProject<>(Lists.newArrayList(colA, colB, colC), relation); - LogicalProject project2 = new LogicalProject<>(Lists.newArrayList(colA, colB), project1); - LogicalProject project3 = new LogicalProject<>(Lists.newArrayList(colA), project2); - - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project3); - List<Rule> rules = Lists.newArrayList(new MergeProjects().build()); - cascadesContext.bottomUpRewrite(rules); - Plan plan = cascadesContext.getMemo().copyOut(); - System.out.println(plan.treeString()); - Assertions.assertTrue(plan instanceof LogicalProject); - Assertions.assertEquals(((LogicalProject<?>) plan).getProjects(), Lists.newArrayList(colA)); - Assertions.assertTrue(plan.child(0) instanceof UnboundRelation); + LogicalPlan plan = new LogicalPlanBuilder(score) + .project(ImmutableList.of(0, 1, 2)) + .project(ImmutableList.of(0, 1)) + .project(ImmutableList.of(0)) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new MergeProjects()) + .matches( + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 1) + ); } /** - * project2(X + 2) - * | - * project1(B, C, A+1 as X) - * | - * relation + * project2(X + 2) -> project1(B, C, A+1 as X) * transform to : - * project2((A + 1) + 2) - * | - * relation + * project2((A + 1) + 2) */ @Test public void testMergeConsecutiveProjectsWithAlias() { - UnboundRelation relation = new UnboundRelation(RelationUtil.newRelationId(), Lists.newArrayList("db", "table")); - NamedExpression colA = new SlotReference("a", IntegerType.INSTANCE, true, Lists.newArrayList("a")); - NamedExpression colB = new SlotReference("b", IntegerType.INSTANCE, true, Lists.newArrayList("b")); - NamedExpression colC = new SlotReference("c", IntegerType.INSTANCE, true, Lists.newArrayList("c")); - Alias alias = new Alias(new Add(colA, new IntegerLiteral(1)), "X"); - Slot aliasRef = alias.toSlot(); + Alias alias = new Alias(new Add(score.getOutput().get(0), Literal.of(1)), "X"); + LogicalProject<LogicalOlapScan> bottomProject = new LogicalProject<>( + Lists.newArrayList(score.getOutput().get(1), score.getOutput().get(2), alias), + score); - LogicalProject project1 = new LogicalProject<>( - Lists.newArrayList( - colB, - colC, - alias), - relation); - LogicalProject project2 = new LogicalProject<>( + LogicalProject<LogicalProject<LogicalOlapScan>> topProject = new LogicalProject<>( Lists.newArrayList( - new Alias(new Add(aliasRef, new IntegerLiteral(2)), "Y"), - aliasRef - ), - project1); + new Alias(new Add(bottomProject.getOutput().get(2), Literal.of(2)), "Y")), + bottomProject); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(project2); - List<Rule> rules = Lists.newArrayList(new MergeProjects().build()); - cascadesContext.bottomUpRewrite(rules); - Plan plan = cascadesContext.getMemo().copyOut(); - System.out.println(plan.treeString()); - Assertions.assertTrue(plan instanceof LogicalProject); - LogicalProject finalProject = (LogicalProject) plan; - Add aPlus1Plus2 = new Add( - new Add(colA, new IntegerLiteral(1)), - new IntegerLiteral(2) - ); - Assertions.assertEquals(2, finalProject.getProjects().size()); - Assertions.assertEquals(aPlus1Plus2, ((Alias) finalProject.getProjects().get(0)).child()); - Assertions.assertEquals(alias, finalProject.getProjects().get(1)); + PlanChecker.from(MemoTestUtils.createConnectContext(), topProject) + .applyTopDown(new MergeProjects()) + .matches( + logicalProject( + logicalOlapScan() + ).when(project -> Objects.equals(project.getProjects().toString(), + "[((sid#0 + 1) + 2) AS `Y`#4]")) + ); + } + + @Test + void testAlias() { + // project(a+1 as b) -> project(b+1 as c) + LogicalProject<LogicalOlapScan> bottomProject = new LogicalProject<>( + ImmutableList.of(new Alias(new Add(score.getOutput().get(0), Literal.of(1)), "b")), score); + LogicalProject<LogicalProject<LogicalOlapScan>> topProject = new LogicalProject<>( + ImmutableList.of(new Alias(new Add(bottomProject.getOutput().get(0), Literal.of(1)), "b")), + bottomProject); + PlanChecker.from(MemoTestUtils.createConnectContext(), topProject) + .applyTopDown(new MergeProjects()) + .matches( + logicalProject( + logicalOlapScan() + ).when(project -> Objects.equals(project.getProjects().toString(), + "[((sid#0 + 1) + 1) AS `b`#4]")) + ); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org