This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new bb288e0d294 [Enhencement](Nereids) add rule of agg(case when) to agg(filter) (#33598) bb288e0d294 is described below commit bb288e0d294827226cc11a00f7c22f0ec0308087 Author: LiBinfeng <46676950+libinfeng...@users.noreply.github.com> AuthorDate: Tue Apr 16 09:33:28 2024 +0800 [Enhencement](Nereids) add rule of agg(case when) to agg(filter) (#33598) --- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../rules/rewrite/EliminateAggCaseWhen.java | 83 +++++++ .../trees/plans/logical/LogicalAggregate.java | 7 + .../eliminate_aggregate_casewhen.out | 249 +++++++++++++++++++++ .../eliminate_aggregate_casewhen.groovy | 108 +++++++++ 6 files changed, 450 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index e8223524367..2361c276372 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -51,6 +51,7 @@ import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; +import org.apache.doris.nereids.rules.rewrite.EliminateAggCaseWhen; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; @@ -204,6 +205,7 @@ public class Rewriter extends AbstractBatchJobExecutor { new EliminateLimit(), new EliminateFilter(), new EliminateAggregate(), + new EliminateAggCaseWhen(), new ReduceAggregateChildOutputRows(), new EliminateJoinCondition(), new EliminateAssertNumRows(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 3a43e1ad672..696463523f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -202,6 +202,7 @@ public enum RuleType { // Eliminate plan MERGE_AGGREGATE(RuleTypeClass.REWRITE), ELIMINATE_AGGREGATE(RuleTypeClass.REWRITE), + ELIMINATE_AGG_CASE_WHEN(RuleTypeClass.REWRITE), ELIMINATE_LIMIT(RuleTypeClass.REWRITE), ELIMINATE_LIMIT_ON_ONE_ROW_RELATION(RuleTypeClass.REWRITE), ELIMINATE_LIMIT_ON_EMPTY_RELATION(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggCaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggCaseWhen.java new file mode 100644 index 00000000000..33a9bbd9242 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAggCaseWhen.java @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Filter; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Change argument 'case when' or 'if' inside aggregate function , to aggregate function(filter) + * example: + * select sum(case when t1.c1 = 101 then 1 END) from t1; + * ==> + * select sum(1) from t1 where t1.c1 = 101; + * note : + * only If expression is needed to process cause CaseWhenToIf have already changed case when to if + * but in sql we can still see case when so case when is reserved to explain this rule + * we can only have one output aggregate function cause of filter would influence other projection + * we can only have one aggregate function cause of filter would influence other aggregate function + * we can only have case when/if function without else cause of then can only have one branch of choice + * we can only have one case in case when cause of then can only have one branch of choice + */ +public final class EliminateAggCaseWhen extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate().then(agg -> { + Set<AggregateFunction> aggFunctions = agg.getAggregateFunctions(); + // check whether we only have one aggregate function, and only one projection of aggregate function + if (aggFunctions.size() != 1 || agg.getOutputExpressions().size() != 1 + || !agg.getGroupByExpressions().isEmpty()) { + return null; + } + for (AggregateFunction aggFun : aggFunctions) { + // check whether we only have on case when/if in aggregate function + if (aggFun.getArguments().size() != 1) { + return null; + } + // only If expression is needed to process cause CaseWhenToIf have already changed case when to if + if (aggFun.getArgument(0) instanceof If) { + If anIf = (If) aggFun.getArgument(0); + if (!(anIf.getArgument(2) instanceof NullLiteral)) { + return null; + } + Expression operand = anIf.getArgument(0); + Filter filter = new LogicalFilter<>(ExpressionUtils.extractConjunctionToSet(operand), agg.child()); + Expression result = anIf.getArgument(1); + Map<Expression, Expression> constantExprsReplaceMap = new HashMap<>(aggFunctions.size()); + constantExprsReplaceMap.put(aggFun, ((AggregateFunction) aggFun).withChildren(result)); + return agg.withChildAndOutput((Plan) filter, + ExpressionUtils.replaceNamedExpressions( + agg.getOutputExpressions(), constantExprsReplaceMap)); + } + } + return null; + }).toRule(RuleType.ELIMINATE_AGG_CASE_WHEN); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 1e5e5a45abe..73ddde6cef5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -271,6 +271,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child()); } + public LogicalAggregate<Plan> withChildAndOutput(CHILD_TYPE child, + List<NamedExpression> outputExpressionList) { + return new LogicalAggregate<>(groupByExpressions, outputExpressionList, normalized, ordinalIsResolved, + generated, hasPushed, sourceRepeat, Optional.empty(), + Optional.empty(), child); + } + @Override public List<NamedExpression> getOutputs() { return outputExpressions; diff --git a/regression-test/data/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.out b/regression-test/data/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.out new file mode 100644 index 00000000000..46bfb749845 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.out @@ -0,0 +1,249 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !basic_1 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t1.c1 > 100)) +----------PhysicalOlapScan[t1] + +-- !basic_2 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t1.c1 > 100)) +----------PhysicalOlapScan[t1] + +-- !basic_3 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t1] + +-- !basic_4 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t1.c1 > 100)) +----------PhysicalOlapScan[t1] + +-- !basic_5 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t1.c1 > 100)) +----------PhysicalOlapScan[t1] + +-- !basic_6 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t1] + +-- !basic_2_1 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t2.c2 > 100)) +----------PhysicalOlapScan[t2] + +-- !basic_2_2 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t2.c2 > 100)) +----------PhysicalOlapScan[t2] + +-- !basic_2_3 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t2] + +-- !basic_2_4 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t2.c2 > 100)) +----------PhysicalOlapScan[t2] + +-- !basic_2_5 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t2.c2 > 100)) +----------PhysicalOlapScan[t2] + +-- !basic_2_6 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t2] + +-- !basic_3_1 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t3.c3 > 100)) +----------PhysicalOlapScan[t3] + +-- !basic_3_2 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t3.c3 > 100)) +----------PhysicalOlapScan[t3] + +-- !basic_3_3 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t3] + +-- !basic_3_4 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t3.c3 > 100)) +----------PhysicalOlapScan[t3] + +-- !basic_3_5 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t3.c3 > 100)) +----------PhysicalOlapScan[t3] + +-- !basic_3_6 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t3] + +-- !basic_4_1 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t4.c4 > 100)) +----------PhysicalOlapScan[t4] + +-- !basic_4_2 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t4.c4 > 100)) +----------PhysicalOlapScan[t4] + +-- !basic_4_4 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t4] + +-- !basic_4_4 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t4.c4 > 100)) +----------PhysicalOlapScan[t4] + +-- !basic_4_5 -- +PhysicalResultSink +--hashAgg[GLOBAL] +----hashAgg[LOCAL] +------PhysicalProject +--------filter((t4.c4 > 100)) +----------PhysicalOlapScan[t4] + +-- !basic_4_6 -- +PhysicalResultSink +--hashAgg[LOCAL] +----PhysicalProject +------PhysicalOlapScan[t4] + +-- !basic_1 -- +10 + +-- !basic_2 -- +1 + +-- !basic_3 -- +101 1 + +-- !basic_4 -- +10 + +-- !basic_5 -- +1 + +-- !basic_6 -- +101 1 + +-- !basic_2_1 -- +\N + +-- !basic_2_2 -- +0 + +-- !basic_2_3 -- +\N 0 + +-- !basic_2_4 -- +\N + +-- !basic_2_5 -- +0 + +-- !basic_2_6 -- +\N 0 + +-- !basic_3_1 -- +\N + +-- !basic_3_2 -- +0 + +-- !basic_3_3 -- + +-- !basic_3_4 -- +\N + +-- !basic_3_5 -- +0 + +-- !basic_3_6 -- + +-- !basic_4_1 -- +10 + +-- !basic_4_2 -- +2 + +-- !basic_4_4 -- +102 1 +103 1 + +-- !basic_4_4 -- +10 + +-- !basic_4_5 -- +2 + +-- !basic_4_6 -- +102 1 +103 1 + diff --git a/regression-test/suites/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.groovy b/regression-test/suites/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.groovy new file mode 100644 index 00000000000..6623df4716e --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/eliminate_aggregate_casewhen/eliminate_aggregate_casewhen.groovy @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("eliminate_aggregate_casewhen") { + sql "SET enable_nereids_planner=true" + sql "set runtime_filter_mode=OFF" + sql "SET enable_fallback_to_original_planner=false" + sql "SET ignore_shape_nodes='PhysicalDistribute'" + sql 'DROP DATABASE IF EXISTS test_eliminate_aggregate_casewhen' + sql 'CREATE DATABASE IF NOT EXISTS test_eliminate_aggregate_casewhen' + sql 'use test_eliminate_aggregate_casewhen' + + // create tables + sql """drop table if exists t1;""" + sql """drop table if exists t2;""" + sql """drop table if exists t3;""" + sql """drop table if exists t4;""" + + sql """create table t1 (c1 int, c11 int) distributed by hash(c1) buckets 3 properties('replication_num' = '1');""" + sql """create table t2 (c2 int, c22 int) distributed by hash(c2) buckets 3 properties('replication_num' = '1');""" + sql """create table t3 (c3 int, c33 int) distributed by hash(c3) buckets 3 properties('replication_num' = '1');""" + sql """create table t4 (c4 int, c44 int) distributed by hash(c4) buckets 3 properties('replication_num' = '1');""" + + sql "insert into t1 values (101, 101)" + sql "insert into t2 values (null, null)" + sql "insert into t4 values (102, 102)" + sql "insert into t4 values (103, 103)" + + /* ******** with one row ******** */ + qt_basic_1 """explain shape plan select max(case when t1.c1 > 100 then 10 end) from t1;""" + qt_basic_2 """explain shape plan select count(case when t1.c1 > 100 then 10 end) from t1;""" + qt_basic_3 """explain shape plan select t1.c1, count(case when t1.c1 > 100 then 10 end) from t1 group by t1.c1;""" + qt_basic_4 """explain shape plan select max(case when t1.c1 > 100 then 10 end) from t1;""" + qt_basic_5 """explain shape plan select count(case when t1.c1 > 100 then 10 end) from t1;""" + qt_basic_6 """explain shape plan select t1.c1, count(case when t1.c1 > 100 then 10 end) from t1 group by t1.c1;""" + + /* ******** with one row "null" ******** */ + qt_basic_2_1 """explain shape plan select max(case when t2.c2 > 100 then 10 end) from t2;""" + qt_basic_2_2 """explain shape plan select count(case when t2.c2 > 100 then 10 end) from t2;""" + qt_basic_2_3 """explain shape plan select t2.c2, count(case when t2.c2 > 100 then 10 end) from t2 group by t2.c2;""" + qt_basic_2_4 """explain shape plan select max(case when t2.c2 > 100 then 10 end) from t2;""" + qt_basic_2_5 """explain shape plan select count(case when t2.c2 > 100 then 10 end) from t2;""" + qt_basic_2_6 """explain shape plan select t2.c2, count(case when t2.c2 > 100 then 10 end) from t2 group by t2.c2;""" + + /* ******** with empty table ******** */ + qt_basic_3_1 """explain shape plan select max(case when t3.c3 > 100 then 10 end) from t3;""" + qt_basic_3_2 """explain shape plan select count(case when t3.c3 > 100 then 10 end) from t3;""" + qt_basic_3_3 """explain shape plan select t3.c3, count(case when t3.c3 > 100 then 10 end) from t3 group by t3.c3;""" + qt_basic_3_4 """explain shape plan select max(case when t3.c3 > 100 then 10 end) from t3;""" + qt_basic_3_5 """explain shape plan select count(case when t3.c3 > 100 then 10 end) from t3;""" + qt_basic_3_6 """explain shape plan select t3.c3, count(case when t3.c3 > 100 then 10 end) from t3 group by t3.c3;""" + + /* ******** with different group table ******** */ + qt_basic_4_1 """explain shape plan select max(case when t4.c4 > 100 then 10 end) from t4;""" + qt_basic_4_2 """explain shape plan select count(case when t4.c4 > 100 then 10 end) from t4;""" + qt_basic_4_4 """explain shape plan select t4.c4, count(case when t4.c4 > 100 then 10 end) from t4 group by t4.c4;""" + qt_basic_4_4 """explain shape plan select max(case when t4.c4 > 100 then 10 end) from t4;""" + qt_basic_4_5 """explain shape plan select count(case when t4.c4 > 100 then 10 end) from t4;""" + qt_basic_4_6 """explain shape plan select t4.c4, count(case when t4.c4 > 100 then 10 end) from t4 group by t4.c4;""" + + /* ******** Output ******** */ + + /* ******** with one row ******** */ + order_qt_basic_1 """select max(case when t1.c1 > 100 then 10 end) from t1;""" + order_qt_basic_2 """select count(case when t1.c1 > 100 then 10 end) from t1;""" + order_qt_basic_3 """select t1.c1, count(case when t1.c1 > 100 then 10 end) from t1 group by t1.c1;""" + order_qt_basic_4 """select max(case when t1.c1 > 100 then 10 end) from t1;""" + order_qt_basic_5 """select count(case when t1.c1 > 100 then 10 end) from t1;""" + order_qt_basic_6 """select t1.c1, count(case when t1.c1 > 100 then 10 end) from t1 group by t1.c1;""" + + /* ******** with one row "null" ******** */ + order_qt_basic_2_1 """select max(case when t2.c2 > 100 then 10 end) from t2;""" + order_qt_basic_2_2 """select count(case when t2.c2 > 100 then 10 end) from t2;""" + order_qt_basic_2_3 """select t2.c2, count(case when t2.c2 > 100 then 10 end) from t2 group by t2.c2;""" + order_qt_basic_2_4 """select max(case when t2.c2 > 100 then 10 end) from t2;""" + order_qt_basic_2_5 """select count(case when t2.c2 > 100 then 10 end) from t2;""" + order_qt_basic_2_6 """select t2.c2, count(case when t2.c2 > 100 then 10 end) from t2 group by t2.c2;""" + + /* ******** with empty table ******** */ + order_qt_basic_3_1 """select max(case when t3.c3 > 100 then 10 end) from t3;""" + order_qt_basic_3_2 """select count(case when t3.c3 > 100 then 10 end) from t3;""" + order_qt_basic_3_3 """select t3.c3, count(case when t3.c3 > 100 then 10 end) from t3 group by t3.c3;""" + order_qt_basic_3_4 """select max(case when t3.c3 > 100 then 10 end) from t3;""" + order_qt_basic_3_5 """select count(case when t3.c3 > 100 then 10 end) from t3;""" + order_qt_basic_3_6 """select t3.c3, count(case when t3.c3 > 100 then 10 end) from t3 group by t3.c3;""" + + /* ******** with different group table ******** */ + order_qt_basic_4_1 """select max(case when t4.c4 > 100 then 10 end) from t4;""" + order_qt_basic_4_2 """select count(case when t4.c4 > 100 then 10 end) from t4;""" + order_qt_basic_4_4 """select t4.c4, count(case when t4.c4 > 100 then 10 end) from t4 group by t4.c4;""" + order_qt_basic_4_4 """select max(case when t4.c4 > 100 then 10 end) from t4;""" + order_qt_basic_4_5 """select count(case when t4.c4 > 100 then 10 end) from t4;""" + order_qt_basic_4_6 """select t4.c4, count(case when t4.c4 > 100 then 10 end) from t4 group by t4.c4;""" +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org