morrySnow commented on code in PR #54079:
URL: https://github.com/apache/doris/pull/54079#discussion_r2241856840


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostModel.java:
##########
@@ -103,6 +105,7 @@ public CostModel(ConnectContext connectContext) {
         }
         this.hboPlanStatisticsProvider = 
Objects.requireNonNull(Env.getCurrentEnv().getHboPlanStatisticsManager()
                 .getHboPlanStatisticsProvider(), "HboPlanStatisticsProvider is 
null");
+        this.requestChildrenProperties = childrenProperties;

Review Comment:
   why need this?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java:
##########
@@ -216,7 +216,8 @@ public void execute() {
             // if break when running the loop above, the condition must be 
false.
             if (curChildIndex == groupExpression.arity()) {
                 if (!calculateEnforce(requestChildrenProperties, 
outputChildrenProperties)) {
-                    return; // if error exists, return
+                    clear();

Review Comment:
   Is this a bug in the base code too? If so, we should fix it in a separate PR 
and add some UTs and regression cases for it.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -1111,6 +1112,8 @@ public PlanFragment visitPhysicalHashAggregate(
         // 2. collect agg expressions and generate agg function to slot 
reference map
         List<Slot> aggFunctionOutput = Lists.newArrayList();
         ArrayList<FunctionCallExpr> execAggregateFunctions = 
Lists.newArrayListWithCapacity(outputExpressions.size());
+        boolean isPartial = 
aggregate.getAggregateParam().aggMode.productAggregateBuffer;

Review Comment:
   Is this a bug in the base code? Maybe we should change the logic that 
generates aggMode.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =
+            ImmutableMap.of(
+                    Sum.class, Pair.of(Sum.class, Sum.class),
+                    Min.class, Pair.of(Min.class, Min.class),
+                    Max.class, Pair.of(Max.class, Max.class),
+                    Count.class, Pair.of(Count.class, Sum0.class),
+                    Sum0.class, Pair.of(Sum0.class, Sum0.class),
+                    AnyValue.class, Pair.of(AnyValue.class, AnyValue.class)
+                    );
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getGroupByExpressions().isEmpty())
+                .then(this::apply).toRule(RuleType.DISTINCT_AGGREGATE_SPLIT);
+    }
+
+    private boolean checkByStatistics(LogicalAggregate<? extends Plan> 
aggregate) {
+        // 带group by的场景, group by key ndv低, distinct key ndv高, 
则转为multi_distinct
+        //      其他情况都拆分
+        // 不带group by的场景, distinct key的ndv低,使用multi_distinct, ndv高,使用cte拆分

Review Comment:
   Use English for these comments, and add a TODO.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java:
##########
@@ -639,8 +641,11 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
                     rewriteJobs.addAll(jobs(topic("or expansion",
                             custom(RuleType.OR_EXPANSION, () -> 
OrExpansion.INSTANCE))));
                 }
+                // rewriteJobs.addAll(jobs(topic("split multi distinct",
+                //         custom(RuleType.SPLIT_MULTI_DISTINCT, () -> 
SplitMultiDistinct.INSTANCE))));
+

Review Comment:
   redundant code?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java:
##########
@@ -639,8 +641,11 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
                     rewriteJobs.addAll(jobs(topic("or expansion",
                             custom(RuleType.OR_EXPANSION, () -> 
OrExpansion.INSTANCE))));
                 }
+                // rewriteJobs.addAll(jobs(topic("split multi distinct",
+                //         custom(RuleType.SPLIT_MULTI_DISTINCT, () -> 
SplitMultiDistinct.INSTANCE))));
+
                 rewriteJobs.addAll(jobs(topic("split multi distinct",
-                        custom(RuleType.SPLIT_MULTI_DISTINCT, () -> 
SplitMultiDistinct.INSTANCE))));

Review Comment:
   Is the `SplitMultiDistinct` file useless now?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =
+            ImmutableMap.of(
+                    Sum.class, Pair.of(Sum.class, Sum.class),
+                    Min.class, Pair.of(Min.class, Min.class),
+                    Max.class, Pair.of(Max.class, Max.class),
+                    Count.class, Pair.of(Count.class, Sum0.class),
+                    Sum0.class, Pair.of(Sum0.class, Sum0.class),
+                    AnyValue.class, Pair.of(AnyValue.class, AnyValue.class)
+                    );
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getGroupByExpressions().isEmpty())
+                .then(this::apply).toRule(RuleType.DISTINCT_AGGREGATE_SPLIT);
+    }
+
+    private boolean checkByStatistics(LogicalAggregate<? extends Plan> 
aggregate) {
+        // 带group by的场景, group by key ndv低, distinct key ndv高, 
则转为multi_distinct
+        //      其他情况都拆分
+        // 不带group by的场景, distinct key的ndv低,使用multi_distinct, ndv高,使用cte拆分
+        return true;
+    }
+
+    private Plan apply(LogicalAggregate<? extends Plan> aggregate) {
+        Set<AggregateFunction> aggFuncs = aggregate.getAggregateFunctions();
+        //这个函数同时也要处理count(distinct a,b)这种
+        // 需要保证1.只有一个count(distinct)函数
+        // 如果是multi_distinct不处理
+        Set<AggregateFunction> distinctAggFuncs = new HashSet<>();
+        Set<AggregateFunction> otherFunctions = new HashSet<>();
+        for (AggregateFunction aggFunc : aggFuncs) {
+            if (aggFunc.isDistinct()) {
+                distinctAggFuncs.add(aggFunc);
+            } else {
+                otherFunctions.add(aggFunc);
+            }
+        }
+        if (distinctAggFuncs.size() != 1) {
+            return null;
+        }
+
+        // 并不是所有的都能拆分, other function里面如果有一些不能拆分的函数,就不拆
+        // 这里先实现一下,然后后续再考虑一下什么函数可以拆分.
+        for (AggregateFunction aggFunc : otherFunctions) {
+            if (!supportSplitOtherFunctions.contains(aggFunc.getClass())) {
+                return null;
+            }
+        }
+
+        // 满足条件了,开始写拆分的代码
+        // 先构造下面进行去重的AGG
+        // group by key为group by key + distinct key
+        List<Expression> groupByKeys = ImmutableList.<Expression>builder()
+                .addAll(aggregate.getGroupByExpressions())
+                
.addAll(distinctAggFuncs.iterator().next().getDistinctArguments())
+                .build();

Review Comment:
   need a set to remove duplicate keys



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java:
##########
@@ -38,22 +40,21 @@ public AggregateParam(AggPhase aggPhase, AggMode aggMode) {
         this(aggPhase, aggMode, true);
     }
 
-    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
canBeBanned) {
+    /** AggregateParam */
+    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
needSplit) {
+        this(aggPhase, aggMode, true, needSplit);
+    }
+
+    /** AggregateParam */
+    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
canBeBanned, boolean needSplit) {
         this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be 
null");
         this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be 
null");
         this.canBeBanned = canBeBanned;
-    }
-
-    public AggregateParam withAggPhase(AggPhase aggPhase) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
-    }
-
-    public AggregateParam withAggPhase(AggMode aggMode) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
-    }
-
-    public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode 
aggMode) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
+        // 从三阶段拆分出来的顶层一阶段的agg needSplit设置为false,
+        // 如果指定了false,那么不拆分, 如果指定了true,不代表需要拆分,
+        // 需要同时满足(!aggMode.productAggregateBuffer && 
!aggMode.consumeAggregateBuffer)
+        // needSplit && (!aggMode.productAggregateBuffer && 
!aggMode.consumeAggregateBuffer) 才需要拆分
+        this.needSplit = needSplit && (!aggMode.productAggregateBuffer && 
!aggMode.consumeAggregateBuffer);

Review Comment:
   should not update attribute, return a new one instead



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java:
##########
@@ -413,6 +415,52 @@ public Void visitPhysicalWindow(PhysicalWindow<? extends 
Plan> window, PlanConte
         return null;
     }
 
