This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit b1e39ad7859d22c37a34d9704787bb481f0d0f58
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Fri Jul 7 17:45:55 2023 +0800

    [opt](Nereids) forbid some bad cases in agg plans (#21565)
    
    1. forbid all candidates that require a gather step, except where gathering is unavoidable
    2. forbid doing the local agg after the reshuffle in a two-phase distinct agg
    3. forbid one-phase agg after a reshuffle
    4. forbid three- or four-phase distinct agg if any stage needs a reshuffle
    5. forbid the multi-distinct rewrite of a single-distinct agg when no reshuffle is needed
---
 .../glue/translator/PhysicalPlanTranslator.java    |   2 +-
 .../properties/ChildrenPropertiesRegulator.java    |  53 +++-
 .../rules/implementation/AggregateStrategies.java  | 333 ++++++++++++---------
 .../expressions/functions/agg/AggregateParam.java  |  31 +-
 .../functions/agg/MultiDistinctCount.java          |   2 +-
 .../functions/agg/MultiDistinctGroupConcat.java    |   2 +-
 .../functions/agg/MultiDistinctSum.java            |   4 +-
 .../functions/agg/MultiDistinction.java            |  27 ++
 .../nereids/trees/plans/algebra/Aggregate.java     |   6 +-
 .../physical/PhysicalStorageLayerAggregate.java    |  14 +-
 .../apache/doris/nereids/util/ExpressionUtils.java |   5 +
 .../rules/rewrite/AggregateStrategiesTest.java     |   2 +
 .../nereids_tpcds_shape_sf100_p0/shape/query94.out |   4 +-
 .../nereids_tpcds_shape_sf100_p0/shape/query95.out |   4 +-
 .../nereids_tpch_shape_sf1000_p0/shape/q16.out     |   4 +-
 .../data/nereids_tpch_shape_sf500_p0/shape/q16.out |   4 +-
 .../suites/nereids_p0/join/test_join.groovy        |   9 -
 .../suites/nereids_syntax_p0/agg_4_phase.groovy    |  12 +-
 .../nereids_syntax_p0/aggregate_strategies.groovy  |   7 +-
 .../suites/nereids_syntax_p0/group_concat.groovy   |   4 +-
 20 files changed, 314 insertions(+), 215 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 650bcc43a9..9844179633 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -1742,7 +1742,7 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 .map(e -> {
                     Expression function = e.child(0).child(0);
                     if (function instanceof AggregateFunction) {
-                        AggregateParam param = AggregateParam.localResult();
+                        AggregateParam param = AggregateParam.LOCAL_RESULT;
                         function = new AggregateExpression((AggregateFunction) 
function, param);
                     }
                     return ExpressionTranslator.translate(function, context);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index 811c26569a..54e4c4780a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -23,7 +23,12 @@ import org.apache.doris.nereids.cost.CostCalculator;
 import org.apache.doris.nereids.jobs.JobContext;
 import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
 import org.apache.doris.nereids.trees.plans.AggMode;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
@@ -43,6 +48,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * ensure child add enough distribute. update children properties if we do 
regular
@@ -88,12 +94,55 @@ public class ChildrenPropertiesRegulator extends 
PlanVisitor<Boolean, Void> {
 
     @Override
     public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends 
Plan> agg, Void context) {
-        // forbid one phase agg on distribute
-        if (agg.getAggMode() == AggMode.INPUT_TO_RESULT
+        if (!agg.getAggregateParam().canBeBanned) {
+            return true;
+        }
+        // forbid one phase agg on distribute and three or four stage distinct 
agg inter by distribute
+        if ((agg.getAggMode() == AggMode.INPUT_TO_RESULT || agg.getAggMode() 
== AggMode.BUFFER_TO_BUFFER)
                 && children.get(0).getPlan() instanceof PhysicalDistribute) {
             // this means one stage gather agg, usually bad pattern
             return false;
         }
+        // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
+        // TODO: this is forbid good plan after cte reuse by mistake
+        if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER
+                && requiredProperties.get(0).getDistributionSpec() instanceof 
DistributionSpecHash
+                && children.get(0).getPlan() instanceof PhysicalDistribute) {
+            return false;
+        }
+        // forbid multi distinct opt that bad than multi-stage version when 
multi-stage can be executed in one fragment
+        if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER || agg.getAggMode() == 
AggMode.INPUT_TO_RESULT) {
+            List<MultiDistinction> multiDistinctions = 
agg.getOutputExpressions().stream()
+                    .filter(Alias.class::isInstance)
+                    .map(a -> ((Alias) a).child())
+                    .filter(AggregateExpression.class::isInstance)
+                    .map(a -> ((AggregateExpression) a).getFunction())
+                    .filter(MultiDistinction.class::isInstance)
+                    .map(MultiDistinction.class::cast)
+                    .collect(Collectors.toList());
+            if (multiDistinctions.size() == 1) {
+                Expression distinctChild = multiDistinctions.get(0).child(0);
+                DistributionSpec childDistribution = 
childrenProperties.get(0).getDistributionSpec();
+                if (distinctChild instanceof SlotReference && 
childDistribution instanceof DistributionSpecHash) {
+                    SlotReference slotReference = (SlotReference) 
distinctChild;
+                    DistributionSpecHash distributionSpecHash = 
(DistributionSpecHash) childDistribution;
+                    List<ExprId> groupByColumns = 
agg.getGroupByExpressions().stream()
+                            .map(SlotReference.class::cast)
+                            .map(SlotReference::getExprId)
+                            .collect(Collectors.toList());
+                    DistributionSpecHash groupByRequire = new 
DistributionSpecHash(
+                            groupByColumns, ShuffleType.REQUIRE);
+                    List<ExprId> distinctChildColumns = 
Lists.newArrayList(slotReference.getExprId());
+                    distinctChildColumns.add(slotReference.getExprId());
+                    DistributionSpecHash distinctChildRequire = new 
DistributionSpecHash(
+                            distinctChildColumns, ShuffleType.REQUIRE);
+                    if ((!groupByColumns.isEmpty() && 
distributionSpecHash.satisfy(groupByRequire))
+                            || (groupByColumns.isEmpty() && 
distributionSpecHash.satisfy(distinctChildRequire))) {
+                        return false;
+                    }
+                }
+            }
+        }
         // process must shuffle
         visit(agg, context);
         // process agg
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index 551e73532c..6df6b2f817 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -70,11 +70,13 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -123,41 +125,43 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     .when(agg -> agg.getDistinctArguments().size() == 0)
                     .thenApplyMulti(ctx -> 
twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
             ),
-            RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
-                basePattern
-                    .when(this::containsCountDistinctMultiExpr)
-                    .thenApplyMulti(ctx -> 
twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
-            ),
+            // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
+            //     basePattern
+            //         .when(this::containsCountDistinctMultiExpr)
+            //         .thenApplyMulti(ctx -> 
twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
+            // ),
             RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
                 basePattern
                     .when(this::containsCountDistinctMultiExpr)
                     .thenApplyMulti(ctx -> 
threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
             ),
-            RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
-                basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 1)
-                    .thenApplyMulti(ctx -> 
twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
-            ),
             RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
                 basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 1 && 
enableSingleDistinctColumnOpt())
+                    .when(agg -> agg.getDistinctArguments().size() == 1 && 
couldConvertToMulti(agg))
                     .thenApplyMulti(ctx -> 
onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
             ),
             RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
                 basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 1 && 
enableSingleDistinctColumnOpt())
+                    .when(agg -> agg.getDistinctArguments().size() == 1 && 
couldConvertToMulti(agg))
                     .thenApplyMulti(ctx -> 
twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
             ),
