This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch 2.1-tmp
in repository https://gitbox.apache.org/repos/asf/doris.git

commit c3ffbdb799b1a1f62e8a03f1657782e827441acd
Author: minghong <engle...@gmail.com>
AuthorDate: Tue Apr 2 08:52:03 2024 +0800

    [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
    
    * cse fe part
---
 .../glue/translator/PhysicalPlanTranslator.java    |  50 ++++++--
 .../post/CommonSubExpressionCollector.java         |  59 ++++++++++
 .../processor/post/CommonSubExpressionOpt.java     | 125 ++++++++++++++++++++
 .../nereids/processor/post/PlanPostProcessors.java |   3 +-
 .../trees/plans/physical/PhysicalProject.java      |  81 ++++++++++++-
 .../java/org/apache/doris/planner/PlanNode.java    |  38 +++++-
 .../apache/doris/catalog/CreateFunctionTest.java   |  41 ++++---
 .../postprocess/CommonSubExpressionTest.java       | 131 +++++++++++++++++++++
 regression-test/data/tpch_sf0.1_p1/sql/cse.out     |  30 +++++
 .../doris/regression/action/ExplainAction.groovy   |  15 +++
 .../suites/tpch_sf0.1_p1/sql/cse.groovy            |  49 ++++++++
 11 files changed, 591 insertions(+), 31 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 15e149cdd4d..ab72d995573 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -1833,15 +1833,38 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
         }
 
-        List<Expr> projectionExprs = project.getProjects()
-                .stream()
-                .map(e -> ExpressionTranslator.translate(e, context))
-                .collect(Collectors.toList());
-        List<Slot> slots = project.getProjects()
-                .stream()
-                .map(NamedExpression::toSlot)
-                .collect(Collectors.toList());
-
+        PlanNode inputPlanNode = inputFragment.getPlanRoot();
+        List<Expr> projectionExprs = null;
+        List<Expr> allProjectionExprs = Lists.newArrayList();
+        List<Slot> slots = null;
+        if (project.hasMultiLayerProjection()) {
+            int layerCount = project.getMultiLayerProjects().size();
+            for (int i = 0; i < layerCount; i++) {
+                List<NamedExpression> layer = project.getMultiLayerProjects().get(i);
+                projectionExprs = layer.stream()
+                        .map(e -> ExpressionTranslator.translate(e, context))
+                        .collect(Collectors.toList());
+                slots = layer.stream()
+                        .map(NamedExpression::toSlot)
+                        .collect(Collectors.toList());
+                if (i < layerCount - 1) {
+                    inputPlanNode.addIntermediateProjectList(projectionExprs);
+                    TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context);
+                    inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple);
+                }
+                allProjectionExprs.addAll(projectionExprs);
+            }
+        } else {
+            projectionExprs = project.getProjects()
+                    .stream()
+                    .map(e -> ExpressionTranslator.translate(e, context))
+                    .collect(Collectors.toList());
+            slots = project.getProjects()
+                    .stream()
+                    .map(NamedExpression::toSlot)
+                    .collect(Collectors.toList());
+            allProjectionExprs.addAll(projectionExprs);
+        }
         // process multicast sink
         if (inputFragment instanceof MultiCastPlanFragment) {
             MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
@@ -1853,10 +1876,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             return inputFragment;
         }
 
-        PlanNode inputPlanNode = inputFragment.getPlanRoot();
         List<Expr> conjuncts = inputPlanNode.getConjuncts();
         Set<SlotId> requiredSlotIdSet = Sets.newHashSet();
-        for (Expr expr : projectionExprs) {
+        for (Expr expr : allProjectionExprs) {
             Expr.extractSlots(expr, requiredSlotIdSet);
         }
         Set<SlotId> requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet);
@@ -1891,8 +1913,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
                 requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
                 for (ExprId exprId : requiredExprIds) {
                     SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId);
