This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 49a3bab399e [fix](nereids) fix aggregate function roll up when 
expression arguments is not equals (#29256)
49a3bab399e is described below

commit 49a3bab399ef743ad38ef9bf2ec5f71e56edea14
Author: seawinde <149132972+seawi...@users.noreply.github.com>
AuthorDate: Wed Jan 3 18:58:18 2024 +0800

    [fix](nereids) fix aggregate function roll up when expression arguments is 
not equals (#29256)
    
    When rolling up an aggregate function, we should check that the query and 
mv function arguments are equal.
    For example, with the mv def and query sql below, the rewrite should not 
succeed, because the argument of the bitmap_union_basic field is
    not equal to the argument of the `count(distinct case when o_shippriority > 10 and 
o_orderkey IN (1, 3) then o_custkey else null end)` field in the query.
    
    mv def:
    >      select l_shipdate, o_orderdate, l_partkey, l_suppkey,
    >            sum(o_totalprice) as sum_total,
    >            max(o_totalprice) as max_total,
    >            min(o_totalprice) as min_total,
    >            count(*) as count_all,
    >            bitmap_union(to_bitmap(case when o_shippriority > 1 and 
o_orderkey IN (1, 3) then o_custkey else null end)) as bitmap_union_basic
    >           from lineitem
    >           left join orders on lineitem.l_orderkey = orders.o_orderkey and 
l_shipdate = o_orderdate
    >            group by
    >         l_shipdate,
    >         o_orderdate,
    >          l_partkey,
    >         l_suppkey;
    
    query sql:
    
    >             select t1.l_partkey, t1.l_suppkey, o_orderdate,
    >           sum(o_totalprice),
    >            max(o_totalprice),
    >           min(o_totalprice),
    >           count(*),
    >            count(distinct case when o_shippriority > 10 and o_orderkey IN 
(1, 3) then o_custkey else null end)
    >            from (select * from lineitem where l_shipdate = '2023-12-11') 
t1
    >            left join orders on t1.l_orderkey = orders.o_orderkey and 
t1.l_shipdate = o_orderdate
    >            group by
    >            o_orderdate,
    >            l_partkey,
    >            l_suppkey;
---
 .../mv/AbstractMaterializedViewAggregateRule.java  | 103 ++++++++++++++++-----
 .../org/apache/doris/nereids/trees/TreeNode.java   |  28 ++++++
 .../doris/nereids/trees/expressions/Any.java       |  10 ++
 .../mv/agg_with_roll_up/aggregate_with_roll_up.out |   6 ++
 .../agg_with_roll_up/aggregate_with_roll_up.groovy |  35 +++++++
 5 files changed, 158 insertions(+), 24 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
index 685f8a8c3a9..11faaa6a6d3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java
@@ -26,6 +26,7 @@ import 
org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext
 import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
 import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
 import org.apache.doris.nereids.trees.expressions.Any;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -35,11 +36,14 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
 import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
 import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
+import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
@@ -47,10 +51,11 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
 import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -60,15 +65,18 @@ import java.util.stream.Collectors;
  */
 public abstract class AbstractMaterializedViewAggregateRule extends 
AbstractMaterializedViewRule {
 
-    protected static final Map<Expression, Expression>
-            AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();
+    // we only support roll up function which has only one argument currently
+    protected static final Multimap<Expression, Expression>
+            AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = 
ArrayListMultimap.create();
     protected final String currentClassName = this.getClass().getSimpleName();
 
     private final Logger logger = LogManager.getLogger(this.getClass());
 
     static {
         AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, 
Any.INSTANCE),
-                new BitmapUnion(Any.INSTANCE));
+                new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, 
BigIntType.INSTANCE))));
+        AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, 
Any.INSTANCE),
+                new BitmapUnion(new ToBitmap(Any.INSTANCE)));
     }
 
     @Override
@@ -249,17 +257,30 @@ public abstract class 
AbstractMaterializedViewAggregateRule extends AbstractMate
         return rewrittenAggregate;
     }
 
