This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 5ac4b6a1373cfc274d8efe30ff7435b8442d7b05
Author: xzj7019 <131111794+xzj7...@users.noreply.github.com>
AuthorDate: Mon Feb 19 14:13:32 2024 +0800

    [opt](Nereids) refine group by elimination column prune (#30953)
---
 .../doris/nereids/rules/rewrite/ColumnPruning.java | 54 ++++++++++++++++++----
 .../nereids/rules/rewrite/EliminateGroupByKey.java |  5 +-
 2 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
index dcb330cd28e..284ac52e14d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java
@@ -18,7 +18,9 @@
 package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
@@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
 import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -173,16 +176,18 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
     private Plan pruneAggregate(Aggregate agg, PruneContext context) {
         // first try to prune group by and aggregate functions
         Aggregate prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
+        Set<Integer> enableNereidsRules = ConnectContext.get().getSessionVariable().getEnableNereidsRules();
+        Aggregate fillUpAggr;
 
-        List<Expression> groupByExpressions = prunedOutputAgg.getGroupByExpressions();
-        List<NamedExpression> outputExpressions = prunedOutputAgg.getOutputExpressions();
-
-        // then fill up group by
-        Aggregate fillUpOutputRepeat = fillUpGroupByToOutput(groupByExpressions, outputExpressions)
-                .map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
-                .orElse(prunedOutputAgg);
+        if (!enableNereidsRules.contains(RuleType.ELIMINATE_GROUP_BY_KEY.type())) {
+            fillUpAggr = fillUpGroupByToOutput(prunedOutputAgg)
+                    .map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
+                    .orElse(prunedOutputAgg);
+        } else {
+            fillUpAggr = fillUpGroupByAndOutput(prunedOutputAgg);
+        }
 
-        return pruneChildren(fillUpOutputRepeat);
+        return pruneChildren(fillUpAggr);
     }
 
     private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
@@ -193,8 +198,9 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
         return pruneChildren(plan, requireAllOutputOfChildren);
     }
 
-    private static Optional<List<NamedExpression>> fillUpGroupByToOutput(
-            List<Expression> groupBy, List<NamedExpression> output) {
+    private static Optional<List<NamedExpression>> fillUpGroupByToOutput(Aggregate prunedOutputAgg) {
+        List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
+        List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
 
         if (output.containsAll(groupBy)) {
             return Optional.empty();
@@ -209,6 +215,34 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
                 .build());
     }
 
+    private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) {
+        List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
+        List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
+
+        if (!(prunedOutputAgg instanceof LogicalAggregate)) {
+            return prunedOutputAgg;
+        }
+        // add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY
+        // if related output expressions are not in pruned output list.
+        List<NamedExpression> remainedOutputExprs = Lists.newArrayList(output);
+        remainedOutputExprs.removeAll(groupBy);
+
+        List<NamedExpression> newOutputList = Lists.newArrayList();
+        newOutputList.addAll((List) groupBy);
+        newOutputList.addAll(remainedOutputExprs);
+
+        if (!(prunedOutputAgg instanceof LogicalAggregate)) {
+            return prunedOutputAgg.withAggOutput(newOutputList);
+        } else {
+            List<Expression> newGroupByExprList = newOutputList.stream().filter(e ->
+                    !(prunedOutputAgg.getAggregateFunctions().contains(e)
+                            || e instanceof Alias && prunedOutputAgg.getAggregateFunctions()
+                                .contains(((Alias) e).child()))
+            ).collect(Collectors.toList());
+            return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList);
+        }
+    }
+
     /** prune output */
     public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
             Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
index d922252bebc..69a34a680ec 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java
@@ -50,10 +50,7 @@ public class EliminateGroupByKey extends OneRewriteRuleFactory {
             List<FdItem> uniqueFdItems = new ArrayList<>();
             List<FdItem> nonUniqueFdItems = new ArrayList<>();
             if (agg.getGroupByExpressions().isEmpty()
-                    || agg.getGroupByExpressions().equals(agg.getOutputExpressions())
-                    || !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)
-                    || agg.getGroupByExpressions().stream().allMatch(e ->
-                        ((SlotReference) e).getColumn().isPresent() && ((SlotReference) e).getTable().isPresent())) {
+                    || !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)) {
                 return null;
             }
             ImmutableSet<FdItem> fdItems = childPlan.getLogicalProperties().getFunctionalDependencies().getFdItems();


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to