morrySnow commented on code in PR #59116:
URL: https://github.com/apache/doris/pull/59116#discussion_r2645385227


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java:
##########
@@ -100,6 +100,7 @@ private Plan 
mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate
         Map<ExprId, Expression> innerAggExprIdToAggFunc = 
getInnerAggExprIdToAggFuncMap(innerAgg);
         // rewrite agg function. e.g. max(max)
         List<NamedExpression> replacedAggFunc = 
replacedOutputExpressions.stream()
+                .filter(e -> e.containsType(AggregateFunction.class))

Review Comment:
   is this a another bug? or it is a optimization?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);

Review Comment:
   maybe we could support more functions, just like what we do in MV rollup 
matching



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
+        Plan aggChild = aggregate.child();
+        if (!(aggChild instanceof LogicalRepeat)) {
+            return -1;
+        }
+        // check agg func
+        Set<AggregateFunction> aggFunctions = 
aggregate.getAggregateFunctions();
+        for (AggregateFunction aggFunction : aggFunctions) {
+            if (!SUPPORT_AGG_FUNCTIONS.contains(aggFunction.getClass())) {
+                return -1;
+            }
+            if (aggFunction.isDistinct()) {
+                return -1;
+            }
+        }
+        // find max group
+        LogicalRepeat<Plan> repeat = (LogicalRepeat) aggChild;
+        for (NamedExpression expr : repeat.getOutputExpressions()) {
+            if (expr.containsType(GroupingScalarFunction.class)) {
+                return -1;
+            }

Review Comment:
   why?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
+        Plan aggChild = aggregate.child();
+        if (!(aggChild instanceof LogicalRepeat)) {
+            return -1;
+        }
+        // check agg func
+        Set<AggregateFunction> aggFunctions = 
aggregate.getAggregateFunctions();
+        for (AggregateFunction aggFunction : aggFunctions) {
+            if (!SUPPORT_AGG_FUNCTIONS.contains(aggFunction.getClass())) {
+                return -1;
+            }
+            if (aggFunction.isDistinct()) {
+                return -1;
+            }
+        }
+        // find max group
+        LogicalRepeat<Plan> repeat = (LogicalRepeat) aggChild;
+        for (NamedExpression expr : repeat.getOutputExpressions()) {
+            if (expr.containsType(GroupingScalarFunction.class)) {
+                return -1;
+            }
+        }
+        List<List<Expression>> groupingSets = repeat.getGroupingSets();
+        if (groupingSets.size() <= 3) {
+            return -1;
+        }
+        ImmutableSet<Expression> maxGroup = 
ImmutableSet.copyOf(groupingSets.get(0));
+        int maxGroupIndex = 0;
+        for (int i = 1; i < groupingSets.size(); ++i) {
+            List<Expression> groupingSet = groupingSets.get(i);
+            if (maxGroup.containsAll(groupingSet)) {
+                continue;
+            }
+            if (groupingSet.size() <= maxGroup.size()) {
+                maxGroupIndex = -1;
+                break;

Review Comment:
   i don't think it is right. think about `[1,2], [3,4], [1,2,3,4]`, we only 
need to find a group that contains all other group. but this function require 
more than that



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,

Review Comment:
   add java doc style comment to all private functions



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));

Review Comment:
   add a regression test for rollup + session var guard



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),

