This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 8699bb7 [Query] Optimize where clause by extracting the common predicate in the OR compound predicate. (#3278)
8699bb7 is described below

commit 8699bb7bd47033c6dd5db3e9021ef3cc5d324f63
Author: yangzhg <780531...@qq.com>
AuthorDate: Thu Apr 9 21:57:45 2020 +0800

[Query] Optimize where clause by extracting the common predicate in the OR compound predicate. (#3278)

Queries like the one below cannot finish in an acceptable time. `store_sales` has 28 million (2800w) rows and
`customer_address` has 50 thousand (5w) rows. Currently Doris creates only a single cross join node to execute
this SQL, and evaluating the where clause takes about 200-300 ns per row pair. The where clause has to be
evaluated 2800w * 5w = 1.4 * 10^12 times, which is extremely large and costs about
2800w * 5w * 250 ns = 350,000 seconds (roughly 4 days):
```
select avg(ss_quantity) ,avg(ss_ext_sales_price) ,avg(ss_ext_wholesale_cost) ,sum(ss_ext_wholesale_cost)
from store_sales, customer_address
where ((ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('CO', 'IL', 'MN') and ss_net_profit between 100 and 200 )
or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('OH', 'MT', 'NM') and ss_net_profit between 150 and 300 )
or (ss_addr_sk = ca_address_sk and ca_country = 'United States' and ca_state in ('TX', 'MO', 'MI') and ss_net_profit between 50 and 250 ))
```
However, this SQL can be rewritten as
```
select avg(ss_quantity) ,avg(ss_ext_sales_price) ,avg(ss_ext_wholesale_cost) ,sum(ss_ext_wholesale_cost)
from store_sales, customer_address
where ss_addr_sk = ca_address_sk and ca_country = 'United States'
and (((ca_state in ('CO', 'IL', 'MN') and ss_net_profit between 100 and 200 ) or (ca_state in ('OH', 'MT', 'NM') and ss_net_profit between 150 and 300 ) or (ca_state in ('TX', 'MO', 'MI') and ss_net_profit between 50 and 250 )) )
```
Therefore we can do a hash join first and then use
```
(((ca_state in ('CO', 'IL', 'MN') and ss_net_profit between 100 and 200 ) or (ca_state
in ('OH', 'MT', 'NM') and ss_net_profit between 150 and 300 ) or (ca_state in ('TX', 'MO', 'MI') and ss_net_profit between 50 and 250 )) )
```
to filter the rows. On the TPC-DS 10 GB dataset, the rewritten SQL takes only about 1 second.
If one OR branch consists only of the common terms, e.g. `a or (a and b)`, the whole OR
collapses to the common terms (absorption law).
---
 .../java/org/apache/doris/analysis/SelectStmt.java | 141 ++++++++++++++
 .../org/apache/doris/analysis/SelectStmtTest.java  | 212 ++++++++++++++++++++-
 2 files changed, 348 insertions(+), 5 deletions(-)

diff --git a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
index ebb5c50..2e84409 100644
--- a/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
+++ b/fe/src/main/java/org/apache/doris/analysis/SelectStmt.java
@@ -54,6 +54,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -495,6 +496,10 @@ public class SelectStmt extends QueryStmt {
     }
 
     private void whereClauseRewrite() {
+        Expr deDuplicatedWhere = deduplicateOrs(whereClause);
+        if (deDuplicatedWhere != null) {
+            whereClause = deDuplicatedWhere;
+        }
         if (whereClause instanceof IntLiteral) {
             if (((IntLiteral) whereClause).getLongValue() == 0) {
                 whereClause = new BoolLiteral(false);
@@ -505,6 +510,142 @@ public class SelectStmt extends QueryStmt {
     }
 
     /**
+     * Flattens an OR predicate shaped like (a and b and c) or (d and e and f) into the
+     * AND-connected terms of each branch, e.g. [[a, b, c], [d, e, f]]. Nested ORs are flattened as well.
+     */
+    private List<List<Expr>> extractDuplicateOrs(CompoundPredicate expr) {
+        List<List<Expr>> orExprs = new ArrayList<>();
+        for (Expr child : expr.getChildren()) {
+            if (child instanceof CompoundPredicate) {
+                CompoundPredicate childCp = (CompoundPredicate) child;
+                if (childCp.getOp() == CompoundPredicate.Operator.OR) {
+                    orExprs.addAll(extractDuplicateOrs(childCp));
+                    continue;
+                } else if (childCp.getOp() == CompoundPredicate.Operator.AND) {
+                    orExprs.add(flatAndExpr(child));
+                    continue;
+                }
+            }
+            orExprs.add(Arrays.asList(child));
+        }
+        return orExprs;
+    }
+
+    /**
+     * This function attempts to apply the inverse OR distributive law:
+     * ((A AND B) OR (A AND C))  =>  (A AND (B OR C))
+     * That is, locate OR clauses in which every subclause contains an
+     * identical term, and pull out the duplicated terms. Returns null only when expr is null.
+     */
+    private Expr deduplicateOrs(Expr expr) {
+        if (expr == null) {
+            return null;
+        }
+        if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.OR) {
+            Expr rewritedExpr = processDuplicateOrs(extractDuplicateOrs((CompoundPredicate) expr));
+            if (rewritedExpr != null) {
+                return rewritedExpr;
+            }
+        } else {
+            for (int i = 0; i < expr.getChildren().size(); i++) {
+                Expr rewritedExpr = deduplicateOrs(expr.getChild(i));
+                if (rewritedExpr != null) {
+                    expr.setChild(i, rewritedExpr);
+                }
+            }
+        }
+        return expr;
+    }
+
+    /**
+     * Flattens nested AND predicates: a and b and c => [a, b, c].
+     */
+    private List<Expr> flatAndExpr(Expr expr) {
+        List<Expr> andExprs = new ArrayList<>();
+        if (expr instanceof CompoundPredicate && ((CompoundPredicate) expr).getOp() == CompoundPredicate.Operator.AND) {
+            andExprs.addAll(flatAndExpr(expr.getChild(0)));
+            andExprs.addAll(flatAndExpr(expr.getChild(1)));
+        } else {
+            andExprs.add(expr);
+        }
+        return andExprs;
+    }
+
+    /**
+     * The input is a list of lists: each inner list holds AND-connected exprs, the outer list is OR-connected.
+     * For example, (a and b and c) or (a and e and f) is passed in as [[a, b, c], [a, e, f]].
+     * The terms shared by every branch are pulled out ([a] plus [[b, c], [e, f]]) and the expr is rebuilt as
+     * a and ((b and c) or (e and f)). Returns null when there are fewer than two OR branches.
+     */
+    private Expr processDuplicateOrs(List<List<Expr>> exprs) {
+        if (exprs.size() < 2) {
+            return null;
+        }
+        // 1. remove duplicated elements [[a,a], [a, b], [a,b]] => [[a], [a,b]]
+        Set<Set<Expr>> set = new LinkedHashSet<>();
+        for (List<Expr> ex : exprs) {
+            Set<Expr> es = new LinkedHashSet<>();
+            es.addAll(ex);
+            set.add(es);
+        }
+        List<List<Expr>> clearExprs = new ArrayList<>();
+        for (Set<Expr> es : set) {
+            List<Expr> el = new ArrayList<>();
+            el.addAll(es);
+            clearExprs.add(el);
+        }
+        if (clearExprs.size() == 1) {
+            return makeCompound(clearExprs.get(0), CompoundPredicate.Operator.AND);
+        }
+        // 2. find the terms shared by every clause
+        List<Expr> cloneExprs = new ArrayList<>(clearExprs.get(0));
+        for (int i = 1; i < clearExprs.size(); ++i) {
+            cloneExprs.retainAll(clearExprs.get(i));
+        }
+        List<Expr> temp = new ArrayList<>();
+        if (CollectionUtils.isNotEmpty(cloneExprs)) {
+            temp.add(makeCompound(cloneExprs, CompoundPredicate.Operator.AND));
+        }
+
+        for (List<Expr> exprList : clearExprs) {
+            exprList.removeAll(cloneExprs);
+            if (exprList.isEmpty()) {
+                // This branch consisted only of the shared terms, e.g. (a) or (a and b). By the absorption
+                // law the whole OR reduces to the shared terms; building an OR with an empty branch would fail.
+                return temp.get(0);
+            }
+            temp.add(makeCompound(exprList, CompoundPredicate.Operator.AND));
+        }
+
+        // rebuild the CompoundPredicate: if shared terms were found build (shared) and (.. or ..),
+        // otherwise build just (.. or ..)
+        Expr result = CollectionUtils.isNotEmpty(cloneExprs) ? new CompoundPredicate(CompoundPredicate.Operator.AND,
+                temp.get(0), makeCompound(temp.subList(1, temp.size()), CompoundPredicate.Operator.OR))
+                : makeCompound(temp, CompoundPredicate.Operator.OR);
+        if (LOG.isDebugEnabled()) {
+            LOG.debug("rewrite ors: " + result.toSql());
+        }
+        return result;
+    }
+
+    /**
+     * Left-folds exprs with op, e.g. [a, e, f] with AND => (a and e) and f; returns null for an empty list.
+     */
+    private Expr makeCompound(List<Expr> exprs, CompoundPredicate.Operator op) {
+        if (CollectionUtils.isEmpty(exprs)) {
+            return null;
+        }
+        if (exprs.size() == 1) {
+            return exprs.get(0);
+        }
+        CompoundPredicate result = new CompoundPredicate(op, exprs.get(0), exprs.get(1));
+        for (int i = 2; i < exprs.size(); ++i) {
+            result = new CompoundPredicate(op, result.clone(), exprs.get(i));
+        }
+        return result;
+    }
+
+    /**
      * Generates and registers !empty() predicates to filter out empty collections directly
      * in the parent scan of collection table refs. This is a performance optimization to
      * avoid the expensive processing of empty collections inside a subplan that would
diff --git a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 171776b..92db5b1 100644
--- a/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -20,6 +20,7 @@ package org.apache.doris.analysis;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.rewrite.ExprRewriter;
+import org.apache.doris.thrift.TPrimitiveType;
 import org.apache.doris.utframe.DorisAssert;
 import org.apache.doris.utframe.UtFrameUtils;
 import org.junit.AfterClass;
@@ -29,6 +30,7 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 
+import java.io.IOException;
 import java.util.UUID;
 
 public class SelectStmtTest {
@@ -89,10 +91,210 @@ public class SelectStmtTest {
                 "FROM db1.tbl1;";
         SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql1, ctx);
         stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
-        Assert.assertEquals("SELECT CASE WHEN `$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3` END" +
-                " AS `kk4` FROM `default_cluster:db1`.`tbl1` (SELECT count(*) / 2.0 AS `count(*) / 2.0` FROM " +
-                "`default_cluster:db1`.`tbl1`) $a$1 (SELECT avg(`k4`) AS `avg(``k4``)` FROM" +
-                " `default_cluster:db1`.`tbl1`) $a$2 (SELECT sum(`k4`) AS `sum(``k4``)` " +
-                "FROM `default_cluster:db1`.`tbl1`) $a$3", stmt.toSql());
+        Assert.assertTrue(stmt.toSql().contains("`$a$1`.`$c$1` > `k4` THEN `$a$2`.`$c$2` ELSE `$a$3`.`$c$3`"));
+    }
+
+    @Test
+    public void testDeduplicateOrs() throws Exception {
+        ConnectContext ctx = UtFrameUtils.createDefaultCtx();
+        String sql = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2,\n" +
+                " db1.tbl1 t3,\n" +
+                " db1.tbl1 t4,\n" +
+                " db1.tbl1 t5,\n" +
+                " db1.tbl1 t6\n" +
+                "where\n" +
+                " t2.k1 = t1.k1\n" +
+                " and t1.k2 = t6.k2\n" +
+                " and t6.k4 = 2001\n" +
+                " and(\n" +
+                " (\n" +
+                " t1.k2 = t4.k2\n" +
+                " and t3.k3 = t1.k3\n" +
+                " and t3.k1 = 'D'\n" +
+                " and t4.k3 = '2 yr Degree'\n" +
+                " and t1.k4 between 100.00\n" +
+                " and 150.00\n" +
+                " and t4.k4 = 3\n" +
+                " )\n" +
+                " or (\n" +
+                " t1.k2 = t4.k2\n" +
+                " and t3.k3 = t1.k3\n" +
+                " and t3.k1 = 'S'\n" +
+                " and t4.k3 = 'Secondary'\n" +
+                " and t1.k4 between 50.00\n" +
+                " and 100.00\n" +
+                " and t4.k4 = 1\n" +
+                " )\n" +
+                " or (\n" +
+                " t1.k2 = t4.k2\n" +
+                " and t3.k3 = t1.k3\n" +
+                " and t3.k1 = 'W'\n" +
+                " and t4.k3 = 'Advanced Degree'\n" +
+                " and t1.k4 between 150.00\n" +
+                " and 200.00\n" +
+                " and t4.k4 = 1\n" +
+                " )\n" +
+                " )\n" +
+                " and(\n" +
+                " (\n" +
+                " t1.k1 = t5.k1\n" +
+                " and t5.k2 = 'United States'\n" +
+                " and t5.k3 in ('CO', 'IL', 'MN')\n" +
+                " and t1.k4 between 100\n" +
+                " and 200\n" +
+                " )\n" +
+                " or (\n" +
+                " t1.k1 = t5.k1\n" +
+                " and t5.k2 = 'United States'\n" +
+                " and t5.k3 in ('OH', 'MT', 'NM')\n" +
+                " and t1.k4 between 150\n" +
+                " and 300\n" +
+                " )\n" +
+                " or (\n" +
+                " t1.k1 = t5.k1\n" +
+                " and t5.k2 = 'United States'\n" +
+                " and t5.k3 in ('TX', 'MO', 'MI')\n" +
+                " and t1.k4 between 50 and 250\n" +
+                " )\n" +
+                " );";
+        SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, ctx);
+        stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        String rewritedFragment1 = "(((`t1`.`k2` = `t4`.`k2`) AND (`t3`.`k3` = `t1`.`k3`)) AND ((((((`t3`.`k1` = 'D')" +
+                " AND (`t4`.`k3` = '2 yr Degree')) AND ((`t1`.`k4` >= 100.00) AND (`t1`.`k4` <= 150.00))) AND" +
+                " (`t4`.`k4` = 3)) OR ((((`t3`.`k1` = 'S') AND (`t4`.`k3` = 'Secondary')) AND ((`t1`.`k4` >= 50.00)" +
+                " AND (`t1`.`k4` <= 100.00))) AND (`t4`.`k4` = 1))) OR ((((`t3`.`k1` = 'W') AND " +
+                "(`t4`.`k3` = 'Advanced Degree')) AND ((`t1`.`k4` >= 150.00) AND (`t1`.`k4` <= 200.00)))" +
+                " AND (`t4`.`k4` = 1))))";
+        String rewritedFragment2 = "(((`t1`.`k1` = `t5`.`k1`) AND (`t5`.`k2` = 'United States')) AND" +
+                " ((((`t5`.`k3` IN ('CO', 'IL', 'MN')) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200)))" +
+                " OR ((`t5`.`k3` IN ('OH', 'MT', 'NM')) AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300))))" +
+                " OR ((`t5`.`k3` IN ('TX', 'MO', 'MI')) AND ((`t1`.`k4` >= 50) AND (`t1`.`k4` <= 250)))))";
+        Assert.assertTrue(stmt.toSql().contains(rewritedFragment1));
+        Assert.assertTrue(stmt.toSql().contains(rewritedFragment2));
+
+        String sql2 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                "(\n" +
+                " t1.k1 = t2.k3\n" +
+                " and t2.k2 = 'United States'\n" +
+                " and t2.k3 in ('CO', 'IL', 'MN')\n" +
+                " and t1.k4 between 100\n" +
+                " and 200\n" +
+                ")\n" +
+                "or (\n" +
+                " t1.k1 = t2.k1\n" +
+                " and t2.k2 = 'United States1'\n" +
+                " and t2.k3 in ('OH', 'MT', 'NM')\n" +
+                " and t1.k4 between 150\n" +
+                " and 300\n" +
+                ")\n" +
+                "or (\n" +
+                " t1.k1 = t2.k1\n" +
+                " and t2.k2 = 'United States'\n" +
+                " and t2.k3 in ('TX', 'MO', 'MI')\n" +
+                " and t1.k4 between 50 and 250\n" +
+                ")";
+        SelectStmt stmt2 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql2, ctx);
+        stmt2.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        String fragment3 = "(((((`t1`.`k1` = `t2`.`k3`) AND (`t2`.`k2` = 'United States')) AND " +
+                "(`t2`.`k3` IN ('CO', 'IL', 'MN'))) AND ((`t1`.`k4` >= 100) AND (`t1`.`k4` <= 200))) OR" +
+                " ((((`t1`.`k1` = `t2`.`k1`) AND (`t2`.`k2` = 'United States1')) AND (`t2`.`k3` IN ('OH', 'MT', 'NM')))" +
+                " AND ((`t1`.`k4` >= 150) AND (`t1`.`k4` <= 300)))) OR ((((`t1`.`k1` = `t2`.`k1`) AND " +
+                "(`t2`.`k2` = 'United States')) AND (`t2`.`k3` IN ('TX', 'MO', 'MI'))) AND ((`t1`.`k4` >= 50)" +
+                " AND (`t1`.`k4` <= 250)))";
+        Assert.assertTrue(stmt2.toSql().contains(fragment3));
+
+        String sql3 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t1.k1 = t2.k3 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
+        SelectStmt stmt3 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql3, ctx);
+        stmt3.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertFalse(stmt3.toSql().contains("((`t1`.`k1` = `t2`.`k3`) OR (`t1`.`k1` = `t2`.`k3`)) OR" +
+                " (`t1`.`k1` = `t2`.`k3`)"));
+
+        String sql4 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t1.k1 = t2.k2 or t1.k1 = t2.k3 or t1.k1 = t2.k3";
+        SelectStmt stmt4 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql4, ctx);
+        stmt4.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt4.toSql().contains("(`t1`.`k1` = `t2`.`k2`) OR (`t1`.`k1` = `t2`.`k3`)"));
+
+        String sql5 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t2.k1 is not null or t1.k1 is not null or t1.k1 is not null";
+        SelectStmt stmt5 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql5, ctx);
+        stmt5.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt5.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
+        Assert.assertEquals(2, stmt5.toSql().split(" OR ").length);
+
+        String sql6 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t2.k1 is not null or t1.k1 is not null and t1.k1 is not null";
+        SelectStmt stmt6 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql6, ctx);
+        stmt6.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt6.toSql().contains("(`t2`.`k1` IS NOT NULL) OR (`t1`.`k1` IS NOT NULL)"));
+        Assert.assertEquals(2, stmt6.toSql().split(" OR ").length);
+
+        String sql7 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t2.k1 is not null or t1.k1 is not null and t1.k2 is not null";
+        SelectStmt stmt7 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql7, ctx);
+        stmt7.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt7.toSql().contains("(`t2`.`k1` IS NOT NULL) OR ((`t1`.`k1` IS NOT NULL) " +
+                "AND (`t1`.`k2` IS NOT NULL))"));
+
+        String sql8 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t2.k1 is not null and t1.k1 is not null and t1.k1 is not null";
+        SelectStmt stmt8 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql8, ctx);
+        stmt8.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt8.toSql().contains("((`t2`.`k1` IS NOT NULL) AND (`t1`.`k1` IS NOT NULL))" +
+                " AND (`t1`.`k1` IS NOT NULL)"));
+
+        // (a) or (a and b) must collapse to the common term a instead of failing on an empty branch
+        String sql9 = "select\n" +
+                " avg(t1.k4)\n" +
+                "from\n" +
+                " db1.tbl1 t1,\n" +
+                " db1.tbl1 t2\n" +
+                "where\n" +
+                " t2.k1 is not null or (t2.k1 is not null and t1.k1 is not null)";
+        SelectStmt stmt9 = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql9, ctx);
+        stmt9.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
+        Assert.assertTrue(stmt9.toSql().contains("`t2`.`k1` IS NOT NULL"));
+        Assert.assertFalse(stmt9.toSql().contains(" OR "));
+        Assert.assertFalse(stmt9.toSql().contains("`t1`.`k1` IS NOT NULL"));
     }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org