morrySnow commented on code in PR #64849:
URL: https://github.com/apache/doris/pull/64849#discussion_r3489312106


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java:
##########
@@ -38,69 +42,162 @@
 import java.util.Map.Entry;
 import java.util.Set;
 
-
 /**
  * Eliminate group by key based on fd item information.
  * such as:
  *  for a -> b, we can get:
  *          group by a, b, c  => group by a, c
+ *
+ * When a group-by key is FD-redundant but still needed in the output,
+ * it is wrapped with any_value() and assigned a fresh ExprId.
+ * Upper plan references are rewritten via ExprIdRewriter so that
+ * all ancestor nodes see the new ExprIds.
  */
-@DependsRules({EliminateGroupBy.class, ColumnPruning.class})
-public class EliminateGroupByKey implements RewriteRuleFactory {
+public class EliminateGroupByKey extends DefaultPlanRewriter<Map<ExprId, 
ExprId>> implements CustomRewriter {
+    private ExprIdRewriter exprIdReplacer;
 
     @Override
-    public List<Rule> buildRules() {
-        return ImmutableList.of(
-                RuleType.ELIMINATE_GROUP_BY_KEY.build(
-                        logicalProject(logicalAggregate().when(agg -> 
!agg.getSourceRepeat().isPresent()))
-                                .then(proj -> {
-                                    LogicalAggregate<? extends Plan> agg = 
proj.child();
-                                    LogicalAggregate<Plan> newAgg = 
eliminateGroupByKey(agg, proj.getInputSlots());
-                                    if (newAgg == null) {
-                                        return null;
-                                    }
-                                    return proj.withChildren(newAgg);
-                                })),
-                RuleType.ELIMINATE_FILTER_GROUP_BY_KEY.build(
-                        logicalProject(logicalFilter(logicalAggregate()
-                                .when(agg -> 
!agg.getSourceRepeat().isPresent())))
-                                .then(proj -> {
-                                    LogicalAggregate<? extends Plan> agg = 
proj.child().child();
-                                    Set<Slot> requireSlots = new 
HashSet<>(proj.getInputSlots());
-                                    
requireSlots.addAll(proj.child(0).getInputSlots());
-                                    LogicalAggregate<Plan> newAgg = 
eliminateGroupByKey(agg, requireSlots);
-                                    if (newAgg == null) {
-                                        return null;
-                                    }
-                                    return 
proj.withChildren(proj.child().withChildren(newAgg));
-                                })
-                )
-        );
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        if (!plan.containsType(Aggregate.class)) {
+            return plan;
+        }
+        Map<ExprId, ExprId> replaceMap = new HashMap<>();
+        ExprIdRewriter.ReplaceRule replaceRule = new 
ExprIdRewriter.ReplaceRule(replaceMap, false);
+        exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
+        return plan.accept(this, replaceMap);
     }
 
-    LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends 
Plan> agg, Set<Slot> requireOutput) {
-        Set<Expression> removeExpression = findCanBeRemovedExpressions(agg, 
requireOutput,
+    @Override
+    public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
+        plan = visitChildren(this, plan, replaceMap);
+        plan = exprIdReplacer.rewriteExpr(plan, replaceMap);
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalProject(LogicalProject<? extends Plan> proj, 
Map<ExprId, ExprId> replaceMap) {
+        proj = visitChildren(this, proj, replaceMap);
+
+        // Find the Aggregate child, possibly through a Filter
+        Plan child = proj.child(0);
+        LogicalAggregate<? extends Plan> agg;
+        boolean hasFilter = child instanceof LogicalFilter;
+        if (hasFilter && child.child(0) instanceof LogicalAggregate) {
+            agg = (LogicalAggregate<? extends Plan>) child.child(0);
+        } else if (child instanceof LogicalAggregate) {
+            agg = (LogicalAggregate<? extends Plan>) child;
+        } else {
+            return exprIdReplacer.rewriteExpr(proj, replaceMap);
+        }
+
+        // Don't transform if source repeat is present
+        if (agg.getSourceRepeat().isPresent()) {
+            return exprIdReplacer.rewriteExpr(proj, replaceMap);
+        }
+
+        // Compute requireOutput: slots needed by the Project (and Filter, if 
present)
+        Set<Slot> requireOutput = new HashSet<>(proj.getInputSlots());
+        if (hasFilter) {
+            requireOutput.addAll(child.getInputSlots());
+        }
+
+        // Transform the aggregate
+        EliminateResult result = eliminateGroupByKeyWithMap(agg, 
requireOutput);
+        if (!result.changed) {
+            return exprIdReplacer.rewriteExpr(proj, replaceMap);
+        }
+
+        // Merge into the global replaceMap so that all ancestor nodes get 
rewritten
+        replaceMap.putAll(result.replaceMap);
+
+        // Rebuild the child chain with the new aggregate,
+        // and rewrite the Filter (if present) and Project expressions
+        Plan newChild;
+        if (hasFilter) {
+            Plan updatedFilter = child.withChildren(result.newAgg);
+            newChild = exprIdReplacer.rewriteExpr(updatedFilter, replaceMap);
+        } else {
+            newChild = result.newAgg;
+        }
+        Plan newProj = exprIdReplacer.rewriteExpr(proj.withChildren(newChild), 
replaceMap);
+        return newProj;
+    }
+
+    /** Result of eliminateGroupByKey: the new aggregate and a map of old->new 
ExprIds. */
+    private static class EliminateResult {
+        final LogicalAggregate<Plan> newAgg;
+        final Map<ExprId, ExprId> replaceMap;
+        final boolean changed;
+
+        EliminateResult(LogicalAggregate<Plan> newAgg, Map<ExprId, ExprId> 
replaceMap, boolean changed) {
+            this.newAgg = newAgg;
+            this.replaceMap = replaceMap;
+            this.changed = changed;
+        }
+    }
+
+    EliminateResult eliminateGroupByKeyWithMap(LogicalAggregate<? extends 
Plan> agg, Set<Slot> requireOutput) {
+        FindResult result = findCanBeRemovedExpressionsInternal(agg, 
requireOutput,
                 agg.child().getLogicalProperties().getTrait());
+        Set<Expression> removeExpression = result.removeExpression;
+        Set<Expression> wrapWithAnyValue = result.wrapWithAnyValue;
+
         List<Expression> newGroupExpression = new ArrayList<>();
         for (Expression expression : agg.getGroupByExpressions()) {
-            if (!removeExpression.contains(expression)) {
+            if (!removeExpression.contains(expression)
+                    && !wrapWithAnyValue.contains(expression)) {
                 newGroupExpression.add(expression);

Review Comment:
   Nit: consider adding a defensive guard for the case where every entry in 
`groupByExpressions` turns out to be redundant and `newGroupExpression` ends 
up empty. `EliminateGroupByKeyByUniform` retains one eliminated key in this 
scenario. Without such a guard, the aggregate degenerates into a scalar 
aggregate (zero group-by keys), which returns one row on an empty input table 
where the original grouped aggregate returns none, i.e. an incorrect result. 
While current FD patterns may not trigger this, the guard would make the 
rewrite more resilient.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to