This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b6d48ed5b [feature](nereids) support distinct count (#12159)
5b6d48ed5b is described below

commit 5b6d48ed5b6db033607224523579da0a77d957f2
Author: yinzhijian <373141...@qq.com>
AuthorDate: Thu Sep 15 13:01:47 2022 +0800

    [feature](nereids) support distinct count (#12159)
    
    Support count(distinct ...) in queries that have a GROUP BY clause.
    For example:
    SELECT count(distinct c_custkey + 1) FROM customer group by c_nation;
    
    TODO: support distinct count in queries without a GROUP BY clause.
---
 .../glue/translator/ExpressionTranslator.java      |   2 +
 .../glue/translator/PhysicalPlanTranslator.java    |  17 +-
 .../properties/ChildOutputPropertyDeriver.java     |   2 +-
 .../nereids/properties/RequestPropertyDeriver.java |   7 +-
 .../doris/nereids/rules/analysis/BindFunction.java |  12 ++
 .../expression/rewrite/ExpressionRewrite.java      |   2 +-
 .../LogicalAggToPhysicalHashAgg.java               |   1 +
 .../rules/rewrite/AggregateDisassemble.java        | 235 +++++++++++++++------
 .../rules/rewrite/logical/NormalizeAggregate.java  |   2 +-
 .../expressions/functions/AggregateFunction.java   |  31 +++
 .../nereids/trees/expressions/functions/Count.java |  12 +-
 .../trees/plans/logical/LogicalAggregate.java      |  36 +++-
 .../trees/plans/physical/PhysicalAggregate.java    |  35 ++-
 .../doris/nereids/parser/HavingClauseTest.java     |   4 +-
 .../properties/ChildOutputPropertyDeriverTest.java |   2 +
 .../properties/RequestPropertyDeriverTest.java     |   3 +
 .../rewrite/logical/AggregateDisassembleTest.java  |  81 +++++++
 .../trees/expressions/ExpressionEqualsTest.java    |  20 ++
 .../doris/nereids/trees/plans/PlanEqualsTest.java  |  12 +-
 .../data/nereids_syntax_p0/function.out            |   5 +
 .../suites/nereids_syntax_p0/function.groovy       |   4 +
 21 files changed, 416 insertions(+), 109 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 017ec6b5b7..1c3f59361d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -256,6 +256,8 @@ public class ExpressionTranslator extends 
DefaultExpressionVisitor<Expr, PlanTra
             Count count = (Count) function;
             if (count.isStar()) {
                 return new FunctionCallExpr(function.getName(), 
FunctionParams.createStarParam());
+            } else if (count.isDistinct()) {
+                return new FunctionCallExpr(function.getName(), new 
FunctionParams(true, paramList));
             }
         }
         return new FunctionCallExpr(function.getName(), paramList);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index d47bf1c1aa..a783567a70 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -191,12 +191,17 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         // 3. generate output tuple
         List<Slot> slotList = Lists.newArrayList();
         TupleDescriptor outputTupleDesc;
-        if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
+        if (aggregate.getAggPhase() == AggPhase.LOCAL) {
+            outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, 
context);
+        } else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && 
aggregate.isFinalPhase())
+                || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
             slotList.addAll(groupSlotList);
             slotList.addAll(aggFunctionOutput);
             outputTupleDesc = generateTupleDesc(slotList, null, context);
         } else {
-            outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, 
context);
+            // In the distinct agg scenario, global shares local's desc
+            AggregationNode localAggNode = (AggregationNode) 
inputPlanFragment.getPlanRoot().getChild(0);
+            outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
         }
 
         if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
@@ -204,6 +209,13 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 execAggregateFunction.setMergeForNereids(true);
             }
         }
+        if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
+            for (FunctionCallExpr execAggregateFunction : 
execAggregateFunctions) {
+                if (!execAggregateFunction.isDistinct()) {
+                    execAggregateFunction.setMergeForNereids(true);
+                }
+            }
+        }
         AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, 
execAggregateFunctions, outputTupleDesc,
                 outputTupleDesc, aggregate.getAggPhase().toExec());
         AggregationNode aggregationNode = new 