-                    Preconditions.checkState(slotId != null);
-                    ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+                    // Preconditions.checkState(slotId != null);
+                    if (slotId != null) {
+                        ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+                    }
                 }
             }
             return inputFragment;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
new file mode 100644
index 00000000000..5abc5f6f60f
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * collect common expr
+ */
+public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
+    public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
+    private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();
+
+    @Override
+    public Integer visit(Expression expr, Void context) {
+        if (expr.children().isEmpty()) {
+            return 0;
+        }
+        return collectCommonExpressionByDepth(expr.children().stream().map(child ->
+                child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr);
+    }
+
+    private int collectCommonExpressionByDepth(int depth, Expression expr) {
+        Set<Expression> expressions = getExpressionsFromDepthMap(depth, expressionsByDepth);
+        if (expressions.contains(expr)) {
+            Set<Expression> commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth);
+            commonExpression.add(expr);
+        }
+        expressions.add(expr);
+        return depth;
+    }
+
+    public static Set<Expression> getExpressionsFromDepthMap(
+            int depth, Map<Integer, Set<Expression>> depthMap) {
+        depthMap.putIfAbsent(depth, new LinkedHashSet<>());
+        return depthMap.get(depth);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
new file mode 100644
index 00000000000..dfaf2de757e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T
+ *
+ * before optimize
+ * projection:
+ * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D
+ *
+ * ---
+ * after optimize:
+ * Projection: List < List < Expression > >
+ * A+B, C, D
+ * A+B, A+B+C, D
+ * A+B, (A+B+C)*2, (A+B+C)*3, D
+ */
+public class CommonSubExpressionOpt extends PlanPostProcessor {
+    @Override
+    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
+
+        List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
+                project.getInputSlots(), project.getProjects());
+        project.setMultiLayerProjects(multiLayers);
+        return project;
+    }
+
+    private List<List<NamedExpression>> computeMultiLayerProjections(
+            Set<Slot> inputSlots, List<NamedExpression> projects) {
+
+        List<List<NamedExpression>> multiLayers = Lists.newArrayList();
+        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
+        for (Expression expr : projects) {
+            expr.accept(collector, null);
+        }
+        Map<Expression, Alias> commonExprToAliasMap = new HashMap<>();
+        collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream())
+                .forEach(expression -> {
+                    if (expression instanceof Alias) {
+                        commonExprToAliasMap.put(expression, (Alias) expression);
+                    } else {
+                        commonExprToAliasMap.put(expression, new Alias(expression));
+                    }
+                });
+        Map<Expression, Alias> aliasMap = new HashMap<>();
+        if (!collector.commonExprByDepth.isEmpty()) {
+            for (int i = 1; i <= collector.commonExprByDepth.size(); i++) {
+                List<NamedExpression> layer = Lists.newArrayList();
+                layer.addAll(inputSlots);
+                Set<Expression> exprsInDepth = CommonSubExpressionCollector
+                        .getExpressionsFromDepthMap(i, collector.commonExprByDepth);
+                exprsInDepth.forEach(expr -> {
+                    Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                    Alias alias = new Alias(rewritten);
+                    aliasMap.put(expr, alias);
+                });
+                layer.addAll(aliasMap.values());
+                multiLayers.add(layer);
+            }
+            // final layer
+            List<NamedExpression> finalLayer = Lists.newArrayList();
+            projects.forEach(expr -> {
+                Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                if (rewritten instanceof Slot) {
+                    finalLayer.add((NamedExpression) rewritten);
+                } else if (rewritten instanceof Alias) {
+                    finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName()));
+                }
+            });
+            multiLayers.add(finalLayer);
+        }
+        return multiLayers;
+    }
+
+    /**
+     * replace sub expr by aliasMap
+     */
+    public static class ExpressionReplacer
+            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Alias>> {
+        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
+
+        private ExpressionReplacer() {
+        }
+
+        @Override
+        public Expression visit(Expression expr, Map<? extends Expression, ? extends Alias> replaceMap) {
+            if (replaceMap.containsKey(expr)) {
+                return replaceMap.get(expr).toSlot();
+            }
+            return super.visit(expr, replaceMap);
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
index 60c1a74445e..86c8486ef45 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
@@ -63,8 +63,9 @@ public class PlanPostProcessors {
         builder.add(new MergeProjectPostProcessor());
         builder.add(new RecomputeLogicalPropertiesProcessor());
         builder.add(new AddOffsetIntoDistribute());
+        builder.add(new CommonSubExpressionOpt());
+        // DO NOT replace PLAN NODE from here
         builder.add(new TopNScanOpt());
-        // after generate rf, DO NOT replace PLAN NODE
         builder.add(new FragmentProcessor());
         if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
                         .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
index af7bb950a97..e8472b6af23 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
 import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -41,6 +42,7 @@ import org.apache.doris.thrift.TRuntimeFilterType;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 
 import java.util.List;
 import java.util.Objects;
@@ -52,6 +54,12 @@ import java.util.Optional;
 public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {
 
     private final List<NamedExpression> projects;
+    //multiLayerProjects is used to extract common expressions
+    // projects: (A+B) * 2, (A+B) * 3
+    // multiLayerProjects:
+    //            L1: A+B as x
+    //            L2: x*2, x*3
+    private List<List<NamedExpression>> multiLayerProjects = Lists.newArrayList();
 
     public PhysicalProject(List<NamedExpression> projects, LogicalProperties logicalProperties, CHILD_TYPE child) {
         this(projects, Optional.empty(), logicalProperties, child);
@@ -227,7 +235,12 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
 
     @Override
     public List<Slot> computeOutput() {
-        return projects.stream()
+        List<NamedExpression> output = projects;
+        if (! multiLayerProjects.isEmpty()) {
+            int layers = multiLayerProjects.size();
+            output = multiLayerProjects.get(layers - 1);
+        }
+        return output.stream()
                 .map(NamedExpression::toSlot)
                 .collect(ImmutableList.toImmutableList());
     }
@@ -237,4 +250,70 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
         return new PhysicalProject<>(projects, groupExpression, null, physicalProperties,
                 statistics, child());
     }
+
+    /**
+     * extract common expr, set multi layer projects
+     */
+    public void computeMultiLayerProjectsForCommonExpress() {
+        // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier;
+        if (projects.size() == 3) {
+            if (projects.get(2) instanceof SlotReference) {
+                SlotReference sName = (SlotReference) projects.get(2);
+                if (sName.getName().equals("s_name")) {
+                    Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey)
+                    Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey)
+                    // L1: (s_suppkey + s_nationkey) as x, s_name
+                    multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2)));
+                    List<NamedExpression> l2 = Lists.newArrayList();
+                    l2.add(a1.toSlot());
+                    Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName());
+                    l2.add(a3);
+                    l2.add(sName);
+                    // L2: x, (1+x) as y, s_name
+                    multiLayerProjects.add(l2);
+                }
+            }
+        }
+        // hard code:
+        // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y
+        // from supplier join nation on s_nationkey=n_nationkey
+        // projects: x, y
+        // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z
+        //       L2: z +1 as x, z+2 as y
+        if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias
+                && ((Alias) projects.get(0)).getName().equals("x")
+                && ((Alias) projects.get(1)).getName().equals("y")) {
+            Alias a0 = (Alias) projects.get(0);
+            Alias a1 = (Alias) projects.get(1);
+            Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey
+            List<NamedExpression> l1 = Lists.newArrayList();
+            common.children().stream().forEach(child -> l1.add((SlotReference) child));
+            Alias aliasOfCommon = new Alias(common);
+            l1.add(aliasOfCommon);
+            multiLayerProjects.add(l1);
+            Add add1 = new Add(common, a0.child().child(0).child(1));
+            Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName());
+            Add add2 = new Add(common, a1.child().child(0).child(1));
+            Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName());
+            List<NamedExpression> l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2);
+            multiLayerProjects.add(l2);
+        }
+    }
+
+    public boolean hasMultiLayerProjection() {
+        return !multiLayerProjects.isEmpty();
+    }
+
+    public List<List<NamedExpression>> getMultiLayerProjects() {
+        return multiLayerProjects;
+    }
+
+    public void setMultiLayerProjects(List<List<NamedExpression>> multiLayers) {
+        this.multiLayerProjects = multiLayers;
+    }
+
+    @Override
+    public List<Slot> getOutput() {
+        return computeOutput();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index b404bc4ad35..8cc18a527a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -59,6 +59,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Each PlanNode represents a single relational operator
@@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
     protected int nereidsId = -1;
 
     private List<List<Expr>> childrenDistributeExprLists = new ArrayList<>();
+    private List<TupleDescriptor> intermediateOutputTupleDescList = Lists.newArrayList();
+    private List<List<Expr>> intermediateProjectListList = Lists.newArrayList();
 
     protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String planNodeName,
             StatisticalType statisticalType) {
@@ -536,10 +539,20 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
             expBuilder.append(detailPrefix + "limit: " + limit + "\n");
         }
         if (!CollectionUtils.isEmpty(projectList)) {
-            expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n");
-            expBuilder.append(detailPrefix).append("project output tuple id: ")
+            expBuilder.append(detailPrefix).append("final projections: ")
+                .append(getExplainString(projectList)).append("\n");
+            expBuilder.append(detailPrefix).append("final project output tuple id: ")
                     .append(outputTupleDesc.getId().asInt()).append("\n");
         }