-            RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
-                basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 1)
-                    .thenApplyMulti(ctx -> 
threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
-            ),
             RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
                 basePattern
-                    .when(agg -> agg.getDistinctArguments().size() > 1 && 
!containsCountDistinctMultiExpr(agg))
+                    .when(agg -> agg.getDistinctArguments().size() > 1
+                            && !containsCountDistinctMultiExpr(agg)
+                            && couldConvertToMulti(agg))
                     .thenApplyMulti(ctx -> 
twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
             ),
+            // RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
+            //     basePattern
+            //         .when(agg -> agg.getDistinctArguments().size() == 1)
+            //         .thenApplyMulti(ctx -> 
twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+            // ),
+            RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
+                    basePattern
+                            .when(agg -> agg.getDistinctArguments().size() == 
1)
+                            .thenApplyMulti(ctx -> 
threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
+            ),
             RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
                     basePattern
                             .when(agg -> agg.getDistinctArguments().size() == 
1)
@@ -169,15 +173,15 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(*) from tbl
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *               LogicalAggregate(groupBy=[], output=[count(*)])
      *                                |
      *                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *               LogicalAggregate(groupBy=[], output=[count(*)])
      *                                |
      *        PhysicalStorageLayerAggregate(pushAggOp=COUNT, 
table=PhysicalOlapScan(table=tbl))
@@ -205,7 +209,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .map(AggregateFunction::getClass)
                 .collect(Collectors.toSet());
 
-        Map<Class, PushDownAggOp> supportedAgg = 
PushDownAggOp.supportedFunctions();
+        Map<Class<? extends AggregateFunction>, PushDownAggOp> supportedAgg = 
PushDownAggOp.supportedFunctions();
         if (!supportedAgg.keySet().containsAll(functionClasses)) {
             return canNotPush;
         }
@@ -292,7 +296,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                          || colType == PrimitiveType.STRING) {
                     return canNotPush;
                 }
-                if (colType.isCharFamily() && mergeOp != PushDownAggOp.COUNT 
&& column.getType().getLength() > 512) {
+                if (colType.isCharFamily() && column.getType().getLength() > 
512) {
                     return canNotPush;
                 }
             }
@@ -324,25 +328,25 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(*) from tbl group by id
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[id], output=[count(*)])
      *                       |
      *               LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *             PhysicalHashAggregate(groupBy=[id], output=[count(*)])
      *                              |
      *                 PhysicalDistribute(distributionSpec=GATHER)
      *                             |
      *                     LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *            PhysicalHashAggregate(groupBy=[id], output=[count(*)])
      *                                    |
      *           LogicalOlapScan(table=tbl, **already distribute by id**)
@@ -351,7 +355,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
     private List<PhysicalHashAggregate<Plan>> onePhaseAggregateWithoutDistinct(
             LogicalAggregate<? extends Plan> logicalAgg, ConnectContext 
connectContext) {
         RequireProperties requireGather = 
RequireProperties.of(PhysicalProperties.GATHER);
-        AggregateParam inputToResultParam = AggregateParam.localResult();
+        AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
         List<NamedExpression> newOutput = 
ExpressionUtils.rewriteDownShortCircuit(
                 logicalAgg.getOutputExpressions(), outputChild -> {
                     if (outputChild instanceof AggregateFunction) {
@@ -366,7 +370,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 requireGather, logicalAgg.child());
 
         if (logicalAgg.getGroupByExpressions().isEmpty()) {
-            return ImmutableList.of(gatherLocalAgg);
+            // TODO: usually bad, disable it until we could do better cost 
computation.
+            // return ImmutableList.of(gatherLocalAgg);
+            return ImmutableList.of();
         } else {
             RequireProperties requireHash = RequireProperties.of(
                     
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), 
ShuffleType.REQUIRE));
@@ -383,17 +389,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id, name) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[name], output=[count(distinct id, 
name)])
      *                               |
      *                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, 
null, name))])
      *                                |
      *          PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
@@ -401,9 +407,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *           PhysicalDistribute(distributionSpec=GATHER)
      *                               |
      *                     LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, 
null, name))])
      *                                |
      *          PhysicalHashAggregate(groupBy=[name, id], output=[name, id])
@@ -415,7 +421,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      */
     private List<PhysicalHashAggregate<Plan>> 