AggregationNode(context.nextPlanNodeId(),
@@ -216,6 +228,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 aggregationNode.setIntermediateTuple();
                 break;
             case GLOBAL:
+            case DISTINCT_LOCAL:
                 if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
                     ExchangeNode exchangeNode = (ExchangeNode) 
currentFragment.getPlanRoot();
                     currentFragment = new 
PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
index 1d7974e161..ba8976e71d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java
@@ -80,12 +80,12 @@ public class ChildOutputPropertyDeriver extends 
PlanVisitor<PhysicalProperties,
             case LOCAL:
                 return new 
PhysicalProperties(childOutputProperty.getDistributionSpec());
             case GLOBAL:
+            case DISTINCT_LOCAL:
                 List<ExprId> columns = agg.getPartitionExpressions().stream()
                         .map(SlotReference.class::cast)
                         .map(SlotReference::getExprId)
                         .collect(Collectors.toList());
                 return PhysicalProperties.createHash(new 
DistributionSpecHash(columns, ShuffleType.AGGREGATE));
-            case DISTINCT_LOCAL:
             case DISTINCT_GLOBAL:
             default:
                 throw new RuntimeException("Could not derive output properties 
for agg phase: " + agg.getAggPhase());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
index 0c2f2089ae..67a9032f85 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
@@ -25,6 +25,7 @@ import 
org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.plans.AggPhase;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
@@ -82,14 +83,16 @@ public class RequestPropertyDeriver extends 
PlanVisitor<Void, PlanContext> {
             addToRequestPropertyToChildren(PhysicalProperties.ANY);
             return null;
         }
-
+        if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) {
+            addToRequestPropertyToChildren(requestPropertyFromParent);
+            return null;
+        }
         // 2. second phase agg, need to return shuffle with partition key
         List<Expression> partitionExpressions = agg.getPartitionExpressions();
         if (partitionExpressions.isEmpty()) {
             addToRequestPropertyToChildren(PhysicalProperties.GATHER);
             return null;
         }
-
         // TODO: when parent is a join node,
         //    use requestPropertyFromParent to keep column order as join to 
avoid shuffle again.
         if 
(partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
index 40782b2e28..fcef341f5f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
 import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
@@ -115,6 +116,17 @@ public class BindFunction implements AnalysisRuleFactory {
 
         @Override
         public BoundFunction visitUnboundFunction(UnboundFunction 
unboundFunction, Env env) {
+            // FunctionRegistry can't support boolean arg now, tricky here.
+            if (unboundFunction.getName().equalsIgnoreCase("count")) {
+                List<Expression> arguments = unboundFunction.getArguments();
+                if ((arguments.size() == 0 && unboundFunction.isStar()) || 
arguments.stream()
+                        .allMatch(Expression::isConstant)) {
+                    return new Count();
+                }
+                if (arguments.size() == 1) {
+                    return new Count(unboundFunction.getArguments().get(0), 
unboundFunction.isDistinct());
+                }
+            }
             FunctionRegistry functionRegistry = env.getFunctionRegistry();
             String functionName = unboundFunction.getName();
             FunctionBuilder builder = functionRegistry.findFunctionBuilder(
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
index d808183c24..f660ee0ec0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java
@@ -126,7 +126,7 @@ public class ExpressionRewrite implements 
RewriteRuleFactory {
                     return agg;
                 }
                 return new LogicalAggregate<>(newGroupByExprs, 
newOutputExpressions,
-                        agg.isDisassembled(), agg.isNormalized(), 
agg.getAggPhase(), agg.child());
+                        agg.isDisassembled(), agg.isNormalized(), 
agg.isFinalPhase(), agg.getAggPhase(), agg.child());
             }).toRule(RuleType.REWRITE_AGG_EXPRESSION);
         }
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
index ecc59393d4..4e4d52b551 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java
@@ -36,6 +36,7 @@ public class LogicalAggToPhysicalHashAgg extends 
OneImplementationRuleFactory {
                 ImmutableList.of(),
                 agg.getAggPhase(),
                 false,
+                agg.isFinalPhase(),
                 agg.getLogicalProperties(),
                 agg.child())
         ).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 7d68752e07..4166d9db5d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -49,92 +49,187 @@ import java.util.stream.Collectors;
  *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as 
b], groupByExpr: [k + 1])
  *       +-- childPlan
  *