+    @Override
+    public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends 
Plan> agg, PlanContext context) {
+        // 先在这里实现一下
+        // group by a,b
+        // 如果agg收到的请求是a,agg发出的请求是a,b,a是a,b的子集, 
那么agg发送a请求给孩子(这里判断一下a的ndv,如果很小的话就还是发a,b)
+        // 如果agg没有收到请求,那还是发出a,b
+        // 如果agg收到了请求,但是没有交集,那么agg仍然发出a,b
+        // 如果是local agg,那么发出any, 如果是global agg,才有要求
+        DistributionSpec parentDist = 
requestPropertyFromParent.getDistributionSpec();
+        if (agg.getAggPhase().isLocal()) {
+            addRequestPropertyToChildren(PhysicalProperties.ANY);
+            return null;
+        } else if (agg.getAggPhase().isGlobal()) {
+            if (agg.getPartitionExpressions().isPresent() && 
!agg.getPartitionExpressions().get().isEmpty()) {
+                addRequestPropertyToChildren(
+                        
PhysicalProperties.createHash(agg.getPartitionExpressions().get(), 
ShuffleType.REQUIRE));
+                return null;
+            }
+            if (agg.getGroupByExpressions().isEmpty()) {
+                addRequestPropertyToChildren(PhysicalProperties.GATHER);
+                return null;
+            }
+            //获得当前的group by key的expr id
+            List<ExprId> groupByExprIds = agg.getGroupByExpressions().stream()
+                    .filter(SlotReference.class::isInstance)
+                    .map(SlotReference.class::cast)
+                    .map(SlotReference::getExprId)
+                    .collect(Collectors.toList());
+            if (parentDist instanceof DistributionSpecHash) {
+                DistributionSpecHash distributionRequestFromParent = 
(DistributionSpecHash) parentDist;
+                List<ExprId> hashExprIds = 
distributionRequestFromParent.getOrderedShuffledColumns();
+                // 还需加上ndv的判断
+                if (new HashSet<>(groupByExprIds).containsAll(hashExprIds)) {

Review Comment:
   Why must it contain all of them? Why not just require that the intersection is not empty?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java:
##########
@@ -413,6 +415,52 @@ public Void visitPhysicalWindow(PhysicalWindow<? extends 
Plan> window, PlanConte
         return null;
     }
 
+    @Override
+    public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends 
Plan> agg, PlanContext context) {
+        // 先在这里实现一下
+        // group by a,b
+        // 如果agg收到的请求是a,agg发出的请求是a,b,a是a,b的子集, 
那么agg发送a请求给孩子(这里判断一下a的ndv,如果很小的话就还是发a,b)
+        // 如果agg没有收到请求,那还是发出a,b
+        // 如果agg收到了请求,但是没有交集,那么agg仍然发出a,b
+        // 如果是local agg,那么发出any, 如果是global agg,才有要求
+        DistributionSpec parentDist = 
requestPropertyFromParent.getDistributionSpec();
+        if (agg.getAggPhase().isLocal()) {
+            addRequestPropertyToChildren(PhysicalProperties.ANY);
+            return null;
+        } else if (agg.getAggPhase().isGlobal()) {
+            if (agg.getPartitionExpressions().isPresent() && 
!agg.getPartitionExpressions().get().isEmpty()) {
+                addRequestPropertyToChildren(
+                        
PhysicalProperties.createHash(agg.getPartitionExpressions().get(), 
ShuffleType.REQUIRE));
+                return null;
+            }
+            if (agg.getGroupByExpressions().isEmpty()) {
+                addRequestPropertyToChildren(PhysicalProperties.GATHER);
+                return null;
+            }
+            //获得当前的group by key的expr id
+            List<ExprId> groupByExprIds = agg.getGroupByExpressions().stream()
+                    .filter(SlotReference.class::isInstance)
+                    .map(SlotReference.class::cast)
+                    .map(SlotReference::getExprId)
+                    .collect(Collectors.toList());
+            if (parentDist instanceof DistributionSpecHash) {
+                DistributionSpecHash distributionRequestFromParent = 
(DistributionSpecHash) parentDist;
+                List<ExprId> hashExprIds = 
distributionRequestFromParent.getOrderedShuffledColumns();
+                // 还需加上ndv的判断

Review Comment:
   add TODO



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java:
##########
@@ -110,76 +109,85 @@ public List<List<PhysicalProperties>> 
visitPhysicalHashAggregate(
         if (agg.getGroupByExpressions().isEmpty() && 
agg.getOutputExpressions().isEmpty()) {
             return ImmutableList.of();
         }
-        if (!agg.getAggregateParam().canBeBanned) {
-            return visit(agg, context);
-        }
-        // forbid one phase agg on distribute
-        if (agg.getAggMode() == AggMode.INPUT_TO_RESULT && 
children.get(0).getPlan() instanceof PhysicalDistribute) {
-            // this means one stage gather agg, usually bad pattern
-            return ImmutableList.of();
-        }
-
-        // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
-        // TODO: this is forbid good plan after cte reuse by mistake
-        if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER
-                && requiredProperties.get(0).getDistributionSpec() instanceof 
DistributionSpecHash
-                && children.get(0).getPlan() instanceof PhysicalDistribute) {
-            return ImmutableList.of();
+        // 如果origin property 满足group by key, 但是不满足required, 那么禁用这个计划
+        PhysicalProperties originChildProperty = 
originChildrenProperties.get(0);

Review Comment:
   add ut



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggStrategySelector.java:
##########
@@ -0,0 +1,156 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.StatementContext;
+import org.apache.doris.nereids.jobs.JobContext;
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Chooses the optimal execution strategy for queries with multiple DISTINCT 
aggregations.
+ *
+ * Handles queries like "SELECT COUNT(DISTINCT c1), COUNT(DISTINCT c2) FROM t" 
by selecting between:
+ * - CTE decomposition: Splits into multiple CTEs, each computing one DISTINCT 
aggregate
+ * - Multi-DISTINCT function: Processes all distinct function use multi 
distinct function
+ *
+ * Selection criteria includes:
+ * - Number of distinct aggregates
+ * - Estimated cardinality of distinct values
+ * - Available memory resources
+ * - Query complexity
+ */
+public class DistinctAggStrategySelector extends 
DefaultPlanRewriter<DistinctSelectorContext>

Review Comment:
   add ut



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =
+            ImmutableMap.of(
+                    Sum.class, Pair.of(Sum.class, Sum.class),
+                    Min.class, Pair.of(Min.class, Min.class),
+                    Max.class, Pair.of(Max.class, Max.class),
+                    Count.class, Pair.of(Count.class, Sum0.class),
+                    Sum0.class, Pair.of(Sum0.class, Sum0.class),
+                    AnyValue.class, Pair.of(AnyValue.class, AnyValue.class)
+                    );
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getGroupByExpressions().isEmpty())
+                .then(this::apply).toRule(RuleType.DISTINCT_AGGREGATE_SPLIT);
+    }
+
+    private boolean checkByStatistics(LogicalAggregate<? extends Plan> 
aggregate) {
+        // 带group by的场景, group by key ndv低, distinct key ndv高, 
则转为multi_distinct
+        //      其他情况都拆分
+        // 不带group by的场景, distinct key的ndv低,使用multi_distinct, ndv高,使用cte拆分
+        return true;
+    }
+
+    private Plan apply(LogicalAggregate<? extends Plan> aggregate) {
+        Set<AggregateFunction> aggFuncs = aggregate.getAggregateFunctions();
+        //这个函数同时也要处理count(distinct a,b)这种
+        // 需要保证1.只有一个count(distinct)函数
+        // 如果是multi_distinct不处理
+        Set<AggregateFunction> distinctAggFuncs = new HashSet<>();
+        Set<AggregateFunction> otherFunctions = new HashSet<>();
+        for (AggregateFunction aggFunc : aggFuncs) {
+            if (aggFunc.isDistinct()) {
+                distinctAggFuncs.add(aggFunc);
+            } else {
+                otherFunctions.add(aggFunc);
+            }
+        }
+        if (distinctAggFuncs.size() != 1) {
+            return null;
+        }
+
+        // 并不是所有的都能拆分, other function里面如果有一些不能拆分的函数,就不拆
+        // 这里先实现一下,然后后续再考虑一下什么函数可以拆分.
+        for (AggregateFunction aggFunc : otherFunctions) {
+            if (!supportSplitOtherFunctions.contains(aggFunc.getClass())) {
+                return null;
+            }
+        }
+
+        // 满足条件了,开始写拆分的代码
+        // 先构造下面进行去重的AGG
+        // group by key为group by key + distinct key
+        List<Expression> groupByKeys = ImmutableList.<Expression>builder()
+                .addAll(aggregate.getGroupByExpressions())
+                
.addAll(distinctAggFuncs.iterator().next().getDistinctArguments())

Review Comment:
   Should check that all distinctArguments are slots.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java:
##########
@@ -110,8 +110,10 @@ public PhysicalHashAggregate(List<Expression> 
groupByExpressions, List<NamedExpr
         this.partitionExpressions = Objects.requireNonNull(
                 partitionExpressions, "partitionExpressions cannot be null");
         this.aggregateParam = Objects.requireNonNull(aggregateParam, 
"aggregate param cannot be null");
+        // this.aggregateParam = aggregateParam;
         this.maybeUsingStream = maybeUsingStream;
-        this.requireProperties = Objects.requireNonNull(requireProperties, 
"requireProperties cannot be null");
+        // this.requireProperties = Objects.requireNonNull(requireProperties, 
"requireProperties cannot be null");
+        this.requireProperties = requireProperties;

Review Comment:
   If it could be null, use Optional.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =
+            ImmutableMap.of(
+                    Sum.class, Pair.of(Sum.class, Sum.class),
+                    Min.class, Pair.of(Min.class, Min.class),
+                    Max.class, Pair.of(Max.class, Max.class),
+                    Count.class, Pair.of(Count.class, Sum0.class),
+                    Sum0.class, Pair.of(Sum0.class, Sum0.class),
+                    AnyValue.class, Pair.of(AnyValue.class, AnyValue.class)
+                    );
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getGroupByExpressions().isEmpty())
+                .then(this::apply).toRule(RuleType.DISTINCT_AGGREGATE_SPLIT);
+    }
+
+    private boolean checkByStatistics(LogicalAggregate<? extends Plan> 
aggregate) {

Review Comment:
   add this check to the `when` condition, and rename it to `shouldSplit`



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);

Review Comment:
   is group_concat ok?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java:
##########
@@ -76,7 +77,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
     private final boolean ordinalIsResolved;
     private final boolean generated;
     private final boolean hasPushed;
-
+    private final AggregateParam aggregateParam;
+    private final Optional<List<Expression>> partitionExpressions;

Review Comment:
   partition expression is useless now



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java:
##########
@@ -38,22 +40,21 @@ public AggregateParam(AggPhase aggPhase, AggMode aggMode) {
         this(aggPhase, aggMode, true);
     }
 
-    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
canBeBanned) {
+    /** AggregateParam */
+    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
needSplit) {
+        this(aggPhase, aggMode, true, needSplit);
+    }
+
+    /** AggregateParam */
+    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
canBeBanned, boolean needSplit) {
         this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be 
null");
         this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be 
null");
         this.canBeBanned = canBeBanned;
-    }
-
-    public AggregateParam withAggPhase(AggPhase aggPhase) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
-    }
-
-    public AggregateParam withAggPhase(AggMode aggMode) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
-    }
-
-    public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode 
aggMode) {
-        return new AggregateParam(aggPhase, aggMode, canBeBanned);
+        // 从三阶段拆分出来的顶层一阶段的agg needSplit设置为false,
+        // 如果指定了false,那么不拆分, 如果指定了true,不代表需要拆分,
+        // 需要同时满足(!aggMode.productAggregateBuffer && 
!aggMode.consumeAggregateBuffer)
+        // needSplit && (!aggMode.productAggregateBuffer && 
!aggMode.consumeAggregateBuffer) 才需要拆分