twoPhaseAggregateWithCountDistinctMulti(
             LogicalAggregate<? extends Plan> logicalAgg, CascadesContext 
cascadesContext) {
-        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER);
+        AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
         Collection<Expression> countDistinctArguments = 
logicalAgg.getDistinctArguments();
 
         List<Expression> localAggGroupBy = 
ImmutableList.copyOf(ImmutableSet.<Expression>builder()
@@ -487,7 +493,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     .withRequireTree(requireHash.withChildren(requireHash))
                     
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
             return ImmutableList.<PhysicalHashAggregate<Plan>>builder()
-                    .add(gatherLocalGatherDistinctAgg)
+                    // TODO: usually bad, disable it until we could do better 
cost computation.
+                    //.add(gatherLocalGatherDistinctAgg)
                     .add(hashLocalHashGlobalAgg)
                     .build();
         }
@@ -495,17 +502,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id, name) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[name], output=[count(distinct id, 
name)])
      *                               |
      *                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, 
null, name))])
      *                                   |
      *          PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
@@ -515,9 +522,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *       PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
      *                                   |
      *                        LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, 
null, name))])
      *                                   |
      *          PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
@@ -566,11 +573,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
         List<Expression> globalAggGroupBy = localAggGroupBy;
 
-        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
+        boolean hasCountDistinctMulti = 
logicalAgg.getAggregateFunctions().stream()
+                .filter(AggregateFunction::isDistinct)
+                .filter(Count.class::isInstance)
+                .anyMatch(c -> c.arity() > 1);
+        AggregateParam bufferToBufferParam = new AggregateParam(
+                AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, 
!hasCountDistinctMulti);
+
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
                 nonDistinctAggFunctionToAliasPhase1.entrySet()
                         .stream()
-                        .collect(ImmutableMap.toImmutableMap(kv -> 
kv.getKey(), kv -> {
+                        .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv 
-> {
                             AggregateFunction originFunction = kv.getKey();
                             Alias localOutputAlias = kv.getValue();
                             AggregateExpression globalAggExpr = new 
AggregateExpression(
@@ -596,7 +609,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 logicalAgg, cascadesContext).first;
 
         AggregateParam distinctInputToResultParam
-                = new AggregateParam(AggPhase.DISTINCT_LOCAL, 
AggMode.INPUT_TO_RESULT);
+                = new AggregateParam(AggPhase.DISTINCT_LOCAL, 
AggMode.INPUT_TO_RESULT, !hasCountDistinctMulti);
         AggregateParam globalBufferToResultParam
                 = new AggregateParam(AggPhase.GLOBAL, 
AggMode.BUFFER_TO_RESULT);
         List<NamedExpression> distinctOutput = 
ExpressionUtils.rewriteDownShortCircuit(
@@ -621,19 +634,19 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 logicalAgg.getLogicalProperties(), requireGather, 
anyLocalGatherGlobalAgg
         );
 
-        RequireProperties requireDistinctHash = RequireProperties.of(
-                
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), 
ShuffleType.REQUIRE));
-        PhysicalHashAggregate<? extends Plan> 
anyLocalHashGlobalGatherDistinctAgg
-                = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of(
-                        anyLocalGatherGlobalAgg
-                                .withRequire(requireDistinctHash)
-                                
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
-                ));
+        // RequireProperties requireDistinctHash = RequireProperties.of(
+        //         
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), 
ShuffleType.REQUIRE));
+        // PhysicalHashAggregate<? extends Plan> 
anyLocalHashGlobalGatherDistinctAgg
+        //         = 
anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of(
+        //                 anyLocalGatherGlobalAgg
+        //                         .withRequire(requireDistinctHash)
+        //                         
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
+        //         ));
 
         if (logicalAgg.getGroupByExpressions().isEmpty()) {
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
                     .add(anyLocalGatherGlobalGatherAgg)
-                    .add(anyLocalHashGlobalGatherDistinctAgg)
+                    //.add(anyLocalHashGlobalGatherDistinctAgg)
                     .build();
         } else {
             RequireProperties requireGroupByHash = RequireProperties.of(
@@ -646,8 +659,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     )
                     
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
-                    .add(anyLocalGatherGlobalGatherAgg)
-                    .add(anyLocalHashGlobalGatherDistinctAgg)
+                    // .add(anyLocalGatherGlobalGatherAgg)
+                    // .add(anyLocalHashGlobalGatherDistinctAgg)
                     .add(anyLocalHashGlobalHashDistinctAgg)
                     .build();
         }
@@ -655,17 +668,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select name, count(value) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[name], output=[name, count(value)])
      *                               |
      *                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], 
mode=BUFFER_TO_RESULT)
      *                                |
      *               PhysicalDistribute(distributionSpec=GATHER)
@@ -673,9 +686,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *          PhysicalHashAggregate(groupBy=[name], output=[name, 
count(value)], mode=INPUT_TO_BUFFER)
      *                                |
      *                     LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], 
mode=BUFFER_TO_RESULT)
      *                                |
      *               PhysicalDistribute(distributionSpec=HASH(name))
@@ -713,6 +726,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
         AggregateParam bufferToResultParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
         List<NamedExpression> globalAggOutput = 