+ * Distinct Agg With Group By Processing:
+ * If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, 
Alias(COUNT(distinct v1 * v2) + 1) #2]
+ *                            , groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1, 
Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
+ *   +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
+ *       +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as 
a], groupByExpr: [k + 1, a])
+ *           +-- childPlan
+ *
  * TODO:
  *     1. use different class represent different phase aggregate
  *     2. if instance count is 1, shouldn't disassemble the agg plan
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
+    // used in secondDisassemble to transform local expressions into global
+    private final Map<Expression, Expression> globalOutputSubstitutionMap = 
Maps.newHashMap();
+    // used in secondDisassemble to transform local expressions into global
+    private final Map<Expression, Expression> globalGroupBySubstitutionMap = 
Maps.newHashMap();
+    // used to indicate the existence of a distinct function for the entire 
phase
+    private boolean hasDistinctAgg = false;
 
     @Override
     public Rule build() {
         return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
             LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+            LogicalAggregate firstAggregate = firstDisassemble(aggregate);
+            if (!hasDistinctAgg) {
+                return firstAggregate;
+            }
+            return secondDisassemble(firstAggregate);
+        }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+    }
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate 
secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
+        LogicalAggregate<GroupPlan> local = aggregate.child();
+        // replace expression in globalOutputExprs and globalGroupByExprs
+        List<NamedExpression> globalOutputExprs = 
local.getOutputExpressions().stream()
+                .map(e -> ExpressionUtils.replace(e, 
globalOutputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        List<Expression> globalGroupByExprs = 
local.getGroupByExpressions().stream()
+                .map(e -> ExpressionUtils.replace(e, 
globalGroupBySubstitutionMap))
+                .collect(Collectors.toList());
+
+        // generate new plan
+        LogicalAggregate globalAggregate = new LogicalAggregate<>(
+                globalGroupByExprs,
+                globalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                false,
+                AggPhase.GLOBAL,
+                local
+        );
+        return new LogicalAggregate<>(
+                aggregate.getGroupByExpressions(),
+                aggregate.getOutputExpressions(),
+                true,
+                aggregate.isNormalized(),
+                true,
+                AggPhase.DISTINCT_LOCAL,
+                globalAggregate
+        );
+    }
+
+    private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+        Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
 
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+        // 1. generate a map from local aggregate output to global aggregate 
expr substitution.
+        //    inputSubstitutionMap use for replacing expression in global 
aggregate
+        //    replace rule is:
+        //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
+        //        b. Expression is a group by key and is an expression. e.g. 
group by k1 + 1
+        //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | situation | origin expression   | local output expression | 
expression in global aggregate |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | a         | Ref(k1)#1           | Ref(k1)#1               | 
Ref(k1)#1                      |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | 
Ref(key)#2                     |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | 
AF(af#3)                       |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: 
ExprId x
+        // 2. collect local aggregate output expressions and local aggregate 
group by expression list
+        List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+        List<NamedExpression> localOutputExprs = Lists.newArrayList();
+        for (Expression originGroupByExpr : originGroupByExprs) {
+            if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+                globalOutputSubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
+                globalGroupBySubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
+                localOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                NamedExpression localOutputExpr = new Alias(originGroupByExpr, 
originGroupByExpr.toSql());
+                inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
+                globalOutputSubstitutionMap.put(localOutputExpr, 
localOutputExpr.toSlot());
+                globalGroupBySubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
+                localOutputExprs.add(localOutputExpr);
+            }
+        }
+        List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
+        List<NamedExpression> distinctExprsForLocalOutput = 
Lists.newArrayList();
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            Set<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (inputSubstitutionMap.containsKey(aggregateFunction)) {
                     continue;
                 }
-                if (originGroupByExpr instanceof SlotReference) {
-                    inputSubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
-                    localOutputExprs.add((SlotReference) originGroupByExpr);
-                } else {
-                    NamedExpression localOutputExpr = new 
Alias(originGroupByExpr, originGroupByExpr.toSql());
-                    inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                    localOutputExprs.add(localOutputExpr);
-                }
-            }
-            for (NamedExpression originOutputExpr : originOutputExprs) {
-                Set<AggregateFunction> aggregateFunctions
-                        = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
-                for (AggregateFunction aggregateFunction : aggregateFunctions) 
{
-                    if (inputSubstitutionMap.containsKey(aggregateFunction)) {
-                        continue;
+                if (aggregateFunction.isDistinct()) {
+                    hasDistinctAgg = true;
+                    for (Expression expr : aggregateFunction.children()) {
+                        if (expr instanceof SlotReference) {
+                            distinctExprsForLocalOutput.add((SlotReference) 
expr);
+                            if (!inputSubstitutionMap.containsKey(expr)) {
+                                inputSubstitutionMap.put(expr, expr);
+                                globalOutputSubstitutionMap.put(expr, expr);
+                                globalGroupBySubstitutionMap.put(expr, expr);
+                            }
+                        } else {
+                            NamedExpression globalOutputExpr = new Alias(expr, 
expr.toSql());
+                            distinctExprsForLocalOutput.add(globalOutputExpr);
+                            if (!inputSubstitutionMap.containsKey(expr)) {
+                                inputSubstitutionMap.put(expr, 
globalOutputExpr.toSlot());
+                                
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
+                                globalGroupBySubstitutionMap.put(expr, 
globalOutputExpr.toSlot());
+                            }
+                        }
+                        distinctExprsForLocalGroupBy.add(expr);
                     }
-                    NamedExpression localOutputExpr = new 
Alias(aggregateFunction, aggregateFunction.toSql());
-                    Expression substitutionValue = 
aggregateFunction.withChildren(
-                            Lists.newArrayList(localOutputExpr.toSlot()));
-                    inputSubstitutionMap.put(aggregateFunction, 
substitutionValue);
-                    localOutputExprs.add(localOutputExpr);
+                    continue;
                 }
+                NamedExpression localOutputExpr = new Alias(aggregateFunction, 
aggregateFunction.toSql());
+                Expression substitutionValue = aggregateFunction.withChildren(
+                        Lists.newArrayList(localOutputExpr.toSlot()));
+                inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                globalOutputSubstitutionMap.put(aggregateFunction, 
substitutionValue);
+                localOutputExprs.add(localOutputExpr);
             }
+        }
 
-            // 3. replace expression in globalOutputExprs and 
globalGroupByExprs
-            List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
-                    .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
-                    .map(NamedExpression.class::cast)
-                    .collect(Collectors.toList());
-            List<Expression> globalGroupByExprs = localGroupByExprs.stream()
-                    .map(e -> ExpressionUtils.replace(e, 
inputSubstitutionMap)).collect(Collectors.toList());
-
-            // 4. generate new plan
-            LogicalAggregate localAggregate = new LogicalAggregate<>(
-                    localGroupByExprs,
-                    localOutputExprs,
-                    true,
-                    aggregate.isNormalized(),
-                    AggPhase.LOCAL,
-                    aggregate.child()
-            );
-            return new LogicalAggregate<>(
-                    globalGroupByExprs,
-                    globalOutputExprs,
-                    true,
-                    aggregate.isNormalized(),
-                    AggPhase.GLOBAL,
-                    localAggregate
-            );
-        }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+        // 3. replace expression in globalOutputExprs and globalGroupByExprs
+        List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
+                .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+                .map(e -> ExpressionUtils.replace(e, 
inputSubstitutionMap)).collect(Collectors.toList());
+        // To avoid repeated substitution of distinct expressions,
+        // here the expressions are put into the local after the substitution 
is completed
+        localOutputExprs.addAll(distinctExprsForLocalOutput);
+        localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
+        // 4. generate new plan
+        LogicalAggregate localAggregate = new LogicalAggregate<>(
+                localGroupByExprs,
+                localOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                false,
+                AggPhase.LOCAL,
+                aggregate.child()
+        );
+        return new LogicalAggregate<>(
+                globalGroupByExprs,
+                globalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                true,
+                AggPhase.GLOBAL,
+                localAggregate
+        );
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index 0fe139b85b..45a4a3c027 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -124,7 +124,7 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory {
                 root = new LogicalProject<>(bottomProjections, root);
             }
             root = new LogicalAggregate<>(newKeys, newOutputs, 
aggregate.isDisassembled(),
-                    true, aggregate.getAggPhase(), root);
+                    true, aggregate.isFinalPhase(), aggregate.getAggPhase(), 
root);
             List<NamedExpression> projections = outputs.stream()
                     .map(e -> ExpressionUtils.replace(e, substitutionMap))
                     .map(NamedExpression.class::cast)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