+        if (!intermediateProjectListList.isEmpty()) {
+            int layers = intermediateProjectListList.size();
+            for (int i = layers - 1; i >= 0; i--) {
+                expBuilder.append(detailPrefix).append("intermediate projections: ")
+                        .append(getExplainString(intermediateProjectListList.get(i))).append("\n");
+                expBuilder.append(detailPrefix).append("intermediate tuple id: ")
+                        .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n");
+            }
+        }
         if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) {
             for (List<Expr> distributeExprList : childrenDistributeExprLists) {
                 expBuilder.append(detailPrefix).append("distribute expr lists: ")
@@ -660,6 +673,19 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
                 }
             }
         }
+
+        if (!intermediateOutputTupleDescList.isEmpty()) {
+            intermediateOutputTupleDescList
+                    .forEach(
+                            tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt()));
+        }
+
+        if (!intermediateProjectListList.isEmpty()) {
+            intermediateProjectListList.forEach(
+                    projectList -> msg.addToIntermediateProjectionsList(
+                            projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList())));
+        }
+
         if (this instanceof ExchangeNode) {
             msg.num_children = 0;
             return;
@@ -1221,4 +1247,12 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
     public void setNereidsId(int nereidsId) {
         this.nereidsId = nereidsId;
     }
+
+    public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) {
+        intermediateOutputTupleDescList.add(tupleDescriptor);
+    }
+
+    public void addIntermediateProjectList(List<Expr> exprs) {
+        intermediateProjectListList.add(exprs);
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
index 0f464ba2946..c342d858fe1 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -74,6 +74,7 @@ public class CreateFunctionTest {
     public void test() throws Exception {
         ConnectContext ctx = UtFrameUtils.createDefaultCtx();
         ctx.getSessionVariable().setEnableNereidsPlanner(false);
+        ctx.getSessionVariable().enableFallbackToOriginalPlanner = true;
         ctx.getSessionVariable().setEnableFoldConstantByBe(false);
         // create database db1
         createDatabase(ctx, "create database db1;");
@@ -113,8 +114,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
 
         queryStr = "select db1.id_masking(k1) from db1.tbl1";
-        Assert.assertTrue(
-                dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
 
         // create alias function with cast
         // cast any type to decimal with specific precision and scale
@@ -142,14 +143,16 @@ public class CreateFunctionTest {
 
         queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
         if (Config.enable_decimal_conversion) {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMALV3(4, 1))"));
         } else {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMAL(4, 1))"));
         }
 
         // cast any type to varchar with fixed length
-        createFuncStr = "create alias function db1.varchar(all) with parameter(text) as "
-                + "cast(text as varchar(65533));";
+        createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as "
+                + "cast(text as varchar(length));";
         createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
         Env.getCurrentEnv().createFunction(createFunctionStmt);
 
@@ -172,7 +175,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
 
         queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS VARCHAR(65533))"));
 
         // cast any type to char with fixed length
         createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as "
@@ -199,7 +203,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
 
         queryStr = "select db1.to_char(k1, 4) from db1.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS CHARACTER"));
     }
 
     @Test