Review Comment:
   why not compute needSplit in the constructor?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =
+            ImmutableMap.of(
+                    Sum.class, Pair.of(Sum.class, Sum.class),
+                    Min.class, Pair.of(Min.class, Min.class),
+                    Max.class, Pair.of(Max.class, Max.class),
+                    Count.class, Pair.of(Count.class, Sum0.class),
+                    Sum0.class, Pair.of(Sum0.class, Sum0.class),
+                    AnyValue.class, Pair.of(AnyValue.class, AnyValue.class)
+                    );
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getGroupByExpressions().isEmpty())
+                .then(this::apply).toRule(RuleType.DISTINCT_AGGREGATE_SPLIT);
+    }
+
+    private boolean checkByStatistics(LogicalAggregate<? extends Plan> 
aggregate) {
+        // 带group by的场景, group by key ndv低, distinct key ndv高, 
则转为multi_distinct
+        //      其他情况都拆分
+        // 不带group by的场景, distinct key的ndv低,使用multi_distinct, ndv高,使用cte拆分
+        return true;
+    }
+
+    private Plan apply(LogicalAggregate<? extends Plan> aggregate) {
+        Set<AggregateFunction> aggFuncs = aggregate.getAggregateFunctions();
+        //这个函数同时也要处理count(distinct a,b)这种
+        // 需要保证1.只有一个count(distinct)函数
+        // 如果是multi_distinct不处理
+        Set<AggregateFunction> distinctAggFuncs = new HashSet<>();
+        Set<AggregateFunction> otherFunctions = new HashSet<>();
+        for (AggregateFunction aggFunc : aggFuncs) {
+            if (aggFunc.isDistinct()) {
+                distinctAggFuncs.add(aggFunc);
+            } else {
+                otherFunctions.add(aggFunc);
+            }
+        }
+        if (distinctAggFuncs.size() != 1) {
+            return null;
+        }
+
+        // 并不是所有的都能拆分, other function里面如果有一些不能拆分的函数,就不拆
+        // 这里先实现一下,然后后续再考虑一下什么函数可以拆分.
+        for (AggregateFunction aggFunc : otherFunctions) {
+            if (!supportSplitOtherFunctions.contains(aggFunc.getClass())) {
+                return null;
+            }
+        }
+
+        // 满足条件了,开始写拆分的代码
+        // 先构造下面进行去重的AGG
+        // group by key为group by key + distinct key
+        List<Expression> groupByKeys = ImmutableList.<Expression>builder()
+                .addAll(aggregate.getGroupByExpressions())
+                
.addAll(distinctAggFuncs.iterator().next().getDistinctArguments())
+                .build();
+        // 需要把otherFunction输出一下
+        ImmutableList.Builder<NamedExpression> bottomAggOtherFunctions = 
ImmutableList.builder();
+        Map<AggregateFunction, NamedExpression> aggFuncToSlot = new 
HashMap<>();
+        for (AggregateFunction aggFunc : otherFunctions) {
+            Alias bottomAggFuncAlias = new Alias(aggFunc);
+            bottomAggOtherFunctions.add(bottomAggFuncAlias);
+            aggFuncToSlot.put(aggFunc, bottomAggFuncAlias.toSlot());
+        }
+
+        List<NamedExpression> aggOutput = 
ImmutableList.<NamedExpression>builder()
+                .addAll((List) groupByKeys)
+                .addAll(bottomAggOtherFunctions.build())
+                .build();
+
+        LogicalAggregate<Plan> bottomAgg = new LogicalAggregate<>(groupByKeys, 
aggOutput, aggregate.child());
+
+        // 然后构造上面的AGG
+        List<NamedExpression> topAggOutput = 
ExpressionUtils.rewriteDownShortCircuit(aggregate.getOutputExpressions(),
+                expr -> {
+                    if (expr instanceof AggregateFunction) {
+                        AggregateFunction aggFunc = (AggregateFunction) expr;
+                        // 如果是distinct function,那么直接把distinct去掉
+                        if (aggFunc.isDistinct()) {
+                            if (aggFunc.arity() == 1) {
+                                return aggFunc.withDistinctAndChildren(false, 
aggFunc.children());
+                            } else if (aggFunc instanceof Count && 
aggFunc.arity() > 1) {
+                                return 
AggregateUtils.countDistinctMultiExprToCountIf((Count) aggFunc);
+                            }
+                        } else {
+                            if (aggFuncToSlot.get(aggFunc) != null) {
+                                if (aggFunc instanceof Count) {
+                                    return new 
Count(aggFuncToSlot.get(aggFunc));

Review Comment:
   should this be `Sum0`?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateSplitter.java:
##########
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites queries containing DISTINCT aggregate functions by splitting them 
into two processing layers:
+ * 1. Lower layer: Performs deduplication on both grouping columns and 
DISTINCT columns
+ * 2. Upper layer: Applies simple aggregation on the deduplicated results
+ *
+ * For example, transforms:
+ *   SELECT COUNT(DISTINCT a), count(c) FROM t GROUP BY b
+ * Into:
+ *   SELECT COUNT(a), sum0(cnt) FROM (
+ *     SELECT a, b, count(c) cnt FROM t GROUP BY a, b
+ *   ) GROUP BY b
+ */
+public class DistinctAggregateSplitter extends OneRewriteRuleFactory {
+    public static final DistinctAggregateSplitter INSTANCE = new 
DistinctAggregateSplitter();
+    private static final Set<Class<? extends AggregateFunction>> 
supportSplitOtherFunctions = ImmutableSet.of(
+            Sum.class, Min.class, Max.class, Count.class, Sum0.class, 
AnyValue.class);
+    private static final Map<Class<? extends AggregateFunction>,
+            Pair<Class<? extends AggregateFunction>, Class<? extends 
AggregateFunction>>> aggFunctionMap =

Review Comment:
   not used



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/SplitAggMultiPhase.java:
##########
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.stats.ExpressionEstimation;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**SplitAggMultiPhase
+ * only process agg with distinct function, split Agg into multi phase
+ * */
+public class SplitAggMultiPhase extends SplitAggRule implements 
ExplorationRuleFactory {
+    public static final SplitAggMultiPhase INSTANCE = new SplitAggMultiPhase();
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalAggregate()
+                        .when(agg -> agg.getAggregateParam().needSplit)
+                        .when(Aggregate::hasDistinctFunc)
+                        .when(agg -> !agg.getGroupByExpressions().isEmpty())
+                        .thenApplyMulti(ctx -> rewrite(ctx.root))
+                        .toRule(RuleType.SPLIT_AGG_MULTI_PHASE)
+        );
+    }
+
+    private List<Plan> rewrite(LogicalAggregate<? extends Plan> aggregate) {
+        if (shouldUseThreePhase(aggregate)) {
+            return ImmutableList.<Plan>builder()
+                    .add(splitToTwoPlusOnePhase(aggregate))
+                    .add(splitToOnePlusOnePhase(aggregate))
+                    .addAll(splitToOnePlusTwoPhase(aggregate))
+                    .build();
+        } else {
+            return ImmutableList.<Plan>builder()
+                    .add(splitToOnePlusOnePhase(aggregate))
+                    .add(splitToTwoPlusTwoPhase(aggregate))
+                    .addAll(splitToOnePlusTwoPhase(aggregate))
+                    .build();
+        }
+    }
+
+    private Plan splitToTwoPlusOnePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        Map<AggregateFunction, Alias> middleAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan middleAgg = splitDeduplicateTwoPhase(aggregate, 
middleAggFunctionToAlias,
+                aggregate.getGroupByExpressions(), localAggGroupBySet);
+
+        // third phase
+        AggregateParam inputToResultParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.INPUT_TO_RESULT,
+                false);
+        return splitDistinctOnePhase(aggregate, inputToResultParam, 
middleAggFunctionToAlias, middleAgg);
+    }
+
+    private Plan splitToOnePlusOnePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        // first phase
+        AggregateParam inputToResultParamFirst = new 
AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT, false);
+        AggregateParam paramForAggFunc = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_BUFFER);
+        Map<AggregateFunction, Alias> localAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan localAgg = splitDeduplicateOnePhase(aggregate, 
localAggGroupBySet, inputToResultParamFirst,
+                paramForAggFunc, localAggFunctionToAlias, aggregate.child(),
+                Utils.fastToImmutableList(aggregate.getGroupByExpressions()));
+
+        // second phase
+        AggregateParam inputToResultParamSecond = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL,
+                AggMode.INPUT_TO_RESULT, false);
+        return splitDistinctOnePhase(aggregate, inputToResultParamSecond, 
localAggFunctionToAlias, localAgg);
+    }
+
+    private Plan splitToTwoPlusTwoPhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        Map<AggregateFunction, Alias> middleAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan middleAgg = splitDeduplicateTwoPhase(aggregate, 
middleAggFunctionToAlias,
+                Utils.fastToImmutableList(localAggGroupBySet), 
localAggGroupBySet);
+
+        return splitDistinctTwoPhase(aggregate, middleAggFunctionToAlias, 
middleAgg);
+    }
+
+    private List<Plan> splitToOnePlusTwoPhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        // first phase
+        AggregateParam paramForAgg = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, false);
+        AggregateParam paramForAggFunc = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_BUFFER, false);
+
+        Map<AggregateFunction, Alias> localAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan localAgg = splitDeduplicateOnePhase(aggregate, 
localAggGroupBySet, paramForAgg, paramForAggFunc,
+                localAggFunctionToAlias, aggregate.child(),
+                Utils.fastToImmutableList(aggregate.getDistinctArguments()));
+        AggregateParam param = new AggregateParam(AggPhase.DISTINCT_GLOBAL, 
AggMode.INPUT_TO_RESULT, false);
+        return 
ImmutableList.<Plan>builder().add(splitDistinctTwoPhase(aggregate, 
localAggFunctionToAlias, localAgg))
+                .add(splitDistinctOnePhase(aggregate, param, 
localAggFunctionToAlias, localAgg))
+                .build();
+    }
+
+    private LogicalAggregate<? extends Plan> 
splitDistinctOnePhase(LogicalAggregate<? extends Plan> aggregate,
+            AggregateParam inputToResultParamSecond, Map<AggregateFunction, 
Alias> childAggFuncMap, Plan child) {
+        List<NamedExpression> globalOutput = 
ExpressionUtils.rewriteDownShortCircuit(
+                aggregate.getOutputExpressions(), expr -> {
+                    if (expr instanceof AggregateFunction) {
+                        AggregateFunction aggFunc = (AggregateFunction) expr;
+                        if (aggFunc.isDistinct()) {
+                            // 测试一下为什么需要checkArgument here
+                            return new AggregateExpression(
+                                    aggFunc.withDistinctAndChildren(false, 
aggFunc.children()),
+                                    inputToResultParamSecond);
+                        } else {
+                            return new AggregateExpression(aggFunc,
+                                    new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT),
+                                    childAggFuncMap.get(aggFunc).toSlot());
+                        }
+                    }
+                    return expr;
+                });
+        return aggregate.withAggParam(globalOutput, 
aggregate.getGroupByExpressions(),
+                inputToResultParamSecond, aggregate.getLogicalProperties(),
+                aggregate.getGroupByExpressions(), child);
+    }
+
+    private boolean shouldUseThreePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Statistics aggStats = 
aggregate.getGroupExpression().get().getOwnerGroup().getStatistics();
+        Statistics aggChildStats = 
aggregate.getGroupExpression().get().childStatistics(0);
+        for (Expression groupByExpr : aggregate.getGroupByExpressions()) {
+            ColumnStatistic columnStat = 
aggChildStats.findColumnStatistics(groupByExpr);
+            if (columnStat == null) {
+                columnStat = ExpressionEstimation.estimate(groupByExpr, 
aggChildStats);
+            }
+            if (columnStat.isUnKnown) {
+                return true;
+            }
+        }
+        double ndv = aggStats.getRowCount();
+        // 当ndv非常低的情况下,不能使用三阶段AGG,会有倾斜
+        if (ndv < 1000) {

Review Comment:
   why 1000?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java:
##########
@@ -310,4 +311,8 @@ default DataTrait computeDataTrait() {
     void computeEqualSet(DataTrait.Builder builder);
 
     void computeFd(DataTrait.Builder builder);
+
+    default Statistics getStats() {
+        return null;

Review Comment:
   returning null is dangerous



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java:
##########
@@ -413,6 +415,52 @@ public Void visitPhysicalWindow(PhysicalWindow<? extends 
Plan> window, PlanConte
         return null;
     }
 
+    @Override
+    public Void visitPhysicalHashAggregate(PhysicalHashAggregate<? extends 
Plan> agg, PlanContext context) {
+        // 先在这里实现一下
+        // group by a,b
+        // 如果agg收到的请求是a,agg发出的请求是a,b,a是a,b的子集, 
那么agg发送a请求给孩子(这里判断一下a的ndv,如果很小的话就还是发a,b)
+        // 如果agg没有收到请求,那还是发出a,b
+        // 如果agg收到了请求,但是没有交集,那么agg仍然发出a,b
+        // 如果是local agg,那么发出any, 如果是global agg,才有要求
+        DistributionSpec parentDist = 
requestPropertyFromParent.getDistributionSpec();
+        if (agg.getAggPhase().isLocal()) {
+            addRequestPropertyToChildren(PhysicalProperties.ANY);
+            return null;
+        } else if (agg.getAggPhase().isGlobal()) {
+            if (agg.getPartitionExpressions().isPresent() && 
!agg.getPartitionExpressions().get().isEmpty()) {

Review Comment:
   what does PartitionExpressions mean?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/SplitAggMultiPhase.java:
##########
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.stats.ExpressionEstimation;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**SplitAggMultiPhase
+ * only process agg with distinct function, split Agg into multi phase
+ * */
+public class SplitAggMultiPhase extends SplitAggRule implements 
ExplorationRuleFactory {
+    public static final SplitAggMultiPhase INSTANCE = new SplitAggMultiPhase();
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalAggregate()
+                        .when(agg -> agg.getAggregateParam().needSplit)
+                        .when(Aggregate::hasDistinctFunc)
+                        .when(agg -> !agg.getGroupByExpressions().isEmpty())
+                        .thenApplyMulti(ctx -> rewrite(ctx.root))
+                        .toRule(RuleType.SPLIT_AGG_MULTI_PHASE)
+        );
+    }
+
+    private List<Plan> rewrite(LogicalAggregate<? extends Plan> aggregate) {
+        if (shouldUseThreePhase(aggregate)) {

Review Comment:
   would it be better to change the name to twoPlusOneBetterThanTwoPlusTwo?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/AggregateUtils.java:
##########
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Utils for aggregate
+ */
+public class AggregateUtils {
+    public static AggregateFunction 
tryConvertToMultiDistinct(AggregateFunction function) {
+        if (function instanceof SupportMultiDistinct && function.isDistinct()) 
{
+            return ((SupportMultiDistinct) function).convertToMultiDistinct();
+        }
+        return function;
+    }
+
+    public static Expression countDistinctMultiExprToCountIf(Count count) {
+        Set<Expression> arguments = ImmutableSet.copyOf(count.getArguments());
+        Expression countExpr = count.getArgument(arguments.size() - 1);
+        for (int i = arguments.size() - 2; i >= 0; --i) {
+            Expression argument = count.getArgument(i);
+            If ifNull = new If(new IsNull(argument), NullLiteral.INSTANCE, 
countExpr);
+            countExpr = assignNullType(ifNull);

Review Comment:
   why not change the null literal's type here?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/SplitAggMultiPhaseWithoutGbyKey.java:
##########
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum0;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.AggregateUtils;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Supplier;
+
+/**SplitAggMultiPhaseWithoutGbyKey*/
+public class SplitAggMultiPhaseWithoutGbyKey extends SplitAggRule implements 
ExplorationRuleFactory {
+    public static final SplitAggMultiPhaseWithoutGbyKey INSTANCE = new 
SplitAggMultiPhaseWithoutGbyKey();
+    public static final List<Class<? extends AggregateFunction>> 
finalMultiDistinctSupportFunc =
+            ImmutableList.of(Count.class, Sum.class, Sum0.class);
+    public static final List<Class<? extends AggregateFunction>> 
finalMultiDistinctSupportOtherFunc =
+            ImmutableList.of(Count.class, Sum.class, Min.class, Max.class);

Review Comment:
   sum0?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctStrategy.java:
##########
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import 
org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext;
+import org.apache.doris.nereids.trees.copier.DeepCopierContext;
+import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.OrderExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Split multi distinct strategy
+ * */
+public class SplitMultiDistinctStrategy {

Review Comment:
   add ut



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/SplitAggMultiPhase.java:
##########
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.stats.ExpressionEstimation;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**SplitAggMultiPhase
+ * only process agg with distinct function, split Agg into multi phase
+ * */
+public class SplitAggMultiPhase extends SplitAggRule implements 
ExplorationRuleFactory {

Review Comment:
   add ut



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/SplitAggMultiPhase.java:
##########
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.stats.ExpressionEstimation;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import org.apache.doris.nereids.trees.plans.AggMode;
+import org.apache.doris.nereids.trees.plans.AggPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.Utils;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.Statistics;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**SplitAggMultiPhase
+ * only process agg with distinct function, split Agg into multi phase
+ * */
+public class SplitAggMultiPhase extends SplitAggRule implements 
ExplorationRuleFactory {
+    public static final SplitAggMultiPhase INSTANCE = new SplitAggMultiPhase();
+
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalAggregate()
+                        .when(agg -> agg.getAggregateParam().needSplit)
+                        .when(Aggregate::hasDistinctFunc)
+                        .when(agg -> !agg.getGroupByExpressions().isEmpty())
+                        .thenApplyMulti(ctx -> rewrite(ctx.root))
+                        .toRule(RuleType.SPLIT_AGG_MULTI_PHASE)
+        );
+    }
+
+    private List<Plan> rewrite(LogicalAggregate<? extends Plan> aggregate) {
+        if (shouldUseThreePhase(aggregate)) {
+            return ImmutableList.<Plan>builder()
+                    .add(splitToTwoPlusOnePhase(aggregate))
+                    .add(splitToOnePlusOnePhase(aggregate))
+                    .addAll(splitToOnePlusTwoPhase(aggregate))
+                    .build();
+        } else {
+            return ImmutableList.<Plan>builder()
+                    .add(splitToOnePlusOnePhase(aggregate))
+                    .add(splitToTwoPlusTwoPhase(aggregate))
+                    .addAll(splitToOnePlusTwoPhase(aggregate))
+                    .build();
+        }
+    }
+
+    private Plan splitToTwoPlusOnePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        Map<AggregateFunction, Alias> middleAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan middleAgg = splitDeduplicateTwoPhase(aggregate, 
middleAggFunctionToAlias,
+                aggregate.getGroupByExpressions(), localAggGroupBySet);
+
+        // third phase
+        AggregateParam inputToResultParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.INPUT_TO_RESULT,
+                false);
+        return splitDistinctOnePhase(aggregate, inputToResultParam, 
middleAggFunctionToAlias, middleAgg);
+    }
+
+    private Plan splitToOnePlusOnePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        // first phase
+        AggregateParam inputToResultParamFirst = new 
AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT, false);
+        AggregateParam paramForAggFunc = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_BUFFER);
+        Map<AggregateFunction, Alias> localAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan localAgg = splitDeduplicateOnePhase(aggregate, 
localAggGroupBySet, inputToResultParamFirst,
+                paramForAggFunc, localAggFunctionToAlias, aggregate.child(),
+                Utils.fastToImmutableList(aggregate.getGroupByExpressions()));
+
+        // second phase
+        AggregateParam inputToResultParamSecond = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL,
+                AggMode.INPUT_TO_RESULT, false);
+        return splitDistinctOnePhase(aggregate, inputToResultParamSecond, 
localAggFunctionToAlias, localAgg);
+    }
+
+    private Plan splitToTwoPlusTwoPhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        Map<AggregateFunction, Alias> middleAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan middleAgg = splitDeduplicateTwoPhase(aggregate, 
middleAggFunctionToAlias,
+                Utils.fastToImmutableList(localAggGroupBySet), 
localAggGroupBySet);
+
+        return splitDistinctTwoPhase(aggregate, middleAggFunctionToAlias, 
middleAgg);
+    }
+
+    private List<Plan> splitToOnePlusTwoPhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Set<NamedExpression> localAggGroupBySet = getAllKeySet(aggregate);
+        // first phase
+        AggregateParam paramForAgg = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, false);
+        AggregateParam paramForAggFunc = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_BUFFER, false);
+
+        Map<AggregateFunction, Alias> localAggFunctionToAlias = new 
LinkedHashMap<>();
+        Plan localAgg = splitDeduplicateOnePhase(aggregate, 
localAggGroupBySet, paramForAgg, paramForAggFunc,
+                localAggFunctionToAlias, aggregate.child(),
+                Utils.fastToImmutableList(aggregate.getDistinctArguments()));
+        AggregateParam param = new AggregateParam(AggPhase.DISTINCT_GLOBAL, 
AggMode.INPUT_TO_RESULT, false);
+        return 
ImmutableList.<Plan>builder().add(splitDistinctTwoPhase(aggregate, 
localAggFunctionToAlias, localAgg))
+                .add(splitDistinctOnePhase(aggregate, param, 
localAggFunctionToAlias, localAgg))
+                .build();
+    }
+
+    private LogicalAggregate<? extends Plan> 
splitDistinctOnePhase(LogicalAggregate<? extends Plan> aggregate,
+            AggregateParam inputToResultParamSecond, Map<AggregateFunction, 
Alias> childAggFuncMap, Plan child) {
+        List<NamedExpression> globalOutput = 
ExpressionUtils.rewriteDownShortCircuit(
+                aggregate.getOutputExpressions(), expr -> {
+                    if (expr instanceof AggregateFunction) {
+                        AggregateFunction aggFunc = (AggregateFunction) expr;
+                        if (aggFunc.isDistinct()) {
+                            // 测试一下为什么需要checkArgument here
+                            return new AggregateExpression(
+                                    aggFunc.withDistinctAndChildren(false, 
aggFunc.children()),
+                                    inputToResultParamSecond);
+                        } else {
+                            return new AggregateExpression(aggFunc,
+                                    new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT),
+                                    childAggFuncMap.get(aggFunc).toSlot());
+                        }
+                    }
+                    return expr;
+                });
+        return aggregate.withAggParam(globalOutput, 
aggregate.getGroupByExpressions(),
+                inputToResultParamSecond, aggregate.getLogicalProperties(),
+                aggregate.getGroupByExpressions(), child);
+    }
+
+    private boolean shouldUseThreePhase(LogicalAggregate<? extends Plan> 
aggregate) {
+        Statistics aggStats = 
aggregate.getGroupExpression().get().getOwnerGroup().getStatistics();
+        Statistics aggChildStats = 
aggregate.getGroupExpression().get().childStatistics(0);
+        for (Expression groupByExpr : aggregate.getGroupByExpressions()) {
+            ColumnStatistic columnStat = 
aggChildStats.findColumnStatistics(groupByExpr);
+            if (columnStat == null) {
+                columnStat = ExpressionEstimation.estimate(groupByExpr, 
aggChildStats);
+            }
+            if (columnStat.isUnKnown) {
+                return true;

Review Comment:
   if we use four-stage agg to avoid skew when ndv is small, why use 
three-phase agg when the stats are unknown?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to