ExpressionUtils.rewriteDownShortCircuit(
                 logicalAgg.getOutputExpressions(), outputChild -> {
+                    if (!(outputChild instanceof AggregateFunction)) {
+                        return outputChild;
+                    }
                     Alias inputToBufferAlias = 
inputToBufferAliases.get(outputChild);
                     if (inputToBufferAlias == null) {
                         return outputChild;
@@ -722,7 +738,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 });
 
         RequireProperties requireGather = 
RequireProperties.of(PhysicalProperties.GATHER);
-        PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new 
PhysicalHashAggregate(
+        PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new 
PhysicalHashAggregate<>(
                 localAggGroupBy, globalAggOutput, 
Optional.of(partitionExpressions),
                 bufferToResultParam, false, anyLocalAgg.getLogicalProperties(),
                 requireGather, anyLocalAgg);
@@ -746,17 +762,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *               LogicalAggregate(groupBy=[name], output=[name, 
count(distinct id)])
      *                                         |
      *                              LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
@@ -764,9 +780,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *                     PhysicalDistribute(distributionSpec=GATHER)
      *                                          |
      *                               LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
@@ -781,7 +797,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
         Set<AggregateFunction> aggregateFunctions = 
logicalAgg.getAggregateFunctions();
 
         Set<Expression> distinctArguments = aggregateFunctions.stream()
-                .filter(aggregateExpression -> 
aggregateExpression.isDistinct())
+                .filter(AggregateFunction::isDistinct)
                 .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
                 .collect(ImmutableSet.toImmutableSet());
 
@@ -790,7 +806,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .addAll(distinctArguments)
                 .build();
 
-        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, true);
+        AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER;
 
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = 
aggregateFunctions.stream()
                 .filter(aggregateFunction -> !aggregateFunction.isDistinct())
@@ -822,10 +838,13 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     if (outputChild instanceof AggregateFunction) {
                         AggregateFunction aggregateFunction = 
(AggregateFunction) outputChild;
                         if (aggregateFunction.isDistinct()) {
-                            
Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                            Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
+                            Preconditions.checkArgument(aggChild.size() == 1,
+                                    "cannot process more than one child in 
aggregate distinct function: "
+                                            + aggregateFunction);
                             AggregateFunction nonDistinct = aggregateFunction
-                                    .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
-                            return new AggregateExpression(nonDistinct, 
AggregateParam.localResult());
+                                    .withDistinctAndChildren(false, 
ImmutableList.copyOf(aggChild));
+                            return new AggregateExpression(nonDistinct, 
AggregateParam.LOCAL_RESULT);
                         } else {
                             Alias alias = 
nonDistinctAggFunctionToAliasPhase1.get(outputChild);
                             return new AggregateExpression(
@@ -850,7 +869,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                             
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
                     ));
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
-                    .add(gatherLocalGatherGlobalAgg)
+                    //.add(gatherLocalGatherGlobalAgg)
                     .add(hashLocalGatherGlobalAgg)
                     .build();
         } else {
@@ -863,7 +882,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     )
                     
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
-                    .add(gatherLocalGatherGlobalAgg)
+                    // .add(gatherLocalGatherGlobalAgg)
                     .add(hashLocalHashGlobalAgg)
                     .build();
         }
@@ -871,16 +890,16 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *               LogicalAggregate(groupBy=[name], output=[name, 
count(distinct id)])
      *                                         |
      *                              LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
@@ -890,9 +909,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=INPUT_TO_BUFFER)
      *                                          |
      *                               LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
count(distinct(id)], mode=BUFFER_TO_RESULT)
      *                                          |
      *     PhysicalHashAggregate(groupBy=[name, id], output=[name, id], 
mode=BUFFER_TO_BUFFER)
@@ -907,10 +926,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
     // TODO: support one phase aggregate(group by columns + distinct columns) 
+ two phase distinct aggregate
     private List<PhysicalHashAggregate<? extends Plan>> 
threePhaseAggregateWithDistinct(
             LogicalAggregate<? extends Plan> logicalAgg, ConnectContext 
connectContext) {
+        boolean couldBanned = couldConvertToMulti(logicalAgg);
+
         Set<AggregateFunction> aggregateFunctions = 
logicalAgg.getAggregateFunctions();
 
         Set<Expression> distinctArguments = aggregateFunctions.stream()
-                .filter(aggregateExpression -> 
aggregateExpression.isDistinct())
+                .filter(AggregateFunction::isDistinct)
                 .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
                 .collect(ImmutableSet.toImmutableSet());
 
@@ -919,7 +940,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .addAll(distinctArguments)
                 .build();
 
-        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER);
+        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
 
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = 
aggregateFunctions.stream()
                 .filter(aggregateFunction -> !aggregateFunction.isDistinct())
@@ -942,11 +963,11 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 maybeUsingStreamAgg, Optional.empty(), 
logicalAgg.getLogicalProperties(),
                 requireAny, logicalAgg.child());
 
-        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
+        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
                 nonDistinctAggFunctionToAliasPhase1.entrySet()
                     .stream()
-                    .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv 
-> {
+                    .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
                         AggregateFunction originFunction = kv.getKey();
                         Alias localOutput = kv.getValue();
                         AggregateExpression globalAggExpr = new 
AggregateExpression(
@@ -965,15 +986,19 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
                 requireGather, anyLocalAgg);
 
-        AggregateParam bufferToResultParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT);
+        AggregateParam bufferToResultParam = new AggregateParam(
+                AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, couldBanned);
         List<NamedExpression> distinctOutput = 
ExpressionUtils.rewriteDownShortCircuit(
                 logicalAgg.getOutputExpressions(), expr -> {
                     if (expr instanceof AggregateFunction) {
                         AggregateFunction aggregateFunction = 
(AggregateFunction) expr;
                         if (aggregateFunction.isDistinct()) {
-                            
Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                            Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
+                            Preconditions.checkArgument(aggChild.size() == 1,
+                                    "cannot process more than one child in 
aggregate distinct function: "
+                                            + aggregateFunction);
                             AggregateFunction nonDistinct = aggregateFunction
-                                    .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
+                                    .withDistinctAndChildren(false, 
ImmutableList.copyOf(aggChild));
                             return new AggregateExpression(nonDistinct,
                                     bufferToResultParam, 
aggregateFunction.child(0));
                         } else {
@@ -1017,8 +1042,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     )
                     
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
-                    .add(anyLocalGatherGlobalGatherDistinctAgg)
-                    .add(anyLocalHashGlobalGatherDistinctAgg)
+                    // TODO: this plan pattern is not good usually, we remove 
it temporary.
+                    //.add(anyLocalGatherGlobalGatherDistinctAgg)
+                    //.add(anyLocalHashGlobalGatherDistinctAgg)
                     .add(anyLocalHashGlobalHashDistinctAgg)
                     .build();
         }
@@ -1026,25 +1052,25 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id) from (...) group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[name], output=[count(distinct id)])
      *                       |
      *                    any plan
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *             PhysicalHashAggregate(groupBy=[name], 
output=[multi_distinct_count(id)])
      *                                    |
      *                 PhysicalDistribute(distributionSpec=GATHER)
      *                                    |
      *                                any plan
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *            PhysicalHashAggregate(groupBy=[name], 
output=[multi_distinct_count(id)])
      *                                    |
      *                     any plan(**already distribute by name**)
