This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new fa3bdbce966 [opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE (#43856) fa3bdbce966 is described below commit fa3bdbce966512da6986046df8558c1d04e93f61 Author: minghong <zhoumingh...@selectdb.com> AuthorDate: Thu Nov 28 11:19:32 2024 +0800 [opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE (#43856) ### What problem does this PR solve? PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE has some restrictions: (1) it does not support count(*); (2) it does not support joins that have other join conditions; (3) it does not support a project between the agg and the join that contains non-slot expressions. This PR removes the above restrictions for the pattern agg-project-join. --- .../rewrite/PushDownAggThroughJoinOneSide.java | 123 +++++++++++++++------ .../rewrite/PushDownMinMaxSumThroughJoinTest.java | 16 ++- .../push_down_count_through_join_one_side.out | 22 ++++ .../push_down_count_through_join_one_side.groovy | 95 ++++++++++++++++ 4 files changed, 212 insertions(+), 44 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index f32bf8ea91b..c5d3d0fb49a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; @@ -74,8 +75,8 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && 
funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot))); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -88,15 +89,16 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { }) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE), logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + // .when(agg -> agg.child().isAllSlots()) + // .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) + .whenNot(agg -> agg.child() + .child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set<AggregateFunction> funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count) && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot)); }) .thenApply(ctx -> { Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -118,23 +120,6 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) { List<Slot> leftOutput = join.left().getOutput(); List<Slot> rightOutput = join.right().getOutput(); - - List<AggregateFunction> leftFuncs = new ArrayList<>(); - List<AggregateFunction> rightFuncs = new ArrayList<>(); - for (AggregateFunction func : agg.getAggregateFunctions()) { - Slot slot = 
(Slot) func.child(0); - if (leftOutput.contains(slot)) { - leftFuncs.add(func); - } else if (rightOutput.contains(slot)) { - rightFuncs.add(func); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { - return null; - } - Set<Slot> leftGroupBy = new HashSet<>(); Set<Slot> rightGroupBy = new HashSet<>(); for (Expression e : agg.getGroupByExpressions()) { @@ -144,18 +129,71 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { } else if (rightOutput.contains(slot)) { rightGroupBy.add(slot); } else { - return null; + if (projects.isEmpty()) { + // TODO: select ... from ... group by A , B, 1.2; 1.2 is constant + return null; + } else { + for (NamedExpression proj : projects) { + if (proj instanceof Alias && proj.toSlot().equals(slot)) { + Set<Slot> inputForAliasSet = proj.getInputSlots(); + for (Slot aliasInputSlot : inputForAliasSet) { + if (leftOutput.contains(aliasInputSlot)) { + leftGroupBy.add(aliasInputSlot); + } else if (rightOutput.contains(aliasInputSlot)) { + rightGroupBy.add(aliasInputSlot); + } else { + return null; + } + } + break; + } + } + } } } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); + + List<AggregateFunction> leftFuncs = new ArrayList<>(); + List<AggregateFunction> rightFuncs = new ArrayList<>(); + Count countStar = null; + Count rewrittenCountStar = null; + for (AggregateFunction func : agg.getAggregateFunctions()) { + if (func instanceof Count && ((Count) func).isCountStar()) { + countStar = (Count) func; + } else { + Slot slot = (Slot) func.child(0); + if (leftOutput.contains(slot)) { + leftFuncs.add(func); + } else if (rightOutput.contains(slot)) { + rightFuncs.add(func); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); 
+ } + } + } + // rewrite count(*) to count(A), where A is slot from left/right group by key + if (countStar != null) { + if (!leftGroupBy.isEmpty()) { + rewrittenCountStar = (Count) countStar.withChildren(leftGroupBy.iterator().next()); + leftFuncs.add(rewrittenCountStar); + } else if (!rightGroupBy.isEmpty()) { + rewrittenCountStar = (Count) countStar.withChildren(rightGroupBy.iterator().next()); + rightFuncs.add(rewrittenCountStar); } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + return null; + } + } + for (Expression condition : join.getHashJoinConjuncts()) { + for (Slot joinConditionSlot : condition.getInputSlots()) { + if (leftOutput.contains(joinConditionSlot)) { + leftGroupBy.add(joinConditionSlot); + } else if (rightOutput.contains(joinConditionSlot)) { + rightGroupBy.add(joinConditionSlot); + } else { + // apply failed + return null; + } } - })); + } Plan left = join.left(); Plan right = join.right(); @@ -196,6 +234,10 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { for (NamedExpression ne : agg.getOutputExpressions()) { if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) { AggregateFunction func = (AggregateFunction) ((Alias) ne).child(); + if (func instanceof Count && ((Count) func).isCountStar()) { + // countStar is already rewritten as count(left_slot) or count(right_slot) + func = rewrittenCountStar; + } Slot slot = (Slot) func.child(0); if (leftSlotToOutput.containsKey(slot)) { Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot()); @@ -210,9 +252,20 @@ public class PushDownAggThroughJoinOneSide implements RewriteRuleFactory { newOutputExprs.add(ne); } } - - // TODO: column prune project - return agg.withAggOutputChild(newOutputExprs, newJoin); + Plan newAggChild = newJoin; + if (agg.child() instanceof LogicalProject) { + LogicalProject project = (LogicalProject) agg.child(); + List<NamedExpression> newProjections = 
Lists.newArrayList(); + newProjections.addAll(project.getProjects()); + Set<NamedExpression> leftDifference = new HashSet<NamedExpression>(left.getOutput()); + leftDifference.removeAll(project.getProjects()); + newProjections.addAll(leftDifference); + Set<NamedExpression> rightDifference = new HashSet<NamedExpression>(right.getOutput()); + rightDifference.removeAll(project.getProjects()); + newProjections.addAll(rightDifference); + newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin); + } + return agg.withAggOutputChild(newOutputExprs, newAggChild); } private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java index 58ab7fbe9e9..cffe91045d0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java @@ -323,11 +323,11 @@ class PushDownMinMaxSumThroughJoinTest implements MemoPatternMatchSupported { .applyTopDown(new PushDownAggThroughJoinOneSide()) .printlnTree() .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), + logicalJoin( + logicalAggregate( logicalOlapScan() - ) + ), + logicalOlapScan() ) ); } @@ -346,11 +346,9 @@ class PushDownMinMaxSumThroughJoinTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushDownAggThroughJoinOneSide()) .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalJoin( + logicalAggregate(logicalOlapScan()), + logicalAggregate(logicalOlapScan()) ) ); } diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out 
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index da69919becd..8267eb3e38f 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -1034,3 +1034,25 @@ Used: UnUsed: use_push_down_agg_through_join_one_side SyntaxError: +-- !shape -- +PhysicalResultSink +--PhysicalTopN[MERGE_SORT] +----PhysicalTopN[LOCAL_SORT] +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10))) +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) +------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] +------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19')) +--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd] + +Hint log: +Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +UnUsed: +SyntaxError: + +-- !agg_pushed -- +2 是 2024-08-19 + diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy index 02e06710296..e551fa04c91 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") { qt_with_hint_groupby_pushdown_nested_queries """ 
explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name; """ + + sql """ + drop table if exists dw_user_b2c_tracking_info_tmp_ymd; + create table dw_user_b2c_tracking_info_tmp_ymd ( + guid int, + dt varchar, + first_visit_time varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19'); + + drop table if exists dwd_tracking_sensor_init_tmp_ymd; + create table dwd_tracking_sensor_init_tmp_ymd ( + guid int, + dt varchar, + tracking_type varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click'); + """ + sql """ + set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE"; + set disable_join_reorder=true; + """ + + qt_shape """ + explain shape plan + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND 
dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ + + qt_agg_pushed """ + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org