index 73de61a058..69572b070a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
@@ -21,19 +21,50 @@ import 
org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
 
+import java.util.Objects;
+
 /**
  * The function which consume arguments in lots of rows and product one value.
  */
 public abstract class AggregateFunction extends BoundFunction {
 
     private DataType intermediate;
+    private final boolean isDistinct;
 
     public AggregateFunction(String name, Expression... arguments) {
         super(name, arguments);
+        isDistinct = false;
+    }
+
+    public AggregateFunction(String name, boolean isDistinct, Expression... 
arguments) {
+        super(name, arguments);
+        this.isDistinct = isDistinct;
     }
 
     public abstract DataType getIntermediateType();
 
+    public boolean isDistinct() {
+        return isDistinct;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        AggregateFunction that = (AggregateFunction) o;
+        return Objects.equals(isDistinct, that.isDistinct) && 
Objects.equals(intermediate, that.intermediate)
+                && Objects.equals(getName(), that.getName()) && 
Objects.equals(children, that.children);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(isDistinct, intermediate, getName(), children);
+    }
+
     @Override
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitAggregateFunction(this, context);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
index a31122ab7a..e594671733 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java
@@ -37,8 +37,8 @@ public class Count extends AggregateFunction {
         this.isStar = true;
     }
 
-    public Count(Expression child) {
-        super("count", child);
+    public Count(Expression child, boolean isDistinct) {
+        super("count", isDistinct, child);
         this.isStar = false;
     }
 
@@ -62,7 +62,7 @@ public class Count extends AggregateFunction {
         if (children.size() == 0) {
             return new Count();
         }
-        return new Count(children.get(0));
+        return new Count(children.get(0), isDistinct());
     }
 
     @Override
@@ -79,6 +79,9 @@ public class Count extends AggregateFunction {
                 .stream()
                 .map(Expression::toSql)
                 .collect(Collectors.joining(", "));
+        if (isDistinct()) {
+            return "count(distinct " + args + ")";
+        }
         return "count(" + args + ")";
     }
 
@@ -91,6 +94,9 @@ public class Count extends AggregateFunction {
                 .stream()
                 .map(Expression::toString)
                 .collect(Collectors.joining(", "));
+        if (isDistinct()) {
+            return "count(distinct " + args + ")";
+        }
         return "count(" + args + ")";
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index cbe9e402ef..0cca04950d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -59,6 +59,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
     private final List<NamedExpression> outputExpressions;
     private final AggPhase aggPhase;
 
+    // Used for scenarios containing distinct agg:
+    // 1. If there are LOCAL and GLOBAL phases, GLOBAL is the final phase
+    // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
+    // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL and DISTINCT_GLOBAL phases,
+    // DISTINCT_GLOBAL is the final phase
+    private final boolean isFinalPhase;
+    private final boolean isFinalPhase;
+
     /**
      * Desc: Constructor for LogicalAggregate.
      */
@@ -66,7 +73,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             List<Expression> groupByExpressions,
             List<NamedExpression> outputExpressions,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, false, false, 
AggPhase.GLOBAL, child);
+        this(groupByExpressions, outputExpressions, false, false, true, 
AggPhase.GLOBAL, child);
     }
 
     public LogicalAggregate(
@@ -74,9 +81,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             List<NamedExpression> outputExpressions,
             boolean disassembled,
             boolean normalized,
+            boolean isFinalPhase,
             AggPhase aggPhase,
             CHILD_TYPE child) {
-        this(groupByExpressions, outputExpressions, disassembled, normalized,
+        this(groupByExpressions, outputExpressions, disassembled, normalized, 
isFinalPhase,
                 aggPhase, Optional.empty(), Optional.empty(), child);
     }
 
@@ -88,6 +96,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
             List<NamedExpression> outputExpressions,
             boolean disassembled,
             boolean normalized,
+            boolean isFinalPhase,
             AggPhase aggPhase,
             Optional<GroupExpression> groupExpression,
             Optional<LogicalProperties> logicalProperties,
@@ -97,6 +106,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
         this.outputExpressions = outputExpressions;
         this.disassembled = disassembled;
         this.normalized = normalized;
+        this.isFinalPhase = isFinalPhase;
         this.aggPhase = aggPhase;
     }
 
@@ -149,6 +159,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
         return normalized;
     }
 
