morrySnow commented on code in PR #12159:
URL: https://github.com/apache/doris/pull/12159#discussion_r957615245


##########
regression-test/data/nereids_syntax_p0/function.out:
##########
@@ -11,6 +11,11 @@
 -- !count --
 3      3
 
+-- !distinct_count_with_group_by --
+1      1
+1      1
+1      1

Review Comment:
   We need to load more data into the table to verify this behavior.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java:
##########
@@ -31,21 +31,28 @@
 public class Count extends AggregateFunction {
 
     private final boolean isStar;
+    private final boolean isDistinct;

Review Comment:
   Move this field to the superclass and add a two-parameter constructor; the
original constructor should set it to false by default. In the superclass, the
isDistinct() method should return this value directly.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        // origin sql: select count(distinct a) from t1 group by b;
+        // global agg: select a from t1 group by b, a;
+        // second local agg: select count(distinct a) from t1 group by b;
+        // In order to get the second local agg from global agg:
+        // 1. the distinct expression needs to be removed from the output and 
the group by of global agg
+        // 2. add distinct agg function back to output
+        List<NamedExpression> originFirstGlobalOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originFirstGlobalGroupByExprs = 
aggregate.getGroupByExpressions();
+        List<NamedExpression> secondLocalOutputExprs = new 
ArrayList<>(aggregate.getOutputExpressions());
+        // add origin distinct function back
+        secondLocalOutputExprs.addAll(distinctOriginOutputExprs);
+        List<Expression> secondLocalGroupByExprs = new 
ArrayList<>(aggregate.getGroupByExpressions());
+
+        List<Expression> distinctExprs = new ArrayList<>();
+        // remove distinct param exprs from secondLocalGroupByExprs and 
secondLocalOutputExprs
+        for (NamedExpression expression : distinctAggFunctionParams) {
+            
secondLocalGroupByExprs.remove(inputSubstitutionMap.get(expression));
+            
secondLocalOutputExprs.remove(inputSubstitutionMap.get(expression));
+            distinctExprs.add(inputSubstitutionMap.get(expression));
+        }
+
+        Map<Expression, Expression> secondSubstitutionMap = Maps.newHashMap();
+
+        // replace the original slot reference with the latest one
+        Expression distinctAgg = 
distinctAggFunctions.get(0).withChildren(distinctExprs);
+        secondSubstitutionMap.put(distinctAggFunctions.get(0), distinctAgg);
+        List<NamedExpression> secondLocalOutputNamedExprs = 
secondLocalOutputExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
secondSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+
+        // firstGlobalOutputExprs = originOutputExprs + originGroupByExprs
+        List<NamedExpression> firstGlobalOutputExprs = new 
ArrayList<>(originFirstGlobalOutputExprs);
+        for (Expression originGroupByExpr : originFirstGlobalGroupByExprs) {
+            if (firstGlobalOutputExprs.contains(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                firstGlobalOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                Preconditions.checkState(false);
+            }
+        }
+
+        // generate new plan
+        LogicalAggregate globalAggregate = new LogicalAggregate<>(
+                originFirstGlobalGroupByExprs,
+                firstGlobalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                aggregate.child()
+        );
+
+        return new LogicalAggregate<>(
+                secondLocalGroupByExprs,
+                secondLocalOutputNamedExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.DISTINCT_LOCAL,
+                globalAggregate
+        );
+    }
+
+    private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+        // 1. generate a map from local aggregate output to global aggregate 
expr substitution.
+        //    inputSubstitutionMap use for replacing expression in global 
aggregate
+        //    replace rule is:
+        //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
+        //        b. Expression is a group by key and is an expression. e.g. 
group by k1 + 1
+        //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | situation | origin expression   | local output expression | 
expression in global aggregate |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | a         | Ref(k1)#1           | Ref(k1)#1               | 
Ref(k1)#1                      |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | 
Ref(key)#2                     |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | 
AF(af#3)                       |
+        //    
+-----------+---------------------+-------------------------+--------------------------------+
+        //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: 
ExprId x
+        // 2. collect local aggregate output expressions and local aggregate 
group by expression list
+        List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+        List<NamedExpression> localOutputExprs = Lists.newArrayList();
+        for (Expression originGroupByExpr : originGroupByExprs) {
+            if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+                continue;
+            }
+            if (originGroupByExpr instanceof SlotReference) {
+                inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+                localOutputExprs.add((SlotReference) originGroupByExpr);
+            } else {
+                NamedExpression localOutputExpr = new Alias(originGroupByExpr, 
originGroupByExpr.toSql());
+                inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
+                localOutputExprs.add(localOutputExpr);
+            }
+        }
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions
+                    = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (inputSubstitutionMap.containsKey(aggregateFunction)) {
                     continue;
                 }
-                if (originGroupByExpr instanceof SlotReference) {
-                    inputSubstitutionMap.put(originGroupByExpr, 
originGroupByExpr);
-                    localOutputExprs.add((SlotReference) originGroupByExpr);
-                } else {
-                    NamedExpression localOutputExpr = new 
Alias(originGroupByExpr, originGroupByExpr.toSql());
-                    inputSubstitutionMap.put(originGroupByExpr, 
localOutputExpr.toSlot());
-                    localOutputExprs.add(localOutputExpr);
-                }
+                NamedExpression localOutputExpr = new Alias(aggregateFunction, 
aggregateFunction.toSql());
+                Expression substitutionValue = aggregateFunction.withChildren(
+                        Lists.newArrayList(localOutputExpr.toSlot()));
+                inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                localOutputExprs.add(localOutputExpr);
             }