@@ -1052,7 +1078,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      */
     private List<PhysicalHashAggregate<? extends Plan>> 
onePhaseAggregateWithMultiDistinct(
             LogicalAggregate<? extends Plan> logicalAgg, ConnectContext 
connectContext) {
-        AggregateParam inputToResultParam = AggregateParam.localResult();
+        AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT;
         List<NamedExpression> newOutput = 
ExpressionUtils.rewriteDownShortCircuit(
                 logicalAgg.getOutputExpressions(), outputChild -> {
                     if (outputChild instanceof AggregateFunction) {
@@ -1068,7 +1094,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 maybeUsingStreamAgg(connectContext, logicalAgg),
                 logicalAgg.getLogicalProperties(), requireGather, 
logicalAgg.child());
         if (logicalAgg.getGroupByExpressions().isEmpty()) {
-            return ImmutableList.of(gatherLocalAgg);
+            // TODO: usually bad, disable it until we could do better cost 
computation.
+            // return ImmutableList.of(gatherLocalAgg);
+            return ImmutableList.of();
         } else {
             RequireProperties requireHash = RequireProperties.of(
                     
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), 
ShuffleType.REQUIRE));
@@ -1085,17 +1113,17 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * sql: select count(distinct id) from tbl group by name
-     *
+     * <p>
      * before:
-     *
+     * <p>
      *          LogicalAggregate(groupBy=[name], output=[name, count(distinct 
id)])
      *                               |
      *                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      * after:
-     *
+     * <p>
      *  single node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
multi_count_distinct(value)], mode=BUFFER_TO_RESULT)
      *                                                 |
      *                                
PhysicalDistribute(distributionSpec=GATHER)
@@ -1103,9 +1131,9 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
multi_count_distinct(value)], mode=INPUT_TO_BUFFER)
      *                                                 |
      *                                       LogicalOlapScan(table=tbl)
-     *
+     * <p>
      *  distribute node aggregate:
-     *
+     * <p>
      *     PhysicalHashAggregate(groupBy=[name], output=[name, 
multi_count_distinct(value)], mode=BUFFER_TO_RESULT)
      *                                                |
      *                               
PhysicalDistribute(distributionSpec=HASH(name))
@@ -1157,17 +1185,16 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg);
 
         if (logicalAgg.getGroupByExpressions().isEmpty()) {
-            Collection<Expression> distinctArguments = 
logicalAgg.getDistinctArguments();
-            RequireProperties requireDistinctHash = 
RequireProperties.of(PhysicalProperties.createHash(
-                    distinctArguments, ShuffleType.REQUIRE));
-            PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = 
anyLocalGatherGlobalAgg
-                    .withChildren(ImmutableList.of(anyLocalAgg
-                            .withRequire(requireDistinctHash)
-                            
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
-                    ));
+            // Collection<Expression> distinctArguments = 
logicalAgg.getDistinctArguments();
+            // RequireProperties requireDistinctHash = 
RequireProperties.of(PhysicalProperties.createHash(
+            //         distinctArguments, ShuffleType.REQUIRE));
+            // PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg 
= anyLocalGatherGlobalAgg
+            //         .withChildren(ImmutableList.of(anyLocalAgg
+            //                 .withRequire(requireDistinctHash)
+            //                 
.withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments()))
+            //         ));
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
                     .add(anyLocalGatherGlobalAgg)
-                    .add(hashLocalGatherGlobalAgg)
                     .build();
         } else {
             RequireProperties requireHash = RequireProperties.of(
@@ -1176,7 +1203,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     .withRequire(requireHash)
                     
.withPartitionExpressions(logicalAgg.getGroupByExpressions());
             return ImmutableList.<PhysicalHashAggregate<? extends 
Plan>>builder()
-                    .add(anyLocalGatherGlobalAgg)
+                    // TODO: usually bad, disable it until we could do better 
cost computation.
+                    // .add(anyLocalGatherGlobalAgg)
                     .add(anyLocalHashGlobalAgg)
                     .build();
         }
@@ -1215,7 +1243,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
     /**
      * countDistinctMultiExprToCountIf.
-     *
+     * <p>
      * NOTE: this function will break the normalized output, e.g. from 
`count(distinct slot1, slot2)` to
      *       `count(if(slot1 is null, null, slot2))`. So if you invoke this 
method, and separate the
      *       phase of aggregate, please normalize to slot and create a bottom 
project like NormalizeAggregate.
@@ -1268,15 +1296,10 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
         return connectContext == null || 
connectContext.getSessionVariable().enablePushDownNoGroupAgg();
     }
 
-    private boolean enableSingleDistinctColumnOpt() {
-        ConnectContext connectContext = ConnectContext.get();
-        return connectContext == null || 
connectContext.getSessionVariable().enableSingleDistinctColumnOpt();
-    }
-
     /**
      * sql:
      * select count(distinct name), sum(age) from student;
-     *
+     * <p>
      * 4 phase plan
      * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), 
sum(age#5)], [GATHER]
      * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), 
partial_sum(age)), hash distribute by name
@@ -1286,26 +1309,29 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
      */
     private List<PhysicalHashAggregate<? extends Plan>> 
fourPhaseAggregateWithDistinct(
             LogicalAggregate<? extends Plan> logicalAgg, ConnectContext 
connectContext) {
+        boolean couldBanned = couldConvertToMulti(logicalAgg);
+
         Set<AggregateFunction> aggregateFunctions = 
logicalAgg.getAggregateFunctions();
 
-        Set<Expression> distinctArguments = aggregateFunctions.stream()
-                .filter(aggregateExpression -> 
aggregateExpression.isDistinct())
+        Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
+                .filter(AggregateFunction::isDistinct)
                 .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
+                .map(NamedExpression.class::cast)
                 .collect(ImmutableSet.toImmutableSet());
 
         Set<NamedExpression> localAggGroupBySet = 
ImmutableSet.<NamedExpression>builder()
-                .addAll((List) logicalAgg.getGroupByExpressions())
+                .addAll((List<NamedExpression>) (List) 
logicalAgg.getGroupByExpressions())
                 .addAll(distinctArguments)
                 .build();
 
-        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, true);
+        AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
 
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = 
aggregateFunctions.stream()
                 .filter(aggregateFunction -> !aggregateFunction.isDistinct())
                 .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
                     AggregateExpression localAggExpr = new 
AggregateExpression(expr, inputToBufferParam);
                     return new Alias(localAggExpr, localAggExpr.toSql());
-                }));
+                }, (oldValue, newValue) -> newValue));
 
         List<NamedExpression> localAggOutput = 