@@ -235,8 +240,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, false);
 
         queryStr = "select id_masking(k1) from db2.tbl1";
-        Assert.assertTrue(
-                dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
 
         // 4. create alias function with cast
         // cast any type to decimal with specific precision and scale
@@ -253,9 +258,11 @@ public class CreateFunctionTest {
 
         queryStr = "select decimal(k3, 4, 1) from db2.tbl1;";
         if (Config.enable_decimal_conversion) {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMALV3(4, 1))"));
         } else {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMAL(4, 1))"));
         }
 
         // 5. cast any type to varchar with fixed length
@@ -271,7 +278,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, true);
 
         queryStr = "select varchar(k1, 4) from db2.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS VARCHAR(65533))"));
 
         // 6. cast any type to char with fixed length
         createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as "
@@ -286,7 +294,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, true);
 
         queryStr = "select to_char(k1, 4) from db2.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS CHARACTER)"));
     }
 
     private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception {
@@ -320,4 +329,8 @@ public class CreateFunctionTest {
         Env.getCurrentEnv().createDb(createDbStmt);
         System.out.println(Env.getCurrentInternalCatalog().getDbNames());
     }
+
+    private boolean containsIgnoreCase(String str, String sub) {
+        return str.toLowerCase().contains(sub.toLowerCase());
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
new file mode 100644
index 00000000000..56b67e087d5
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.postprocess;
+
+import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector;
+import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.IntegerType;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Unit tests for {@link CommonSubExpressionCollector} and {@link CommonSubExpressionOpt}.
+ *
+ * <p>Projections are written as small comma separated expression lists; every slot name is bound
+ * to one shared integer-typed {@link SlotReference} so that the same name always parses to an
+ * equal expression.
+ */
+public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
+    /**
+     * Feeds {@code a+b, a+b+1, abs(a+b+1), a} to the collector and expects {@code a+b} under
+     * key 1 and {@code a+b+1} under key 2 of {@code commonExprByDepth}.
+     */
+    @Test
+    public void testExtractCommonExpr() {
+        List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a");
+        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
+        exprs.forEach(expr -> collector.visit(expr, null));
+        System.out.println(collector.commonExprByDepth);
+        Assertions.assertEquals(2, collector.commonExprByDepth.size());
+        List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1))
+                .stream().collect(Collectors.toList());
+        List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2))
+                .stream().collect(Collectors.toList());
+        Assertions.assertEquals(1, l1.size());
+        assertExpression(l1.get(0), "a+b");
+        Assertions.assertEquals(1, l2.size());
+        assertExpression(l2.get(0), "a+b+1");
+    }
+
+    /**
+     * Invokes the non-public {@code computeMultiLayerProjections} through reflection and checks
+     * the size of each returned layer and that the last layer keeps the expression ids of the
+     * original projections, in order.
+     */
+    @Test
+    public void testMultiLayers() throws Exception {
+        List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a");
+        Set<Slot> inputSlots = exprs.get(0).getInputSlots();
+        CommonSubExpressionOpt opt = new CommonSubExpressionOpt();
+        Method computeMultiLayerProjections = CommonSubExpressionOpt.class
+                .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class);
+        computeMultiLayerProjections.setAccessible(true);
+        // Method.invoke returns Object, so the generic cast cannot be checked at runtime.
+        @SuppressWarnings("unchecked")
+        List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultiLayerProjections
+                .invoke(opt, inputSlots, exprs);
+        System.out.println(multiLayers);
+        Assertions.assertEquals(3, multiLayers.size());
+        List<NamedExpression> l0 = multiLayers.get(0);
+        Assertions.assertEquals(2, l0.size());
+        Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a")));
+        Assertions.assertTrue(l0.get(1) instanceof Alias);
+        assertExpression(l0.get(1).child(0), "a+b");
+        // JUnit expects (expected, actual); the order matters for the failure message.
+        Assertions.assertEquals(3, multiLayers.get(1).size());
+        Assertions.assertEquals(exprs.size(), multiLayers.get(2).size());
+        List<NamedExpression> l2 = multiLayers.get(2);
+        for (int i = 0; i < exprs.size(); i++) {
+            Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt());
+        }
+    }
+
+    /** Asserts that {@code expr} equals the expression parsed from {@code str}. */
+    private void assertExpression(Expression expr, String str) {
+        Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr);
+    }
+
+    /**
+     * Parses a comma separated list of expressions into projections, wrapping every expression
+     * that is not already a {@link NamedExpression} in an {@link Alias}.
+     *
+     * <p>The list is split on every comma, so an item must not contain a comma itself
+     * (for example a function call with several arguments).
+     */
+    private List<NamedExpression> parseProjections(String exprList) {
+        List<NamedExpression> result = new ArrayList<>();
+        for (String item : exprList.split(",")) {
+            Expression expr = ExprParser.INSTANCE.parseExpression(item);
+            if (expr instanceof NamedExpression) {
+                result.add((NamedExpression) expr);
+            } else {
+                result.add(new Alias(expr));
+            }
+        }
+        return result;
+    }
+
+    /** Parses expression text and replaces every slot by a shared, typed {@link SlotReference}. */
+    public static class ExprParser {
+        public static final ExprParser INSTANCE = new ExprParser();
+        // Shared by every parse call on this instance, so a slot name maps to one SlotReference.
+        final Map<String, SlotReference> slotMap = new HashMap<>();
+
+        public Expression parseExpression(String str) {
+            Expression expr = PARSER.parseExpression(str);
+            return expr.accept(DataTypeAssignor.INSTANCE, slotMap);
+        }
+    }
+
+    /** Rewrites each slot into the integer-typed {@link SlotReference} registered for its name. */
+    public static class DataTypeAssignor extends DefaultExpressionRewriter<Map<String, SlotReference>> {
+        public static final DataTypeAssignor INSTANCE = new DataTypeAssignor();
+
+        @Override
+        public Expression visitSlot(Slot slot, Map<String, SlotReference> slotMap) {
+            // Reuse the reference already bound to this name, or bind a new one on first sight.
+            return slotMap.computeIfAbsent(slot.getName(),
+                    name -> new SlotReference(name, IntegerType.INSTANCE));
+        }
+    }
+}
diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out 
b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
new file mode 100644
index 00000000000..5ab44655661
--- /dev/null
+++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
@@ -0,0 +1,30 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !cse --
+1      1       3       4
+2      0       3       4
+3      1       5       6
+4      0       5       6
+5      4       10      11
+6      0       7       8
+7      3       11      12
+8      1       10      11
+9      4       14      15
+10     1       12      13
+
+-- !cse_2 --
+17     1       18      19      19
+5      2       7       8       8
+1      3       4       5       5
+15     4       19      20      20
+11     5       16      17      17
+14     6       20      21      21
+23     7       30      31      31
+17     8       25      26      26
+10     9       19      20      20
+24     10      34      35      35
+
+-- !cse_3 --
+12093  13093   14093   15093
+
+-- !cse_4 --
+12093  13093   14093   15093
\ No newline at end of file
diff --git 
a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
 
