morrySnow commented on code in PR #12274:
URL: https://github.com/apache/doris/pull/12274#discussion_r961829420


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java:
##########
@@ -54,15 +59,63 @@ public class PruneAggChildColumns extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return 
RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
+            Slot slot = handleAggregateConstant(agg);
+            List<Slot> childOutput = agg.child().getOutput();
+            if (slot != null) {
+                if (childOutput.size() == 1 && 
childOutput.get(0).equals(slot)) {
+                    return agg;
+                }
+                return agg.withChildren(ImmutableList.of(new 
LogicalProject<>(ImmutableList.of(slot), agg.child())));

Review Comment:
   It is wrong to replace the child's output with a new slot. Consider this
situation:
   ```sql
   select count(1) from t group by a
   ```
   The child's output is just `a`, which the GROUP BY still needs, so we cannot replace it with a new slot.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java:
##########
@@ -54,15 +59,63 @@ public class PruneAggChildColumns extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return 
RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
+            Slot slot = handleAggregateConstant(agg);
+            List<Slot> childOutput = agg.child().getOutput();
+            if (slot != null) {
+                if (childOutput.size() == 1 && 
childOutput.get(0).equals(slot)) {
+                    return agg;
+                }
+                return agg.withChildren(ImmutableList.of(new 
LogicalProject<>(ImmutableList.of(slot), agg.child())));
+            }
             List<Expression> slots = Lists.newArrayList();
             slots.addAll(agg.getExpressions());
             Set<Slot> outputs = SlotExtractor.extractSlot(slots);
-            List<NamedExpression> prunedOutputs = 
agg.child().getOutput().stream().filter(outputs::contains)
+            List<NamedExpression> prunedOutputs = 
childOutput.stream().filter(outputs::contains)
                     .collect(Collectors.toList());
             if (prunedOutputs.size() == agg.child().getOutput().size()) {
                 return agg;
             }
             return agg.withChildren(ImmutableList.of(new 
LogicalProject<>(prunedOutputs, agg.child())));
         }));
     }
+
+    /**
+     * For these aggregate function with constant param. Such as:
+     *  count(*), count(1), sum(1)..etc.
+     */
+    private Slot handleAggregateConstant(LogicalAggregate<GroupPlan> agg) {
+        List<NamedExpression> outputExpressions = agg.getOutputExpressions();
+        for (NamedExpression namedExpression : outputExpressions) {
+            if (!(namedExpression instanceof Alias)) {
+                return null;
+            }
+            Expression childOfAlias = ((Alias) namedExpression).child();

Review Comment:
   Consider `COUNT(*) + 1`: the alias's child is not itself an aggregate function. You need to use `collect` to gather all aggregate functions
from the current expression.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java:
##########
@@ -54,15 +59,63 @@ public class PruneAggChildColumns extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return 
RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
+            Slot slot = handleAggregateConstant(agg);
+            List<Slot> childOutput = agg.child().getOutput();
+            if (slot != null) {
+                if (childOutput.size() == 1 && 
childOutput.get(0).equals(slot)) {
+                    return agg;
+                }
+                return agg.withChildren(ImmutableList.of(new 
LogicalProject<>(ImmutableList.of(slot), agg.child())));
+            }
             List<Expression> slots = Lists.newArrayList();
             slots.addAll(agg.getExpressions());
             Set<Slot> outputs = SlotExtractor.extractSlot(slots);
-            List<NamedExpression> prunedOutputs = 
agg.child().getOutput().stream().filter(outputs::contains)
+            List<NamedExpression> prunedOutputs = 
childOutput.stream().filter(outputs::contains)
                     .collect(Collectors.toList());
             if (prunedOutputs.size() == agg.child().getOutput().size()) {
                 return agg;
             }
             return agg.withChildren(ImmutableList.of(new 
LogicalProject<>(prunedOutputs, agg.child())));
         }));
     }
+
+    /**
+     * For these aggregate function with constant param. Such as:
+     *  count(*), count(1), sum(1)..etc.
+     */
+    private Slot handleAggregateConstant(LogicalAggregate<GroupPlan> agg) {
+        List<NamedExpression> outputExpressions = agg.getOutputExpressions();
+        for (NamedExpression namedExpression : outputExpressions) {
+            if (!(namedExpression instanceof Alias)) {
+                return null;
+            }
+            Expression childOfAlias = ((Alias) namedExpression).child();
+            if (!(childOfAlias instanceof AggregateFunction)) {
+                return null;
+            }
+            if (childOfAlias instanceof Count) {

Review Comment:
   Not just `count` needs to be handled. You need to check whether the aggregate function's
children contain a slot reference.



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java:
##########
@@ -177,6 +178,20 @@ public void testPruneColumns4() {
                 );
     }
 
+    @Test
+    public void pruneCountStarStmt() {

Review Comment:
   Please add more tests, such as:
   ```sql
   SELECT COUNT(*), SUM(1) + SUM(2) FROM t GROUP BY a, b;
   SELECT COUNT(*) + 1, COUNT(c) FROM t GROUP BY a, b;
   ```
   



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java:
##########
@@ -38,8 +38,16 @@
  */
 public class PhysicalProject<CHILD_TYPE extends Plan> extends 
PhysicalUnary<CHILD_TYPE> implements Project {
 
+    private static int sId = 0;

Review Comment:
   What is this for?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to