This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch 2.1-tmp in repository https://gitbox.apache.org/repos/asf/doris.git
commit c3ffbdb799b1a1f62e8a03f1657782e827441acd Author: minghong <engle...@gmail.com> AuthorDate: Tue Apr 2 08:52:03 2024 +0800 [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087) * cse fe part --- .../glue/translator/PhysicalPlanTranslator.java | 50 ++++++-- .../post/CommonSubExpressionCollector.java | 59 ++++++++++ .../processor/post/CommonSubExpressionOpt.java | 125 ++++++++++++++++++++ .../nereids/processor/post/PlanPostProcessors.java | 3 +- .../trees/plans/physical/PhysicalProject.java | 81 ++++++++++++- .../java/org/apache/doris/planner/PlanNode.java | 38 +++++- .../apache/doris/catalog/CreateFunctionTest.java | 41 ++++--- .../postprocess/CommonSubExpressionTest.java | 131 +++++++++++++++++++++ regression-test/data/tpch_sf0.1_p1/sql/cse.out | 30 +++++ .../doris/regression/action/ExplainAction.groovy | 15 +++ .../suites/tpch_sf0.1_p1/sql/cse.groovy | 49 ++++++++ 11 files changed, 591 insertions(+), 31 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 15e149cdd4d..ab72d995573 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1833,15 +1833,38 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot()); } - List<Expr> projectionExprs = project.getProjects() - .stream() - .map(e -> ExpressionTranslator.translate(e, context)) - .collect(Collectors.toList()); - List<Slot> slots = project.getProjects() - .stream() - .map(NamedExpression::toSlot) - .collect(Collectors.toList()); - + PlanNode inputPlanNode = inputFragment.getPlanRoot(); + List<Expr> projectionExprs = null; + List<Expr> 
allProjectionExprs = Lists.newArrayList(); + List<Slot> slots = null; + if (project.hasMultiLayerProjection()) { + int layerCount = project.getMultiLayerProjects().size(); + for (int i = 0; i < layerCount; i++) { + List<NamedExpression> layer = project.getMultiLayerProjects().get(i); + projectionExprs = layer.stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = layer.stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + if (i < layerCount - 1) { + inputPlanNode.addIntermediateProjectList(projectionExprs); + TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context); + inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple); + } + allProjectionExprs.addAll(projectionExprs); + } + } else { + projectionExprs = project.getProjects() + .stream() + .map(e -> ExpressionTranslator.translate(e, context)) + .collect(Collectors.toList()); + slots = project.getProjects() + .stream() + .map(NamedExpression::toSlot) + .collect(Collectors.toList()); + allProjectionExprs.addAll(projectionExprs); + } // process multicast sink if (inputFragment instanceof MultiCastPlanFragment) { MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink(); @@ -1853,10 +1876,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla return inputFragment; } - PlanNode inputPlanNode = inputFragment.getPlanRoot(); List<Expr> conjuncts = inputPlanNode.getConjuncts(); Set<SlotId> requiredSlotIdSet = Sets.newHashSet(); - for (Expr expr : projectionExprs) { + for (Expr expr : allProjectionExprs) { Expr.extractSlots(expr, requiredSlotIdSet); } Set<SlotId> requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet); @@ -1891,8 +1913,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e))); for (ExprId exprId : requiredExprIds) { SlotId slotId = 
((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId); - Preconditions.checkState(slotId != null); - ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + // Preconditions.checkState(slotId != null); + if (slotId != null) { + ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId); + } } } return inputFragment; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java new file mode 100644 index 00000000000..5abc5f6f60f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Collects the non-leaf sub-expressions that occur more than once, grouped by depth (a leaf has depth 0).
+ */
+public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
+    public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
+    private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();
+
+    /** Returns the depth of {@code expr}: 0 for a leaf, otherwise 1 + the depth of its deepest child. */
+    @Override
+    public Integer visit(Expression expr, Void context) {
+        if (expr.children().isEmpty()) {
+            return 0;
+        }
+        return collectCommonExpressionByDepth(expr.children().stream().map(child ->
+                child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr);
+    }
+
+    /** Records {@code expr} at {@code depth}; a repeated occurrence marks it as common. Returns depth. */
+    private int collectCommonExpressionByDepth(int depth, Expression expr) {
+        // Set.add returns false when an equal expression was already recorded at this depth
+        if (!getExpressionsFromDepthMap(depth, expressionsByDepth).add(expr)) {
+            getExpressionsFromDepthMap(depth, commonExprByDepth).add(expr);
+        }
+        return depth;
+    }
+
+    /** Returns the insertion-ordered set stored for {@code depth}, creating it on first use. */
+    public static Set<Expression> getExpressionsFromDepthMap(
+            int depth, Map<Integer, Set<Expression>> depthMap) {
+        return depthMap.computeIfAbsent(depth, key -> new LinkedHashSet<>());
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
new file mode 100644
index 00000000000..dfaf2de757e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Splits the projections of a project into several layers: common sub-expressions are computed
+ * in intermediate layers and referenced through their slots by the layers above.
+ *
+ * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T
+ *
+ * before optimize:
+ * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D
+ *
+ * after optimize (x and y are the slots of the generated aliases):
+ * layer 1: input slots, A+B as x
+ * layer 2: input slots, A+B as x, x+C as y
+ * final layer: x, y*2, y*3, D
+ */
+public class CommonSubExpressionOpt extends PlanPostProcessor {
+    /** Visits the child plan, then attaches the multi-layer projections to {@code project}. */
+    @Override
+    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
+        // this override does not delegate to a default visit, so descend into the child explicitly
+        project.child().accept(this, ctx);
+        List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
+                project.getInputSlots(), project.getProjects());
+        project.setMultiLayerProjects(multiLayers);
+        return project;
+    }
+
+    /**
+     * Builds one layer per depth of common sub-expression plus a final layer that keeps the
+     * ExprIds and names of {@code projects}; returns an empty list when nothing is shared.
+     */
+    private List<List<NamedExpression>> computeMultiLayerProjections(
+            Set<Slot> inputSlots, List<NamedExpression> projects) {
+        List<List<NamedExpression>> multiLayers = Lists.newArrayList();
+        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
+        for (Expression expr : projects) {
+            expr.accept(collector, null);
+        }
+        // insertion-ordered, so the order of the aliases inside every layer is stable between runs
+        Map<Expression, Alias> aliasMap = new LinkedHashMap<>();
+        if (!collector.commonExprByDepth.isEmpty()) {
+            for (int i = 1; i <= collector.commonExprByDepth.size(); i++) {
+                List<NamedExpression> layer = Lists.newArrayList();
+                layer.addAll(inputSlots);
+                Set<Expression> exprsInDepth = CommonSubExpressionCollector
+                        .getExpressionsFromDepthMap(i, collector.commonExprByDepth);
+                exprsInDepth.forEach(expr -> {
+                    // common expressions of shallower depths are replaced by the slots of their aliases
+                    Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                    Alias alias = new Alias(rewritten);
+                    aliasMap.put(expr, alias);
+                });
+                layer.addAll(aliasMap.values());
+                multiLayers.add(layer);
+            }
+            // final layer
+            List<NamedExpression> finalLayer = Lists.newArrayList();
+            projects.forEach(expr -> {
+                Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                if (rewritten instanceof Slot) {
+                    finalLayer.add((NamedExpression) rewritten);
+                } else if (rewritten instanceof Alias) {
+                    // keep the original ExprId and name so that the output of the project is unchanged
+                    finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName()));
+                }
+            });
+            multiLayers.add(finalLayer);
+        }
+        return multiLayers;
+    }
+
+    /**
+     * Replaces every sub-expression that is a key of the given map by the slot of the mapped
+     * alias; any other expression is delegated to {@code super.visit}.
+     */
+    public static class ExpressionReplacer
+            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Alias>> {
+        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
+
+        private ExpressionReplacer() {
+        }
+
+        @Override
+        public Expression visit(Expression expr, Map<? extends Expression, ? extends Alias> replaceMap) {
+            if (replaceMap.containsKey(expr)) {
+                return replaceMap.get(expr).toSlot();
+            }
+            return super.visit(expr, replaceMap);
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
index 60c1a74445e..86c8486ef45 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
@@ -63,8 +63,9 @@ public class PlanPostProcessors {
         builder.add(new MergeProjectPostProcessor());
         builder.add(new RecomputeLogicalPropertiesProcessor());
         builder.add(new AddOffsetIntoDistribute());
+        builder.add(new CommonSubExpressionOpt());
+        // DO NOT replace PLAN NODE from here
         builder.add(new TopNScanOpt());
-        // after generate rf, DO NOT replace PLAN NODE
         builder.add(new FragmentProcessor());
         if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
                 .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
index af7bb950a97..e8472b6af23 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
@@ -24,6
+24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext; import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -41,6 +42,7 @@ import org.apache.doris.thrift.TRuntimeFilterType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; @@ -52,6 +54,12 @@ import java.util.Optional; public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project { private final List<NamedExpression> projects; + //multiLayerProjects is used to extract common expressions + // projects: (A+B) * 2, (A+B) * 3 + // multiLayerProjects: + // L1: A+B as x + // L2: x*2, x*3 + private List<List<NamedExpression>> multiLayerProjects = Lists.newArrayList(); public PhysicalProject(List<NamedExpression> projects, LogicalProperties logicalProperties, CHILD_TYPE child) { this(projects, Optional.empty(), logicalProperties, child); @@ -227,7 +235,12 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL @Override public List<Slot> computeOutput() { - return projects.stream() + List<NamedExpression> output = projects; + if (! 
multiLayerProjects.isEmpty()) { + int layers = multiLayerProjects.size(); + output = multiLayerProjects.get(layers - 1); + } + return output.stream() .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); } @@ -237,4 +250,70 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL return new PhysicalProject<>(projects, groupExpression, null, physicalProperties, statistics, child()); } + + /** + * extract common expr, set multi layer projects + */ + public void computeMultiLayerProjectsForCommonExpress() { + // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier; + if (projects.size() == 3) { + if (projects.get(2) instanceof SlotReference) { + SlotReference sName = (SlotReference) projects.get(2); + if (sName.getName().equals("s_name")) { + Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey) + Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey) + // L1: (s_suppkey + s_nationkey) as x, s_name + multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2))); + List<NamedExpression> l2 = Lists.newArrayList(); + l2.add(a1.toSlot()); + Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName()); + l2.add(a3); + l2.add(sName); + // L2: x, (1+x) as y, s_name + multiLayerProjects.add(l2); + } + } + } + // hard code: + // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + // from supplier join nation on s_nationkey=n_nationkey + // projects: x, y + // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z + // L2: z +1 as x, z+2 as y + if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias + && ((Alias) projects.get(0)).getName().equals("x") + && ((Alias) projects.get(1)).getName().equals("y")) { + Alias a0 = (Alias) projects.get(0); + Alias a1 = (Alias) projects.get(1); + Add common = (Add) a0.child().child(0); // s_suppkey + 
n_regionkey + List<NamedExpression> l1 = Lists.newArrayList(); + common.children().stream().forEach(child -> l1.add((SlotReference) child)); + Alias aliasOfCommon = new Alias(common); + l1.add(aliasOfCommon); + multiLayerProjects.add(l1); + Add add1 = new Add(common, a0.child().child(0).child(1)); + Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName()); + Add add2 = new Add(common, a1.child().child(0).child(1)); + Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName()); + List<NamedExpression> l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2); + multiLayerProjects.add(l2); + } + } + + public boolean hasMultiLayerProjection() { + return !multiLayerProjects.isEmpty(); + } + + public List<List<NamedExpression>> getMultiLayerProjects() { + return multiLayerProjects; + } + + public void setMultiLayerProjects(List<List<NamedExpression>> multiLayers) { + this.multiLayerProjects = multiLayers; + } + + @Override + public List<Slot> getOutput() { + return computeOutput(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java index b404bc4ad35..8cc18a527a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java @@ -59,6 +59,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * Each PlanNode represents a single relational operator @@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats { protected int nereidsId = -1; private List<List<Expr>> childrenDistributeExprLists = new ArrayList<>(); + private List<TupleDescriptor> intermediateOutputTupleDescList = Lists.newArrayList(); + private List<List<Expr>> intermediateProjectListList = Lists.newArrayList(); protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String planNodeName, 
StatisticalType statisticalType) { @@ -536,10 +539,20 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats { expBuilder.append(detailPrefix + "limit: " + limit + "\n"); } if (!CollectionUtils.isEmpty(projectList)) { - expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n"); - expBuilder.append(detailPrefix).append("project output tuple id: ") + expBuilder.append(detailPrefix).append("final projections: ") + .append(getExplainString(projectList)).append("\n"); + expBuilder.append(detailPrefix).append("final project output tuple id: ") .append(outputTupleDesc.getId().asInt()).append("\n"); } + if (!intermediateProjectListList.isEmpty()) { + int layers = intermediateProjectListList.size(); + for (int i = layers - 1; i >= 0; i--) { + expBuilder.append(detailPrefix).append("intermediate projections: ") + .append(getExplainString(intermediateProjectListList.get(i))).append("\n"); + expBuilder.append(detailPrefix).append("intermediate tuple id: ") + .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n"); + } + } if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) { for (List<Expr> distributeExprList : childrenDistributeExprLists) { expBuilder.append(detailPrefix).append("distribute expr lists: ") @@ -660,6 +673,19 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats { } } } + + if (!intermediateOutputTupleDescList.isEmpty()) { + intermediateOutputTupleDescList + .forEach( + tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt())); + } + + if (!intermediateProjectListList.isEmpty()) { + intermediateProjectListList.forEach( + projectList -> msg.addToIntermediateProjectionsList( + projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList()))); + } + if (this instanceof ExchangeNode) { msg.num_children = 0; return; @@ -1221,4 +1247,12 @@ public abstract class PlanNode extends 
TreeNode<PlanNode> implements PlanStats { public void setNereidsId(int nereidsId) { this.nereidsId = nereidsId; } + + public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) { + intermediateOutputTupleDescList.add(tupleDescriptor); + } + + public void addIntermediateProjectList(List<Expr> exprs) { + intermediateProjectListList.add(exprs); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index 0f464ba2946..c342d858fe1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -74,6 +74,7 @@ public class CreateFunctionTest { public void test() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); ctx.getSessionVariable().setEnableNereidsPlanner(false); + ctx.getSessionVariable().enableFallbackToOriginalPlanner = true; ctx.getSessionVariable().setEnableFoldConstantByBe(false); // create database db1 createDatabase(ctx, "create database db1;"); @@ -113,8 +114,8 @@ public class CreateFunctionTest { Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr); queryStr = "select db1.id_masking(k1) from db1.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // create alias function with cast // cast any type to decimal with specific precision and scale @@ -142,14 +143,16 @@ public class CreateFunctionTest { queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;"; if (Config.enable_decimal_conversion) { - 
Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // cast any type to varchar with fixed length - createFuncStr = "create alias function db1.varchar(all) with parameter(text) as " - + "cast(text as varchar(65533));"; + createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as " + + "cast(text as varchar(length));"; createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx); Env.getCurrentEnv().createFunction(createFunctionStmt); @@ -172,7 +175,8 @@ public class CreateFunctionTest { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.varchar(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // cast any type to char with fixed length createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as " @@ -199,7 +203,8 @@ public class CreateFunctionTest { Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral); queryStr = "select db1.to_char(k1, 4) from db1.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER")); } @Test @@ -235,8 +240,8 @@ public class CreateFunctionTest { testFunctionQuery(ctx, queryStr, false); queryStr = "select id_masking(k1) from 
db2.tbl1"; - Assert.assertTrue( - dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))")); // 4. create alias function with cast // cast any type to decimal with specific precision and scale @@ -253,9 +258,11 @@ public class CreateFunctionTest { queryStr = "select decimal(k3, 4, 1) from db2.tbl1;"; if (Config.enable_decimal_conversion) { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMALV3(4, 1))")); } else { - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k3` AS DECIMAL(4, 1))")); } // 5. cast any type to varchar with fixed length @@ -271,7 +278,8 @@ public class CreateFunctionTest { testFunctionQuery(ctx, queryStr, true); queryStr = "select varchar(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS VARCHAR(65533))")); // 6. 
cast any type to char with fixed length createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as " @@ -286,7 +294,8 @@ public class CreateFunctionTest { testFunctionQuery(ctx, queryStr, true); queryStr = "select to_char(k1, 4) from db2.tbl1;"; - Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)")); + Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(), + "CAST(`k1` AS CHARACTER)")); } private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception { @@ -320,4 +329,8 @@ public class CreateFunctionTest { Env.getCurrentEnv().createDb(createDbStmt); System.out.println(Env.getCurrentInternalCatalog().getDbNames()); } + + private boolean containsIgnoreCase(String str, String sub) { + return str.toLowerCase().contains(sub.toLowerCase()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java new file mode 100644 index 00000000000..56b67e087d5 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.postprocess; + +import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector; +import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.types.IntegerType; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class CommonSubExpressionTest extends ExpressionRewriteTestHelper { + @Test + public void testExtractCommonExpr() { + List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a"); + CommonSubExpressionCollector collector = + new CommonSubExpressionCollector(); + exprs.forEach(expr -> collector.visit(expr, null)); + System.out.println(collector.commonExprByDepth); + Assertions.assertEquals(2, collector.commonExprByDepth.size()); + List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1)) + .stream().collect(Collectors.toList()); + List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2)) + .stream().collect(Collectors.toList()); + Assertions.assertEquals(1, l1.size()); + assertExpression(l1.get(0), "a+b"); + Assertions.assertEquals(1, l2.size()); + assertExpression(l2.get(0), "a+b+1"); + } + + @Test + public void 
testMultiLayers() throws Exception { + List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a"); + Set<Slot> inputSlots = exprs.get(0).getInputSlots(); + CommonSubExpressionOpt opt = new CommonSubExpressionOpt(); + Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class + .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class); + computeMultLayerProjectionsMethod.setAccessible(true); + List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultLayerProjectionsMethod + .invoke(opt, inputSlots, exprs); + System.out.println(multiLayers); + Assertions.assertEquals(3, multiLayers.size()); + List<NamedExpression> l0 = multiLayers.get(0); + Assertions.assertEquals(2, l0.size()); + Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a"))); + Assertions.assertTrue(l0.get(1) instanceof Alias); + assertExpression(l0.get(1).child(0), "a+b"); + Assertions.assertEquals(multiLayers.get(1).size(), 3); + Assertions.assertEquals(multiLayers.get(2).size(), 5); + List<NamedExpression> l2 = multiLayers.get(2); + for (int i = 0; i < 5; i++) { + Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt()); + } + + } + + private void assertExpression(Expression expr, String str) { + Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr); + } + + private List<NamedExpression> parseProjections(String exprList) { + List<NamedExpression> result = new ArrayList<>(); + String[] exprArray = exprList.split(","); + for (String item : exprArray) { + Expression expr = ExprParser.INSTANCE.parseExpression(item); + if (expr instanceof NamedExpression) { + result.add((NamedExpression) expr); + } else { + result.add(new Alias(expr)); + } + } + return result; + } + + public static class ExprParser { + public static ExprParser INSTANCE = new ExprParser(); + HashMap<String, SlotReference> slotMap = new HashMap<>(); + + public Expression parseExpression(String str) { 
+ Expression expr = PARSER.parseExpression(str); + return expr.accept(DataTypeAssignor.INSTANCE, slotMap); + } + } + + /** Expression rewriter used by the test parser: every visited slot is replaced by a SlotReference looked up, or created, by slot name in the supplied map. */ public static class DataTypeAssignor extends DefaultExpressionRewriter<Map<String, SlotReference>> { + public static DataTypeAssignor INSTANCE = new DataTypeAssignor(); + + /** Returns the SlotReference already stored in {@code slotMap} under the slot's name; otherwise creates one with IntegerType.INSTANCE, stores it under that name, and returns it. */ @Override + public Expression visitSlot(Slot slot, Map<String, SlotReference> slotMap) { + SlotReference exitsSlot = slotMap.get(slot.getName()); + if (exitsSlot != null) { + return exitsSlot; + } else { + SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE); + slotMap.put(slot.getName(), slotReference); + return slotReference; + } + } + } + +} diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out new file mode 100644 index 00000000000..5ab44655661 --- /dev/null +++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out @@ -0,0 +1,30 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !cse -- +1 1 3 4 +2 0 3 4 +3 1 5 6 +4 0 5 6 +5 4 10 11 +6 0 7 8 +7 3 11 12 +8 1 10 11 +9 4 14 15 +10 1 12 13 + +-- !cse_2 -- +17 1 18 19 19 +5 2 7 8 8 +1 3 4 5 5 +15 4 19 20 20 +11 5 16 17 17 +14 6 20 21 21 +23 7 30 31 31 +17 8 25 26 26 +10 9 19 20 20 +24 10 34 35 35 + +-- !cse_3 -- +12093 13093 14093 15093 + +-- !cse_4 -- +12093 13093 14093 15093 \ No newline at end of file diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy index e6f05c6c765..cf0c03fc3bd 100644 --- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy +++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy @@ -32,6 +32,7 @@ class ExplainAction implements SuiteAction { private SuiteContext context private Set<String> containsStrings = new 
LinkedHashSet<>() private Set<String> notContainsStrings = new LinkedHashSet<>() + /* substring -> exact number of occurrences required in the explain output */ private Map<String, Integer> multiContainsStrings = new HashMap<>() private String coonType private Closure checkFunction @@ -56,6 +57,10 @@ class ExplainAction implements SuiteAction { containsStrings.add(subString) } + /** Registers a substring that the explain output must contain exactly {@code n} times; registering the same substring again replaces the earlier count. */ void multiContains(String subString, int n) { + multiContainsStrings.put(subString, n); + } + void notContains(String subString) { notContainsStrings.add(subString) } @@ -112,6 +117,16 @@ class ExplainAction implements SuiteAction { throw t } } + /* Fail when a registered substring does not occur exactly the expected number of times. */ for (Map.Entry entry : multiContainsStrings) { + int count = explainString.count(entry.key); + if (count != entry.value) { + /* The message names the substring through entry.key, the loop variable in scope here. */ String msg = ("Explain and check failed, expect multiContains '${entry.key}' , '${entry.value}' times, actural '${count}' times." + + "Actual explain string is:\n${explainString}").toString() + log.info(msg) + def t = new IllegalStateException(msg) + throw t + } + } } } diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy new file mode 100644 index 00000000000..698dbd3e5d0 --- /dev/null +++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +// These cases are copied from https://github.com/trinodb/trino/tree/master +// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/tpcds +// and modified by Doris. + +/* Runs four queries whose select lists reuse the same sub-expression, and checks that the explain output of q1 and q2 contains "intermediate projections:". */ suite('cse') { + def q1 = """select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y + from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ; + """ + + def q2 = """select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1) + from supplier order by s_suppkey , s_suppkey limit 10 ;""" + + qt_cse "${q1}" + + explain { + sql "${q1}" + contains "intermediate projections:" + } + + qt_cse_2 "${q2}" + + explain { + sql "${q2}" + multiContains("intermediate projections:", 2) + } + + qt_cse_3 """ select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 ) , sum(s_nationkey + 3 ) from supplier ;""" + + qt_cse_4 """select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey) + 3 * count(1) from supplier ;""" + + +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org