+    public boolean isFinalPhase() {
+        return isFinalPhase;
+    }
+
     /**
      * Determine the equality with another plan
      */
@@ -164,37 +178,37 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> 
extends LogicalUnary<CHIL
                 && Objects.equals(outputExpressions, that.outputExpressions)
                 && aggPhase == that.aggPhase
                 && disassembled == that.disassembled
-                && normalized == that.normalized;
+                && normalized == that.normalized
+                && isFinalPhase == that.isFinalPhase;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExpressions, outputExpressions, aggPhase, 
normalized, disassembled);
+        return Objects.hash(groupByExpressions, outputExpressions, aggPhase, 
normalized, disassembled, isFinalPhase);
     }
 
     @Override
     public LogicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
         return new LogicalAggregate<>(groupByExpressions, outputExpressions,
-                disassembled, normalized, aggPhase, children.get(0));
+                disassembled, normalized, isFinalPhase, aggPhase, 
children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
-        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
-                disassembled, normalized, aggPhase, groupExpression, 
Optional.of(getLogicalProperties()),
-                children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
disassembled, normalized, isFinalPhase,
+                aggPhase, groupExpression, 
Optional.of(getLogicalProperties()), children.get(0));
     }
 
     @Override
     public LogicalAggregate<Plan> 
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
-        return new LogicalAggregate<>(groupByExpressions, outputExpressions,
-                disassembled, normalized, aggPhase, Optional.empty(), 
logicalProperties, children.get(0));
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, 
disassembled, normalized, isFinalPhase,
+                aggPhase, Optional.empty(), logicalProperties, 
children.get(0));
     }
 
     public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> 
groupByExprList,
             List<NamedExpression> outputExpressionList) {
         return new LogicalAggregate<>(groupByExprList, outputExpressionList,
-                disassembled, normalized, aggPhase, child());
+                disassembled, normalized, isFinalPhase, aggPhase, child());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
index f2384920e5..8557a61ea5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java
@@ -53,11 +53,18 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
 
     private final boolean usingStream;
 
+    // Used for scenarios containing distinct agg:
+    // 1. If there are LOCAL and GLOBAL phases, GLOBAL is the final phase
+    // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
+    // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL and DISTINCT_GLOBAL phases,
+    // DISTINCT_GLOBAL is the final phase
+    private final boolean isFinalPhase;
+
     public PhysicalAggregate(List<Expression> groupByExpressions, 
List<NamedExpression> outputExpressions,
             List<Expression> partitionExpressions, AggPhase aggPhase, boolean 
usingStream,
-            LogicalProperties logicalProperties, CHILD_TYPE child) {
+            boolean isFinalPhase, LogicalProperties logicalProperties, 
CHILD_TYPE child) {
         this(groupByExpressions, outputExpressions, partitionExpressions, 
aggPhase, usingStream,
-                Optional.empty(), logicalProperties, child);
+                isFinalPhase, Optional.empty(), logicalProperties, child);
     }
 
     /**
@@ -69,7 +76,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
      * @param usingStream whether it's stream agg.
      */
     public PhysicalAggregate(List<Expression> groupByExpressions, 
List<NamedExpression> outputExpressions,
-            List<Expression> partitionExpressions, AggPhase aggPhase, boolean 
usingStream,
+            List<Expression> partitionExpressions, AggPhase aggPhase, boolean 
usingStream, boolean isFinalPhase,
             Optional<GroupExpression> groupExpression, LogicalProperties 
logicalProperties,
             CHILD_TYPE child) {
         super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, 
child);
@@ -78,6 +85,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
         this.aggPhase = aggPhase;
         this.partitionExpressions = partitionExpressions;
         this.usingStream = usingStream;