-            for (NamedExpression originOutputExpr : originOutputExprs) {
-                List<AggregateFunction> aggregateFunctions
-                        = 
originOutputExpr.collect(AggregateFunction.class::isInstance);
-                for (AggregateFunction aggregateFunction : aggregateFunctions) 
{
-                    if (inputSubstitutionMap.containsKey(aggregateFunction)) {
-                        continue;
-                    }
-                    NamedExpression localOutputExpr = new 
Alias(aggregateFunction, aggregateFunction.toSql());
-                    Expression substitutionValue = 
aggregateFunction.withChildren(
-                            Lists.newArrayList(localOutputExpr.toSlot()));
-                    inputSubstitutionMap.put(aggregateFunction, 
substitutionValue);
-                    localOutputExprs.add(localOutputExpr);
+        }
+
+        // 3. replace expression in globalOutputExprs and globalGroupByExprs
+        List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap))
+                .map(NamedExpression.class::cast)
+                .collect(Collectors.toList());
+        List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+                .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap)).collect(Collectors.toList());
+
+        // 4. generate new plan
+        LogicalAggregate localAggregate = new LogicalAggregate<>(
+                localGroupByExprs,
+                localOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.LOCAL,
+                aggregate.child()
+        );
+        return new LogicalAggregate<>(
+                globalGroupByExprs,
+                globalOutputExprs,
+                true,
+                aggregate.isNormalized(),
+                AggPhase.GLOBAL,
+                localAggregate
+        );
+    }
+
+    private void 
moveDistinctExprFromOutputToGroupBy(LogicalAggregate<GroupPlan> aggregate) {
+        // for example:
+        // select count(distinct a) from t1 group by b;
+        // => select a from t1 group by b, a;
+        List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
+        List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
+
+        for (NamedExpression originOutputExpr : originOutputExprs) {
+            List<AggregateFunction> aggregateFunctions =
+                    
originOutputExpr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                if (aggregateFunction.isDistinct()) {
+                    distinctAggFunctions.add(aggregateFunction);
+                    distinctOriginOutputExprs.add(originOutputExpr);
                 }
             }
+        }
+        if (!distinctAggFunctions.isEmpty()) {
+            for (Expression expr : distinctAggFunctions.get(0).children()) {
+                distinctAggFunctionParams.add((NamedExpression) expr);
+            }
+        }
+        originOutputExprs.removeAll(distinctOriginOutputExprs);
+        originOutputExprs.addAll(distinctAggFunctionParams);
+        originGroupByExprs.addAll(distinctAggFunctionParams);
+    }
 
-            // 3. replace expression in globalOutputExprs and 
globalGroupByExprs
-            List<NamedExpression> globalOutputExprs = 
aggregate.getOutputExpressions().stream()
-                    .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap))
-                    .map(NamedExpression.class::cast)
-                    .collect(Collectors.toList());
-            List<Expression> globalGroupByExprs = localGroupByExprs.stream()
-                    .map(e -> ExpressionReplacer.INSTANCE.visit(e, 
inputSubstitutionMap)).collect(Collectors.toList());
-
-            // 4. generate new plan
-            LogicalAggregate localAggregate = new LogicalAggregate<>(
-                    localGroupByExprs,
-                    localOutputExprs,
-                    true,
-                    aggregate.isNormalized(),
-                    AggPhase.LOCAL,
-                    aggregate.child()
-            );
-            return new LogicalAggregate<>(
-                    globalGroupByExprs,
-                    globalOutputExprs,
-                    true,
-                    aggregate.isNormalized(),
-                    AggPhase.GLOBAL,
-                    localAggregate
-            );
+    @Override
+    public Rule build() {

Review Comment:
   Move build() so that it is the first method of this class, for easier reading.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -54,86 +56,192 @@
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
-    @Override
-    public Rule build() {
-        return logicalAggregate().when(agg -> 
!agg.isDisassembled()).thenApply(ctx -> {
-            LogicalAggregate<GroupPlan> aggregate = ctx.root;
-            List<NamedExpression> originOutputExprs = 
aggregate.getOutputExpressions();
-            List<Expression> originGroupByExprs = 
aggregate.getGroupByExpressions();
-
-            // 1. generate a map from local aggregate output to global 
aggregate expr substitution.
-            //    inputSubstitutionMap use for replacing expression in global 
aggregate
-            //    replace rule is:
-            //        a: Expression is a group by key and is a slot reference. 
e.g. group by k1
-            //        b. Expression is a group by key and is an expression. 
e.g. group by k1 + 1
-            //        c. Expression is an aggregate function. e.g. sum(v1) in 
select list
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | situation | origin expression   | local output expression 
| expression in global aggregate |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | a         | Ref(k1)#1           | Ref(k1)#1               
| Ref(k1)#1                      |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 
| Ref(key)#2                     |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     
| AF(af#3)                       |
-            //    
+-----------+---------------------+-------------------------+--------------------------------+
-            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, 
#x: ExprId x
-            // 2. collect local aggregate output expressions and local 
aggregate group by expression list
-            Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
-            List<Expression> localGroupByExprs = 
aggregate.getGroupByExpressions();
-            List<NamedExpression> localOutputExprs = Lists.newArrayList();
-            for (Expression originGroupByExpr : originGroupByExprs) {
-                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+    private Map<Expression, Expression> inputSubstitutionMap = 
Maps.newHashMap();
+    private List<NamedExpression> distinctAggFunctionParams = new 
ArrayList<>();
+    private List<AggregateFunction> distinctAggFunctions = new ArrayList<>();
+    private List<NamedExpression> distinctOriginOutputExprs = new 
ArrayList<>();
+
+    // only support distinct function with group by
+    // TODO: support distinct function without group by. (add second global 
phase)
+    private LogicalAggregate secondDisassemble(LogicalAggregate<GroupPlan> 
aggregate) {

Review Comment:
   Please add some unit tests (UT) for this class.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to