This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c847daba0 [Feature](Nereids) Support grouping set for materialized 
index. (#15383)
6c847daba0 is described below

commit 6c847daba038f372b09c97bc9f38b6f01024e25f
Author: Shuo Wang <wangshuo...@gmail.com>
AuthorDate: Thu Dec 29 23:17:02 2022 +0800

    [Feature](Nereids) Support grouping set for materialized index. (#15383)
    
    This PR adds support for materialized index selection when the query has 
grouping sets.
---
 .../jobs/batch/NereidsRewriteJobExecutor.java      |   6 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   5 +
 .../mv/SelectMaterializedIndexWithAggregate.java   | 221 ++++++++++++++++++++-
 .../doris/nereids/rules/mv/SelectMvIndexTest.java  |   6 +-
 4 files changed, 224 insertions(+), 14 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
index b6cd0dc5b6..0715bfa51b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriteJobExecutor.java
@@ -98,14 +98,14 @@ public class NereidsRewriteJobExecutor extends 
BatchRulesJob {
                 .add(topDownBatch(ImmutableList.of(new EliminateFilter())))
                 .add(topDownBatch(ImmutableList.of(new 
PruneOlapScanPartition())))
                 .add(topDownBatch(ImmutableList.of(new 
CountDistinctRewrite())))
-                .add(topDownBatch(ImmutableList.of(new 
SelectMaterializedIndexWithAggregate())))
-                .add(topDownBatch(ImmutableList.of(new 
SelectMaterializedIndexWithoutAggregate())))
-                .add(topDownBatch(ImmutableList.of(new PruneOlapScanTablet())))
                 // we need to execute this rule at the end of rewrite
                 // to avoid two consecutive same project appear when we do 
optimization.
                 .add(topDownBatch(ImmutableList.of(new 
EliminateGroupByConstant())))
                 .add(topDownBatch(ImmutableList.of(new 
EliminateOrderByConstant())))
                 .add(topDownBatch(ImmutableList.of(new 
EliminateUnnecessaryProject())))
+                .add(topDownBatch(ImmutableList.of(new 
SelectMaterializedIndexWithAggregate())))
+                .add(topDownBatch(ImmutableList.of(new 
SelectMaterializedIndexWithoutAggregate())))
+                .add(topDownBatch(ImmutableList.of(new PruneOlapScanTablet())))
                 .add(topDownBatch(ImmutableList.of(new EliminateAggregate())))
                 .add(bottomUpBatch(ImmutableList.of(new MergeSetOperations())))
                 .add(topDownBatch(ImmutableList.of(new LimitPushDown())))
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index fb04269bfa..71149d8d83 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -154,6 +154,11 @@ public enum RuleType {
     MATERIALIZED_INDEX_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
+    MATERIALIZED_INDEX_AGG_REPEAT_SCAN(RuleTypeClass.REWRITE),
+    MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN(RuleTypeClass.REWRITE),
+    MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN(RuleTypeClass.REWRITE),
+    MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
+    MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_SCAN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_FILTER_SCAN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java
index c189292510..a193b63901 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/mv/SelectMaterializedIndexWithAggregate.java
@@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -52,6 +53,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.base.Preconditions;
@@ -160,7 +162,8 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                     ImmutableSet.of(),
                                     extractAggFunctionAndReplaceSlot(agg,
                                             Optional.of(project)),
-                                    agg.getGroupByExpressions()
+                                    
ExpressionUtils.replace(agg.getGroupByExpressions(),
+                                            project.getAliasToProducer())
                             );
 
                             if (result.exprRewriteMap.isEmpty()) {
@@ -262,7 +265,200 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                         filter.withChildren(newProject)
                                 );
                             }
-                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN)
+                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_FILTER_PROJECT_SCAN),
+
+                // only agg above scan
+                // Aggregate(Repeat(Scan))
+                
logicalAggregate(logicalRepeat(logicalOlapScan().when(this::shouldSelectIndex))).then(agg
 -> {
+                    LogicalRepeat<LogicalOlapScan> repeat = agg.child();
+                    LogicalOlapScan scan = repeat.child();
+                    SelectResult result = select(
+                            scan,
+                            agg.getInputSlots(),
+                            ImmutableSet.of(),
+                            extractAggFunctionAndReplaceSlot(agg, 
Optional.empty()),
+                            nonVirtualGroupByExprs(agg));
+                    if (result.exprRewriteMap.isEmpty()) {
+                        return agg.withChildren(
+                                repeat.withChildren(
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
+                        );
+                    } else {
+                        return new LogicalAggregate<>(
+                                agg.getGroupByExpressions(),
+                                replaceAggOutput(agg, Optional.empty(), 
Optional.empty(), result.exprRewriteMap),
+                                agg.isNormalized(),
+                                agg.getSourceRepeat(),
+                                repeat.withChildren(
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
+                        );
+                    }
+                }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_SCAN),
+
+                // filter could push down scan.
+                // Aggregate(Repeat(Filter(Scan)))
+                
logicalAggregate(logicalRepeat(logicalFilter(logicalOlapScan().when(this::shouldSelectIndex))))
+                        .then(agg -> {
+                            LogicalRepeat<LogicalFilter<LogicalOlapScan>> 
repeat = agg.child();
+                            LogicalFilter<LogicalOlapScan> filter = 
repeat.child();
+                            LogicalOlapScan scan = filter.child();
+                            ImmutableSet<Slot> requiredSlots = 
ImmutableSet.<Slot>builder()
+                                    .addAll(agg.getInputSlots())
+                                    .addAll(filter.getInputSlots())
+                                    .build();
+
+                            SelectResult result = select(
+                                    scan,
+                                    requiredSlots,
+                                    filter.getConjuncts(),
+                                    extractAggFunctionAndReplaceSlot(agg, 
Optional.empty()),
+                                    nonVirtualGroupByExprs(agg)
+                            );
+
+                            if (result.exprRewriteMap.isEmpty()) {
+                                return agg.withChildren(
+                                        repeat.withChildren(
+                                                filter.withChildren(
+                                                        
scan.withMaterializedIndexSelected(result.preAggStatus,
+                                                                
result.indexId))
+                                        ));
+                            } else {
+                                return new LogicalAggregate<>(
+                                        agg.getGroupByExpressions(),
+                                        replaceAggOutput(agg, 
Optional.empty(), Optional.empty(),
+                                                result.exprRewriteMap),
+                                        agg.isNormalized(),
+                                        agg.getSourceRepeat(),
+                                        // Note that there is no need to replace slots 
in the filter, because the slots to replace
+                                        // are value columns, which shouldn't 
appear in filters.
+                                        
repeat.withChildren(filter.withChildren(
+                                                
scan.withMaterializedIndexSelected(result.preAggStatus,
+                                                        result.indexId)))
+                                );
+                            }
+                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN),
+
+                // column pruning or other projections such as alias, etc.
+                // Aggregate(Repeat(Project(Scan)))
+                
logicalAggregate(logicalRepeat(logicalProject(logicalOlapScan().when(this::shouldSelectIndex))))
+                        .then(agg -> {
+                            LogicalRepeat<LogicalProject<LogicalOlapScan>> 
repeat = agg.child();
+                            LogicalProject<LogicalOlapScan> project = 
repeat.child();
+                            LogicalOlapScan scan = project.child();
+                            SelectResult result = select(
+                                    scan,
+                                    project.getInputSlots(),
+                                    ImmutableSet.of(),
+                                    extractAggFunctionAndReplaceSlot(agg,
+                                            Optional.of(project)),
+                                    
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
+                                            project.getAliasToProducer())
+                            );
+
+                            if (result.exprRewriteMap.isEmpty()) {
+                                return agg.withChildren(
+                                        repeat.withChildren(
+                                                project.withChildren(
+                                                        
scan.withMaterializedIndexSelected(result.preAggStatus,
+                                                                result.indexId)
+                                                ))
+                                );
+                            } else {
+                                List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                                        result.exprRewriteMap.projectExprMap);
+                                LogicalProject<LogicalOlapScan> newProject = 
new LogicalProject<>(
+                                        newProjectList,
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId));
+                                return new LogicalAggregate<>(
+                                        agg.getGroupByExpressions(),
+                                        replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
+                                                result.exprRewriteMap),
+                                        agg.isNormalized(),
+                                        agg.getSourceRepeat(),
+                                        repeat.withChildren(newProject)
+                                );
+                            }
+                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN),
+
+                // filter could push down and project.
+                // Aggregate(Repeat(Project(Filter(Scan))))
+                
logicalAggregate(logicalRepeat(logicalProject(logicalFilter(logicalOlapScan()
+                        .when(this::shouldSelectIndex))))).then(agg -> {
+                            
LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>> repeat = 
agg.child();
+                            LogicalProject<LogicalFilter<LogicalOlapScan>> 
project = repeat.child();
+                            LogicalFilter<LogicalOlapScan> filter = 
project.child();
+                            LogicalOlapScan scan = filter.child();
+                            Set<Slot> requiredSlots = Stream.concat(
+                                    project.getInputSlots().stream(), 
filter.getInputSlots().stream())
+                                    .collect(Collectors.toSet());
+                            SelectResult result = select(
+                                    scan,
+                                    requiredSlots,
+                                    filter.getConjuncts(),
+                                    extractAggFunctionAndReplaceSlot(agg, 
Optional.of(project)),
+                                    
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
+                                            project.getAliasToProducer())
+                            );
+
+                            if (result.exprRewriteMap.isEmpty()) {
+                                return 
agg.withChildren(repeat.withChildren(project.withChildren(filter.withChildren(
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
+                                )));
+                            } else {
+                                List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                                        result.exprRewriteMap.projectExprMap);
+                                LogicalProject<Plan> newProject = new 
LogicalProject<>(newProjectList,
+                                        
filter.withChildren(scan.withMaterializedIndexSelected(result.preAggStatus,
+                                                result.indexId)));
+
+                                return new LogicalAggregate<>(
+                                        agg.getGroupByExpressions(),
+                                        replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
+                                                result.exprRewriteMap),
+                                        agg.isNormalized(),
+                                        agg.getSourceRepeat(),
+                                        repeat.withChildren(newProject)
+                                );
+                            }
+                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN),
+
+                // filter can't push down
+                // Aggregate(Repeat(Filter(Project(Scan))))
+                
logicalAggregate(logicalRepeat(logicalFilter(logicalProject(logicalOlapScan()
+                        .when(this::shouldSelectIndex))))).then(agg -> {
+                            
LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>> repeat = 
agg.child();
+                            LogicalFilter<LogicalProject<LogicalOlapScan>> 
filter = repeat.child();
+                            LogicalProject<LogicalOlapScan> project = 
filter.child();
+                            LogicalOlapScan scan = project.child();
+                            SelectResult result = select(
+                                    scan,
+                                    project.getInputSlots(),
+                                    ImmutableSet.of(),
+                                    extractAggFunctionAndReplaceSlot(agg, 
Optional.of(project)),
+                                    
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
+                                            project.getAliasToProducer())
+                            );
+
+                            if (result.exprRewriteMap.isEmpty()) {
+                                return 
agg.withChildren(repeat.withChildren(filter.withChildren(project.withChildren(
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId))
+                                )));
+                            } else {
+                                List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                                        result.exprRewriteMap.projectExprMap);
+                                LogicalProject<Plan> newProject = new 
LogicalProject<>(newProjectList,
+                                        
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId));
+
+                                return new LogicalAggregate<>(
+                                        agg.getGroupByExpressions(),
+                                        replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
+                                                result.exprRewriteMap),
+                                        agg.isNormalized(),
+                                        agg.getSourceRepeat(),
+                                        
repeat.withChildren(filter.withChildren(newProject))
+                                );
+                            }
+                        
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN)
         );
     }
 
