This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new f29beba4367 [fix] (inverted index) fix the error result in the query when using count on index (#41375) (#41687)
f29beba4367 is described below

commit f29beba43673c8ad3b21aa5d516996e419f03b4e
Author: Sun Chenyang <csun5...@gmail.com>
AuthorDate: Tue Oct 15 18:09:08 2024 +0800

    [fix] (inverted index) fix the error result in the query when using count on index (#41375) (#41687)
---
 .../rules/implementation/AggregateStrategies.java  | 112 ++++++++++++++-------
 .../inverted_index_p0/test_count_on_index_2.out    |   9 ++
 .../inverted_index_p0/test_count_on_index_2.groovy |  29 ++++++
 3 files changed, 115 insertions(+), 35 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index 7bbbc7841e8..1482ba4d013 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
@@ -108,47 +109,72 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                 logicalAggregate(
                     logicalFilter(
                         
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
-                    ).when(filter -> !filter.getConjuncts().isEmpty()))
-                    .when(agg -> enablePushDownCountOnIndex())
-                    .when(agg -> agg.getGroupByExpressions().isEmpty())
-                    .when(agg -> {
-                        Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
-                        return !funcs.isEmpty() && funcs.stream()
-                                .allMatch(f -> f instanceof Count && 
!f.isDistinct() && (((Count) f).isStar()
-                                || f.children().isEmpty()
-                                || (f.children().size() == 1 && f.child(0) 
instanceof Literal)
-                                || f.child(0) instanceof Slot));
-                    })
-                    .thenApply(ctx -> {
-                        LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = 
ctx.root;
-                        LogicalFilter<LogicalOlapScan> filter = agg.child();
-                        LogicalOlapScan olapScan = filter.child();
-                        return pushdownCountOnIndex(agg, null, filter, 
olapScan, ctx.cascadesContext);
-                    })
+                    )
+                )
+                .when(agg -> enablePushDownCountOnIndex())
+                .when(agg -> agg.getGroupByExpressions().isEmpty())
+                .when(agg -> {
+                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
+                    if (funcs.isEmpty() || !funcs.stream()
+                            .allMatch(f -> f instanceof Count && 
!f.isDistinct() && (((Count) f).isStar()
+                            || f.children().isEmpty()
+                            || (f.children().size() == 1 && f.child(0) 
instanceof Literal)
+                            || f.child(0) instanceof Slot))) {
+                        return false;
+                    }
+                    Set<Expression> conjuncts = agg.child().getConjuncts();
+                    if (conjuncts.isEmpty()) {
+                        return false;
+                    }
+
+                    Set<Slot> aggSlots = funcs.stream()
+                            .flatMap(f -> f.getInputSlots().stream())
+                            .collect(Collectors.toSet());
+                    return conjuncts.stream().allMatch(expr -> 
checkSlotInOrExpression(expr, aggSlots));
+                })
+                .thenApply(ctx -> {
+                    LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = 
ctx.root;
+                    LogicalFilter<LogicalOlapScan> filter = agg.child();
+                    LogicalOlapScan olapScan = filter.child();
+                    return pushdownCountOnIndex(agg, null, filter, olapScan, 
ctx.cascadesContext);
+                })
             ),
             RuleType.COUNT_ON_INDEX.build(
                 logicalAggregate(
                     logicalProject(
                         logicalFilter(
                             
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
-                        ).when(filter -> !filter.getConjuncts().isEmpty())))
-                    .when(agg -> enablePushDownCountOnIndex())
-                    .when(agg -> agg.getGroupByExpressions().isEmpty())
-                    .when(agg -> {
-                        Set<AggregateFunction> funcs = 
agg.getAggregateFunctions();
-                        return !funcs.isEmpty() && funcs.stream()
-                               .allMatch(f -> f instanceof Count && 
!f.isDistinct() && (((Count) f).isStar()
-                               || f.children().isEmpty()
-                               || (f.children().size() == 1 && f.child(0) 
instanceof Literal)
-                               || f.child(0) instanceof Slot));
-                    })
-                    .thenApply(ctx -> {
-                        
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
-                        LogicalProject<LogicalFilter<LogicalOlapScan>> project 
= agg.child();
-                        LogicalFilter<LogicalOlapScan> filter = 
project.child();
-                        LogicalOlapScan olapScan = filter.child();
-                        return pushdownCountOnIndex(agg, project, filter, 
olapScan, ctx.cascadesContext);
-                    })
+                        )
+                    )
+                )
+                .when(agg -> enablePushDownCountOnIndex())
+                .when(agg -> agg.getGroupByExpressions().isEmpty())
+                .when(agg -> {
+                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
+                    if (funcs.isEmpty() || !funcs.stream()
+                            .allMatch(f -> f instanceof Count && 
!f.isDistinct() && (((Count) f).isStar()
+                            || f.children().isEmpty()
+                            || (f.children().size() == 1 && f.child(0) 
instanceof Literal)
+                            || f.child(0) instanceof Slot))) {
+                        return false;
+                    }
+                    Set<Expression> conjuncts = 
agg.child().child().getConjuncts();
+                    if (conjuncts.isEmpty()) {
+                        return false;
+                    }
+
+                    Set<Slot> aggSlots = funcs.stream()
+                            .flatMap(f -> f.getInputSlots().stream())
+                            .collect(Collectors.toSet());
+                    return conjuncts.stream().allMatch(expr -> 
checkSlotInOrExpression(expr, aggSlots));
+                })
+                .thenApply(ctx -> {
+                    
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
+                    LogicalProject<LogicalFilter<LogicalOlapScan>> project = 
agg.child();
+                    LogicalFilter<LogicalOlapScan> filter = project.child();
+                    LogicalOlapScan olapScan = filter.child();
+                    return pushdownCountOnIndex(agg, project, filter, 
olapScan, ctx.cascadesContext);
+                })
             ),
             
RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build(
                 logicalAggregate(
@@ -331,6 +357,22 @@ public class AggregateStrategies implements ImplementationRuleFactory {
         return connectContext != null && 
connectContext.getSessionVariable().isEnablePushDownCountOnIndex();
     }
 
+    private boolean checkSlotInOrExpression(Expression expr, Set<Slot> 
aggSlots) {
+        if (expr instanceof Or) {
+            Set<Slot> slots = expr.getInputSlots();
+            if (!slots.stream().allMatch(aggSlots::contains)) {
+                return false;
+            }
+        } else {
+            for (Expression child : expr.children()) {
+                if (!checkSlotInOrExpression(child, aggSlots)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
     private boolean isDupOrMowKeyTable(LogicalOlapScan logicalScan) {
         if (logicalScan != null) {
             KeysType keysType = logicalScan.getTable().getKeysType();
diff --git a/regression-test/data/inverted_index_p0/test_count_on_index_2.out b/regression-test/data/inverted_index_p0/test_count_on_index_2.out
index 94d2a83388b..de74ba29ffe 100644
--- a/regression-test/data/inverted_index_p0/test_count_on_index_2.out
+++ b/regression-test/data/inverted_index_p0/test_count_on_index_2.out
@@ -101,3 +101,12 @@
 -- !sql --
 3
 
+-- !sql --
+1
+
+-- !sql --
+1
+
+-- !sql --
+1
+
diff --git a/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy b/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy
index 851c9120aa2..8f95e3cb13d 100644
--- a/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy
+++ b/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy
@@ -201,6 +201,35 @@ suite("test_count_on_index_2", "p0"){
         qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 
20) and (b >= 5 and b < 14) and (c >= 16 and c < 25); """
         qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 
20) and (b >= 5 and b < 16) and (c >= 13 and c < 25); """
 
+        sql """ DROP TABLE IF EXISTS tt """
+        sql """
+            CREATE TABLE `tt` (
+                `a` int NULL,
+                `b` int NULL,
+                `c` int NULL,
+                INDEX col_c (`b`) USING INVERTED,
+                INDEX col_b (`c`) USING INVERTED
+            ) ENGINE=OLAP
+            DUPLICATE KEY(`a`)
+            COMMENT 'OLAP'
+            DISTRIBUTED BY RANDOM BUCKETS 1
+            PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+            );
+        """
+
+        sql """ insert into tt values (20, 23, 30); """
+        sql """ insert into tt values (20, null, 30); """
+        qt_sql """ select count(b) from tt where b = 23 or c = 30; """
+        qt_sql """ select count(b) from tt where b = 23  and (c = 20 or c = 
30); """
+        explain {
+            sql("select count(b) from tt where b = 23  and (c = 20 or c = 
30);")
+            contains "COUNT_ON_INDEX"
+        }
+        explain {
+            sql("select count(b) from tt where b = 23 or b = 30;")
+            contains "COUNT_ON_INDEX"
+        }
     } finally {
         //try_sql("DROP TABLE IF EXISTS ${testTable}")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to