+        this.isFinalPhase = isFinalPhase;
     }
 
     /**
@@ -89,7 +97,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
      * @param usingStream whether it's stream agg.
      */
     public PhysicalAggregate(List<Expression> groupByExpressions, 
List<NamedExpression> outputExpressions,
-            List<Expression> partitionExpressions, AggPhase aggPhase, boolean 
usingStream,
+            List<Expression> partitionExpressions, AggPhase aggPhase, boolean 
usingStream, boolean isFinalPhase,
             Optional<GroupExpression> groupExpression, LogicalProperties 
logicalProperties,
             PhysicalProperties physicalProperties, CHILD_TYPE child) {
         super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, 
physicalProperties, child);
@@ -98,6 +106,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
         this.aggPhase = aggPhase;
         this.partitionExpressions = partitionExpressions;
         this.usingStream = usingStream;
+        this.isFinalPhase = isFinalPhase;
     }
 
     public AggPhase getAggPhase() {
@@ -112,6 +121,10 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
         return outputExpressions;
     }
 
+    public boolean isFinalPhase() {
+        return isFinalPhase;
+    }
+
     public boolean isUsingStream() {
         return usingStream;
     }
@@ -156,36 +169,38 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> 
extends PhysicalUnary<CH
                 && Objects.equals(outputExpressions, that.outputExpressions)
                 && Objects.equals(partitionExpressions, 
that.partitionExpressions)
                 && usingStream == that.usingStream
-                && aggPhase == that.aggPhase;
+                && aggPhase == that.aggPhase
+                && isFinalPhase == that.isFinalPhase;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase, usingStream);
+        return Objects.hash(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase, usingStream,
+                isFinalPhase);
     }
 
     @Override
     public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1);
         return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, getLogicalProperties(), children.get(0));
+                usingStream, isFinalPhase, getLogicalProperties(), 
children.get(0));
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
         return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, groupExpression, getLogicalProperties(), child());
+                usingStream, isFinalPhase, groupExpression, 
getLogicalProperties(), child());
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
         return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, Optional.empty(), logicalProperties.get(), 
child());
+                usingStream, isFinalPhase, Optional.empty(), 
logicalProperties.get(), child());
     }
 
     @Override
     public PhysicalAggregate<CHILD_TYPE> 
withPhysicalProperties(PhysicalProperties physicalProperties) {
         return new PhysicalAggregate<>(groupByExpressions, outputExpressions, 
partitionExpressions, aggPhase,
-                usingStream, Optional.empty(), getLogicalProperties(), 
physicalProperties, child());
+                usingStream, isFinalPhase, Optional.empty(), 
getLogicalProperties(), physicalProperties, child());
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
index dd09c58a50..29ef8bb2ff 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java
@@ -360,9 +360,9 @@ public class HavingClauseTest extends AnalyzeCheckTestBase 
implements PatternMat
         Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, 
Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
         Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 
2)), "(pk + 2)");
         Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
-        Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1), 
Literal.of((byte) 1)), "(COUNT(a1) + 1)");
+        Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, 
false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
         Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), 
