This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 16c692dfda7 [opt](nereids) extract single table expession from join run recursively (#49851) 16c692dfda7 is described below commit 16c692dfda73880534166b80139a10a4001b3723 Author: yujun <yu...@selectdb.com> AuthorDate: Tue Apr 29 10:15:44 2025 +0800 [opt](nereids) extract single table expession from join run recursively (#49851) ### What problem does this PR solve? ExtractSingleTableExpressionFromDisjunction extracts each table's expression from a JOIN, but it forgot to run recursively; this PR fixes that. For example, for the SQL `t1 join t2 where t1.a = 1 or (t1.b = 2 and (t1.c = 3 or t1.d = 4 and t2.x = 1))`, when trying to extract table t1's expression, the rule cannot extract it, because it does not run recursively: when it meets the deep child `t1.c = 3 or t1.d = 4 and t2.x = 1`, which contains a slot from t2, it cannot extract an expression for t1. This PR makes the rule extract children recursively. After this PR, it extracts table t1's expression `t1.a = 1 or (t1.b = 2 and (t1.c = 3 or t1.d = 4))`, which is then pushed down through the join. 
--- ...xtractSingleTableExpressionFromDisjunction.java | 48 +++++++++++++++++++-- ...ctSingleTableExpressionFromDisjunctionTest.java | 50 ++++++++++++++++++++++ .../nereids/rules/rewrite/RewriteSqlTest.java | 49 +++++++++++++++++++++ 3 files changed, 143 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java index fe2d7072ef5..c3c64c9e076 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; @@ -28,10 +29,12 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * Paper: Quantifying TPC-H Choke Points and Their Optimizations @@ -106,18 +109,20 @@ public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleF // only check table in first disjunct. 
// In our example, qualifiers = { n1, n2 } // try to extract - for (Slot inputSlot : disjuncts.get(0).getInputSlots()) { - String qualifier = String.join(".", inputSlot.getQualifier()); + Set<String> qualifiers = disjuncts.get(0).getInputSlots().stream() + .map(slot -> String.join(".", slot.getQualifier())) + .collect(Collectors.toCollection(Sets::newLinkedHashSet)); + for (String qualifier : qualifiers) { List<Expression> extractForAll = Lists.newArrayList(); boolean success = true; - for (Expression expr : ExpressionUtils.extractDisjunction(conjunct)) { + for (Expression expr : disjuncts) { Optional<Expression> extracted = extractSingleTableExpression(expr, qualifier); if (!extracted.isPresent()) { // extract failed success = false; break; } else { - extractForAll.add(extracted.get()); + extractForAll.addAll(ExpressionUtils.extractDisjunction(extracted.get())); } } if (success) { @@ -132,11 +137,46 @@ public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleF // example: expr=(n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY'), qualifier="n1." // output: n1.n_name = 'FRANCE' private Optional<Expression> extractSingleTableExpression(Expression expr, String qualifier) { + // suppose the qualifier is table T, then the process steps are as follow: + // 1. split the expression into conjunctions: c1 and c2 and c3 and ... + // 2. for each conjunction ci, suppose its extract is Ei: + // a) if ci's all slots come from T, then the whole ci is extracted, then Ei = ci; + // b) if ci is an OR expression, then split ci into disjunctions: ci => d1 or d2 or d3 or ..., + // for each disjunction, extract it recuirsely, suppose after extract dj, we get ej, + // if all the dj can extracted ej, then extract ci succ, which is Ei = e1 or e2 or e3 or ..., + // if any dj extract failed, then extract ci fail + // 3. 
collect all the succ extracted Ei, and the result for table T is `E1 and E2 and E3 and ...` + // + // for example: + // suppose expr = (t1.a = 1 or (t2.b = 2 and t1.c = 3)) and (t1.d = 4 or t2.e = 5), qualifier = t1, then + // c1 = (t1.a = 1 or (t2.b = 2 and t1.c = 3)), + // because the whole c1 contains slot t2.b not belong to t1, so cannot extract the whole c1, + // but c1 is an OR expression, so split c1 into disjunctions: + // d1 => t1.a = 1, d2 => (t2.b = 2 and t1.c = 3) + // then after extract on d1, we get e1 = t1.a = 1, extract on d2, we get t1.c = 3, + // so we can extract E1 for c1: t1.a = 1 or t1.c = 3 List<Expression> output = Lists.newArrayList(); List<Expression> conjuncts = ExpressionUtils.extractConjunction(expr); for (Expression conjunct : conjuncts) { if (isSingleTableExpression(conjunct, qualifier)) { output.add(conjunct); + } else if (conjunct instanceof Or) { + List<Expression> disjuncts = ExpressionUtils.extractDisjunction(conjunct); + List<Expression> extracted = Lists.newArrayListWithExpectedSize(disjuncts.size()); + boolean success = true; + for (Expression disjunct : disjuncts) { + Optional<Expression> extractedDisjunct = extractSingleTableExpression(disjunct, qualifier); + if (extractedDisjunct.isPresent()) { + extracted.addAll(ExpressionUtils.extractDisjunction(extractedDisjunct.get())); + } else { + // extract failed + success = false; + break; + } + } + if (success) { + output.add(ExpressionUtils.or(extracted)); + } } } if (output.isEmpty()) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java index 39706d39f2c..27901e2db9f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java @@ -20,6 +20,8 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; @@ -42,6 +44,7 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import java.util.Arrays; import java.util.List; import java.util.Set; @@ -215,4 +218,51 @@ public class ExtractSingleTableExpressionFromDisjunctionTest implements MemoPatt ); return conjuncts.size() == 2 && conjuncts.contains(or); } + + @Test + public void testExtractRecursive() { + Expression expr = new Or( + new And( + new GreaterThan(courseCid, new IntegerLiteral(1)), + new Or( + new And(new LessThan(courseCid, new IntegerLiteral(10)), + new EqualTo(studentAge, new IntegerLiteral(6))), + new And(new LessThan(courseCid, new IntegerLiteral(20)), + new EqualTo(studentAge, new IntegerLiteral(7))) + ) + ), + new And( + new EqualTo(studentGender, new IntegerLiteral(1)), + new EqualTo(courseName, new StringLiteral("abc")) + ) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, student, course, null); + LogicalFilter root = new LogicalFilter<>(ImmutableSet.of(expr), join); + PlanChecker.from(MemoTestUtils.createConnectContext(), root) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalFilter() + .when(filter -> verifyTestExtractRecursive(filter.getConjuncts())) + ); + Assertions.assertNotNull(studentGender); + } + + 
private boolean verifyTestExtractRecursive(Set<Expression> conjuncts) { + Expression or1 = new Or( + new And(new GreaterThan(courseCid, new IntegerLiteral(1)), + new Or( + new LessThan(courseCid, new IntegerLiteral(10)), + new LessThan(courseCid, new IntegerLiteral(20)) + ) + ), + new EqualTo(courseName, new StringLiteral("abc")) + ); + Expression or2 = new Or(Arrays.asList( + new EqualTo(studentAge, new IntegerLiteral(6)), + new EqualTo(studentAge, new IntegerLiteral(7)), + new EqualTo(studentGender, new IntegerLiteral(1)) + )); + + return conjuncts.size() == 3 && conjuncts.contains(or1) && conjuncts.contains(or2); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSqlTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSqlTest.java new file mode 100644 index 00000000000..e4a573a8622 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/RewriteSqlTest.java @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.sqltest.SqlTestBase; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Test; + +class RewriteSqlTest extends SqlTestBase { + + @Test + void testExtractSingleTableExpressionFromDisjunction() throws Exception { + createTables( + "CREATE TABLE IF NOT EXISTS test_extract_single_1 (\n" + + " a int,\n" + + " b int,\n" + + " c int,\n" + + " d int\n" + + ")\n" + + "DUPLICATE KEY(a)\n" + + "DISTRIBUTED BY HASH(a) BUCKETS 10\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ")\n" + ); + String sql = "select * from test_extract_single_1 t1, T2 where t1.a = 1 or t1.b = 2 and (t1.c = 3 or t1.d = 4 and T2.id = 5)"; + PlanChecker.from(connectContext) + .analyze(sql) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matches(logicalFilter().when(f -> f.getConjuncts().stream().anyMatch(e -> e.toSql().equals("OR[(a = 1),AND[(b = 2),OR[(c = 3),(d = 4)]]]")))); + } + +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org