-    // only support sum roll up, support other agg functions later.
-    private Function rollup(AggregateFunction queryFunction,
-            Expression queryFunctionShuttled,
+    /**
+     * Roll up query aggregate function when query dimension num is less than 
mv dimension num,
+     *
+     * @param queryAggregateFunction query aggregate function to roll up.
+     * @param queryAggregateFunctionShuttled query aggregate function shuttled 
by lineage.
+     * @param mvExprToMvScanExprQueryBased mv def sql output expressions to mv 
result data output mapping.
+     *         <p>
+     *         Such as query is
+     *         select max(a) + 1 from table group by b.
+     *         mv is
+     *         select max(a) from table group by a, b.
+     *         the queryAggregateFunction is max(a), 
queryAggregateFunctionShuttled is max(a) + 1
+     *         mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) }
+     */
+    private Function rollup(AggregateFunction queryAggregateFunction,
+            Expression queryAggregateFunctionShuttled,
             Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
-        if (!(queryFunction instanceof CouldRollUp)) {
+        if (!(queryAggregateFunction instanceof CouldRollUp)) {
             return null;
         }
         Expression rollupParam = null;
-        if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
+        if 
(mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)) {
             // function can rewrite by view
-            rollupParam = 
mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
+            rollupParam = 
mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
         } else {
             // function can not rewrite by view, try to use complex roll up 
param
             // eg: query is count(distinct param), mv sql is 
bitmap_union(to_bitmap(param))
@@ -267,7 +288,8 @@ public abstract class AbstractMaterializedViewAggregateRule 
extends AbstractMate
                 if (!(mvExprShuttled instanceof Function)) {
                     continue;
                 }
-                if (isAggregateFunctionEquivalent(queryFunction, (Function) 
mvExprShuttled)) {
+                if (isAggregateFunctionEquivalent(queryAggregateFunction, 
queryAggregateFunctionShuttled,
+                        (Function) mvExprShuttled)) {
                     rollupParam = 
mvExprToMvScanExprQueryBased.get(mvExprShuttled);
                 }
             }
@@ -276,7 +298,7 @@ public abstract class AbstractMaterializedViewAggregateRule 
extends AbstractMate
             return null;
         }
         // do roll up
-        return ((CouldRollUp) queryFunction).constructRollUp(rollupParam);
+        return ((CouldRollUp) 
queryAggregateFunction).constructRollUp(rollupParam);
     }
 
     private Pair<Set<? extends Expression>, Set<? extends Expression>> 
topPlanSplitToGroupAndFunction(
@@ -347,22 +369,55 @@ public abstract class 
AbstractMaterializedViewAggregateRule extends AbstractMate
         return true;
     }
 
-    private boolean isAggregateFunctionEquivalent(Function queryFunction, 
Function viewFunction) {
+    /**
+     * Check the queryFunction is equivalent to view function when function 
roll up.
+     * Not only check the function name but also check the argument between 
query and view aggregate function.
+     * Such as query is
+     * select count(distinct a) + 1 from table group by b.
+     * mv is
+     * select bitmap_union(to_bitmap(a)) from table group by a, b.
+     * the queryAggregateFunction is count(distinct a), 
queryAggregateFunctionShuttled is count(distinct a) + 1
+     * mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : 
MTMVScan(output#0) }
+     * This will check the count(distinct a) in query is equivalent to  
bitmap_union(to_bitmap(a)) in mv,
+     * and then check their arguments is equivalent.
+     */
+    private boolean isAggregateFunctionEquivalent(Function queryFunction, 
Expression queryFunctionShuttled,
+            Function viewFunction) {
         if (queryFunction.equals(viewFunction)) {
             return true;
         }
-        // get query equivalent function
-        Expression equivalentFunction = null;
-        for (Map.Entry<Expression, Expression> entry : 
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) {
-            if (entry.getKey().equals(queryFunction)) {
-                equivalentFunction = entry.getValue();
+        // check the argument of rollup function is equivalent to view 
function or not
+        for (Map.Entry<Expression, Collection<Expression>> 
equivalentFunctionEntry :
+                AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
+            if (equivalentFunctionEntry.getKey().equals(queryFunction)) {
+                // check is have equivalent function or not
+                for (Expression equivalentFunction : 
equivalentFunctionEntry.getValue()) {
+                    if (!Any.equals(equivalentFunction, viewFunction)) {
+                        continue;
+                    }
+                    // check param in query function is same as the view 
function
+                    List<Expression> viewFunctionArguments = 
extractViewArguments(equivalentFunction, viewFunction);
+                    if (queryFunctionShuttled.getArguments().size() != 1 || 
viewFunctionArguments.size() != 1) {
+                        continue;
+                    }
+                    if 
(Objects.equals(queryFunctionShuttled.getArguments().get(0), 
viewFunctionArguments.get(0))) {
+                        return true;
+                    }
+                }
             }
         }
-        // check is have equivalent function or not
-        if (equivalentFunction == null) {
-            return false;
-        }
-        // current compare
-        return equivalentFunction.equals(viewFunction);
+        return false;
+    }
+
+    /**
+     * Extract the view function arguments by equivalentFunction pattern
+     * Such as equivalentFunction def is bitmap_union(to_bitmap(Any.INSTANCE)),
+     * viewFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 
end))
+     * after extracting, the return argument is: case when a = 5 then 1 else 2 
end
+     */
+    private List<Expression> extractViewArguments(Expression 
equivalentFunction, Function viewFunction) {
+        Set<Object> exprSetToRemove = equivalentFunction.collectToSet(expr -> 
!(expr instanceof Any));
+        return viewFunction.collectFirst(expr ->
+                exprSetToRemove.stream().noneMatch(exprToRemove -> 
exprToRemove.equals(expr)));
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 557ff43b51d..00ac71eaf24 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.ArrayDeque;
+import java.util.ArrayList;
 import java.util.Deque;
 import java.util.List;
 import java.util.Set;
@@ -150,6 +151,19 @@ public interface TreeNode<NODE_TYPE extends 
TreeNode<NODE_TYPE>> {
         return rewriteFunction.apply(rewrittenChildren);
     }
 
+    /**
+     * Foreach treeNode. Top-down traverse implicitly, stop traverse if 
satisfy test.
+     * @param func foreach function
+     */
+    default void foreach(Predicate<TreeNode<NODE_TYPE>> func) {
+        boolean valid = func.test(this);
+        if (!valid) {
+            for (NODE_TYPE child : children()) {
+                child.foreach(func);
+            }
+        }
+    }
+
     /**
      * Foreach treeNode. Top-down traverse implicitly.
      * @param func foreach function
@@ -241,6 +255,20 @@ public interface TreeNode<NODE_TYPE extends 
TreeNode<NODE_TYPE>> {
         return (Set<T>) result.build();
     }
 
+    /**
+     * Collect the nodes that satisfied the predicate firstly.
+     */
+    default <T> List<T> collectFirst(Predicate<TreeNode<NODE_TYPE>> predicate) 
{
+        List<TreeNode<NODE_TYPE>> result = new ArrayList<>();
+        foreach(node -> {
+            if (result.isEmpty() && predicate.test(node)) {
+                result.add(node);
+            }
+            return !result.isEmpty();
+        });
+        return (List<T>) ImmutableList.copyOf(result);
+    }
+
     /**
      * iterate top down and test predicate if contains any instance of the 
classes
      * @param types classes array
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
index 43d284bf678..2e4bc745b2a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java
@@ -24,6 +24,7 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import com.google.common.collect.ImmutableList;
 
 import java.util.List;
+import java.util.Objects;
 
 /**
  * This represents any expression, it means it equals any expression
@@ -55,6 +56,15 @@ public class Any extends Expression implements 
LeafExpression {
         return true;
     }
 
+    /**
+     * Equals with direction
+     * Since the equals method in Any is always true, that means Any is equals 
to others, but not equal in reverse.
+     * The expression with Any should always be the first argument.
+     */
+    public static boolean equals(Expression expressionWithAny, Expression 
target) {
+        return Objects.equals(expressionWithAny, target);
+    }
+
     @Override
     public int hashCode() {
         return 0;
diff --git 
a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
 
b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
index fb223bc661b..334980ed00c 100644
--- 
a/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
+++ 
b/regression-test/data/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.out
@@ -5,6 +5,12 @@
 -- !query13_0_after --
 3      3       2023-12-11      43.20   43.20   43.20   1       0
 
+-- !query13_1_before --
+3      3       2023-12-11      43.20   43.20   43.20   1       0
+
+-- !query13_1_after --
+3      3       2023-12-11      43.20   43.20   43.20   1       0
+
 -- !query14_0_before --
 2      3       2023-12-08      20.00   10.50   9.50    2       0
 2      3       2023-12-12      \N      \N      \N      1       0
diff --git 
a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
 
b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
index fd3c02408d9..e9d1ee76b37 100644
--- 
a/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/mv/agg_with_roll_up/aggregate_with_roll_up.groovy
@@ -247,6 +247,41 @@ suite("aggregate_with_roll_up") {
     sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_0"""
 
 
+    def mv13_1 = """
+            select l_shipdate, o_orderdate, l_partkey, l_suppkey, 
+            sum(o_totalprice) as sum_total, 
+            max(o_totalprice) as max_total, 
+            min(o_totalprice) as min_total, 
+            count(*) as count_all, 
+            bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey 
IN (1, 3) then o_custkey else null end)) as bitmap_union_basic 
+            from lineitem 
+            left join orders on lineitem.l_orderkey = orders.o_orderkey and 
l_shipdate = o_orderdate 
+            group by 
+            l_shipdate, 
+            o_orderdate, 
+            l_partkey, 
+            l_suppkey;
+    """
+    def query13_1 = """
+            select t1.l_partkey, t1.l_suppkey, o_orderdate,
+            sum(o_totalprice),
+            max(o_totalprice),
+            min(o_totalprice),
+            count(*),
+            count(distinct case when o_shippriority > 10 and o_orderkey IN (1, 
3) then o_custkey else null end)
+            from (select * from lineitem where l_shipdate = '2023-12-11') t1
+            left join orders on t1.l_orderkey = orders.o_orderkey and 
t1.l_shipdate = o_orderdate
+            group by
+            o_orderdate, 
+            l_partkey,
+            l_suppkey;
+    """
+    order_qt_query13_1_before "${query13_1}"
+    check_not_match(mv13_1, query13_1, "mv13_1")
+    order_qt_query13_1_after "${query13_1}"
+    sql """ DROP MATERIALIZED VIEW IF EXISTS mv13_1"""
+
+
     // filter inside + right + use roll up dimension
     def mv14_0 = "select l_shipdate, o_orderdate, l_partkey, l_suppkey, " +
             "sum(o_totalprice) as sum_total, " +


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to