"SUM((a1 + a2))");
-        Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1");
+        Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
         PlanChecker.from(connectContext).analyze(sql)
                 .matchesFromRoot(
                     logicalProject(
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
index 08d91b777f..fe0b577cc4 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java
@@ -263,6 +263,7 @@ public class ChildOutputPropertyDeriverTest {
                 Lists.newArrayList(key),
                 AggPhase.LOCAL,
                 true,
+                true,
                 logicalProperties,
                 groupPlan
         );
@@ -286,6 +287,7 @@ public class ChildOutputPropertyDeriverTest {
                 Lists.newArrayList(partition),
                 AggPhase.GLOBAL,
                 true,
+                true,
                 logicalProperties,
                 groupPlan
         );
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
index dda5c0b006..9802a7d66b 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
@@ -146,6 +146,7 @@ public class RequestPropertyDeriverTest {
                 Lists.newArrayList(key),
                 AggPhase.LOCAL,
                 true,
+                true,
                 logicalProperties,
                 groupPlan
         );
@@ -168,6 +169,7 @@ public class RequestPropertyDeriverTest {
                 Lists.newArrayList(partition),
                 AggPhase.GLOBAL,
                 true,
+                true,
                 logicalProperties,
                 groupPlan
         );
@@ -192,6 +194,7 @@ public class RequestPropertyDeriverTest {
                 Lists.newArrayList(),
                 AggPhase.GLOBAL,
                 true,
+                true,
                 logicalProperties,
                 groupPlan
         );
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index 72f4a8829a..ef32f31def 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
 import org.apache.doris.nereids.trees.expressions.functions.Sum;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.plans.AggPhase;
@@ -269,6 +270,86 @@ public class AggregateDisassembleTest {
                 global.getOutputExpressions().get(0).getExprId());
     }
 
+    /**
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2) as c], groupByExpr: [id + 3])
+     *   +-- childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) + 2) as c], groupByExpr: [a])
+     *   +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a, b])
+     *       +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age + 1) as b], groupByExpr: [id + 3, age + 1])
+     *           +-- childPlan(id, name, age)
+     */
+    @Test
+    public void distinctAggregateWithGroupBy() {
+        List<Expression> groupExpressionList = Lists.newArrayList(
+                new Add(rStudent.getOutput().get(0).toSlot(), new 
IntegerLiteral(3)));
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(new 
Alias(
+                new Add(new Count(new 
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
+                        new IntegerLiteral(2)), "c"));
+        Plan root = new LogicalAggregate<>(groupExpressionList, 
outputExpressionList, rStudent);
+
+        Plan after = rewrite(root);
+
+        Assertions.assertTrue(after instanceof LogicalUnary);
+        Assertions.assertTrue(after instanceof LogicalAggregate);
+        Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
+        LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
+        LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
+        LogicalAggregate<Plan> local = (LogicalAggregate) 
after.child(0).child(0);
+        Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, 
distinctLocal.getAggPhase());
+        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
+        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+        // check local:
+        // id + 3
+        Expression localOutput0 = new 
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
+        // age + 1
+        Expression localOutput1 = new 
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+        // id + 3
+        Expression localGroupBy0 = new 
Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
+        // age + 1
+        Expression localGroupBy1 = new 
Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
+
+        Assertions.assertEquals(2, local.getOutputExpressions().size());
+        Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof 
Alias);
+        Assertions.assertEquals(localOutput0, 
local.getOutputExpressions().get(0).child(0));
+        Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof 
Alias);
+        Assertions.assertEquals(localOutput1, 
local.getOutputExpressions().get(1).child(0));
+        Assertions.assertEquals(2, local.getGroupByExpressions().size());
+        Assertions.assertEquals(localGroupBy0, 
local.getGroupByExpressions().get(0));
+        Assertions.assertEquals(localGroupBy1, 
local.getGroupByExpressions().get(1));
+
+        // check global:
+        Expression globalOutput0 = 
local.getOutputExpressions().get(0).toSlot();
+        Expression globalOutput1 = 
local.getOutputExpressions().get(1).toSlot();
+        Expression globalGroupBy0 = 
local.getOutputExpressions().get(0).toSlot();
+        Expression globalGroupBy1 = 
local.getOutputExpressions().get(1).toSlot();
+
+        Assertions.assertEquals(2, global.getOutputExpressions().size());
+        Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof 
SlotReference);
+        Assertions.assertEquals(globalOutput0, 
global.getOutputExpressions().get(0));
+        Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof 
SlotReference);
+        Assertions.assertEquals(globalOutput1, 
global.getOutputExpressions().get(1));
+        Assertions.assertEquals(2, global.getGroupByExpressions().size());
+        Assertions.assertEquals(globalGroupBy0, 
global.getGroupByExpressions().get(0));
+        Assertions.assertEquals(globalGroupBy1, 
global.getGroupByExpressions().get(1));
+
+        // check distinct local:
+        Expression distinctLocalOutput = new Add(new 
Count(local.getOutputExpressions().get(1).toSlot(), true),
+                new IntegerLiteral(2));
+        Expression distinctLocalGroupBy = 
local.getOutputExpressions().get(0).toSlot();
+
+        Assertions.assertEquals(1, 
distinctLocal.getOutputExpressions().size());
+        Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) 
instanceof Alias);
+        Assertions.assertEquals(distinctLocalOutput, 
distinctLocal.getOutputExpressions().get(0).child(0));
+        Assertions.assertEquals(1, 
distinctLocal.getGroupByExpressions().size());
+        Assertions.assertEquals(distinctLocalGroupBy, 
distinctLocal.getGroupByExpressions().get(0));
+
+        // check id:
+        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
+                distinctLocal.getOutputExpressions().get(0).getExprId());
+    }
+
     private Plan rewrite(Plan input) {
         return PlanRewriter.topDownRewrite(input, new ConnectContext(), new 
AggregateDisassemble());
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
index 5860d95c6b..71d8248655 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
 import org.apache.doris.nereids.analyzer.UnboundAlias;
 import org.apache.doris.nereids.analyzer.UnboundFunction;
 import org.apache.doris.nereids.analyzer.UnboundStar;
+import org.apache.doris.nereids.trees.expressions.functions.Count;
 import org.apache.doris.nereids.trees.expressions.functions.Sum;
 import org.apache.doris.nereids.types.IntegerType;
 
@@ -168,6 +169,25 @@ public class ExpressionEqualsTest {
         Assertions.assertEquals(sum1.hashCode(), sum2.hashCode());
     }
 
+    @Test
+    public void testAggregateFunction() {
+        Count count1 = new Count();
+        Count count2 = new Count();
+        Assertions.assertEquals(count1, count2);
+        Assertions.assertEquals(count1.hashCode(), count2.hashCode());
+
+        Count count3 = new Count(child1, true);
+        Count count4 = new Count(child2, true);
+        Assertions.assertEquals(count3, count4);
+        Assertions.assertEquals(count3.hashCode(), count4.hashCode());
+
+        // negative case: counts that differ only in isDistinct must not be equal
+        Count count5 = new Count(child1, true);
+        Count count6 = new Count(child2, false);
+        Assertions.assertNotEquals(count5, count6);
+        Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());
+    }
+
     @Test
     public void testNamedExpression() {
         ExprId aliasId = new ExprId(2);
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
index 1d7878a2db..cdd5454e78 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
@@ -71,17 +71,17 @@ public class PlanEqualsTest {
 
         unexpected = new LogicalAggregate<>(Lists.newArrayList(), 
ImmutableList.of(
                 new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, 
true, Lists.newArrayList())),
-                true, false, AggPhase.GLOBAL, child);
+                true, false, true, AggPhase.GLOBAL, child);
         Assertions.assertNotEquals(unexpected, actual);
 
         unexpected = new LogicalAggregate<>(Lists.newArrayList(), 
ImmutableList.of(
                 new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, 
true, Lists.newArrayList())),
-                false, true, AggPhase.GLOBAL, child);
+                false, true, true, AggPhase.GLOBAL, child);
         Assertions.assertNotEquals(unexpected, actual);
 
         unexpected = new LogicalAggregate<>(Lists.newArrayList(), 
ImmutableList.of(
                 new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, 
true, Lists.newArrayList())),
-                false, false, AggPhase.LOCAL, child);
+                false, false, true, AggPhase.LOCAL, child);
         Assertions.assertNotEquals(unexpected, actual);
     }
 