ImmutableList.<NamedExpression>builder()
                 .addAll(localAggGroupBySet)
@@ -1321,11 +1347,11 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 maybeUsingStreamAgg, Optional.empty(), 
logicalAgg.getLogicalProperties(),
                 requireAny, logicalAgg.child());
 
-        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER);
+        AggregateParam bufferToBufferParam = new 
AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
                 nonDistinctAggFunctionToAliasPhase1.entrySet()
                         .stream()
-                        .collect(ImmutableMap.toImmutableMap(kv -> 
kv.getKey(), kv -> {
+                        .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv 
-> {
                             AggregateFunction originFunction = kv.getKey();
                             Alias localOutput = kv.getValue();
                             AggregateExpression globalAggExpr = new 
AggregateExpression(
@@ -1350,7 +1376,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 requireDistinctHash, anyLocalAgg);
 
         // phase 3
-        AggregateParam distinctLocalParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER);
+        AggregateParam distinctLocalParam = new AggregateParam(
+                AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
         Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = 
new HashMap<>();
         List<NamedExpression> localDistinctOutput = Lists.newArrayList();
         for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
@@ -1361,9 +1388,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                         if (expr instanceof AggregateFunction) {
                             AggregateFunction aggregateFunction = 
(AggregateFunction) expr;
                             if (aggregateFunction.isDistinct()) {
-                                
Preconditions.checkArgument(aggregateFunction.arity() == 1);
+                                Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
+                                Preconditions.checkArgument(aggChild.size() == 
1,
+                                        "cannot process more than one child in 
aggregate distinct function: "
+                                                + aggregateFunction);
                                 AggregateFunction nonDistinct = 
aggregateFunction
-                                        .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
+                                        .withDistinctAndChildren(false, 
ImmutableList.copyOf(aggChild));
                                 AggregateExpression nonDistinctAggExpr = new 
AggregateExpression(nonDistinct,
                                         distinctLocalParam, 
aggregateFunction.child(0));
                                 return nonDistinctAggExpr;
@@ -1389,7 +1419,8 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 requireDistinctHash, anyLocalHashGlobalAgg);
 
         //phase 4
-        AggregateParam distinctGlobalParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT);
+        AggregateParam distinctGlobalParam = new AggregateParam(
+                AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, 
couldBanned);
         List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
         for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
             NamedExpression outputExpr = 
logicalAgg.getOutputExpressions().get(i);
@@ -1397,9 +1428,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 if (expr instanceof AggregateFunction) {
                     AggregateFunction aggregateFunction = (AggregateFunction) 
expr;
                     if (aggregateFunction.isDistinct()) {
-                        Preconditions.checkArgument(aggregateFunction.arity() 
== 1);
+                        Set<Expression> aggChild = 
Sets.newHashSet(aggregateFunction.children());
+                        Preconditions.checkArgument(aggChild.size() == 1,
+                                "cannot process more than one child in 
aggregate distinct function: "
+                                        + aggregateFunction);
                         AggregateFunction nonDistinct = aggregateFunction
-                                .withDistinctAndChildren(false, 
aggregateFunction.getArguments());
+                                .withDistinctAndChildren(false, 
ImmutableList.copyOf(aggChild));
                         int idx = 
logicalAgg.getOutputExpressions().indexOf(outputExpr);
                         Alias localDistinctAlias = (Alias) 
(localDistinctOutput.get(idx));
                         return new AggregateExpression(nonDistinct,
@@ -1424,4 +1458,11 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .add(distinctGlobal)
                 .build();
     }
+
+    private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> 
aggregate) {
+        return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), 
expr ->
+                expr instanceof AggregateFunction && ((AggregateFunction) 
expr).isDistinct()
+                        && (expr.arity() > 1
+                        || !(expr instanceof Count || expr instanceof Sum || 
expr instanceof GroupConcat)));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
index 2ff8eb262f..89e9c1ea7d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java
@@ -25,43 +25,35 @@ import java.util.Objects;
 /** AggregateParam. */
 public class AggregateParam {
 
-    public final AggPhase aggPhase;
+    public static final AggregateParam LOCAL_RESULT = new 
AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT);
+    public static final AggregateParam LOCAL_BUFFER = new 
AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
 
+    public final AggPhase aggPhase;
     public final AggMode aggMode;
-
-    // TODO remove this flag, and generate it in enforce and cost job
-    public boolean needColocateScan;
+    // TODO: this is a short-term plan to process count(distinct a, b) 
correctly
+    public final boolean canBeBanned;
 
     /** AggregateParam */
     public AggregateParam(AggPhase aggPhase, AggMode aggMode) {
-        this(aggPhase, aggMode, false);
+        this(aggPhase, aggMode, true);
     }
 
-    /** AggregateParam */
-    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
needColocateScan) {
+    public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean 
canBeBanned) {
         this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be 
null");
         this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be 
null");
-        this.needColocateScan = needColocateScan;
-    }
-
-    public static AggregateParam localResult() {
-        return new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT, 
true);
+        this.canBeBanned = canBeBanned;
     }
 
     public AggregateParam withAggPhase(AggPhase aggPhase) {
-        return new AggregateParam(aggPhase, aggMode, needColocateScan);
+        return new AggregateParam(aggPhase, aggMode, canBeBanned);
     }
 
     public AggregateParam withAggPhase(AggMode aggMode) {
-        return new AggregateParam(aggPhase, aggMode, needColocateScan);
+        return new AggregateParam(aggPhase, aggMode, canBeBanned);
     }
 
     public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode 
aggMode) {
-        return new AggregateParam(aggPhase, aggMode, needColocateScan);
-    }
-
-    public AggregateParam withNeedColocateScan(boolean needColocateScan) {
-        return new AggregateParam(aggPhase, aggMode, needColocateScan);
+        return new AggregateParam(aggPhase, aggMode, canBeBanned);
     }
 
     @Override
@@ -87,7 +79,6 @@ public class AggregateParam {
         return "AggregateParam{"
                 + "aggPhase=" + aggPhase
                 + ", aggMode=" + aggMode
-                + ", needColocateScan=" + needColocateScan
                 + '}';
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
index 72a26d288e..b9e7c1fdb1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
@@ -35,7 +35,7 @@ import java.util.List;
 
 /** MultiDistinctCount */
 public class MultiDistinctCount extends AggregateFunction
-        implements AlwaysNotNullable, ExplicitlyCastableSignature {
+        implements AlwaysNotNullable, ExplicitlyCastableSignature, 
MultiDistinction {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java
index 737e895906..5f5bb6815a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java
@@ -36,7 +36,7 @@ import java.util.List;
 
 /** MultiDistinctGroupConcat */
 public class MultiDistinctGroupConcat extends NullableAggregateFunction
-        implements ExplicitlyCastableSignature {
+        implements ExplicitlyCastableSignature, MultiDistinction {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java
index a378dc0960..8441e02828 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java
@@ -35,8 +35,8 @@ import com.google.common.collect.ImmutableList;
 import java.util.List;
 
 /** MultiDistinctSum */
-public class MultiDistinctSum extends AggregateFunction
-        implements UnaryExpression, AlwaysNotNullable, 
ExplicitlyCastableSignature, ComputePrecisionForSum {
+public class MultiDistinctSum extends AggregateFunction implements 
UnaryExpression, AlwaysNotNullable,
+        ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java
new file mode 100644
index 0000000000..ab8842f730
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.nereids.trees.TreeNode;
+import org.apache.doris.nereids.trees.expressions.Expression;
+
+/**
+ * base class of multi-distinct agg function
+ */
+public interface MultiDistinction extends TreeNode<Expression> {
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
index 6731bde58a..8361e230be 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
@@ -25,7 +25,7 @@ import org.apache.doris.nereids.trees.plans.UnaryPlan;
 import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
-import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 
 import java.util.List;
 import java.util.Set;
@@ -53,10 +53,10 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends 
UnaryPlan<CHILD_TYPE
         return ExpressionUtils.collect(getOutputExpressions(), 
AggregateFunction.class::isInstance);
     }
 
-    default List<Expression> getDistinctArguments() {
+    default Set<Expression> getDistinctArguments() {
         return getAggregateFunctions().stream()
                 .filter(AggregateFunction::isDistinct)
                 .flatMap(aggregateExpression -> 
aggregateExpression.getArguments().stream())
-                .collect(ImmutableList.toImmutableList());
+                .collect(ImmutableSet.toImmutableSet());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java
index ed57eebcb0..094f5d75cd 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java
@@ -21,7 +21,7 @@ import org.apache.doris.catalog.Table;
 import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
-import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
@@ -30,7 +30,6 @@ import 
org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
 import org.apache.doris.nereids.util.Utils;
 import org.apache.doris.statistics.Statistics;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 
 import java.util.List;
@@ -78,11 +77,6 @@ public class PhysicalStorageLayerAggregate extends 
PhysicalRelation {
         return visitor.visitPhysicalStorageLayerAggregate(this, context);
     }
 
-    @Override
-    public List<? extends Expression> getExpressions() {
-        return ImmutableList.of();
-    }
-
     @Override
     public boolean equals(Object o) {
         if (this == o) {
@@ -113,7 +107,7 @@ public class PhysicalStorageLayerAggregate extends 
PhysicalRelation {
     }
 
     public PhysicalStorageLayerAggregate withPhysicalOlapScan(PhysicalOlapScan 
physicalOlapScan) {
-        return new PhysicalStorageLayerAggregate(relation, aggOp);
+        return new PhysicalStorageLayerAggregate(physicalOlapScan, aggOp);
     }
 
     @Override
@@ -142,8 +136,8 @@ public class PhysicalStorageLayerAggregate extends 
PhysicalRelation {
     public enum PushDownAggOp {
         COUNT, MIN_MAX, MIX;
 
-        public static Map<Class, PushDownAggOp> supportedFunctions() {
-            return ImmutableMap.<Class, PushDownAggOp>builder()
+        public static Map<Class<? extends AggregateFunction>, PushDownAggOp> 
supportedFunctions() {
+            return ImmutableMap.<Class<? extends AggregateFunction>, 
PushDownAggOp>builder()
                     .put(Count.class, PushDownAggOp.COUNT)
                     .put(Min.class, PushDownAggOp.MIN_MAX)
                     .put(Max.class, PushDownAggOp.MIN_MAX)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index e1fbefad61..a3a3ca1b80 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -403,6 +403,11 @@ public class ExpressionUtils {
                 .anyMatch(expr -> expr.anyMatch(predicate));
     }
 
+    public static boolean noneMatch(List<? extends Expression> expressions, 
Predicate<TreeNode<Expression>> predicate) {
+        return expressions.stream()
+                .noneMatch(expr -> expr.anyMatch(predicate));
+    }
+
     public static boolean containsType(List<? extends Expression> expressions, 
Class type) {
         return anyMatch(expressions, type::isInstance);
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
index 417605bfa5..e8419b2458 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
@@ -223,6 +223,8 @@ public class AggregateStrategiesTest implements 
MemoPatternMatchSupported {
      * </pre>
      */
     @Test
+    @Disabled
+    @Developing("reopen it after we could choose agg phase by CBO")
     public void distinctAggregateWithoutGroupByApply2PhaseRule() {
         List<Expression> groupExpressionList = new ArrayList<>();
         List<NamedExpression> outputExpressionList = Lists.newArrayList(new 
Alias(
diff --git 
a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out 
b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out
index e12f8482f5..c26035693e 100644
--- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out
+++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out
@@ -4,8 +4,8 @@ PhysicalTopN
 --PhysicalTopN
 ----PhysicalProject
 ------hashAgg[GLOBAL]
---------hashAgg[LOCAL]
-----------PhysicalDistribute
+--------PhysicalDistribute
+----------hashAgg[LOCAL]
 ------------PhysicalProject
 --------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk)
 ----------------PhysicalProject
diff --git 
a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out 
b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out
index 6ddca5c0c2..014535c50d 100644
--- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out
+++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out
@@ -14,8 +14,8 @@ CteAnchor[cteId= ( CTEId#3=] )
 ----PhysicalTopN
 ------PhysicalProject
 --------hashAgg[GLOBAL]
-----------hashAgg[LOCAL]
-------------PhysicalDistribute
+----------PhysicalDistribute
+------------hashAgg[LOCAL]
 --------------PhysicalProject
 ----------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk)
 ------------------PhysicalProject
diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out 
b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out
index 515a72a29d..d72afcf57d 100644
--- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out
+++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out
@@ -4,8 +4,8 @@ PhysicalQuickSort
 --PhysicalDistribute
 ----PhysicalQuickSort
 ------hashAgg[GLOBAL]
---------hashAgg[LOCAL]
-----------PhysicalDistribute
+--------PhysicalDistribute
+----------hashAgg[LOCAL]
 ------------PhysicalProject
 --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = 
supplier.s_suppkey)
 ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey)
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out 
b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out
index 515a72a29d..d72afcf57d 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out
@@ -4,8 +4,8 @@ PhysicalQuickSort
 --PhysicalDistribute
 ----PhysicalQuickSort
 ------hashAgg[GLOBAL]
---------hashAgg[LOCAL]
-----------PhysicalDistribute
+--------PhysicalDistribute
+----------hashAgg[LOCAL]
 ------------PhysicalProject
 --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = 
supplier.s_suppkey)
 ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey)
diff --git a/regression-test/suites/nereids_p0/join/test_join.groovy 
b/regression-test/suites/nereids_p0/join/test_join.groovy
index ce643f081a..584b80384a 100644
--- a/regression-test/suites/nereids_p0/join/test_join.groovy
+++ b/regression-test/suites/nereids_p0/join/test_join.groovy
@@ -24,15 +24,6 @@ suite("test_join", "nereids_p0") {
     def tbName1 = "test"
     def tbName2 = "baseall"
     def tbName3 = "bigtable"
-    def empty_name = "empty"
-
-    qt_agg_sql1 """select 
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 count(distinct k1, NULL) from test;"""
-    qt_agg_sql2 """select 
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 count(distinct k1, NULL), avg(k2) from baseall;"""
-    qt_agg_sql3 """select 
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by 
k1;"""
-
-    qt_agg_sql4 """select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 count(distinct k1, NULL) from test;"""
-    qt_agg_sql5 """select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 count(distinct k1, NULL), avg(k2) from baseall;"""
-    qt_agg_sql6 """select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/
 k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by 
k1;"""
 
     order_sql """select j.*, d.* from ${tbName2} j full outer join ${tbName1} 
d on (j.k1=d.k1) order by j.k1, j.k2, j.k3, j.k4, d.k1, d.k2
             limit 100"""
diff --git a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy 
b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
index d2d48e3e08..a672f8dee3 100644
--- a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
+++ b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
@@ -43,16 +43,16 @@ suite("agg_4_phase") {
         (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20);
     """
     def test_sql = """
-        select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
 
-            count(distinct name), sum(age) 
+        select
+            count(distinct id)
         from agg_4_phase_tbl;
         """
     explain{
         sql(test_sql)
-        contains "6:VAGGREGATE (merge finalize)"
-        contains "5:VEXCHANGE"
-        contains "4:VAGGREGATE (update serialize)"
-        contains "3:VAGGREGATE (merge serialize)"
+        contains "5:VAGGREGATE (merge finalize)"
+        contains "4:VEXCHANGE"
+        contains "3:VAGGREGATE (update serialize)"
+        contains "2:VAGGREGATE (merge serialize)"
         contains "1:VAGGREGATE (update serialize)"
     }
     qt_4phase (test_sql)
diff --git 
a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy 
b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
index 39742e8d21..ea63b5b789 100644
--- a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
+++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy
@@ -82,7 +82,6 @@ suite("aggregate_strategies") {
         explain {
             sql """
             select
-                
/*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT,
 FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/
                 count(distinct id)
                 from $tableName
             """
@@ -90,17 +89,17 @@ suite("aggregate_strategies") {
             notContains "STREAMING"
         }
 
+        // test multi_distinct
         test {
             sql """select
-                
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/
-                count(distinct id)
+                count(distinct name)
                 from $tableName"""
             result([[5L]])
         }
 
+        // test four phase distinct
         test {
             sql """select
-                
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/
                 count(distinct id)
                 from $tableName"""
             result([[5L]])
diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy 
b/regression-test/suites/nereids_syntax_p0/group_concat.groovy
index fe2062d66d..60f52c2ba0 100644
--- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy
+++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy
@@ -21,14 +21,14 @@ suite("group_concat") {
 
 
     test {
-        sql """select 
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/
+        sql """select
                  group_concat(cast(number as string), ',' order by number)
                from numbers('number'='10')"""
         result([["0,1,2,3,4,5,6,7,8,9"]])
     }
 
     test {
-        sql """select 
/*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/
+        sql """select
                  group_concat(cast(number as string), ',' order by number)
                from numbers('number'='10')"""
         result([["0,1,2,3,4,5,6,7,8,9"]])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to