Review Comment:
   use meaningful name



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {

Review Comment:
   add comment to explain what the return value mean



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,

Review Comment:
   add function ut for all private functions



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {

Review Comment:
   add a function ut to it



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
+        Plan aggChild = aggregate.child();
+        if (!(aggChild instanceof LogicalRepeat)) {
+            return -1;
+        }
+        // check agg func
+        Set<AggregateFunction> aggFunctions = 
aggregate.getAggregateFunctions();
+        for (AggregateFunction aggFunction : aggFunctions) {
+            if (!SUPPORT_AGG_FUNCTIONS.contains(aggFunction.getClass())) {
+                return -1;
+            }
+            if (aggFunction.isDistinct()) {
+                return -1;
+            }
+        }
+        // find max group
+        LogicalRepeat<Plan> repeat = (LogicalRepeat) aggChild;
+        for (NamedExpression expr : repeat.getOutputExpressions()) {
+            if (expr.containsType(GroupingScalarFunction.class)) {
+                return -1;
+            }
+        }
+        List<List<Expression>> groupingSets = repeat.getGroupingSets();
+        if (groupingSets.size() <= 3) {

Review Comment:
   explain why use `3` here



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,
+            LogicalAggregate<? extends Plan> aggregate) {
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<NamedExpression> unionOutputs = new ArrayList<>();
+        List<List<SlotReference>> childrenOutputs = new ArrayList<>();
+        childrenOutputs.add((List) child1.getOutput());
+        childrenOutputs.add((List) child2.getOutput());
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                continue;
+            }
+            unionOutputs.add(expr.toSlot());
+        }
+        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, 
ImmutableList.of(),
+                false, ImmutableList.of(child1, child2));
+    }
+
+    private int canOptimize(LogicalAggregate<? extends Plan> aggregate) {
+        Plan aggChild = aggregate.child();
+        if (!(aggChild instanceof LogicalRepeat)) {
+            return -1;
+        }
+        // check agg func
+        Set<AggregateFunction> aggFunctions = 
aggregate.getAggregateFunctions();
+        for (AggregateFunction aggFunction : aggFunctions) {
+            if (!SUPPORT_AGG_FUNCTIONS.contains(aggFunction.getClass())) {
+                return -1;
+            }
+            if (aggFunction.isDistinct()) {
+                return -1;
+            }
+        }
+        // find max group
+        LogicalRepeat<Plan> repeat = (LogicalRepeat) aggChild;
+        for (NamedExpression expr : repeat.getOutputExpressions()) {
+            if (expr.containsType(GroupingScalarFunction.class)) {
+                return -1;
+            }
+        }
+        List<List<Expression>> groupingSets = repeat.getGroupingSets();
+        if (groupingSets.size() <= 3) {
+            return -1;
+        }
+        ImmutableSet<Expression> maxGroup = 
ImmutableSet.copyOf(groupingSets.get(0));
+        int maxGroupIndex = 0;
+        for (int i = 1; i < groupingSets.size(); ++i) {
+            List<Expression> groupingSet = groupingSets.get(i);
+            if (maxGroup.containsAll(groupingSet)) {
+                continue;
+            }
+            if (groupingSet.size() <= maxGroup.size()) {
+                maxGroupIndex = -1;
+                break;
+            }
+            ImmutableSet<Expression> currentSet = 
ImmutableSet.copyOf(groupingSet);
+            if (currentSet.containsAll(maxGroup)) {
+                maxGroup = currentSet;
+                maxGroupIndex = i;
+            } else {
+                maxGroupIndex = -1;
+                break;
+            }
+        }
+        return maxGroupIndex;
+    }
+
+    private LogicalCTEProducer<LogicalAggregate<Plan>> 
constructProducer(LogicalAggregate<? extends Plan> aggregate,
+            int maxGroupIndex, DistinctSelectorContext ctx, Map<Slot, Slot> 
preToCloneSlotMap) {
+        LogicalRepeat<? extends Plan> repeat = (LogicalRepeat<? extends Plan>) 
aggregate.child();
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        List<NamedExpression> originAggOutputs = 
aggregate.getOutputExpressions();
+        Set<NamedExpression> preAggOutputSet = new 
HashSet<NamedExpression>((List) maxGroupByList);
+        for (NamedExpression aggOutput : originAggOutputs) {
+            if (aggOutput.containsType(AggregateFunction.class)) {
+                preAggOutputSet.add(aggOutput);
+            }
+        }
+        List<NamedExpression> orderedAggOutputs = new ArrayList<>();
+        // keep order
+        for (NamedExpression aggOutput : originAggOutputs) {
+            if (preAggOutputSet.contains(aggOutput)) {
+                orderedAggOutputs.add(aggOutput);
+            }
+        }
+
+        LogicalAggregate<Plan> preAgg = new 
LogicalAggregate<>(Utils.fastToImmutableList(maxGroupByList),
+                orderedAggOutputs, repeat.child());
+        LogicalAggregate<Plan> preAggClone = (LogicalAggregate<Plan>) 
LogicalPlanDeepCopier.INSTANCE
+                .deepCopy(preAgg, new DeepCopierContext());
+        for (int i = 0; i < preAgg.getOutputExpressions().size(); ++i) {
+            preToCloneSlotMap.put(preAgg.getOutput().get(i), 
preAggClone.getOutput().get(i));
+        }
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer =
+                new LogicalCTEProducer<>(ctx.statementContext.getNextCTEId(), 
preAggClone);
+        ctx.cteProducerList.add(producer);
+        return producer;
+    }
+
+    private LogicalRepeat<Plan> constructRepeat(LogicalRepeat<Plan> repeat, 
LogicalPlan child,
+            List<List<Expression>> newGroupingSets, Map<Slot, Slot> 
producerToConsumer2SlotMap) {

Review Comment:
   2 -> To ?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java:
##########
@@ -0,0 +1,354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SessionVarGuardExpr;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * RewriteRepeatByCte will rewrite grouping sets. eg:
+ *      select a, b, c, d, e sum(f) from t group by rollup(a, b, c, d, e);
+ * rewrite to:
+ *    with cte1 as (select a, b, c, d, e, sum(f) x from t group by rollup(a, 
b, c, d, e))
+ *      select * fom cte1
+ *      union all
+ *      select a, b, c, d, null, sum(x) x from t group by rollup(a, b, c, d)
+ *
+ * LogicalAggregate(gby: a,b,c,d,e,grouping_id 
output:a,b,c,d,e,grouping_id,sum(f))
+ *   +--LogicalRepeat(grouping sets: 
(a,b,c,d,e),(a,b,c,d),(a,b,c),(a,b),(a),())
+ * ->
+ * LogicalCTEAnchor
+ *   +--LogicalCTEProducer(cte)
+ *     +--LogicalAggregate(gby: a,b,c,d,e; aggFunc: sum(f) as x)
+ *   +--LogicalUnionAll
+ *     +--LogicalProject(a,b,c,d, null as e, sum(x))
+ *       +--LogicalAggregate(gby:a,b,c,d,grouping_id; aggFunc: sum(x))
+ *         +--LogicalRepeat(grouping sets: (a,b,c,d),(a,b,c),(a,b),(a),())
+ *           +--LogicalCTEConsumer(cteConsumer1)
+ *     +--LogicalCTEConsumer(cteConsumer2)
+ */
+public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<DistinctSelectorContext>
+        implements CustomRewriter {
+    public static final DecomposeRepeatWithPreAggregation INSTANCE = new 
DecomposeRepeatWithPreAggregation();
+    private static final Set<Class<? extends AggregateFunction>> 
SUPPORT_AGG_FUNCTIONS =
+            ImmutableSet.of(Sum.class, Min.class, Max.class);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        DistinctSelectorContext ctx = new 
DistinctSelectorContext(jobContext.getCascadesContext().getStatementContext(),
+                jobContext.getCascadesContext());
+        plan = plan.accept(this, ctx);
+        for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
ctx.cteProducerList.get(i);
+            plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
+        }
+        return plan;
+    }
+
+    @Override
+    public Plan visitLogicalCTEAnchor(
+            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, 
DistinctSelectorContext ctx) {
+        Plan child1 = anchor.child(0).accept(this, ctx);
+        DistinctSelectorContext consumerContext =
+                new DistinctSelectorContext(ctx.statementContext, 
ctx.cascadesContext);
+        Plan child2 = anchor.child(1).accept(this, consumerContext);
+        for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
+            LogicalCTEProducer<? extends Plan> producer = 
consumerContext.cteProducerList.get(i);
+            child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, 
child2);
+        }
+        return anchor.withChildren(ImmutableList.of(child1, child2));
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, DistinctSelectorContext ctx) {
+        aggregate = visitChildren(this, aggregate, ctx);
+        int maxGroupIndex = canOptimize(aggregate);
+        if (maxGroupIndex < 0) {
+            return aggregate;
+        }
+        Map<Slot, Slot> preToProducerSlotMap = new HashMap<>();
+        LogicalCTEProducer<LogicalAggregate<Plan>> producer = 
constructProducer(aggregate, maxGroupIndex, ctx,
+                preToProducerSlotMap);
+        LogicalCTEConsumer consumer1 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+        LogicalCTEConsumer consumer2 = new 
LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
+                producer.getCteId(), "", producer);
+
+        // build map : origin slot to consumer slot
+        Map<Slot, Slot> producerToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : 
consumer1.getProducerToConsumerOutputMap().entries()) {
+            producerToConsumerMap.put(entry.getKey(), entry.getValue());
+        }
+        Map<Slot, Slot> originToConsumerMap = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : preToProducerSlotMap.entrySet()) {
+            originToConsumerMap.put(entry.getKey(), 
producerToConsumerMap.get(entry.getValue()));
+        }
+
+        LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
+        List<List<Expression>> newGroupingSets = new ArrayList<>();
+        for (int i = 0; i < repeat.getGroupingSets().size(); ++i) {
+            if (i == maxGroupIndex) {
+                continue;
+            }
+            newGroupingSets.add(repeat.getGroupingSets().get(i));
+        }
+        LogicalRepeat<Plan> newRepeat = constructRepeat(repeat, consumer1, 
newGroupingSets, originToConsumerMap);
+        Set<Expression> needRemovedExprSet = getNeedAddNullExpressions(repeat, 
newGroupingSets, maxGroupIndex);
+        LogicalProject<Plan> project = constructProjectAggregate(aggregate,
+                originToConsumerMap, repeat, newRepeat, needRemovedExprSet);
+        return constructUnion(project, consumer2, aggregate);
+    }
+
+    private Map<AggregateFunction, Slot> 
getAggFuncSlotMap(List<NamedExpression> outputExpressions,
+            Map<Slot, Slot> pToc) {
+        // build map : aggFunc to Slot
+        Map<AggregateFunction, Slot> aggFuncSlotMap = new HashMap<>();
+        for (NamedExpression expr : outputExpressions) {
+            if (expr instanceof Alias) {
+                Expression aggFunc = 
SessionVarGuardExpr.getSessionVarGuardChild(expr.child(0));
+                if (!(aggFunc instanceof AggregateFunction)) {
+                    continue;
+                }
+                aggFuncSlotMap.put((AggregateFunction) aggFunc, 
pToc.get(expr.toSlot()));
+            }
+        }
+        return aggFuncSlotMap;
+    }
+
+    private Set<Expression> getNeedAddNullExpressions(LogicalRepeat<Plan> 
repeat,
+            List<List<Expression>> newGroupingSets, int maxGroupIndex) {
+        Set<Expression> otherGroupExprSet = new HashSet<>();
+        for (List<Expression> groupingSet : newGroupingSets) {
+            otherGroupExprSet.addAll(groupingSet);
+        }
+        List<Expression> maxGroupByList = 
repeat.getGroupingSets().get(maxGroupIndex);
+        Set<Expression> needRemovedExprSet = new HashSet<>(maxGroupByList);
+        needRemovedExprSet.removeAll(otherGroupExprSet);
+        return needRemovedExprSet;
+    }
+
+    private LogicalProject<Plan> constructProjectAggregate(LogicalAggregate<? 
extends Plan> aggregate,
+            Map<Slot, Slot> originToConsumerMap, LogicalRepeat<Plan> repeat,
+            LogicalRepeat<Plan> newRepeat, Set<Expression> needRemovedExprSet) 
{
+        Map<AggregateFunction, Slot> aggFuncSlotMap = 
getAggFuncSlotMap(aggregate.getOutputExpressions(),
+                originToConsumerMap);
+        List<NamedExpression> topAggOutput = new ArrayList<>();
+        List<NamedExpression> projects = new ArrayList<>();
+        for (NamedExpression expr : aggregate.getOutputExpressions()) {
+            if (needRemovedExprSet.contains(expr)) {
+                projects.add(new Alias(new NullLiteral(expr.getDataType()), 
expr.getName()));
+            } else {
+                if (expr instanceof Alias && 
expr.containsType(AggregateFunction.class)) {
+                    NamedExpression aggFuncAfterRewrite = (NamedExpression) 
expr.rewriteDownShortCircuit(e -> {
+                        if (e instanceof AggregateFunction) {
+                            return e.withChildren(aggFuncSlotMap.get(e));
+                        } else {
+                            return e;
+                        }
+                    });
+                    aggFuncAfterRewrite = ((Alias) aggFuncAfterRewrite)
+                            .withExprId(StatementScopeIdGenerator.newExprId());
+                    NamedExpression replacedExpr = (NamedExpression) 
aggFuncAfterRewrite.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                } else {
+                    if 
(expr.getExprId().equals(repeat.getGroupingId().get().getExprId())) {
+                        topAggOutput.add(expr);
+                        continue;
+                    }
+                    NamedExpression replacedExpr = (NamedExpression) 
expr.rewriteDownShortCircuit(
+                            e -> {
+                                if (originToConsumerMap.containsKey(e)) {
+                                    return originToConsumerMap.get(e);
+                                } else {
+                                    return e;
+                                }
+                            }
+                    );
+                    topAggOutput.add(replacedExpr);
+                    projects.add(replacedExpr.toSlot());
+                }
+            }
+        }
+        Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf(
+                ExpressionUtils.flatExpressions((List) 
newRepeat.getGroupingSets()));
+        List<Expression> topAggGby2 = new ArrayList<>(groupingSetsUsedSlot);
+        topAggGby2.add(newRepeat.getGroupingId().get());
+        List<Expression> replacedAggGby = ExpressionUtils.replace(topAggGby2, 
originToConsumerMap);
+        LogicalAggregate<Plan> topAgg = new LogicalAggregate<>(replacedAggGby, 
topAggOutput,
+                Optional.of(newRepeat), newRepeat);
+        return new LogicalProject<>(projects, topAgg);
+    }
+
+    private LogicalUnion constructUnion(LogicalPlan child1, LogicalPlan child2,

Review Comment:
   parameters should use meaningful name



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to