@@ -284,9 +480,13 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
             Set<Expression> predicates,
             List<AggregateFunction> aggregateFunctions,
             List<Expression> groupingExprs) {
-        
Preconditions.checkArgument(scan.getOutputSet().containsAll(requiredScanOutput),
+        // remove virtual slot for grouping sets.
+        Set<Slot> nonVirtualRequiredScanOutput = requiredScanOutput.stream()
+                .filter(slot -> !(slot instanceof VirtualSlotReference))
+                .collect(ImmutableSet.toImmutableSet());
+        
Preconditions.checkArgument(scan.getOutputSet().containsAll(nonVirtualRequiredScanOutput),
                 String.format("Scan's output (%s) should contains all the 
input required scan output (%s).",
-                        scan.getOutput(), requiredScanOutput));
+                        scan.getOutput(), nonVirtualRequiredScanOutput));
 
         OlapTable table = scan.getTable();
 
@@ -303,7 +503,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                     return new SelectResult(preAggStatus, 
scan.getTable().getBaseIndexId(), new ExprRewriteMap());
                 } else {
                     List<MaterializedIndex> rollupsWithAllRequiredCols = 
table.getVisibleIndex().stream()
-                            .filter(index -> containAllRequiredColumns(index, 
scan, requiredScanOutput))
+                            .filter(index -> containAllRequiredColumns(index, 
scan, nonVirtualRequiredScanOutput))
                             .collect(Collectors.toList());
                     return new SelectResult(preAggStatus, 
selectBestIndex(rollupsWithAllRequiredCols, scan, predicates),
                             new ExprRewriteMap());
@@ -328,7 +528,8 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 ImmutableList.of())
                         .stream()
                         .filter(index -> 
!candidatesWithoutRewriting.contains(index))
-                        .map(index -> rewriteAgg(index, scan, 
requiredScanOutput, predicates, aggregateFunctions,
+                        .map(index -> rewriteAgg(index, scan, 
nonVirtualRequiredScanOutput, predicates,
+                                aggregateFunctions,
                                 groupingExprs))
                         .filter(aggRewriteResult -> checkPreAggStatus(scan, 
aggRewriteResult.index.getId(),
                                 predicates,
@@ -340,7 +541,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
 
                 List<MaterializedIndex> haveAllRequiredColumns = 
Streams.concat(
                         candidatesWithoutRewriting.stream()
-                                .filter(index -> 
containAllRequiredColumns(index, scan, requiredScanOutput)),
+                                .filter(index -> 
containAllRequiredColumns(index, scan, nonVirtualRequiredScanOutput)),
                         candidatesWithRewriting
                                 .stream()
                                 .filter(aggRewriteResult -> 
containAllRequiredColumns(aggRewriteResult.index, scan,
@@ -995,4 +1196,10 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                 .map(expr -> (NamedExpression) ExpressionUtils.replace(expr, 
projectMap))
                 .collect(ImmutableList.toImmutableList());
     }
+
+    private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends 
Plan> agg) {
+        return agg.getGroupByExpressions().stream()
+                .filter(expr -> !(expr instanceof VirtualSlotReference))
+                .collect(ImmutableList.toImmutableList());
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java
index 422041e93c..cb1eb6dd65 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/mv/SelectMvIndexTest.java
@@ -223,9 +223,8 @@ public class SelectMvIndexTest extends 
BaseMaterializedIndexSelectTest implement
     /**
      * Aggregation query with groupSets at coarser level of aggregation than
      * aggregation materialized view.
-     * TODO: enable this when group by rollup is supported.
      */
-    @Disabled
+    @Test
     public void testGroupingSetQueryOnAggMV() throws Exception {
         String createMVSql = "create materialized view " + EMPS_MV_NAME + " as 
select empid, deptno, sum(salary) "
                 + "from " + EMPS_TABLE_NAME + " group by empid, deptno;";
@@ -271,9 +270,8 @@ public class SelectMvIndexTest extends 
BaseMaterializedIndexSelectTest implement
 
     /**
      * Query with rollup and arithmetic expr
-     * TODO: enable this when group by rollup is supported.
      */
-    @Disabled
+    @Test
     public void testAggQueryOnAggMV10() throws Exception {
         String createMVSql = "create materialized view " + EMPS_MV_NAME + " as 
select deptno, commission, sum(salary) "
                 + "from " + EMPS_TABLE_NAME + " group by deptno, commission;";


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to