This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 5b6d48ed5b [feature](nereids) support distinct count (#12159) 5b6d48ed5b is described below commit 5b6d48ed5b6db033607224523579da0a77d957f2 Author: yinzhijian <373141...@qq.com> AuthorDate: Thu Sep 15 13:01:47 2022 +0800 [feature](nereids) support distinct count (#12159) support distinct count with group by clause. for example: SELECT count(distinct c_custkey + 1) FROM customer group by c_nation; TODO: support distinct count without group by clause. --- .../glue/translator/ExpressionTranslator.java | 2 + .../glue/translator/PhysicalPlanTranslator.java | 17 +- .../properties/ChildOutputPropertyDeriver.java | 2 +- .../nereids/properties/RequestPropertyDeriver.java | 7 +- .../doris/nereids/rules/analysis/BindFunction.java | 12 ++ .../expression/rewrite/ExpressionRewrite.java | 2 +- .../LogicalAggToPhysicalHashAgg.java | 1 + .../rules/rewrite/AggregateDisassemble.java | 235 +++++++++++++++------ .../rules/rewrite/logical/NormalizeAggregate.java | 2 +- .../expressions/functions/AggregateFunction.java | 31 +++ .../nereids/trees/expressions/functions/Count.java | 12 +- .../trees/plans/logical/LogicalAggregate.java | 36 +++- .../trees/plans/physical/PhysicalAggregate.java | 35 ++- .../doris/nereids/parser/HavingClauseTest.java | 4 +- .../properties/ChildOutputPropertyDeriverTest.java | 2 + .../properties/RequestPropertyDeriverTest.java | 3 + .../rewrite/logical/AggregateDisassembleTest.java | 81 +++++++ .../trees/expressions/ExpressionEqualsTest.java | 20 ++ .../doris/nereids/trees/plans/PlanEqualsTest.java | 12 +- .../data/nereids_syntax_p0/function.out | 5 + .../suites/nereids_syntax_p0/function.groovy | 4 + 21 files changed, 416 insertions(+), 109 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 017ec6b5b7..1c3f59361d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -256,6 +256,8 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra Count count = (Count) function; if (count.isStar()) { return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam()); + } else if (count.isDistinct()) { + return new FunctionCallExpr(function.getName(), new FunctionParams(true, paramList)); } } return new FunctionCallExpr(function.getName(), paramList); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index d47bf1c1aa..a783567a70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -191,12 +191,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla // 3. generate output tuple List<Slot> slotList = Lists.newArrayList(); TupleDescriptor outputTupleDesc; - if (aggregate.getAggPhase() == AggPhase.GLOBAL) { + if (aggregate.getAggPhase() == AggPhase.LOCAL) { + outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context); + } else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase()) + || aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) { slotList.addAll(groupSlotList); slotList.addAll(aggFunctionOutput); outputTupleDesc = generateTupleDesc(slotList, null, context); } else { - outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context); + // In the distinct agg scenario, global shares local's desc + AggregationNode localAggNode = (AggregationNode) inputPlanFragment.getPlanRoot().getChild(0); + outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc(); } if (aggregate.getAggPhase() == AggPhase.GLOBAL) { @@ -204,6 +209,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla execAggregateFunction.setMergeForNereids(true); } } + if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) { + for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) { + if (!execAggregateFunction.isDistinct()) { + execAggregateFunction.setMergeForNereids(true); + } + } + } AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, outputTupleDesc, outputTupleDesc, aggregate.getAggPhase().toExec()); AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(), @@ -216,6 +228,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla aggregationNode.setIntermediateTuple(); break; case GLOBAL: + case DISTINCT_LOCAL: if (currentFragment.getPlanRoot() instanceof ExchangeNode) { ExchangeNode exchangeNode = (ExchangeNode) currentFragment.getPlanRoot(); currentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java index 1d7974e161..ba8976e71d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java @@ -80,12 +80,12 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties, case LOCAL: return new PhysicalProperties(childOutputProperty.getDistributionSpec()); case GLOBAL: + case DISTINCT_LOCAL: List<ExprId> columns = agg.getPartitionExpressions().stream() .map(SlotReference.class::cast) .map(SlotReference::getExprId) .collect(Collectors.toList()); return PhysicalProperties.createHash(new DistributionSpecHash(columns, ShuffleType.AGGREGATE)); - case DISTINCT_LOCAL: case DISTINCT_GLOBAL: default: throw new RuntimeException("Could not derive output properties for agg phase: " + agg.getAggPhase()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 0c2f2089ae..67a9032f85 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.AggPhase; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; @@ -82,14 +83,16 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> { addToRequestPropertyToChildren(PhysicalProperties.ANY); return null; } - + if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) { + addToRequestPropertyToChildren(requestPropertyFromParent); + return null; + } // 2. second phase agg, need to return shuffle with partition key List<Expression> partitionExpressions = agg.getPartitionExpressions(); if (partitionExpressions.isEmpty()) { addToRequestPropertyToChildren(PhysicalProperties.GATHER); return null; } - // TODO: when parent is a join node, // use requestPropertyFromParent to keep column order as join to avoid shuffle again. if (partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java index 40782b2e28..fcef341f5f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.Count; import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.GroupPlan; @@ -115,6 +116,17 @@ public class BindFunction implements AnalysisRuleFactory { @Override public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) { + // FunctionRegistry can't support boolean arg now, tricky here. + if (unboundFunction.getName().equalsIgnoreCase("count")) { + List<Expression> arguments = unboundFunction.getArguments(); + if ((arguments.size() == 0 && unboundFunction.isStar()) || arguments.stream() + .allMatch(Expression::isConstant)) { + return new Count(); + } + if (arguments.size() == 1) { + return new Count(unboundFunction.getArguments().get(0), unboundFunction.isDistinct()); + } + } FunctionRegistry functionRegistry = env.getFunctionRegistry(); String functionName = unboundFunction.getName(); FunctionBuilder builder = functionRegistry.findFunctionBuilder( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java index d808183c24..f660ee0ec0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java @@ -126,7 +126,7 @@ public class ExpressionRewrite implements RewriteRuleFactory { return agg; } return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions, - agg.isDisassembled(), agg.isNormalized(), agg.getAggPhase(), agg.child()); + agg.isDisassembled(), agg.isNormalized(), agg.isFinalPhase(), agg.getAggPhase(), agg.child()); }).toRule(RuleType.REWRITE_AGG_EXPRESSION); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java index ecc59393d4..4e4d52b551 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalAggToPhysicalHashAgg.java @@ -36,6 +36,7 @@ public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory { ImmutableList.of(), agg.getAggPhase(), false, + agg.isFinalPhase(), agg.getLogicalProperties(), agg.child()) ).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java index 7d68752e07..4166d9db5d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java @@ -49,92 +49,187 @@ import java.util.stream.Collectors; * +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1]) * +-- childPlan * + * Distinct Agg With Group By Processing: + * If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1 + * the initial plan is: + * Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(COUNT(distinct v1 * v2) + 1) #2] + * , groupByExpr: [k + 1]) + * +-- childPlan + * we should rewrite to: + * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1, Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b]) + * +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a]) + * +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as a], groupByExpr: [k + 1, a]) + * +-- childPlan + * * TODO: * 1. use different class represent different phase aggregate * 2. if instance count is 1, shouldn't disassemble the agg plan */ public class AggregateDisassemble extends OneRewriteRuleFactory { + // used in secondDisassemble to transform local expressions into global + private final Map<Expression, Expression> globalOutputSubstitutionMap = Maps.newHashMap(); + // used in secondDisassemble to transform local expressions into global + private final Map<Expression, Expression> globalGroupBySubstitutionMap = Maps.newHashMap(); + // used to indicate the existence of a distinct function for the entire phase + private boolean hasDistinctAgg = false; @Override public Rule build() { return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> { LogicalAggregate<GroupPlan> aggregate = ctx.root; - List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions(); - List<Expression> originGroupByExprs = aggregate.getGroupByExpressions(); + LogicalAggregate firstAggregate = firstDisassemble(aggregate); + if (!hasDistinctAgg) { + return firstAggregate; + } + return secondDisassemble(firstAggregate); + }).toRule(RuleType.AGGREGATE_DISASSEMBLE); + } + + // only support distinct function with group by + // TODO: support distinct function without group by. (add second global phase) + private LogicalAggregate secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) { + LogicalAggregate<GroupPlan> local = aggregate.child(); + // replace expression in globalOutputExprs and globalGroupByExprs + List<NamedExpression> globalOutputExprs = local.getOutputExpressions().stream() + .map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap)) + .map(NamedExpression.class::cast) + .collect(Collectors.toList()); + List<Expression> globalGroupByExprs = local.getGroupByExpressions().stream() + .map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap)) + .collect(Collectors.toList()); + + // generate new plan + LogicalAggregate globalAggregate = new LogicalAggregate<>( + globalGroupByExprs, + globalOutputExprs, + true, + aggregate.isNormalized(), + false, + AggPhase.GLOBAL, + local + ); + return new LogicalAggregate<>( + aggregate.getGroupByExpressions(), + aggregate.getOutputExpressions(), + true, + aggregate.isNormalized(), + true, + AggPhase.DISTINCT_LOCAL, + globalAggregate + ); + } + + private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> aggregate) { + List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions(); + List<Expression> originGroupByExprs = aggregate.getGroupByExpressions(); + Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap(); - // 1. generate a map from local aggregate output to global aggregate expr substitution. - // inputSubstitutionMap use for replacing expression in global aggregate - // replace rule is: - // a: Expression is a group by key and is a slot reference. e.g. group by k1 - // b. Expression is a group by key and is an expression. e.g. group by k1 + 1 - // c. Expression is an aggregate function. e.g. sum(v1) in select list - // +-----------+---------------------+-------------------------+--------------------------------+ - // | situation | origin expression | local output expression | expression in global aggregate | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 | - // +-----------+---------------------+-------------------------+--------------------------------+ - // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) | - // +-----------+---------------------+-------------------------+--------------------------------+ - // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x - // 2. collect local aggregate output expressions and local aggregate group by expression list - Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap(); - List<Expression> localGroupByExprs = aggregate.getGroupByExpressions(); - List<NamedExpression> localOutputExprs = Lists.newArrayList(); - for (Expression originGroupByExpr : originGroupByExprs) { - if (inputSubstitutionMap.containsKey(originGroupByExpr)) { + // 1. generate a map from local aggregate output to global aggregate expr substitution. + // inputSubstitutionMap use for replacing expression in global aggregate + // replace rule is: + // a: Expression is a group by key and is a slot reference. e.g. group by k1 + // b. Expression is a group by key and is an expression. e.g. group by k1 + 1 + // c. Expression is an aggregate function. e.g. sum(v1) in select list + // +-----------+---------------------+-------------------------+--------------------------------+ + // | situation | origin expression | local output expression | expression in global aggregate | + // +-----------+---------------------+-------------------------+--------------------------------+ + // | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 | + // +-----------+---------------------+-------------------------+--------------------------------+ + // | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 | + // +-----------+---------------------+-------------------------+--------------------------------+ + // | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) | + // +-----------+---------------------+-------------------------+--------------------------------+ + // NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x + // 2. collect local aggregate output expressions and local aggregate group by expression list + List<Expression> localGroupByExprs = aggregate.getGroupByExpressions(); + List<NamedExpression> localOutputExprs = Lists.newArrayList(); + for (Expression originGroupByExpr : originGroupByExprs) { + if (inputSubstitutionMap.containsKey(originGroupByExpr)) { + continue; + } + if (originGroupByExpr instanceof SlotReference) { + inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); + globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); + globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr); + localOutputExprs.add((SlotReference) originGroupByExpr); + } else { + NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql()); + inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); + globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot()); + globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); + localOutputExprs.add(localOutputExpr); + } + } + List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList(); + List<NamedExpression> distinctExprsForLocalOutput = Lists.newArrayList(); + for (NamedExpression originOutputExpr : originOutputExprs) { + Set<AggregateFunction> aggregateFunctions + = originOutputExpr.collect(AggregateFunction.class::isInstance); + for (AggregateFunction aggregateFunction : aggregateFunctions) { + if (inputSubstitutionMap.containsKey(aggregateFunction)) { continue; } - if (originGroupByExpr instanceof SlotReference) { - inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr); - localOutputExprs.add((SlotReference) originGroupByExpr); - } else { - NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql()); - inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot()); - localOutputExprs.add(localOutputExpr); - } - } - for (NamedExpression originOutputExpr : originOutputExprs) { - Set<AggregateFunction> aggregateFunctions - = originOutputExpr.collect(AggregateFunction.class::isInstance); - for (AggregateFunction aggregateFunction : aggregateFunctions) { - if (inputSubstitutionMap.containsKey(aggregateFunction)) { - continue; + if (aggregateFunction.isDistinct()) { + hasDistinctAgg = true; + for (Expression expr : aggregateFunction.children()) { + if (expr instanceof SlotReference) { + distinctExprsForLocalOutput.add((SlotReference) expr); + if (!inputSubstitutionMap.containsKey(expr)) { + inputSubstitutionMap.put(expr, expr); + globalOutputSubstitutionMap.put(expr, expr); + globalGroupBySubstitutionMap.put(expr, expr); + } + } else { + NamedExpression globalOutputExpr = new Alias(expr, expr.toSql()); + distinctExprsForLocalOutput.add(globalOutputExpr); + if (!inputSubstitutionMap.containsKey(expr)) { + inputSubstitutionMap.put(expr, globalOutputExpr.toSlot()); + globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot()); + globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot()); + } + } + distinctExprsForLocalGroupBy.add(expr); } - NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql()); - Expression substitutionValue = aggregateFunction.withChildren( - Lists.newArrayList(localOutputExpr.toSlot())); - inputSubstitutionMap.put(aggregateFunction, substitutionValue); - localOutputExprs.add(localOutputExpr); + continue; } + NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql()); + Expression substitutionValue = aggregateFunction.withChildren( + Lists.newArrayList(localOutputExpr.toSlot())); + inputSubstitutionMap.put(aggregateFunction, substitutionValue); + globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue); + localOutputExprs.add(localOutputExpr); } + } - // 3. replace expression in globalOutputExprs and globalGroupByExprs - List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream() - .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)) - .map(NamedExpression.class::cast) - .collect(Collectors.toList()); - List<Expression> globalGroupByExprs = localGroupByExprs.stream() - .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList()); - - // 4. generate new plan - LogicalAggregate localAggregate = new LogicalAggregate<>( - localGroupByExprs, - localOutputExprs, - true, - aggregate.isNormalized(), - AggPhase.LOCAL, - aggregate.child() - ); - return new LogicalAggregate<>( - globalGroupByExprs, - globalOutputExprs, - true, - aggregate.isNormalized(), - AggPhase.GLOBAL, - localAggregate - ); - }).toRule(RuleType.AGGREGATE_DISASSEMBLE); + // 3. replace expression in globalOutputExprs and globalGroupByExprs + List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream() + .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)) + .map(NamedExpression.class::cast) + .collect(Collectors.toList()); + List<Expression> globalGroupByExprs = localGroupByExprs.stream() + .map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList()); + // To avoid repeated substitution of distinct expressions, + // here the expressions are put into the local after the substitution is completed + localOutputExprs.addAll(distinctExprsForLocalOutput); + localGroupByExprs.addAll(distinctExprsForLocalGroupBy); + // 4. generate new plan + LogicalAggregate localAggregate = new LogicalAggregate<>( + localGroupByExprs, + localOutputExprs, + true, + aggregate.isNormalized(), + false, + AggPhase.LOCAL, + aggregate.child() + ); + return new LogicalAggregate<>( + globalGroupByExprs, + globalOutputExprs, + true, + aggregate.isNormalized(), + true, + AggPhase.GLOBAL, + localAggregate + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index 0fe139b85b..45a4a3c027 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -124,7 +124,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory { root = new LogicalProject<>(bottomProjections, root); } root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(), - true, aggregate.getAggPhase(), root); + true, aggregate.isFinalPhase(), aggregate.getAggPhase(), root); List<NamedExpression> projections = outputs.stream() .map(e -> ExpressionUtils.replace(e, substitutionMap)) .map(NamedExpression.class::cast) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java index 73de61a058..69572b070a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java @@ -21,19 +21,50 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; +import java.util.Objects; + /** * The function which consume arguments in lots of rows and product one value. */ public abstract class AggregateFunction extends BoundFunction { private DataType intermediate; + private final boolean isDistinct; public AggregateFunction(String name, Expression... arguments) { super(name, arguments); + isDistinct = false; + } + + public AggregateFunction(String name, boolean isDistinct, Expression... arguments) { + super(name, arguments); + this.isDistinct = isDistinct; } public abstract DataType getIntermediateType(); + public boolean isDistinct() { + return isDistinct; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AggregateFunction that = (AggregateFunction) o; + return Objects.equals(isDistinct, that.isDistinct) && Objects.equals(intermediate, that.intermediate) + && Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children); + } + + @Override + public int hashCode() { + return Objects.hash(isDistinct, intermediate, getName(), children); + } + @Override public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { return visitor.visitAggregateFunction(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java index a31122ab7a..e594671733 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Count.java @@ -37,8 +37,8 @@ public class Count extends AggregateFunction { this.isStar = true; } - public Count(Expression child) { - super("count", child); + public Count(Expression child, boolean isDistinct) { + super("count", isDistinct, child); this.isStar = false; } @@ -62,7 +62,7 @@ public class Count extends AggregateFunction { if (children.size() == 0) { return new Count(); } - return new Count(children.get(0)); + return new Count(children.get(0), isDistinct()); } @Override @@ -79,6 +79,9 @@ public class Count extends AggregateFunction { .stream() .map(Expression::toSql) .collect(Collectors.joining(", ")); + if (isDistinct()) { + return "count(distinct " + args + ")"; + } return "count(" + args + ")"; } @@ -91,6 +94,9 @@ public class Count extends AggregateFunction { .stream() .map(Expression::toString) .collect(Collectors.joining(", ")); + if (isDistinct()) { + return "count(distinct " + args + ")"; + } return "count(" + args + ")"; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index cbe9e402ef..0cca04950d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -59,6 +59,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL private final List<NamedExpression> outputExpressions; private final AggPhase aggPhase; + // use for scenes containing distinct agg + // 1. If there are LOCAL and GLOBAL phases, global is the final phase + // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase + // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases, + // DISTINCT_GLOBAL is the final phase + private final boolean isFinalPhase; + /** * Desc: Constructor for LogicalAggregate. */ @@ -66,7 +73,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL List<Expression> groupByExpressions, List<NamedExpression> outputExpressions, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, false, false, AggPhase.GLOBAL, child); + this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child); } public LogicalAggregate( @@ -74,9 +81,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL List<NamedExpression> outputExpressions, boolean disassembled, boolean normalized, + boolean isFinalPhase, AggPhase aggPhase, CHILD_TYPE child) { - this(groupByExpressions, outputExpressions, disassembled, normalized, + this(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase, aggPhase, Optional.empty(), Optional.empty(), child); } @@ -88,6 +96,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL List<NamedExpression> outputExpressions, boolean disassembled, boolean normalized, + boolean isFinalPhase, AggPhase aggPhase, Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, @@ -97,6 +106,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL this.outputExpressions = outputExpressions; this.disassembled = disassembled; this.normalized = normalized; + this.isFinalPhase = isFinalPhase; this.aggPhase = aggPhase; } @@ -149,6 +159,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL return normalized; } + public boolean isFinalPhase() { + return isFinalPhase; + } + /** * Determine the equality with another plan */ @@ -164,37 +178,37 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL && Objects.equals(outputExpressions, that.outputExpressions) && aggPhase == that.aggPhase && disassembled == that.disassembled - && normalized == that.normalized; + && normalized == that.normalized + && isFinalPhase == that.isFinalPhase; } @Override public int hashCode() { - return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled); + return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled, isFinalPhase); } @Override public LogicalAggregate<Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 1); return new LogicalAggregate<>(groupByExpressions, outputExpressions, - disassembled, normalized, aggPhase, children.get(0)); + disassembled, normalized, isFinalPhase, aggPhase, children.get(0)); } @Override public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) { - return new LogicalAggregate<>(groupByExpressions, outputExpressions, - disassembled, normalized, aggPhase, groupExpression, Optional.of(getLogicalProperties()), - children.get(0)); + return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase, + aggPhase, groupExpression, Optional.of(getLogicalProperties()), children.get(0)); } @Override public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) { - return new LogicalAggregate<>(groupByExpressions, outputExpressions, - disassembled, normalized, aggPhase, Optional.empty(), logicalProperties, children.get(0)); + return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase, + aggPhase, Optional.empty(), logicalProperties, children.get(0)); } public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList, List<NamedExpression> outputExpressionList) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, - disassembled, normalized, aggPhase, child()); + disassembled, normalized, isFinalPhase, aggPhase, child()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java index f2384920e5..8557a61ea5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalAggregate.java @@ -53,11 +53,18 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH private final boolean usingStream; + // use for scenes containing distinct agg + // 1. If there are LOCAL and GLOBAL phases, global is the final phase + // 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase + // 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases, + // DISTINCT_GLOBAL is the final phase + private final boolean isFinalPhase; + public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions, List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, - LogicalProperties logicalProperties, CHILD_TYPE child) { + boolean isFinalPhase, LogicalProperties logicalProperties, CHILD_TYPE child) { this(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream, - Optional.empty(), logicalProperties, child); + isFinalPhase, Optional.empty(), logicalProperties, child); } /** @@ -69,7 +76,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH * @param usingStream whether it's stream agg. */ public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions, - List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, + List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, CHILD_TYPE child) { super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, child); @@ -78,6 +85,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH this.aggPhase = aggPhase; this.partitionExpressions = partitionExpressions; this.usingStream = usingStream; + this.isFinalPhase = isFinalPhase; } /** @@ -89,7 +97,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH * @param usingStream whether it's stream agg. */ public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions, - List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, + List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, CHILD_TYPE child) { super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, physicalProperties, child); @@ -98,6 +106,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH this.aggPhase = aggPhase; this.partitionExpressions = partitionExpressions; this.usingStream = usingStream; + this.isFinalPhase = isFinalPhase; } public AggPhase getAggPhase() { @@ -112,6 +121,10 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH return outputExpressions; } + public boolean isFinalPhase() { + return isFinalPhase; + } + public boolean isUsingStream() { return usingStream; } @@ -156,36 +169,38 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH && Objects.equals(outputExpressions, that.outputExpressions) && Objects.equals(partitionExpressions, that.partitionExpressions) && usingStream == that.usingStream - && aggPhase == that.aggPhase; + && aggPhase == that.aggPhase + && isFinalPhase == that.isFinalPhase; } @Override public int hashCode() { - return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream); + return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream, + isFinalPhase); } @Override public PhysicalAggregate<Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 1); return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, - usingStream, getLogicalProperties(), children.get(0)); + usingStream, isFinalPhase, getLogicalProperties(), children.get(0)); } @Override public PhysicalAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) { return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, - usingStream, groupExpression, getLogicalProperties(), child()); + usingStream, isFinalPhase, groupExpression, getLogicalProperties(), child()); } @Override public PhysicalAggregate<CHILD_TYPE> withLogicalProperties(Optional<LogicalProperties> logicalProperties) { return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, - usingStream, Optional.empty(), logicalProperties.get(), child()); + usingStream, isFinalPhase, Optional.empty(), logicalProperties.get(), child()); } @Override public PhysicalAggregate<CHILD_TYPE> withPhysicalProperties(PhysicalProperties physicalProperties) { return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, - usingStream, Optional.empty(), getLogicalProperties(), physicalProperties, child()); + usingStream, isFinalPhase, Optional.empty(), getLogicalProperties(), physicalProperties, child()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java index dd09c58a50..29ef8bb2ff 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/HavingClauseTest.java @@ -360,9 +360,9 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)"); Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)"); Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)"); - Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); + Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)"); Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))"); - Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1"); + Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1"); PlanChecker.from(connectContext).analyze(sql) .matchesFromRoot( logicalProject( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java index 08d91b777f..fe0b577cc4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java @@ -263,6 +263,7 @@ public class ChildOutputPropertyDeriverTest { Lists.newArrayList(key), AggPhase.LOCAL, true, + true, logicalProperties, groupPlan ); @@ -286,6 +287,7 @@ public class ChildOutputPropertyDeriverTest { Lists.newArrayList(partition), AggPhase.GLOBAL, true, + true, logicalProperties, groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index dda5c0b006..9802a7d66b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -146,6 +146,7 @@ public class RequestPropertyDeriverTest { Lists.newArrayList(key), AggPhase.LOCAL, true, + true, logicalProperties, groupPlan ); @@ -168,6 +169,7 @@ public class RequestPropertyDeriverTest { Lists.newArrayList(partition), AggPhase.GLOBAL, true, + true, logicalProperties, groupPlan ); @@ -192,6 +194,7 @@ public class RequestPropertyDeriverTest { Lists.newArrayList(), AggPhase.GLOBAL, true, + true, logicalProperties, groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java index 72f4a8829a..ef32f31def 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Count; import org.apache.doris.nereids.trees.expressions.functions.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.AggPhase; @@ -269,6 +270,86 @@ public class AggregateDisassembleTest { global.getOutputExpressions().get(0).getExprId()); } + /** + * the initial plan is: + * Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2) as c], groupByExpr: [id + 3]) + * +-- childPlan(id, name, age) + * we should rewrite to: + * Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) + 2) as c], groupByExpr: [a]) + * +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a, b]) + * +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age + 1) as b], groupByExpr: [id + 3, age + 1]) + * +-- childPlan(id, name, age) + */ + @Test + public void distinctAggregateWithGroupBy() { + List<Expression> groupExpressionList = Lists.newArrayList( + new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3))); + List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias( + new Add(new Count(new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true), + new IntegerLiteral(2)), "c")); + Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent); + + Plan after = rewrite(root); + + Assertions.assertTrue(after instanceof LogicalUnary); + Assertions.assertTrue(after instanceof LogicalAggregate); + Assertions.assertTrue(after.child(0) instanceof LogicalUnary); + LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after; + LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0); + LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0); + Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase()); + Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase()); + Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase()); + // check local: + // id + 3 + Expression localOutput0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3)); + // age + 1 + Expression localOutput1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)); + // id + 3 + Expression localGroupBy0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3)); + // age + 1 + Expression localGroupBy1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)); + + Assertions.assertEquals(2, local.getOutputExpressions().size()); + Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias); + Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0)); + Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias); + Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0)); + Assertions.assertEquals(2, local.getGroupByExpressions().size()); + Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0)); + Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1)); + + // check global: + Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot(); + Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot(); + Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot(); + Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot(); + + Assertions.assertEquals(2, global.getOutputExpressions().size()); + Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference); + Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0)); + Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference); + Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1)); + Assertions.assertEquals(2, global.getGroupByExpressions().size()); + Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0)); + Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1)); + + // check distinct local: + Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true), + new IntegerLiteral(2)); + Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot(); + + Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size()); + Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias); + Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0)); + Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size()); + Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0)); + + // check id: + Assertions.assertEquals(outputExpressionList.get(0).getExprId(), + distinctLocal.getOutputExpressions().get(0).getExprId()); + } + private Plan rewrite(Plan input) { return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java index 5860d95c6b..71d8248655 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionEqualsTest.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions; import org.apache.doris.nereids.analyzer.UnboundAlias; import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundStar; +import org.apache.doris.nereids.trees.expressions.functions.Count; import org.apache.doris.nereids.trees.expressions.functions.Sum; import org.apache.doris.nereids.types.IntegerType; @@ -168,6 +169,25 @@ public class ExpressionEqualsTest { Assertions.assertEquals(sum1.hashCode(), sum2.hashCode()); } + @Test + public void testAggregateFunction() { + Count count1 = new Count(); + Count count2 = new Count(); + Assertions.assertEquals(count1, count2); + Assertions.assertEquals(count1.hashCode(), count2.hashCode()); + + Count count3 = new Count(child1, true); + Count count4 = new Count(child2, true); + Assertions.assertEquals(count3, count4); + Assertions.assertEquals(count3.hashCode(), count4.hashCode()); + + // bad case + Count count5 = new Count(child1, true); + Count count6 = new Count(child2, false); + Assertions.assertNotEquals(count5, count6); + Assertions.assertNotEquals(count5.hashCode(), count6.hashCode()); + } + @Test public void testNamedExpression() { ExprId aliasId = new ExprId(2); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java index 1d7878a2db..cdd5454e78 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java @@ -71,17 +71,17 @@ public class PlanEqualsTest { unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - true, false, AggPhase.GLOBAL, child); + true, false, true, AggPhase.GLOBAL, child); Assertions.assertNotEquals(unexpected, actual); unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - false, true, AggPhase.GLOBAL, child); + false, true, true, AggPhase.GLOBAL, child); Assertions.assertNotEquals(unexpected, actual); unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of( new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), - false, false, AggPhase.LOCAL, child); + false, false, true, AggPhase.LOCAL, child); Assertions.assertNotEquals(unexpected, actual); } @@ -178,20 +178,20 @@ public class PlanEqualsTest { List<NamedExpression> outputExpressionList = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); PhysicalAggregate<Plan> actual = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList, - Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child); + Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child); List<NamedExpression> outputExpressionList1 = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); PhysicalAggregate<Plan> expected = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList1, - Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child); + Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child); Assertions.assertEquals(expected, actual); List<NamedExpression> outputExpressionList2 = ImmutableList.of( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())); PhysicalAggregate<Plan> unexpected = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList2, - Lists.newArrayList(), AggPhase.LOCAL, false, logicalProperties, child); + Lists.newArrayList(), AggPhase.LOCAL, false, true, logicalProperties, child); Assertions.assertNotEquals(unexpected, actual); } diff --git a/regression-test/data/nereids_syntax_p0/function.out b/regression-test/data/nereids_syntax_p0/function.out index cac9a7c5b1..b1d705b814 100644 --- a/regression-test/data/nereids_syntax_p0/function.out +++ b/regression-test/data/nereids_syntax_p0/function.out @@ -11,6 +11,11 @@ -- !count -- 3 3 +-- !distinct_count -- +1 +1 +1 + -- !avg -- 2.5E-323 1.1644193E-317 diff --git a/regression-test/suites/nereids_syntax_p0/function.groovy b/regression-test/suites/nereids_syntax_p0/function.groovy index c4099a0798..a041fc36ab 100644 --- a/regression-test/suites/nereids_syntax_p0/function.groovy +++ b/regression-test/suites/nereids_syntax_p0/function.groovy @@ -41,6 +41,10 @@ suite("function") { SELECT count(c_city), count(*) AS custdist FROM customer; """ + order_qt_distinct_count """ + SELECT count(distinct c_custkey + 1) AS custdist FROM customer group by c_city; + """ + order_qt_avg """ SELECT avg(lo_tax), avg(lo_extendedprice) AS avg_extendedprice FROM lineorder; """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org