This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit b1e39ad7859d22c37a34d9704787bb481f0d0f58 Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Fri Jul 7 17:45:55 2023 +0800 [opt](Nereids) forbid some bad case on agg plans (#21565) 1. forbid all candidates that gather data to a single node unless gathering is unavoidable 2. forbid running the local agg of a two-phase distinct agg after a reshuffle 3. forbid one-phase agg after a reshuffle 4. forbid three- or four-phase distinct agg if any stage needs a reshuffle 5. forbid the multi-distinct rewrite of a single-distinct agg when no reshuffle is needed --- .../glue/translator/PhysicalPlanTranslator.java | 2 +- .../properties/ChildrenPropertiesRegulator.java | 53 +++- .../rules/implementation/AggregateStrategies.java | 333 ++++++++++++--------- .../expressions/functions/agg/AggregateParam.java | 31 +- .../functions/agg/MultiDistinctCount.java | 2 +- .../functions/agg/MultiDistinctGroupConcat.java | 2 +- .../functions/agg/MultiDistinctSum.java | 4 +- .../functions/agg/MultiDistinction.java | 27 ++ .../nereids/trees/plans/algebra/Aggregate.java | 6 +- .../physical/PhysicalStorageLayerAggregate.java | 14 +- .../apache/doris/nereids/util/ExpressionUtils.java | 5 + .../rules/rewrite/AggregateStrategiesTest.java | 2 + .../nereids_tpcds_shape_sf100_p0/shape/query94.out | 4 +- .../nereids_tpcds_shape_sf100_p0/shape/query95.out | 4 +- .../nereids_tpch_shape_sf1000_p0/shape/q16.out | 4 +- .../data/nereids_tpch_shape_sf500_p0/shape/q16.out | 4 +- .../suites/nereids_p0/join/test_join.groovy | 9 - .../suites/nereids_syntax_p0/agg_4_phase.groovy | 12 +- .../nereids_syntax_p0/aggregate_strategies.groovy | 7 +- .../suites/nereids_syntax_p0/group_concat.groovy | 4 +- 20 files changed, 314 insertions(+), 215 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 650bcc43a9..9844179633 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -1742,7 +1742,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla .map(e -> { Expression function = e.child(0).child(0); if (function instanceof AggregateFunction) { - AggregateParam param = AggregateParam.localResult(); + AggregateParam param = AggregateParam.LOCAL_RESULT; function = new AggregateExpression((AggregateFunction) function, param); } return ExpressionTranslator.translate(function, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 811c26569a..54e4c4780a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -23,7 +23,12 @@ import org.apache.doris.nereids.cost.CostCalculator; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; +import org.apache.doris.nereids.trees.expressions.AggregateExpression; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; @@ -43,6 +48,7 @@ import java.util.ArrayList; import java.util.List; import 
java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * ensure child add enough distribute. update children properties if we do regular @@ -88,12 +94,55 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> { @Override public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg, Void context) { - // forbid one phase agg on distribute - if (agg.getAggMode() == AggMode.INPUT_TO_RESULT + if (!agg.getAggregateParam().canBeBanned) { + return true; + } + // forbid one phase agg on distribute and three or four stage distinct agg inter by distribute + if ((agg.getAggMode() == AggMode.INPUT_TO_RESULT || agg.getAggMode() == AggMode.BUFFER_TO_BUFFER) && children.get(0).getPlan() instanceof PhysicalDistribute) { // this means one stage gather agg, usually bad pattern return false; } + // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle + // TODO: this is forbid good plan after cte reuse by mistake + if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER + && requiredProperties.get(0).getDistributionSpec() instanceof DistributionSpecHash + && children.get(0).getPlan() instanceof PhysicalDistribute) { + return false; + } + // forbid multi distinct opt that bad than multi-stage version when multi-stage can be executed in one fragment + if (agg.getAggMode() == AggMode.INPUT_TO_BUFFER || agg.getAggMode() == AggMode.INPUT_TO_RESULT) { + List<MultiDistinction> multiDistinctions = agg.getOutputExpressions().stream() + .filter(Alias.class::isInstance) + .map(a -> ((Alias) a).child()) + .filter(AggregateExpression.class::isInstance) + .map(a -> ((AggregateExpression) a).getFunction()) + .filter(MultiDistinction.class::isInstance) + .map(MultiDistinction.class::cast) + .collect(Collectors.toList()); + if (multiDistinctions.size() == 1) { + Expression distinctChild = multiDistinctions.get(0).child(0); + DistributionSpec childDistribution = childrenProperties.get(0).getDistributionSpec(); + if (distinctChild 
instanceof SlotReference && childDistribution instanceof DistributionSpecHash) { + SlotReference slotReference = (SlotReference) distinctChild; + DistributionSpecHash distributionSpecHash = (DistributionSpecHash) childDistribution; + List<ExprId> groupByColumns = agg.getGroupByExpressions().stream() + .map(SlotReference.class::cast) + .map(SlotReference::getExprId) + .collect(Collectors.toList()); + DistributionSpecHash groupByRequire = new DistributionSpecHash( + groupByColumns, ShuffleType.REQUIRE); + List<ExprId> distinctChildColumns = Lists.newArrayList(slotReference.getExprId()); + distinctChildColumns.add(slotReference.getExprId()); + DistributionSpecHash distinctChildRequire = new DistributionSpecHash( + distinctChildColumns, ShuffleType.REQUIRE); + if ((!groupByColumns.isEmpty() && distributionSpecHash.satisfy(groupByRequire)) + || (groupByColumns.isEmpty() && distributionSpecHash.satisfy(distinctChildRequire))) { + return false; + } + } + } + } // process must shuffle visit(agg, context); // process agg diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 551e73532c..6df6b2f817 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -70,11 +70,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -123,41 +125,43 @@ public class AggregateStrategies implements 
ImplementationRuleFactory { .when(agg -> agg.getDistinctArguments().size() == 0) .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext)) ), - RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( - basePattern - .when(this::containsCountDistinctMultiExpr) - .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) - ), + // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( + // basePattern + // .when(this::containsCountDistinctMultiExpr) + // .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) + // ), RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build( basePattern .when(this::containsCountDistinctMultiExpr) .thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext)) ), - RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) - ), RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build( basePattern - .when(agg -> agg.getDistinctArguments().size() == 1 && enableSingleDistinctColumnOpt()) + .when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), - RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) - ), 
RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build( basePattern - .when(agg -> agg.getDistinctArguments().size() > 1 && !containsCountDistinctMultiExpr(agg)) + .when(agg -> agg.getDistinctArguments().size() > 1 + && !containsCountDistinctMultiExpr(agg) + && couldConvertToMulti(agg)) .thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext)) ), + // RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build( + // basePattern + // .when(agg -> agg.getDistinctArguments().size() == 1) + // .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + // ), + RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + ), RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build( basePattern .when(agg -> agg.getDistinctArguments().size() == 1) @@ -169,15 +173,15 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(*) from tbl - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[], output=[count(*)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * LogicalAggregate(groupBy=[], output=[count(*)]) * | * PhysicalStorageLayerAggregate(pushAggOp=COUNT, table=PhysicalOlapScan(table=tbl)) @@ -205,7 +209,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .map(AggregateFunction::getClass) .collect(Collectors.toSet()); - Map<Class, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions(); + Map<Class<? 
extends AggregateFunction>, PushDownAggOp> supportedAgg = PushDownAggOp.supportedFunctions(); if (!supportedAgg.keySet().containsAll(functionClasses)) { return canNotPush; } @@ -292,7 +296,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { || colType == PrimitiveType.STRING) { return canNotPush; } - if (colType.isCharFamily() && mergeOp != PushDownAggOp.COUNT && column.getType().getLength() > 512) { + if (colType.isCharFamily() && column.getType().getLength() > 512) { return canNotPush; } } @@ -324,25 +328,25 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(*) from tbl group by id - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[id], output=[count(*)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[id], output=[count(*)]) * | * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[id], output=[count(*)]) * | * LogicalOlapScan(table=tbl, **already distribute by id**) @@ -351,7 +355,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { private List<PhysicalHashAggregate<Plan>> onePhaseAggregateWithoutDistinct( LogicalAggregate<? 
extends Plan> logicalAgg, ConnectContext connectContext) { RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); - AggregateParam inputToResultParam = AggregateParam.localResult(); + AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT; List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { if (outputChild instanceof AggregateFunction) { @@ -366,7 +370,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireGather, logicalAgg.child()); if (logicalAgg.getGroupByExpressions().isEmpty()) { - return ImmutableList.of(gatherLocalAgg); + // TODO: usually bad, disable it until we could do better cost computation. + // return ImmutableList.of(gatherLocalAgg); + return ImmutableList.of(); } else { RequireProperties requireHash = RequireProperties.of( PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); @@ -383,17 +389,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id, name) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) @@ -401,9 +407,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id]) @@ -415,7 +421,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private 
List<PhysicalHashAggregate<Plan>> twoPhaseAggregateWithCountDistinctMulti( LogicalAggregate<? extends Plan> logicalAgg, CascadesContext cascadesContext) { - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER; Collection<Expression> countDistinctArguments = logicalAgg.getDistinctArguments(); List<Expression> localAggGroupBy = ImmutableList.copyOf(ImmutableSet.<Expression>builder() @@ -487,7 +493,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withRequireTree(requireHash.withChildren(requireHash)) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.<PhysicalHashAggregate<Plan>>builder() - .add(gatherLocalGatherDistinctAgg) + // TODO: usually bad, disable it until we could do better cost computation. + //.add(gatherLocalGatherDistinctAgg) .add(hashLocalHashGlobalAgg) .build(); } @@ -495,17 +502,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id, name) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[count(distinct id, name)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -515,9 +522,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[count(if(id is null, null, name))]) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -566,11 +573,17 @@ public class AggregateStrategies 
implements ImplementationRuleFactory { List<Expression> globalAggGroupBy = localAggGroupBy; - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + boolean hasCountDistinctMulti = logicalAgg.getAggregateFunctions().stream() + .filter(AggregateFunction::isDistinct) + .filter(Count.class::isInstance) + .anyMatch(c -> c.arity() > 1); + AggregateParam bufferToBufferParam = new AggregateParam( + AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, !hasCountDistinctMulti); + Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutputAlias = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -596,7 +609,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalAgg, cascadesContext).first; AggregateParam distinctInputToResultParam - = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, !hasCountDistinctMulti); AggregateParam globalBufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit( @@ -621,19 +634,19 @@ public class AggregateStrategies implements ImplementationRuleFactory { logicalAgg.getLogicalProperties(), requireGather, anyLocalGatherGlobalAgg ); - RequireProperties requireDistinctHash = RequireProperties.of( - PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); - PhysicalHashAggregate<? 
extends Plan> anyLocalHashGlobalGatherDistinctAgg - = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of( - anyLocalGatherGlobalAgg - .withRequire(requireDistinctHash) - .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) - )); + // RequireProperties requireDistinctHash = RequireProperties.of( + // PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); + // PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalGatherDistinctAgg + // = anyLocalGatherGlobalGatherAgg.withChildren(ImmutableList.of( + // anyLocalGatherGlobalAgg + // .withRequire(requireDistinctHash) + // .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + // )); if (logicalAgg.getGroupByExpressions().isEmpty()) { return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder() .add(anyLocalGatherGlobalGatherAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + //.add(anyLocalHashGlobalGatherDistinctAgg) .build(); } else { RequireProperties requireGroupByHash = RequireProperties.of( @@ -646,8 +659,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.<PhysicalHashAggregate<? 
extends Plan>>builder() - .add(anyLocalGatherGlobalGatherAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + // .add(anyLocalGatherGlobalGatherAgg) + // .add(anyLocalHashGlobalGatherDistinctAgg) .add(anyLocalHashGlobalHashDistinctAgg) .build(); } @@ -655,17 +668,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select name, count(value) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[name, count(value)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=GATHER) @@ -673,9 +686,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=HASH(name)) @@ -713,6 +726,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT); List<NamedExpression> globalAggOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { + if (!(outputChild instanceof AggregateFunction)) { + return outputChild; + } Alias inputToBufferAlias = inputToBufferAliases.get(outputChild); if (inputToBufferAlias == null) { return outputChild; @@ -722,7 +738,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { }); RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER); - PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate( + PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = 
new PhysicalHashAggregate<>( localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions), bufferToResultParam, false, anyLocalAgg.getLogicalProperties(), requireGather, anyLocalAgg); @@ -746,17 +762,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) @@ -764,9 +780,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalDistribute(distributionSpec=GATHER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) @@ -781,7 +797,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions(); Set<Expression> distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); @@ -790,7 +806,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true); + AggregateParam inputToBufferParam = AggregateParam.LOCAL_BUFFER; Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = 
aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) @@ -822,10 +838,13 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (outputChild instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) outputChild; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); - return new AggregateExpression(nonDistinct, AggregateParam.localResult()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); + return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT); } else { Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild); return new AggregateExpression( @@ -850,7 +869,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) )); return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder() - .add(gatherLocalGatherGlobalAgg) + //.add(gatherLocalGatherGlobalAgg) .add(hashLocalGatherGlobalAgg) .build(); } else { @@ -863,7 +882,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.<PhysicalHashAggregate<? 
extends Plan>>builder() - .add(gatherLocalGatherGlobalAgg) + // .add(gatherLocalGatherGlobalAgg) .add(hashLocalHashGlobalAgg) .build(); } @@ -871,16 +890,16 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -890,9 +909,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT) * | * PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER) @@ -907,10 +926,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { // TODO: support one phase aggregate(group by columns + distinct columns) + two phase distinct aggregate private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithDistinct( LogicalAggregate<? 
extends Plan> logicalAgg, ConnectContext connectContext) { + boolean couldBanned = couldConvertToMulti(logicalAgg); + Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions(); Set<Expression> distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); @@ -919,7 +940,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) @@ -942,11 +963,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()); - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned); Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutput = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -965,15 +986,19 @@ public class AggregateStrategies implements ImplementationRuleFactory { bufferToBufferParam, false, logicalAgg.getLogicalProperties(), requireGather, anyLocalAgg); - 
AggregateParam bufferToResultParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT); + AggregateParam bufferToResultParam = new AggregateParam( + AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_RESULT, couldBanned); List<NamedExpression> distinctOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), expr -> { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction.child(0)); } else { @@ -1017,8 +1042,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { ) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder() - .add(anyLocalGatherGlobalGatherDistinctAgg) - .add(anyLocalHashGlobalGatherDistinctAgg) + // TODO: this plan pattern is not good usually, we remove it temporary. + //.add(anyLocalGatherGlobalGatherDistinctAgg) + //.add(anyLocalHashGlobalGatherDistinctAgg) .add(anyLocalHashGlobalHashDistinctAgg) .build(); } @@ -1026,25 +1052,25 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from (...) 
group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[count(distinct id)]) * | * any plan - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) * | * PhysicalDistribute(distributionSpec=GATHER) * | * any plan - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[multi_distinct_count(id)]) * | * any plan(**already distribute by name**) @@ -1052,7 +1078,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private List<PhysicalHashAggregate<? extends Plan>> onePhaseAggregateWithMultiDistinct( LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) { - AggregateParam inputToResultParam = AggregateParam.localResult(); + AggregateParam inputToResultParam = AggregateParam.LOCAL_RESULT; List<NamedExpression> newOutput = ExpressionUtils.rewriteDownShortCircuit( logicalAgg.getOutputExpressions(), outputChild -> { if (outputChild instanceof AggregateFunction) { @@ -1068,7 +1094,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg(connectContext, logicalAgg), logicalAgg.getLogicalProperties(), requireGather, logicalAgg.child()); if (logicalAgg.getGroupByExpressions().isEmpty()) { - return ImmutableList.of(gatherLocalAgg); + // TODO: usually bad, disable it until we could do better cost computation. 
+ // return ImmutableList.of(gatherLocalAgg); + return ImmutableList.of(); } else { RequireProperties requireHash = RequireProperties.of( PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); @@ -1085,17 +1113,17 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * sql: select count(distinct id) from tbl group by name - * + * <p> * before: - * + * <p> * LogicalAggregate(groupBy=[name], output=[name, count(distinct id)]) * | * LogicalOlapScan(table=tbl) - * + * <p> * after: - * + * <p> * single node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=GATHER) @@ -1103,9 +1131,9 @@ public class AggregateStrategies implements ImplementationRuleFactory { * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=INPUT_TO_BUFFER) * | * LogicalOlapScan(table=tbl) - * + * <p> * distribute node aggregate: - * + * <p> * PhysicalHashAggregate(groupBy=[name], output=[name, multi_count_distinct(value)], mode=BUFFER_TO_RESULT) * | * PhysicalDistribute(distributionSpec=HASH(name)) @@ -1157,17 +1185,16 @@ public class AggregateStrategies implements ImplementationRuleFactory { RequireProperties.of(PhysicalProperties.GATHER), anyLocalAgg); if (logicalAgg.getGroupByExpressions().isEmpty()) { - Collection<Expression> distinctArguments = logicalAgg.getDistinctArguments(); - RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( - distinctArguments, ShuffleType.REQUIRE)); - PhysicalHashAggregate<? 
extends Plan> hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg - .withChildren(ImmutableList.of(anyLocalAgg - .withRequire(requireDistinctHash) - .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) - )); + // Collection<Expression> distinctArguments = logicalAgg.getDistinctArguments(); + // RequireProperties requireDistinctHash = RequireProperties.of(PhysicalProperties.createHash( + // distinctArguments, ShuffleType.REQUIRE)); + // PhysicalHashAggregate<? extends Plan> hashLocalGatherGlobalAgg = anyLocalGatherGlobalAgg + // .withChildren(ImmutableList.of(anyLocalAgg + // .withRequire(requireDistinctHash) + // .withPartitionExpressions(ImmutableList.copyOf(logicalAgg.getDistinctArguments())) + // )); return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder() .add(anyLocalGatherGlobalAgg) - .add(hashLocalGatherGlobalAgg) .build(); } else { RequireProperties requireHash = RequireProperties.of( @@ -1176,7 +1203,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { .withRequire(requireHash) .withPartitionExpressions(logicalAgg.getGroupByExpressions()); return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder() - .add(anyLocalGatherGlobalAgg) + // TODO: usually bad, disable it until we could do better cost computation. + // .add(anyLocalGatherGlobalAgg) .add(anyLocalHashGlobalAgg) .build(); } @@ -1215,7 +1243,7 @@ public class AggregateStrategies implements ImplementationRuleFactory { /** * countDistinctMultiExprToCountIf. - * + * <p> * NOTE: this function will break the normalized output, e.g. from `count(distinct slot1, slot2)` to * `count(if(slot1 is null, null, slot2))`. So if you invoke this method, and separate the * phase of aggregate, please normalize to slot and create a bottom project like NormalizeAggregate. 
@@ -1268,15 +1296,10 @@ public class AggregateStrategies implements ImplementationRuleFactory { return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg(); } - private boolean enableSingleDistinctColumnOpt() { - ConnectContext connectContext = ConnectContext.get(); - return connectContext == null || connectContext.getSessionVariable().enableSingleDistinctColumnOpt(); - } - /** * sql: * select count(distinct name), sum(age) from student; - * + * <p> * 4 phase plan * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER] * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name @@ -1286,26 +1309,29 @@ public class AggregateStrategies implements ImplementationRuleFactory { */ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct( LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) { + boolean couldBanned = couldConvertToMulti(logicalAgg); + Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions(); - Set<Expression> distinctArguments = aggregateFunctions.stream() - .filter(aggregateExpression -> aggregateExpression.isDistinct()) + Set<NamedExpression> distinctArguments = aggregateFunctions.stream() + .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .map(NamedExpression.class::cast) .collect(ImmutableSet.toImmutableSet()); Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder() - .addAll((List) logicalAgg.getGroupByExpressions()) + .addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions()) .addAll(distinctArguments) .build(); - AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, true); + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map<AggregateFunction, 
Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() .filter(aggregateFunction -> !aggregateFunction.isDistinct()) .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); return new Alias(localAggExpr, localAggExpr.toSql()); - })); + }, (oldValue, newValue) -> newValue)); List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder() .addAll(localAggGroupBySet) @@ -1321,11 +1347,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), requireAny, logicalAgg.child()); - AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER); + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned); Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 = nonDistinctAggFunctionToAliasPhase1.entrySet() .stream() - .collect(ImmutableMap.toImmutableMap(kv -> kv.getKey(), kv -> { + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { AggregateFunction originFunction = kv.getKey(); Alias localOutput = kv.getValue(); AggregateExpression globalAggExpr = new AggregateExpression( @@ -1350,7 +1376,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireDistinctHash, anyLocalAgg); // phase 3 - AggregateParam distinctLocalParam = new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER); + AggregateParam distinctLocalParam = new AggregateParam( + AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>(); List<NamedExpression> localDistinctOutput = Lists.newArrayList(); for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { @@ -1361,9 +1388,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { if 
(expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct, distinctLocalParam, aggregateFunction.child(0)); return nonDistinctAggExpr; @@ -1389,7 +1419,8 @@ public class AggregateStrategies implements ImplementationRuleFactory { requireDistinctHash, anyLocalHashGlobalAgg); //phase 4 - AggregateParam distinctGlobalParam = new AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT); + AggregateParam distinctGlobalParam = new AggregateParam( + AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned); List<NamedExpression> globalDistinctOutput = Lists.newArrayList(); for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i); @@ -1397,9 +1428,12 @@ public class AggregateStrategies implements ImplementationRuleFactory { if (expr instanceof AggregateFunction) { AggregateFunction aggregateFunction = (AggregateFunction) expr; if (aggregateFunction.isDistinct()) { - Preconditions.checkArgument(aggregateFunction.arity() == 1); + Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); AggregateFunction nonDistinct = aggregateFunction - .withDistinctAndChildren(false, 
aggregateFunction.getArguments()); + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr); Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx)); return new AggregateExpression(nonDistinct, @@ -1424,4 +1458,11 @@ public class AggregateStrategies implements ImplementationRuleFactory { .add(distinctGlobal) .build(); } + + private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) { + return ExpressionUtils.noneMatch(aggregate.getOutputExpressions(), expr -> + expr instanceof AggregateFunction && ((AggregateFunction) expr).isDistinct() + && (expr.arity() > 1 + || !(expr instanceof Count || expr instanceof Sum || expr instanceof GroupConcat))); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java index 2ff8eb262f..89e9c1ea7d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateParam.java @@ -25,43 +25,35 @@ import java.util.Objects; /** AggregateParam. 
*/ public class AggregateParam { - public final AggPhase aggPhase; + public static AggregateParam LOCAL_RESULT = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT); + public static AggregateParam LOCAL_BUFFER = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER); + public final AggPhase aggPhase; public final AggMode aggMode; - - // TODO remove this flag, and generate it in enforce and cost job - public boolean needColocateScan; + // TODO: this is a short-term plan to process count(distinct a, b) correctly + public final boolean canBeBanned; /** AggregateParam */ public AggregateParam(AggPhase aggPhase, AggMode aggMode) { - this(aggPhase, aggMode, false); + this(aggPhase, aggMode, true); } - /** AggregateParam */ - public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean needColocateScan) { + public AggregateParam(AggPhase aggPhase, AggMode aggMode, boolean canBeBanned) { this.aggMode = Objects.requireNonNull(aggMode, "aggMode cannot be null"); this.aggPhase = Objects.requireNonNull(aggPhase, "aggPhase cannot be null"); - this.needColocateScan = needColocateScan; - } - - public static AggregateParam localResult() { - return new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT, true); + this.canBeBanned = canBeBanned; } public AggregateParam withAggPhase(AggPhase aggPhase) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } public AggregateParam withAggPhase(AggMode aggMode) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } public AggregateParam withAppPhaseAndAppMode(AggPhase aggPhase, AggMode aggMode) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); - } - - public AggregateParam withNeedColocateScan(boolean needColocateScan) { - return new AggregateParam(aggPhase, aggMode, needColocateScan); + return new AggregateParam(aggPhase, aggMode, canBeBanned); } 
@Override @@ -87,7 +79,6 @@ public class AggregateParam { return "AggregateParam{" + "aggPhase=" + aggPhase + ", aggMode=" + aggMode - + ", needColocateScan=" + needColocateScan + '}'; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java index 72a26d288e..b9e7c1fdb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java @@ -35,7 +35,7 @@ import java.util.List; /** MultiDistinctCount */ public class MultiDistinctCount extends AggregateFunction - implements AlwaysNotNullable, ExplicitlyCastableSignature { + implements AlwaysNotNullable, ExplicitlyCastableSignature, MultiDistinction { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java index 737e895906..5f5bb6815a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java @@ -36,7 +36,7 @@ import java.util.List; /** MultiDistinctGroupConcat */ public class MultiDistinctGroupConcat extends NullableAggregateFunction - implements ExplicitlyCastableSignature { + implements ExplicitlyCastableSignature, MultiDistinction { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( 
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java index a378dc0960..8441e02828 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java @@ -35,8 +35,8 @@ import com.google.common.collect.ImmutableList; import java.util.List; /** MultiDistinctSum */ -public class MultiDistinctSum extends AggregateFunction - implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature, ComputePrecisionForSum { +public class MultiDistinctSum extends AggregateFunction implements UnaryExpression, AlwaysNotNullable, + ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java new file mode 100644 index 0000000000..ab8842f730 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.nereids.trees.TreeNode; +import org.apache.doris.nereids.trees.expressions.Expression; + +/** + * base class of multi-distinct agg function + */ +public interface MultiDistinction extends TreeNode<Expression> { +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 6731bde58a..8361e230be 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -25,7 +25,7 @@ import org.apache.doris.nereids.trees.plans.UnaryPlan; import org.apache.doris.nereids.trees.plans.logical.OutputPrunable; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.util.List; import java.util.Set; @@ -53,10 +53,10 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends UnaryPlan<CHILD_TYPE return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance); } - default List<Expression> getDistinctArguments() { + default Set<Expression> getDistinctArguments() { return getAggregateFunctions().stream() .filter(AggregateFunction::isDistinct) .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) - .collect(ImmutableList.toImmutableList()); + 
.collect(ImmutableSet.toImmutableSet()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java index ed57eebcb0..094f5d75cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalStorageLayerAggregate.java @@ -21,7 +21,7 @@ import org.apache.doris.catalog.Table; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.Utils; import org.apache.doris.statistics.Statistics; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -78,11 +77,6 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { return visitor.visitPhysicalStorageLayerAggregate(this, context); } - @Override - public List<? 
extends Expression> getExpressions() { - return ImmutableList.of(); - } - @Override public boolean equals(Object o) { if (this == o) { @@ -113,7 +107,7 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { } public PhysicalStorageLayerAggregate withPhysicalOlapScan(PhysicalOlapScan physicalOlapScan) { - return new PhysicalStorageLayerAggregate(relation, aggOp); + return new PhysicalStorageLayerAggregate(physicalOlapScan, aggOp); } @Override @@ -142,8 +136,8 @@ public class PhysicalStorageLayerAggregate extends PhysicalRelation { public enum PushDownAggOp { COUNT, MIN_MAX, MIX; - public static Map<Class, PushDownAggOp> supportedFunctions() { - return ImmutableMap.<Class, PushDownAggOp>builder() + public static Map<Class<? extends AggregateFunction>, PushDownAggOp> supportedFunctions() { + return ImmutableMap.<Class<? extends AggregateFunction>, PushDownAggOp>builder() .put(Count.class, PushDownAggOp.COUNT) .put(Min.class, PushDownAggOp.MIN_MAX) .put(Max.class, PushDownAggOp.MIN_MAX) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index e1fbefad61..a3a3ca1b80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -403,6 +403,11 @@ public class ExpressionUtils { .anyMatch(expr -> expr.anyMatch(predicate)); } + public static boolean noneMatch(List<? extends Expression> expressions, Predicate<TreeNode<Expression>> predicate) { + return expressions.stream() + .noneMatch(expr -> expr.anyMatch(predicate)); + } + public static boolean containsType(List<? 
extends Expression> expressions, Class type) { return anyMatch(expressions, type::isInstance); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java index 417605bfa5..e8419b2458 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java @@ -223,6 +223,8 @@ public class AggregateStrategiesTest implements MemoPatternMatchSupported { * </pre> */ @Test + @Disabled + @Developing("reopen it after we could choose agg phase by CBO") public void distinctAggregateWithoutGroupByApply2PhaseRule() { List<Expression> groupExpressionList = new ArrayList<>(); List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias( diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out index e12f8482f5..c26035693e 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query94.out @@ -4,8 +4,8 @@ PhysicalTopN --PhysicalTopN ----PhysicalProject ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk) ----------------PhysicalProject diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out index 6ddca5c0c2..014535c50d 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query95.out @@ -14,8 +14,8 @@ CteAnchor[cteId= ( CTEId#3=] ) ----PhysicalTopN 
------PhysicalProject --------hashAgg[GLOBAL] -----------hashAgg[LOCAL] -------------PhysicalDistribute +----------PhysicalDistribute +------------hashAgg[LOCAL] --------------PhysicalProject ----------------hashJoin[INNER_JOIN](ws1.ws_ship_date_sk = date_dim.d_date_sk) ------------------PhysicalProject diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out index 515a72a29d..d72afcf57d 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q16.out @@ -4,8 +4,8 @@ PhysicalQuickSort --PhysicalDistribute ----PhysicalQuickSort ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey) diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out index 515a72a29d..d72afcf57d 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q16.out @@ -4,8 +4,8 @@ PhysicalQuickSort --PhysicalDistribute ----PhysicalQuickSort ------hashAgg[GLOBAL] ---------hashAgg[LOCAL] -----------PhysicalDistribute +--------PhysicalDistribute +----------hashAgg[LOCAL] ------------PhysicalProject --------------hashJoin[LEFT_ANTI_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) ----------------hashJoin[INNER_JOIN](part.p_partkey = partsupp.ps_partkey) diff --git a/regression-test/suites/nereids_p0/join/test_join.groovy b/regression-test/suites/nereids_p0/join/test_join.groovy index ce643f081a..584b80384a 100644 --- a/regression-test/suites/nereids_p0/join/test_join.groovy +++ 
b/regression-test/suites/nereids_p0/join/test_join.groovy @@ -24,15 +24,6 @@ suite("test_join", "nereids_p0") { def tbName1 = "test" def tbName2 = "baseall" def tbName3 = "bigtable" - def empty_name = "empty" - - qt_agg_sql1 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;""" - qt_agg_sql2 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;""" - qt_agg_sql3 """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;""" - - qt_agg_sql4 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL) from test;""" - qt_agg_sql5 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ count(distinct k1, NULL), avg(k2) from baseall;""" - qt_agg_sql6 """select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI')*/ k1,count(distinct k2,k3),min(k4),count(*) from baseall group by k1 order by k1;""" order_sql """select j.*, d.* from ${tbName2} j full outer join ${tbName1} d on (j.k1=d.k1) order by j.k1, j.k2, j.k3, j.k4, d.k1, d.k2 limit 100""" diff --git a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy index d2d48e3e08..a672f8dee3 100644 --- a/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy +++ b/regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy @@ -43,16 +43,16 @@ suite("agg_4_phase") { (0, 0, "aa", 10), (1, 1, "bb",20), (2, 2, "cc", 30), (1, 1, "bb",20); """ def test_sql = """ - select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT,TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ - count(distinct name), sum(age) + select + count(distinct id) from 
agg_4_phase_tbl; """ explain{ sql(test_sql) - contains "6:VAGGREGATE (merge finalize)" - contains "5:VEXCHANGE" - contains "4:VAGGREGATE (update serialize)" - contains "3:VAGGREGATE (merge serialize)" + contains "5:VAGGREGATE (merge finalize)" + contains "4:VEXCHANGE" + contains "3:VAGGREGATE (update serialize)" + contains "2:VAGGREGATE (merge serialize)" contains "1:VAGGREGATE (update serialize)" } qt_4phase (test_sql) diff --git a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy index 39742e8d21..ea63b5b789 100644 --- a/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy +++ b/regression-test/suites/nereids_syntax_p0/aggregate_strategies.groovy @@ -82,7 +82,6 @@ suite("aggregate_strategies") { explain { sql """ select - /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT, FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id) from $tableName """ @@ -90,17 +89,17 @@ suite("aggregate_strategies") { notContains "STREAMING" } + // test multi_distinct test { sql """select - /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ - count(distinct id) + count(distinct name) from $tableName""" result([[5L]]) } + // test four phase distinct test { sql """select - /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id) from $tableName""" result([[5L]]) diff --git a/regression-test/suites/nereids_syntax_p0/group_concat.groovy b/regression-test/suites/nereids_syntax_p0/group_concat.groovy index fe2062d66d..60f52c2ba0 100644 --- a/regression-test/suites/nereids_syntax_p0/group_concat.groovy +++ b/regression-test/suites/nereids_syntax_p0/group_concat.groovy @@ -21,14 +21,14 @@ suite("group_concat") { test { - sql """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/ + sql 
"""select group_concat(cast(number as string), ',' order by number) from numbers('number'='10')""" result([["0,1,2,3,4,5,6,7,8,9"]]) } test { - sql """select /*+SET_VAR(disable_nereids_rules='ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT')*/ + sql """select group_concat(cast(number as string), ',' order by number) from numbers('number'='10')""" result([["0,1,2,3,4,5,6,7,8,9"]]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org