b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
index e6f05c6c765..cf0c03fc3bd 100644
--- 
a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
+++ 
b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
@@ -32,6 +32,7 @@ class ExplainAction implements SuiteAction {
     private SuiteContext context
     private Set<String> containsStrings = new LinkedHashSet<>()
     private Set<String> notContainsStrings = new LinkedHashSet<>()
+    private Map<String, Integer> multiContainsStrings = new HashMap<>()
     private String coonType
     private Closure checkFunction
 
@@ -56,6 +57,10 @@ class ExplainAction implements SuiteAction {
         containsStrings.add(subString)
     }
 
+    /**
+     * Registers an expectation that the explain output contains {@code subString}
+     * exactly {@code n} times; a later call with the same {@code subString} replaces the count.
+     */
+    void multiContains(String subString, int n) {
+        multiContainsStrings[subString] = n
+    }
+
     void notContains(String subString) {
         notContainsStrings.add(subString)
     }
@@ -112,6 +117,16 @@ class ExplainAction implements SuiteAction {
                     throw t
                 }
             }
+            // Every registered substring must occur exactly the expected number of times.
+            for (Map.Entry<String, Integer> entry : multiContainsStrings.entrySet()) {
+                int count = explainString.count(entry.key)
+                if (count != entry.value) {
+                    // Report the substring of this entry; no variable named 'string' exists in this loop.
+                    String msg = ("Explain and check failed, expect multiContains '${entry.key}' , "
+                            + "'${entry.value}' times, actual '${count}' times. "
+                            + "Actual explain string is:\n${explainString}").toString()
+                    log.info(msg)
+                    def t = new IllegalStateException(msg)
+                    throw t
+                }
+            }
         }
     }
 
diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy 
b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
new file mode 100644
index 00000000000..698dbd3e5d0
--- /dev/null
+++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// The cases are copied from https://github.com/trinodb/trino/tree/master
+// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/tpcds
+// and modified by Doris.
+
+// Regression suite for common sub-expression handling: runs four queries whose projections
+// or aggregates repeat a sub-expression, and checks the explain output of the first two.
+suite('cse') {
+    // q1: x and y both build on the shared sub-expression (s_suppkey + n_regionkey).
+    def q1 = """select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as 
x, (s_suppkey + n_regionkey) + 2 as y 
+            from supplier join nation on s_nationkey=n_nationkey order by 
s_suppkey , n_regionkey limit 10 ;
+            """
+
+    // q2: (s_nationkey + s_suppkey) is reused by a second expression, which is in turn
+    // reused inside abs(...), giving two nested levels of shared sub-expressions.
+    def q2 = """select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), 
(s_nationkey + s_suppkey) + 1,  abs((s_nationkey + s_suppkey) + 1) 
+    from supplier order by s_suppkey , s_suppkey limit 10 ;"""
+
+    qt_cse "${q1}"
+
+    // The explain output of q1 must contain the text "intermediate projections:".
+    explain {
+        sql "${q1}"
+        contains "intermediate projections:"
+    }
+
+    qt_cse_2 "${q2}"
+
+    // The explain output of q2 must contain "intermediate projections:" exactly 2 times.
+    explain {
+        sql "${q2}"
+        multiContains("intermediate projections:", 2)
+    }
+
+    // Four sums over s_nationkey and s_nationkey plus a constant.
+    qt_cse_3 """ select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey 
+2 )  , sum(s_nationkey + 3 ) from supplier ;"""
+
+    // sum(s_nationkey) and count(1) are repeated across the four output expressions.
+    qt_cse_4 """select sum(s_nationkey),sum(s_nationkey) + count(1) 
,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey)  + 3 * count(1) from 
supplier ;"""
+
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to