This is an automated email from the ASF dual-hosted git repository. huajianlan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new daf2e27202 [feature] (Nereids) add rule to push down predicate through aggregate (#11162) daf2e27202 is described below commit daf2e2720250416c2908380794c51d8803b7a629 Author: minghong <engle...@gmail.com> AuthorDate: Wed Jul 27 15:16:15 2022 +0800 [feature] (Nereids) add rule to push down predicate through aggregate (#11162) add rule to push predicates down to aggregation node add PushPredicateThroughAggregation.java add unit test PushDownPredicateThroughAggregationTest For example: ``` Logical plan tree: any_node | filter (a>0 and b>0) | group by(a, c) | scan ``` transformed to: ``` any_node | upper filter (b>0) | group by(a, c) | bottom filter (a>0) | scan ``` Note: 'a>0' can be pushed down because 'a' is a group by key; 'b>0' cannot be pushed down because 'b' is not a group by key. --- .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../logical/PushPredicateThroughAggregation.java | 109 +++++++++++ .../PushDownPredicateThroughAggregationTest.java | 206 +++++++++++++++++++++ 3 files changed, 316 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2485fad8a4..fa6df6d5f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -43,6 +43,7 @@ public enum RuleType { COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE), // predicate push down rules PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE), // column prune rules, COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE), COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java new file mode 100644 index 0000000000..bc4155bdec --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Lists; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + + +/** + * Push the predicate in the LogicalFilter to the aggregate child. + * For example: + * Logical plan tree: + * any_node + * | + * filter (a>0 and b>0) + * | + * group by(a, c) + * | + * scan + * transformed to: + * project + * | + * upper filter (b>0) + * | + * group by(a, c) + * | + * bottom filter (a>0) + * | + * scan + * Note: + * 'a>0' could be push down, because 'a' is in group by keys; + * but 'b>0' could not push down, because 'b' is not in group by keys. 
+ * + */ + +public class PushPredicateThroughAggregation extends OneRewriteRuleFactory { + + @Override + public Rule build() { + return logicalFilter(logicalAggregate()).then(filter -> { + LogicalAggregate<GroupPlan> aggregate = filter.child(); + Set<Slot> groupBySlots = new HashSet<>(); + for (Expression groupByExpression : aggregate.getGroupByExpressionList()) { + if (groupByExpression instanceof Slot) { + groupBySlots.add((Slot) groupByExpression); + } + } + List<Expression> pushDownPredicates = Lists.newArrayList(); + List<Expression> filterPredicates = Lists.newArrayList(); + ExpressionUtils.extractConjunct(filter.getPredicates()).forEach(conjunct -> { + Set<Slot> conjunctSlots = SlotExtractor.extractSlot(conjunct); + if (groupBySlots.containsAll(conjunctSlots)) { + pushDownPredicates.add(conjunct); + } else { + filterPredicates.add(conjunct); + } + }); + + return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates); + }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION); + } + + private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate, + List<Expression> pushDownPredicates, List<Expression> filterPredicates) { + if (pushDownPredicates.size() == 0) { + //nothing pushed down, just return origin plan + return filter; + } + LogicalFilter bottomFilter = new LogicalFilter(ExpressionUtils.and(pushDownPredicates), + (Plan) aggregate.child(0)); + if (filterPredicates.isEmpty()) { + //all predicates are pushed down, just exchange filter and aggregate + return aggregate.withChildren(Lists.newArrayList(bottomFilter)); + } else { + aggregate = aggregate.withChildren(Lists.newArrayList(bottomFilter)); + return new LogicalFilter<>(ExpressionUtils.and(filterPredicates), aggregate); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java new file mode 100644 index 0000000000..9980f246ad --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + + + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.Table; +import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.memo.Memo; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Literal; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PlanRewriter; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.junit.Assert; +import org.junit.Test; + +import java.util.List; + +public class PushDownPredicateThroughAggregationTest { + + /** + * origin plan: + * project + * | + * filter gender=1 + * | + * aggregation group by gender + * | + * scan(student) + * + * transformed plan: + * project + * | + * aggregation group by gender + * | + * filter gender=1 + * | + * scan(student) + */ + @Test + public void pushDownPredicateOneFilterTest() { + Table student = new Table(0L, "student", 
Table.TableType.OLAP, + ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "", ""), + new Column("age", Type.INT, true, AggregateType.NONE, "", ""))); + Plan scan = new LogicalOlapScan(student, ImmutableList.of("student")); + Slot gender = scan.getOutput().get(1); + Slot age = scan.getOutput().get(3); + + List<Expression> groupByKeys = Lists.newArrayList(age, gender); + List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age); + Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan); + Expression filterPredicate = new GreaterThan(gender, Literal.of(1)); + LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation); + Plan root = new LogicalProject<>( + Lists.newArrayList(gender), + filter + ); + + Memo memo = rewrite(root); + System.out.println(memo.copyOut().treeString()); + Group rootGroup = memo.getRoot(); + + GroupExpression groupExpression = rootGroup + .getLogicalExpression().child(0) + .getLogicalExpression(); + aggregation = groupExpression.getPlan(); + Assert.assertTrue(aggregation instanceof LogicalAggregate); + + groupExpression = groupExpression.child(0).getLogicalExpression(); + Plan bottomFilter = groupExpression.getPlan(); + Assert.assertTrue(bottomFilter instanceof LogicalFilter); + Expression greater = ((LogicalFilter<?>) bottomFilter).getPredicates(); + Assert.assertTrue(greater instanceof GreaterThan); + Assert.assertTrue(greater.child(0) instanceof Slot); + Assert.assertEquals("gender", ((Slot) greater.child(0)).getName()); + + groupExpression = groupExpression.child(0).getLogicalExpression(); + Plan scan2 = groupExpression.getPlan(); + Assert.assertTrue(scan2 instanceof LogicalOlapScan); + } + + /** + * origin plan: + * project + * | + * filter gender=1 and name="abc" and (gender+10)<100 + * | + * aggregation group by 
gender + * | + * scan(student) + * + * transformed plan: + * project + * | + * filter name="abc" + * | + * aggregation group by gender + * | + * filter gender=1 and and (gender+10)<100 + * | + * scan(student) + */ + @Test + public void pushDownPredicateTwoFilterTest() { + Table student = new Table(0L, "student", Table.TableType.OLAP, + ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "", ""), + new Column("age", Type.INT, true, AggregateType.NONE, "", ""))); + Plan scan = new LogicalOlapScan(student, ImmutableList.of("student")); + Slot gender = scan.getOutput().get(1); + Slot name = scan.getOutput().get(2); + Slot age = scan.getOutput().get(3); + + List<Expression> groupByKeys = Lists.newArrayList(age, gender); + List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age); + Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan); + Expression filterPredicate = ExpressionUtils.and( + new GreaterThan(gender, Literal.of(1)), + new LessThanEqual( + new Add( + gender, + Literal.of(10) + ), + Literal.of(100) + ), + new EqualTo(name, Literal.of("abc")) + ); + LogicalFilter filter = new LogicalFilter(filterPredicate, aggregation); + Plan root = new LogicalProject<>( + Lists.newArrayList(gender), + filter + ); + + Memo memo = rewrite(root); + System.out.println(memo.copyOut().treeString()); + Group rootGroup = memo.getRoot(); + GroupExpression groupExpression = rootGroup.getLogicalExpression().child(0).getLogicalExpression(); + Plan upperFilter = groupExpression.getPlan(); + Assert.assertTrue(upperFilter instanceof LogicalFilter); + Expression upperPredicates = ((LogicalFilter<?>) upperFilter).getPredicates(); + Assert.assertTrue(upperPredicates instanceof EqualTo); + Assert.assertTrue(upperPredicates.child(0) instanceof Slot); + groupExpression = 
groupExpression.child(0).getLogicalExpression(); + aggregation = groupExpression.getPlan(); + Assert.assertTrue(aggregation instanceof LogicalAggregate); + groupExpression = groupExpression.child(0).getLogicalExpression(); + Plan bottomFilter = groupExpression.getPlan(); + Assert.assertTrue(bottomFilter instanceof LogicalFilter); + Expression bottomPredicates = ((LogicalFilter<?>) bottomFilter).getPredicates(); + Assert.assertTrue(bottomPredicates instanceof And); + Assert.assertEquals(2, bottomPredicates.children().size()); + Expression greater = bottomPredicates.child(0); + Assert.assertTrue(greater instanceof GreaterThan); + Assert.assertTrue(greater.child(0) instanceof Slot); + Assert.assertEquals("gender", ((Slot) greater.child(0)).getName()); + Expression less = bottomPredicates.child(1); + Assert.assertTrue(less instanceof LessThanEqual); + Assert.assertTrue(less.child(0) instanceof Add); + + groupExpression = groupExpression.child(0).getLogicalExpression(); + Plan scan2 = groupExpression.getPlan(); + Assert.assertTrue(scan2 instanceof LogicalOlapScan); + } + + private Memo rewrite(Plan plan) { + return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicateThroughAggregation()); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org