@@ -178,20 +178,20 @@ public class PlanEqualsTest {
         List<NamedExpression> outputExpressionList = ImmutableList.of(
                 new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, 
true, Lists.newArrayList()));
         PhysicalAggregate<Plan> actual = new 
PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList,
-                Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, 
child);
+                Lists.newArrayList(), AggPhase.LOCAL, true, true, 
logicalProperties, child);
 
         List<NamedExpression> outputExpressionList1 = ImmutableList.of(
                 new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, 
true, Lists.newArrayList()));
         PhysicalAggregate<Plan> expected = new 
PhysicalAggregate<>(Lists.newArrayList(),
                 outputExpressionList1,
-                Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, 
child);
+                Lists.newArrayList(), AggPhase.LOCAL, true, true, 
logicalProperties, child);
         Assertions.assertEquals(expected, actual);
 
         List<NamedExpression> outputExpressionList2 = ImmutableList.of(
                 new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, 
true, Lists.newArrayList()));
         PhysicalAggregate<Plan> unexpected = new 
PhysicalAggregate<>(Lists.newArrayList(),
                 outputExpressionList2,
-                Lists.newArrayList(), AggPhase.LOCAL, false, 
logicalProperties, child);
+                Lists.newArrayList(), AggPhase.LOCAL, false, true, 
logicalProperties, child);
         Assertions.assertNotEquals(unexpected, actual);
     }
 
diff --git a/regression-test/data/nereids_syntax_p0/function.out 
b/regression-test/data/nereids_syntax_p0/function.out
index cac9a7c5b1..b1d705b814 100644
--- a/regression-test/data/nereids_syntax_p0/function.out
+++ b/regression-test/data/nereids_syntax_p0/function.out
@@ -11,6 +11,11 @@
 -- !count --
 3      3
 
+-- !distinct_count --
+1
+1
+1
+
 -- !avg --
 2.5E-323       1.1644193E-317
 
diff --git a/regression-test/suites/nereids_syntax_p0/function.groovy 
b/regression-test/suites/nereids_syntax_p0/function.groovy
index c4099a0798..a041fc36ab 100644
--- a/regression-test/suites/nereids_syntax_p0/function.groovy
+++ b/regression-test/suites/nereids_syntax_p0/function.groovy
@@ -41,6 +41,10 @@ suite("function") {
         SELECT count(c_city), count(*) AS custdist FROM customer;
     """
 
+    order_qt_distinct_count """
+        SELECT count(distinct c_custkey + 1) AS custdist FROM customer group 
by c_city;
+    """
+
     order_qt_avg """
         SELECT avg(lo_tax), avg(lo_extendedprice) AS avg_extendedprice FROM 
lineorder;
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to