This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 49a3bab399e [fix](nereids) fix aggregate function roll up when expression arguments is not equals (#29256) 49a3bab399e is described below commit 49a3bab399ef743ad38ef9bf2ec5f71e56edea14 Author: seawinde <149132972+seawi...@users.noreply.github.com> AuthorDate: Wed Jan 3 18:58:18 2024 +0800 [fix](nereids) fix aggregate function roll up when expression arguments is not equals (#29256) When rolling up an aggregate function, we should check that the arguments of the query function and the mv function are equal. For example, with the mv def and query sql below, the rewrite should not succeed, because the argument of the bitmap_union_basic field is not equal to the argument of the `count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end)` field in the query. mv def: > select l_shipdate, o_orderdate, l_partkey, l_suppkey, > sum(o_totalprice) as sum_total, > max(o_totalprice) as max_total, > min(o_totalprice) as min_total, > count(*) as count_all, > bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic > from lineitem > left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate > group by > l_shipdate, > o_orderdate, > l_partkey, > l_suppkey; query sql: > select t1.l_partkey, t1.l_suppkey, o_orderdate, > sum(o_totalprice), > max(o_totalprice), > min(o_totalprice), > count(*), > count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end) > from (select * from lineitem where l_shipdate = '2023-12-11') t1 > left join orders on t1.l_orderkey = orders.o_orderkey and t1.l_shipdate = o_orderdate > group by > o_orderdate, > l_partkey, > l_suppkey; --- .../mv/AbstractMaterializedViewAggregateRule.java | 103 ++++++++++++++++----- .../org/apache/doris/nereids/trees/TreeNode.java | 28 ++++++ .../doris/nereids/trees/expressions/Any.java | 10 ++ .../mv/agg_with_roll_up/aggregate_with_roll_up.out | 
6 ++ .../agg_with_roll_up/aggregate_with_roll_up.groovy | 35 +++++++ 5 files changed, 158 insertions(+), 24 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index 685f8a8c3a9..11faaa6a6d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.trees.expressions.Any; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -35,11 +36,14 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashMultimap; import 
com.google.common.collect.Multimap; import com.google.common.collect.Sets; @@ -47,10 +51,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; @@ -60,15 +65,18 @@ import java.util.stream.Collectors; */ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule { - protected static final Map<Expression, Expression> - AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>(); + // we only support roll up function which has only one argument currently + protected static final Multimap<Expression, Expression> + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create(); protected final String currentClassName = this.getClass().getSimpleName(); private final Logger logger = LogManager.getLogger(this.getClass()); static { AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE), - new BitmapUnion(Any.INSTANCE)); + new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE)))); + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE), + new BitmapUnion(new ToBitmap(Any.INSTANCE))); } @Override @@ -249,17 +257,30 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return rewrittenAggregate; } - // only support sum roll up, support other agg functions later. - private Function rollup(AggregateFunction queryFunction, - Expression queryFunctionShuttled, + /** + * Roll up query aggregate function when query dimension num is less than mv dimension num, + * + * @param queryAggregateFunction query aggregate function to roll up. + * @param queryAggregateFunctionShuttled query aggregate function shuttled by lineage. 
+ * @param mvExprToMvScanExprQueryBased mv def sql output expressions to mv result data output mapping. + * <p> + * Such as query is + * select max(a) + 1 from table group by b. + * mv is + * select max(a) from table group by a, b. + * the queryAggregateFunction is max(a), queryAggregateFunctionShuttled is max(a) + 1 + * mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) } + */ + private Function rollup(AggregateFunction queryAggregateFunction, + Expression queryAggregateFunctionShuttled, Map<Expression, Expression> mvExprToMvScanExprQueryBased) { - if (!(queryFunction instanceof CouldRollUp)) { + if (!(queryAggregateFunction instanceof CouldRollUp)) { return null; } Expression rollupParam = null; - if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) { + if (mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)) { // function can rewrite by view - rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled); + rollupParam = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled); } else { // function can not rewrite by view, try to use complex roll up param // eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param)) @@ -267,7 +288,8 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate if (!(mvExprShuttled instanceof Function)) { continue; } - if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) { + if (isAggregateFunctionEquivalent(queryAggregateFunction, queryAggregateFunctionShuttled, + (Function) mvExprShuttled)) { rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled); } } @@ -276,7 +298,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return null; } // do roll up - return ((CouldRollUp) queryFunction).constructRollUp(rollupParam); + return ((CouldRollUp) queryAggregateFunction).constructRollUp(rollupParam); } private Pair<Set<? extends Expression>, Set<? 
extends Expression>> topPlanSplitToGroupAndFunction( @@ -347,22 +369,55 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return true; } - private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) { + /** + * Check the queryFunction is equivalent to view function when function roll up. + * Not only check the function name but also check the argument between query and view aggregate function. + * Such as query is + * select count(distinct a) + 1 from table group by b. + * mv is + * select bitmap_union(to_bitmap(a)) from table group by a, b. + * the queryAggregateFunction is count(distinct a), queryAggregateFunctionShuttled is count(distinct a) + 1 + * mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : MTMVScan(output#0) } + * This will check the count(distinct a) in query is equivalent to bitmap_union(to_bitmap(a)) in mv, + * and then check their arguments is equivalent. + */ + private boolean isAggregateFunctionEquivalent(Function queryFunction, Expression queryFunctionShuttled, + Function viewFunction) { if (queryFunction.equals(viewFunction)) { return true; } - // get query equivalent function - Expression equivalentFunction = null; - for (Map.Entry<Expression, Expression> entry : AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) { - if (entry.getKey().equals(queryFunction)) { - equivalentFunction = entry.getValue(); + // check the argument of rollup function is equivalent to view function or not + for (Map.Entry<Expression, Collection<Expression>> equivalentFunctionEntry : + AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) { + if (equivalentFunctionEntry.getKey().equals(queryFunction)) { + // check is have equivalent function or not + for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) { + if (!Any.equals(equivalentFunction, viewFunction)) { + continue; + } + // check param in query function is same as the view function + List<Expression> 
viewFunctionArguments = extractViewArguments(equivalentFunction, viewFunction); + if (queryFunctionShuttled.getArguments().size() != 1 || viewFunctionArguments.size() != 1) { + continue; + } + if (Objects.equals(queryFunctionShuttled.getArguments().get(0), viewFunctionArguments.get(0))) { + return true; + } + } } } - // check is have equivalent function or not - if (equivalentFunction == null) { - return false; - } - // current compare - return equivalentFunction.equals(viewFunction); + return false; + } + + /** + * Extract the view function arguments by equivalentFunction pattern + * Such as equivalentFunction def is bitmap_union(to_bitmap(Any.INSTANCE)), + * viewFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end)) + * after extracting, the return argument is: case when a = 5 then 1 else 2 end + */ + private List<Expression> extractViewArguments(Expression equivalentFunction, Function viewFunction) { + Set<Object> exprSetToRemove = equivalentFunction.collectToSet(expr -> !(expr instanceof Any)); + return viewFunction.collectFirst(expr -> + exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr))); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index 557ff43b51d..00ac71eaf24 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Deque; import java.util.List; import java.util.Set; @@ -150,6 +151,19 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> { return rewriteFunction.apply(rewrittenChildren); } + /** + * Foreach treeNode. Top-down traverse implicitly, stop traverse if satisfy test. 
+ * @param func foreach function + */ + default void foreach(Predicate<TreeNode<NODE_TYPE>> func) { + boolean valid = func.test(this); + if (!valid) { + for (NODE_TYPE child : children()) { + child.foreach(func); + } + } + } + /** * Foreach treeNode. Top-down traverse implicitly. * @param func foreach function @@ -241,6 +255,20 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> { return (Set<T>) result.build(); } + /** + * Collect the nodes that satisfied the predicate firstly. + */ + default <T> List<T> collectFirst(Predicate<TreeNode<NODE_TYPE>> predicate) { + List<TreeNode<NODE_TYPE>> result = new ArrayList<>(); + foreach(node -> { + if (result.isEmpty() && predicate.test(node)) { + result.add(node); + } + return !result.isEmpty(); + }); + return (List<T>) ImmutableList.copyOf(result); + } + /** * iterate top down and test predicate if contains any instance of the classes * @param types classes array diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java index 43d284bf678..2e4bc745b2a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Objects; /** * This represents any expression, it means it equals any expression @@ -55,6 +56,15 @@ public class Any extends Expression implements LeafExpression { return true; } + /** + * Equals with direction + * Since the equals method in Any is always true, that means Any is equals to others, but not equal in reverse. + * The expression with Any should always be the first argument. 
+ */ + public static boolean equals(Expression expressionWithAny, Expression target) { + return Objects.equals(expressionWithAny, target); + } + @Override public int hashCode() { return 0; diff --git a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out index fb223bc661b..334980ed00c 100644 --- a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out +++ b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out @@ -5,6 +5,12 @@ -- !query13_0_after -- 3 3 2023-12-11 43.20 43.20 43.20 1 0 +-- !query13_1_before -- +3 3 2023-12-11 43.20 43.20 43.20 1 0 + +-- !query13_1_after -- +3 3 2023-12-11 43.20 43.20 43.20 1 0 + -- !query14_0_before -- 2 3 2023-12-08 20.00 10.50 9.50 2 0 2 3 2023-12-12 \N \N \N 1 0 diff --git a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy index fd3c02408d9..e9d1ee76b37 100644 --- a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy @@ -247,6 +247,41 @@ suite("aggregate_with_roll_up") { sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_0""" + def mv13_1 = """ + select l_shipdate, o_orderdate, l_partkey, l_suppkey, + sum(o_totalprice) as sum_total, + max(o_totalprice) as max_total, + min(o_totalprice) as min_total, + count(*) as count_all, + bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic + from lineitem + left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate + group by + l_shipdate, + o_orderdate, + l_partkey, + l_suppkey; + """ + def query13_1 = """ + select t1.l_partkey, t1.l_suppkey, 
o_orderdate, + sum(o_totalprice), + max(o_totalprice), + min(o_totalprice), + count(*), + count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 3) then o_custkey else null end) + from (select * from lineitem where l_shipdate = '2023-12-11') t1 + left join orders on t1.l_orderkey = orders.o_orderkey and t1.l_shipdate = o_orderdate + group by + o_orderdate, + l_partkey, + l_suppkey; + """ + order_qt_query13_1_before "${query13_1}" + check_not_match(mv13_1, query13_1, "mv13_1") + order_qt_query13_1_after "${query13_1}" + sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_1""" + + // filter inside + right + use roll up dimension def mv14_0 = "select l_shipdate, o_orderdate, l_partkey, l_suppkey, " + "sum(o_totalprice) as sum_total, " + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org