This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new bc381f0c97f branch-3.1: [opt](nereids)use SetPreAggStatus rule instead
of AdjustPreAggStatus with join limited #48502 (#51947)
bc381f0c97f is described below
commit bc381f0c97fecd09959f2f6dad41b0709b42157b
Author: starocean999 <[email protected]>
AuthorDate: Thu Jun 19 19:36:01 2025 +0800
branch-3.1: [opt](nereids)use SetPreAggStatus rule instead of
AdjustPreAggStatus with join limited #48502 (#51947)
pick from master #48502
---
.../doris/nereids/jobs/executor/Rewriter.java | 4 +-
.../org/apache/doris/nereids/rules/RuleType.java | 16 +-
.../nereids/rules/rewrite/AdjustPreAggStatus.java | 751 ---------------------
.../nereids/rules/rewrite/SetPreAggStatus.java | 592 ++++++++++++++++
.../rules/rewrite/mv/SelectRollupIndexTest.java | 38 +-
.../nereids_rules_p0/set_preagg/set_preagg.groovy | 312 +++++++++
6 files changed, 926 insertions(+), 787 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 60522e3da39..57df7f9999b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -36,7 +36,6 @@ import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
import org.apache.doris.nereids.rules.rewrite.AddProjectForJoin;
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
-import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus;
import
org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.CTEInline;
@@ -132,6 +131,7 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
import org.apache.doris.nereids.rules.rewrite.ReduceAggregateChildOutputRows;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
+import org.apache.doris.nereids.rules.rewrite.SetPreAggStatus;
import org.apache.doris.nereids.rules.rewrite.SimplifyWindowExpression;
import org.apache.doris.nereids.rules.rewrite.SplitLimit;
import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite;
@@ -416,7 +416,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT,
EliminateUnnecessaryProject::new)
),
topic("adjust preagg status",
- topDown(new AdjustPreAggStatus())
+ custom(RuleType.SET_PREAGG_STATUS,
SetPreAggStatus::new)
),
topic("Point query short circuit",
topDown(new
LogicalResultSinkToShortCircuitPointQuery())),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b7cbe5343d1..565e8843743 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -273,21 +273,7 @@ public enum RuleType {
MATERIALIZED_INDEX_PROJECT_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
MATERIALIZED_INDEX_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_REPEAT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_PROJECT_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_PROJECT_FILTER_SCAN(RuleTypeClass.REWRITE),
- PREAGG_STATUS_FILTER_PROJECT_SCAN(RuleTypeClass.REWRITE),
+ SET_PREAGG_STATUS(RuleTypeClass.REWRITE),
REDUCE_AGGREGATE_CHILD_OUTPUT_ROWS(RuleTypeClass.REWRITE),
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java
deleted file mode 100644
index 495a06870f5..00000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java
+++ /dev/null
@@ -1,751 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite;
-
-import org.apache.doris.catalog.AggregateType;
-import org.apache.doris.catalog.KeysType;
-import org.apache.doris.catalog.MaterializedIndexMeta;
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.annotation.Developing;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.CaseWhen;
-import org.apache.doris.nereids.trees.expressions.Cast;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
-import org.apache.doris.nereids.trees.expressions.WhenClause;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
-import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
-import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.PreAggStatus;
-import org.apache.doris.nereids.trees.plans.algebra.Project;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
-import org.apache.doris.nereids.util.ExpressionUtils;
-
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Sets;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * AdjustPreAggStatus
- */
-@Developing
-public class AdjustPreAggStatus implements RewriteRuleFactory {
- ///////////////////////////////////////////////////////////////////////////
- // All the patterns
- ///////////////////////////////////////////////////////////////////////////
- @Override
- public List<Rule> buildRules() {
- return ImmutableList.of(
- // Aggregate(Scan)
-
logicalAggregate(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))
- .thenApply(ctx -> {
- LogicalAggregate<LogicalOlapScan> agg = ctx.root;
- LogicalOlapScan scan = agg.child();
- PreAggStatus preAggStatus = checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction> aggregateFunctions =
- extractAggFunctionAndReplaceSlot(agg,
Optional.empty());
- List<Expression> groupByExpressions =
agg.getGroupByExpressions();
- Set<Expression> predicates = ImmutableSet.of();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
agg.withChildren(scan.withPreAggStatus(preAggStatus));
- }).toRule(RuleType.PREAGG_STATUS_AGG_SCAN),
-
- // Aggregate(Filter(Scan))
- logicalAggregate(
-
logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalFilter<LogicalOlapScan>> agg = ctx.root;
- LogicalFilter<LogicalOlapScan> filter =
agg.child();
- LogicalOlapScan scan = filter.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.empty());
- List<Expression> groupByExpressions =
- agg.getGroupByExpressions();
- Set<Expression> predicates =
filter.getConjuncts();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return agg.withChildren(filter
-
.withChildren(scan.withPreAggStatus(preAggStatus)));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_SCAN),
-
- // Aggregate(Project(Scan))
- logicalAggregate(logicalProject(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalProject<LogicalOlapScan>> agg =
- ctx.root;
- LogicalProject<LogicalOlapScan> project =
agg.child();
- LogicalOlapScan scan = project.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg,
- Optional.of(project));
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(agg.getGroupByExpressions(),
-
project.getAliasToProducer());
- Set<Expression> predicates =
ImmutableSet.of();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return agg.withChildren(project
-
.withChildren(scan.withPreAggStatus(preAggStatus)));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_SCAN),
-
- // Aggregate(Project(Filter(Scan)))
- logicalAggregate(logicalProject(logicalFilter(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalProject<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
-
LogicalProject<LogicalFilter<LogicalOlapScan>> project = agg.child();
- LogicalFilter<LogicalOlapScan> filter =
project.child();
- LogicalOlapScan scan = filter.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(agg.getGroupByExpressions(),
-
project.getAliasToProducer());
- Set<Expression> predicates =
filter.getConjuncts();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
agg.withChildren(project.withChildren(filter
-
.withChildren(scan.withPreAggStatus(preAggStatus))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_PROJECT_FILTER_SCAN),
-
- // Aggregate(Filter(Project(Scan)))
- logicalAggregate(logicalFilter(logicalProject(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalFilter<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
-
LogicalFilter<LogicalProject<LogicalOlapScan>> filter =
- agg.child();
- LogicalProject<LogicalOlapScan> project =
filter.child();
- LogicalOlapScan scan = project.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(agg.getGroupByExpressions(),
-
project.getAliasToProducer());
- Set<Expression> predicates =
ExpressionUtils.replace(
- filter.getConjuncts(),
project.getAliasToProducer());
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
agg.withChildren(filter.withChildren(project
-
.withChildren(scan.withPreAggStatus(preAggStatus))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_FILTER_PROJECT_SCAN),
-
- // Aggregate(Repeat(Scan))
- logicalAggregate(
-
logicalRepeat(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalRepeat<LogicalOlapScan>> agg = ctx.root;
- LogicalRepeat<LogicalOlapScan> repeat =
agg.child();
- LogicalOlapScan scan = repeat.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.empty());
- List<Expression> groupByExpressions =
nonVirtualGroupByExprs(agg);
- Set<Expression> predicates =
ImmutableSet.of();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return agg.withChildren(repeat
-
.withChildren(scan.withPreAggStatus(preAggStatus)));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_SCAN),
-
- // Aggregate(Repeat(Filter(Scan)))
- logicalAggregate(logicalRepeat(logicalFilter(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalRepeat<LogicalFilter<LogicalOlapScan>>> agg = ctx.root;
-
LogicalRepeat<LogicalFilter<LogicalOlapScan>> repeat = agg.child();
- LogicalFilter<LogicalOlapScan> filter =
repeat.child();
- LogicalOlapScan scan = filter.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.empty());
- List<Expression> groupByExpressions =
- nonVirtualGroupByExprs(agg);
- Set<Expression> predicates =
filter.getConjuncts();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
agg.withChildren(repeat.withChildren(filter
-
.withChildren(scan.withPreAggStatus(preAggStatus))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_SCAN),
-
- // Aggregate(Repeat(Project(Scan)))
- logicalAggregate(logicalRepeat(logicalProject(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalRepeat<LogicalProject<LogicalOlapScan>>> agg = ctx.root;
-
LogicalRepeat<LogicalProject<LogicalOlapScan>> repeat = agg.child();
- LogicalProject<LogicalOlapScan> project =
repeat.child();
- LogicalOlapScan scan = project.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.empty());
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
-
project.getAliasToProducer());
- Set<Expression> predicates =
ImmutableSet.of();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
agg.withChildren(repeat.withChildren(project
-
.withChildren(scan.withPreAggStatus(preAggStatus))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_SCAN),
-
- // Aggregate(Repeat(Project(Filter(Scan))))
- logicalAggregate(logicalRepeat(logicalProject(logicalFilter(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>>>
agg
- = ctx.root;
-
LogicalRepeat<LogicalProject<LogicalFilter<LogicalOlapScan>>> repeat =
agg.child();
-
LogicalProject<LogicalFilter<LogicalOlapScan>> project = repeat.child();
- LogicalFilter<LogicalOlapScan> filter =
project.child();
- LogicalOlapScan scan = filter.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.empty());
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
-
project.getAliasToProducer());
- Set<Expression> predicates =
filter.getConjuncts();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return agg.withChildren(repeat
-
.withChildren(project.withChildren(filter.withChildren(
-
scan.withPreAggStatus(preAggStatus)))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_PROJECT_FILTER_SCAN),
-
- // Aggregate(Repeat(Filter(Project(Scan))))
- logicalAggregate(logicalRepeat(logicalFilter(logicalProject(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))))
- .thenApply(ctx -> {
-
LogicalAggregate<LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>>>
agg
- = ctx.root;
-
LogicalRepeat<LogicalFilter<LogicalProject<LogicalOlapScan>>> repeat =
agg.child();
-
LogicalFilter<LogicalProject<LogicalOlapScan>> filter = repeat.child();
- LogicalProject<LogicalOlapScan> project =
filter.child();
- LogicalOlapScan scan = project.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions =
-
extractAggFunctionAndReplaceSlot(agg, Optional.of(project));
- List<Expression> groupByExpressions =
-
ExpressionUtils.replace(nonVirtualGroupByExprs(agg),
-
project.getAliasToProducer());
- Set<Expression> predicates =
ExpressionUtils.replace(
- filter.getConjuncts(),
project.getAliasToProducer());
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return agg.withChildren(repeat
-
.withChildren(filter.withChildren(project.withChildren(
-
scan.withPreAggStatus(preAggStatus)))));
-
}).toRule(RuleType.PREAGG_STATUS_AGG_REPEAT_FILTER_PROJECT_SCAN),
-
- // Filter(Project(Scan))
- logicalFilter(logicalProject(
-
logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)))
- .thenApply(ctx -> {
-
LogicalFilter<LogicalProject<LogicalOlapScan>> filter = ctx.root;
- LogicalProject<LogicalOlapScan> project =
filter.child();
- LogicalOlapScan scan = project.child();
- PreAggStatus preAggStatus =
checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction>
aggregateFunctions = ImmutableList.of();
- List<Expression> groupByExpressions =
ImmutableList.of();
- Set<Expression> predicates =
ExpressionUtils.replace(
- filter.getConjuncts(),
project.getAliasToProducer());
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return filter.withChildren(project
-
.withChildren(scan.withPreAggStatus(preAggStatus)));
-
}).toRule(RuleType.PREAGG_STATUS_FILTER_PROJECT_SCAN),
-
- // Filter(Scan)
-
logicalFilter(logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet))
- .thenApply(ctx -> {
- LogicalFilter<LogicalOlapScan> filter = ctx.root;
- LogicalOlapScan scan = filter.child();
- PreAggStatus preAggStatus = checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction> aggregateFunctions =
ImmutableList.of();
- List<Expression> groupByExpressions =
ImmutableList.of();
- Set<Expression> predicates =
filter.getConjuncts();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return
filter.withChildren(scan.withPreAggStatus(preAggStatus));
- }).toRule(RuleType.PREAGG_STATUS_FILTER_SCAN),
-
- // only scan.
- logicalOlapScan().when(LogicalOlapScan::isPreAggStatusUnSet)
- .thenApply(ctx -> {
- LogicalOlapScan scan = ctx.root;
- PreAggStatus preAggStatus = checkKeysType(scan);
- if (preAggStatus == PreAggStatus.unset()) {
- List<AggregateFunction> aggregateFunctions =
ImmutableList.of();
- List<Expression> groupByExpressions =
ImmutableList.of();
- Set<Expression> predicates = ImmutableSet.of();
- preAggStatus = checkPreAggStatus(scan,
predicates,
- aggregateFunctions,
groupByExpressions);
- }
- return scan.withPreAggStatus(preAggStatus);
- }).toRule(RuleType.PREAGG_STATUS_SCAN));
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Set pre-aggregation status.
- ///////////////////////////////////////////////////////////////////////////
-
- /**
- * Do aggregate function extraction and replace aggregate function's input
slots by underlying project.
- * <p>
- * 1. extract aggregate functions in aggregate plan.
- * <p>
- * 2. replace aggregate function's input slot by underlying project
expression if project is present.
- * <p>
- * For example:
- * <pre>
- * input arguments:
- * agg: Aggregate(sum(v) as sum_value)
- * underlying project: Project(a + b as v)
- *
- * output:
- * sum(a + b)
- * </pre>
- */
- private List<AggregateFunction>
extractAggFunctionAndReplaceSlot(LogicalAggregate<?> agg,
- Optional<LogicalProject<?>> project) {
- Optional<Map<Slot, Expression>> slotToProducerOpt =
- project.map(Project::getAliasToProducer);
- return agg.getOutputExpressions().stream()
- // extract aggregate functions.
- .flatMap(e ->
e.<AggregateFunction>collect(AggregateFunction.class::isInstance)
- .stream())
- // replace aggregate function's input slot by its producing
expression.
- .map(expr -> slotToProducerOpt
- .map(slotToExpressions -> (AggregateFunction)
ExpressionUtils.replace(expr,
- slotToExpressions))
- .orElse(expr))
- .collect(Collectors.toList());
- }
-
- private PreAggStatus checkKeysType(LogicalOlapScan olapScan) {
- long selectIndexId = olapScan.getSelectedIndexId();
- MaterializedIndexMeta meta =
olapScan.getTable().getIndexMetaByIndexId(selectIndexId);
- if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType() ==
KeysType.UNIQUE_KEYS
- && olapScan.getTable().getEnableUniqueKeyMergeOnWrite())) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.unset();
- }
- }
-
- private PreAggStatus checkPreAggStatus(LogicalOlapScan olapScan,
Set<Expression> predicates,
- List<AggregateFunction> aggregateFuncs, List<Expression>
groupingExprs) {
- Set<Slot> outputSlots = olapScan.getOutputSet();
- Pair<Set<SlotReference>, Set<SlotReference>> splittedSlots =
splitSlots(outputSlots);
- Set<SlotReference> keySlots = splittedSlots.first;
- Set<SlotReference> valueSlots = splittedSlots.second;
- Preconditions.checkState(outputSlots.size() == keySlots.size() +
valueSlots.size(),
- "output slots contains no key or value slots");
-
- Set<Slot> groupingExprsInputSlots =
ExpressionUtils.getInputSlotSet(groupingExprs);
- if (groupingExprsInputSlots.retainAll(keySlots)) {
- return PreAggStatus
- .off(String.format("Grouping expression %s contains
non-key column %s",
- groupingExprs, groupingExprsInputSlots));
- }
-
- Set<Slot> predicateInputSlots =
ExpressionUtils.getInputSlotSet(predicates);
- if (predicateInputSlots.retainAll(keySlots)) {
- return PreAggStatus.off(String.format("Predicate %s contains
non-key column %s",
- predicates, predicateInputSlots));
- }
-
- return checkAggregateFunctions(aggregateFuncs,
groupingExprsInputSlots);
- }
-
- private Pair<Set<SlotReference>, Set<SlotReference>> splitSlots(Set<Slot>
slots) {
- Set<SlotReference> keySlots =
Sets.newHashSetWithExpectedSize(slots.size());
- Set<SlotReference> valueSlots =
Sets.newHashSetWithExpectedSize(slots.size());
- for (Slot slot : slots) {
- if (slot instanceof SlotReference && ((SlotReference)
slot).getColumn().isPresent()) {
- if (((SlotReference) slot).getColumn().get().isKey()) {
- keySlots.add((SlotReference) slot);
- } else {
- valueSlots.add((SlotReference) slot);
- }
- }
- }
- return Pair.of(keySlots, valueSlots);
- }
-
- private static Expression removeCast(Expression expression) {
- while (expression instanceof Cast) {
- expression = ((Cast) expression).child();
- }
- return expression;
- }
-
- private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction
aggFunc,
- Set<SlotReference> keySlots, Set<SlotReference> valueSlots) {
- Expression child = aggFunc.child(0);
- List<Expression> conditionExps = new ArrayList<>();
- List<Expression> returnExps = new ArrayList<>();
-
- // ignore cast
- while (child instanceof Cast) {
- if (!((Cast) child).getDataType().isNumericType()) {
- return PreAggStatus.off(String.format("%s is not numeric
CAST.", child.toSql()));
- }
- child = child.child(0);
- }
- // step 1: extract all condition exprs and return exprs
- if (child instanceof If) {
- conditionExps.add(child.child(0));
- returnExps.add(removeCast(child.child(1)));
- returnExps.add(removeCast(child.child(2)));
- } else if (child instanceof CaseWhen) {
- CaseWhen caseWhen = (CaseWhen) child;
- // WHEN THEN
- for (WhenClause whenClause : caseWhen.getWhenClauses()) {
- conditionExps.add(whenClause.getOperand());
- returnExps.add(removeCast(whenClause.getResult()));
- }
- // ELSE
- returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new
NullLiteral())));
- } else {
- // currently, only IF and CASE WHEN are supported
- returnExps.add(removeCast(child));
- }
-
- // step 2: check condition expressions
- Set<Slot> inputSlots = ExpressionUtils.getInputSlotSet(conditionExps);
- inputSlots.retainAll(valueSlots);
- if (!inputSlots.isEmpty()) {
- return PreAggStatus
- .off(String.format("some columns in condition %s is not
key.", conditionExps));
- }
-
- return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc, returnExps);
- }
-
- private PreAggStatus checkAggregateFunctions(List<AggregateFunction>
aggregateFuncs,
- Set<Slot> groupingExprsInputSlots) {
- PreAggStatus preAggStatus = aggregateFuncs.isEmpty() &&
groupingExprsInputSlots.isEmpty()
- ? PreAggStatus.off("No aggregate on scan.")
- : PreAggStatus.on();
- for (AggregateFunction aggFunc : aggregateFuncs) {
- if (aggFunc.children().isEmpty()) {
- preAggStatus = PreAggStatus.off(
- String.format("can't turn preAgg on for aggregate
function %s", aggFunc));
- } else if (aggFunc.children().size() == 1 && aggFunc.child(0)
instanceof Slot) {
- Slot aggSlot = (Slot) aggFunc.child(0);
- if (aggSlot instanceof SlotReference
- && ((SlotReference) aggSlot).getColumn().isPresent()) {
- if (((SlotReference) aggSlot).getColumn().get().isKey()) {
- preAggStatus =
OneKeySlotAggChecker.INSTANCE.check(aggFunc);
- } else {
- preAggStatus =
OneValueSlotAggChecker.INSTANCE.check(aggFunc,
- ((SlotReference)
aggSlot).getColumn().get().getAggregationType());
- }
- } else {
- preAggStatus = PreAggStatus.off(
- String.format("aggregate function %s use unknown
slot %s from scan",
- aggFunc, aggSlot));
- }
- } else {
- Set<Slot> aggSlots = aggFunc.getInputSlots();
- Pair<Set<SlotReference>, Set<SlotReference>> splitSlots =
splitSlots(aggSlots);
- preAggStatus =
- checkAggWithKeyAndValueSlots(aggFunc,
splitSlots.first, splitSlots.second);
- }
- if (preAggStatus.isOff()) {
- return preAggStatus;
- }
- }
- return preAggStatus;
- }
-
- private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends
Plan> agg) {
- return agg.getGroupByExpressions().stream()
- .filter(expr -> !(expr instanceof VirtualSlotReference))
- .collect(ImmutableList.toImmutableList());
- }
-
- private static class OneValueSlotAggChecker
- extends ExpressionVisitor<PreAggStatus, AggregateType> {
- public static final OneValueSlotAggChecker INSTANCE = new
OneValueSlotAggChecker();
-
- public PreAggStatus check(AggregateFunction aggFun, AggregateType
aggregateType) {
- return aggFun.accept(INSTANCE, aggregateType);
- }
-
- @Override
- public PreAggStatus visit(Expression expr, AggregateType
aggregateType) {
- return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
- }
-
- @Override
- public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
- AggregateType aggregateType) {
- return PreAggStatus
- .off(String.format("%s is not supported.",
aggregateFunction.toSql()));
- }
-
- @Override
- public PreAggStatus visitMax(Max max, AggregateType aggregateType) {
- if (aggregateType == AggregateType.MAX && !max.isDistinct()) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus
- .off(String.format("%s is not match agg mode %s or has
distinct param",
- max.toSql(), aggregateType));
- }
- }
-
- @Override
- public PreAggStatus visitMin(Min min, AggregateType aggregateType) {
- if (aggregateType == AggregateType.MIN && !min.isDistinct()) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus
- .off(String.format("%s is not match agg mode %s or has
distinct param",
- min.toSql(), aggregateType));
- }
- }
-
- @Override
- public PreAggStatus visitSum(Sum sum, AggregateType aggregateType) {
- if (aggregateType == AggregateType.SUM && !sum.isDistinct()) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus
- .off(String.format("%s is not match agg mode %s or has
distinct param",
- sum.toSql(), aggregateType));
- }
- }
-
- @Override
- public PreAggStatus visitBitmapUnionCount(BitmapUnionCount
bitmapUnionCount,
- AggregateType aggregateType) {
- if (aggregateType == AggregateType.BITMAP_UNION) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off("invalid bitmap_union_count: " +
bitmapUnionCount.toSql());
- }
- }
-
- @Override
- public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion,
AggregateType aggregateType) {
- if (aggregateType == AggregateType.BITMAP_UNION) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off("invalid bitmapUnion: " +
bitmapUnion.toSql());
- }
- }
-
- @Override
- public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg,
AggregateType aggregateType) {
- if (aggregateType == AggregateType.HLL_UNION) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off("invalid hllUnionAgg: " +
hllUnionAgg.toSql());
- }
- }
-
- @Override
- public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType
aggregateType) {
- if (aggregateType == AggregateType.HLL_UNION) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off("invalid hllUnion: " +
hllUnion.toSql());
- }
- }
- }
-
- private static class OneKeySlotAggChecker extends
ExpressionVisitor<PreAggStatus, Void> {
- public static final OneKeySlotAggChecker INSTANCE = new
OneKeySlotAggChecker();
-
- public PreAggStatus check(AggregateFunction aggFun) {
- return aggFun.accept(INSTANCE, null);
- }
-
- @Override
- public PreAggStatus visit(Expression expr, Void context) {
- return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
- }
-
- @Override
- public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
- Void context) {
- return PreAggStatus.off(String.format("Aggregate function %s
contains key column %s",
- aggregateFunction.toSql(),
aggregateFunction.child(0).toSql()));
- }
-
- @Override
- public PreAggStatus visitMax(Max max, Void context) {
- return PreAggStatus.on();
- }
-
- @Override
- public PreAggStatus visitMin(Min min, Void context) {
- return PreAggStatus.on();
- }
-
- @Override
- public PreAggStatus visitCount(Count count, Void context) {
- if (count.isDistinct()) {
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off(String.format("%s is not distinct.",
count.toSql()));
- }
- }
- }
-
- private static class KeyAndValueSlotsAggChecker
- extends ExpressionVisitor<PreAggStatus, List<Expression>> {
- public static final KeyAndValueSlotsAggChecker INSTANCE = new
KeyAndValueSlotsAggChecker();
-
- public PreAggStatus check(AggregateFunction aggFun, List<Expression>
returnValues) {
- return aggFun.accept(INSTANCE, returnValues);
- }
-
- @Override
- public PreAggStatus visit(Expression expr, List<Expression>
returnValues) {
- return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
- }
-
- @Override
- public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
- List<Expression> returnValues) {
- return PreAggStatus
- .off(String.format("%s is not supported.",
aggregateFunction.toSql()));
- }
-
- @Override
- public PreAggStatus visitSum(Sum sum, List<Expression> returnValues) {
- for (Expression value : returnValues) {
- if (!(isAggTypeMatched(value, AggregateType.SUM) ||
value.isZeroLiteral()
- || value.isNullLiteral())) {
- return PreAggStatus.off(String.format("%s is not
supported.", sum.toSql()));
- }
- }
- return PreAggStatus.on();
- }
-
- @Override
- public PreAggStatus visitMax(Max max, List<Expression> returnValues) {
- for (Expression value : returnValues) {
- if (!(isAggTypeMatched(value, AggregateType.MAX) ||
isKeySlot(value)
- || value.isNullLiteral())) {
- return PreAggStatus.off(String.format("%s is not
supported.", max.toSql()));
- }
- }
- return PreAggStatus.on();
- }
-
- @Override
- public PreAggStatus visitMin(Min min, List<Expression> returnValues) {
- for (Expression value : returnValues) {
- if (!(isAggTypeMatched(value, AggregateType.MIN) ||
isKeySlot(value)
- || value.isNullLiteral())) {
- return PreAggStatus.off(String.format("%s is not
supported.", min.toSql()));
- }
- }
- return PreAggStatus.on();
- }
-
- @Override
- public PreAggStatus visitCount(Count count, List<Expression>
returnValues) {
- if (count.isDistinct()) {
- for (Expression value : returnValues) {
- if (!(isKeySlot(value) || value.isZeroLiteral() ||
value.isNullLiteral())) {
- return PreAggStatus
- .off(String.format("%s is not supported.",
count.toSql()));
- }
- }
- return PreAggStatus.on();
- } else {
- return PreAggStatus.off(String.format("%s is not supported.",
count.toSql()));
- }
- }
-
- private boolean isKeySlot(Expression expression) {
- return expression instanceof SlotReference
- && ((SlotReference) expression).getColumn().isPresent()
- && ((SlotReference) expression).getColumn().get().isKey();
- }
-
- private boolean isAggTypeMatched(Expression expression, AggregateType
aggregateType) {
- return expression instanceof SlotReference
- && ((SlotReference) expression).getColumn().isPresent()
- && ((SlotReference) expression).getColumn().get()
- .getAggregationType() == aggregateType;
- }
- }
-}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java
new file mode 100644
index 00000000000..77b1a357370
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java
@@ -0,0 +1,592 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.catalog.AggregateType;
+import org.apache.doris.catalog.KeysType;
+import org.apache.doris.catalog.MaterializedIndexMeta;
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.PreAggStatus;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+
+/**
+ * SetPreAggStatus
+ * traverses the plan tree bottom-up and collects the required info into PreAggInfoContext;
+ * the preagg status of each LogicalOlapScan node is then set using the info in its PreAggInfoContext
+ */
+public class SetPreAggStatus extends
DefaultPlanRewriter<Stack<SetPreAggStatus.PreAggInfoContext>>
+ implements CustomRewriter {
+ private Map<RelationId, PreAggInfoContext> olapScanPreAggContexts = new
HashMap<>();
+
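+ /**
+ * PreAggInfoContext holds the filter conjuncts, join conjuncts, group-by expressions, aggregate functions and olap scan ids collected while traversing the plan
+ */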
+ public static class PreAggInfoContext {
+ private List<Expression> filterConjuncts = new ArrayList<>();
+ private List<Expression> joinConjuncts = new ArrayList<>();
+ private List<Expression> groupByExpresssions = new ArrayList<>();
+ private Set<AggregateFunction> aggregateFunctions = new HashSet<>();
+ private Set<RelationId> olapScanIds = new HashSet<>();
+
+ private Map<Slot, Expression> replaceMap = new HashMap<>();
+
+ private void setReplaceMap(Map<Slot, Expression> replaceMap) {
+ this.replaceMap = replaceMap;
+ }
+
+ private void addRelationId(RelationId id) {
+ olapScanIds.add(id);
+ }
+
+ private void addJoinInfo(LogicalJoin logicalJoin) {
+ joinConjuncts.addAll(logicalJoin.getExpressions());
+ joinConjuncts =
Lists.newArrayList(ExpressionUtils.replace(joinConjuncts, replaceMap));
+ }
+
+ private void addFilterConjuncts(List<Expression> conjuncts) {
+ filterConjuncts.addAll(conjuncts);
+ filterConjuncts =
Lists.newArrayList(ExpressionUtils.replace(filterConjuncts, replaceMap));
+ }
+
+ private void addGroupByExpresssions(List<Expression> expressions) {
+ groupByExpresssions.addAll(expressions);
+ groupByExpresssions =
Lists.newArrayList(ExpressionUtils.replace(groupByExpresssions, replaceMap));
+ }
+
+ private void addAggregateFunctions(Set<AggregateFunction> functions) {
+ aggregateFunctions.addAll(functions);
+ Set<AggregateFunction> newAggregateFunctions = Sets.newHashSet();
+ for (AggregateFunction aggregateFunction : aggregateFunctions) {
+ newAggregateFunctions
+ .add((AggregateFunction)
ExpressionUtils.replace(aggregateFunction, replaceMap));
+ }
+ aggregateFunctions = newAggregateFunctions;
+ }
+ }
+
+ @Override
+ public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+ Plan newPlan = plan.accept(this, new Stack<>());
+ return newPlan.accept(SetOlapScanPreAgg.INSTANCE,
olapScanPreAggContexts);
+ }
+
+ @Override
+ public Plan visit(Plan plan, Stack<PreAggInfoContext> context) {
+ Plan newPlan = super.visit(plan, context);
+ context.clear();
+ return newPlan;
+ }
+
+ @Override
+ public Plan visitLogicalOlapScan(LogicalOlapScan logicalOlapScan,
Stack<PreAggInfoContext> context) {
+ if (logicalOlapScan.isPreAggStatusUnSet()) {
+ long selectIndexId = logicalOlapScan.getSelectedIndexId();
+ MaterializedIndexMeta meta =
logicalOlapScan.getTable().getIndexMetaByIndexId(selectIndexId);
+ if (meta.getKeysType() == KeysType.DUP_KEYS || (meta.getKeysType()
== KeysType.UNIQUE_KEYS
+ &&
logicalOlapScan.getTable().getEnableUniqueKeyMergeOnWrite())) {
+ return logicalOlapScan.withPreAggStatus(PreAggStatus.on());
+ } else {
+ if (context.empty()) {
+ context.push(new PreAggInfoContext());
+ }
+ context.peek().addRelationId(logicalOlapScan.getRelationId());
+ return logicalOlapScan;
+ }
+ } else {
+ return logicalOlapScan;
+ }
+ }
+
+ @Override
+ public Plan visitLogicalFilter(LogicalFilter<? extends Plan>
logicalFilter, Stack<PreAggInfoContext> context) {
+ LogicalFilter plan = (LogicalFilter) super.visit(logicalFilter,
context);
+ if (!context.empty()) {
+ context.peek().addFilterConjuncts(plan.getExpressions());
+ }
+ return plan;
+ }
+
+ @Override
+ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
logicalJoin,
+ Stack<PreAggInfoContext> context) {
+ LogicalJoin plan = (LogicalJoin) super.visit(logicalJoin, context);
+ if (!context.empty()) {
+ context.peek().addJoinInfo(plan);
+ }
+ return plan;
+ }
+
+ @Override
+ public Plan visitLogicalProject(LogicalProject<? extends Plan>
logicalProject,
+ Stack<PreAggInfoContext> context) {
+ LogicalProject plan = (LogicalProject) super.visit(logicalProject,
context);
+ if (!context.empty()) {
+ context.peek().setReplaceMap(plan.getAliasToProducer());
+ }
+ return plan;
+ }
+
+ @Override
+ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan>
logicalAggregate,
+ Stack<PreAggInfoContext> context) {
+ Plan plan = super.visit(logicalAggregate, context);
+ if (!context.isEmpty()) {
+ PreAggInfoContext preAggInfoContext = context.pop();
+
preAggInfoContext.addAggregateFunctions(logicalAggregate.getAggregateFunctions());
+
preAggInfoContext.addGroupByExpresssions(nonVirtualGroupByExprs(logicalAggregate));
+ for (RelationId id : preAggInfoContext.olapScanIds) {
+ olapScanPreAggContexts.put(id, preAggInfoContext);
+ }
+ }
+ return plan;
+ }
+
+ @Override
+ public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan>
logicalRepeat, Stack<PreAggInfoContext> context) {
+ return super.visit(logicalRepeat, context);
+ }
+
+ private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends
Plan> agg) {
+ return agg.getGroupByExpressions().stream()
+ .filter(expr -> !(expr instanceof VirtualSlotReference))
+ .collect(ImmutableList.toImmutableList());
+ }
+
+ private static class SetOlapScanPreAgg extends
DefaultPlanRewriter<Map<RelationId, PreAggInfoContext>> {
+ private static SetOlapScanPreAgg INSTANCE = new SetOlapScanPreAgg();
+
+ @Override
+ public Plan visitLogicalOlapScan(LogicalOlapScan olapScan,
Map<RelationId, PreAggInfoContext> context) {
+ if (olapScan.isPreAggStatusUnSet()) {
+ PreAggStatus preAggStatus = PreAggStatus.off("No valid
aggregate on scan.");
+ PreAggInfoContext preAggInfoContext =
context.get(olapScan.getRelationId());
+ if (preAggInfoContext != null) {
+ preAggStatus = createPreAggStatus(olapScan,
preAggInfoContext);
+ }
+ return olapScan.withPreAggStatus(preAggStatus);
+ } else {
+ return olapScan;
+ }
+ }
+
+ private PreAggStatus createPreAggStatus(LogicalOlapScan
logicalOlapScan, PreAggInfoContext context) {
+ List<Expression> filterConjuncts = context.filterConjuncts;
+ List<Expression> joinConjuncts = context.joinConjuncts;
+ Set<AggregateFunction> aggregateFuncs = context.aggregateFunctions;
+ List<Expression> groupingExprs = context.groupByExpresssions;
+ Set<Slot> outputSlots = logicalOlapScan.getOutputSet();
+ Pair<Set<SlotReference>, Set<SlotReference>> splittedSlots =
splitKeyValueSlots(outputSlots);
+ Set<SlotReference> keySlots = splittedSlots.first;
+ Set<SlotReference> valueSlots = splittedSlots.second;
+ Preconditions.checkState(outputSlots.size() == keySlots.size() +
valueSlots.size(),
+ "output slots contains no key or value slots");
+
+ Set<Slot> groupingExprsInputSlots =
ExpressionUtils.getInputSlotSet(groupingExprs);
+ if (!Sets.intersection(groupingExprsInputSlots,
valueSlots).isEmpty()) {
+ return PreAggStatus
+ .off(String.format("Grouping expression %s contains
non-key column %s",
+ groupingExprs, groupingExprsInputSlots));
+ }
+
+ Set<Slot> filterInputSlots =
ExpressionUtils.getInputSlotSet(filterConjuncts);
+ if (!Sets.intersection(filterInputSlots, valueSlots).isEmpty()) {
+ return PreAggStatus.off(String.format("Filter conjuncts %s
contains non-key column %s",
+ filterConjuncts, filterInputSlots));
+ }
+
+ Set<Slot> joinInputSlots =
ExpressionUtils.getInputSlotSet(joinConjuncts);
+ if (!Sets.intersection(joinInputSlots, valueSlots).isEmpty()) {
+ return PreAggStatus.off(String.format("Join conjuncts %s
contains non-key column %s",
+ joinConjuncts, joinInputSlots));
+ }
+ Set<AggregateFunction> candidateAggFuncs = Sets.newHashSet();
+ for (AggregateFunction aggregateFunction : aggregateFuncs) {
+ if (!Sets.intersection(aggregateFunction.getInputSlots(),
outputSlots).isEmpty()) {
+ candidateAggFuncs.add(aggregateFunction);
+ } else {
+ if (!(aggregateFunction instanceof Max ||
aggregateFunction instanceof Min
+ || (aggregateFunction instanceof Count &&
aggregateFunction.isDistinct()))) {
+ return PreAggStatus.off(
+ String.format("can't turn preAgg on because
aggregate function %s in other table",
+ aggregateFunction));
+ }
+ }
+ }
+
+ Set<Slot> candidateGroupByInputSlots = Sets.newHashSet();
+ candidateGroupByInputSlots.addAll(groupingExprsInputSlots);
+ candidateGroupByInputSlots.retainAll(outputSlots);
+ if (candidateAggFuncs.isEmpty() &&
candidateGroupByInputSlots.isEmpty()) {
+ return !aggregateFuncs.isEmpty() || !groupingExprs.isEmpty() ?
PreAggStatus.on()
+ : PreAggStatus.off("No aggregate on scan.");
+ } else {
+ return checkAggregateFunctions(candidateAggFuncs,
candidateGroupByInputSlots);
+ }
+ }
+
+ private PreAggStatus checkAggregateFunctions(Set<AggregateFunction>
aggregateFuncs,
+ Set<Slot> groupingExprsInputSlots) {
+ if (aggregateFuncs.isEmpty() && groupingExprsInputSlots.isEmpty())
{
+ return PreAggStatus.off("No aggregate on scan.");
+ }
+ PreAggStatus preAggStatus = PreAggStatus.on();
+ for (AggregateFunction aggFunc : aggregateFuncs) {
+ if (aggFunc.children().isEmpty()) {
+ preAggStatus = PreAggStatus.off(
+ String.format("can't turn preAgg on for aggregate
function %s", aggFunc));
+ } else if (aggFunc.children().size() == 1 && aggFunc.child(0)
instanceof Slot) {
+ Slot aggSlot = (Slot) aggFunc.child(0);
+ if (aggSlot instanceof SlotReference
+ && ((SlotReference)
aggSlot).getColumn().isPresent()) {
+ if (((SlotReference)
aggSlot).getColumn().get().isKey()) {
+ preAggStatus =
OneKeySlotAggChecker.INSTANCE.check(aggFunc);
+ } else {
+ preAggStatus =
OneValueSlotAggChecker.INSTANCE.check(aggFunc,
+ ((SlotReference)
aggSlot).getColumn().get().getAggregationType());
+ }
+ } else {
+ preAggStatus = PreAggStatus.off(
+ String.format("aggregate function %s use
unknown slot %s from scan",
+ aggFunc, aggSlot));
+ }
+ } else {
+ Set<Slot> aggSlots = aggFunc.getInputSlots();
+ Pair<Set<SlotReference>, Set<SlotReference>> splitSlots =
splitKeyValueSlots(aggSlots);
+ preAggStatus = checkAggWithKeyAndValueSlots(aggFunc,
splitSlots.first, splitSlots.second);
+ }
+ if (preAggStatus.isOff()) {
+ return preAggStatus;
+ }
+ }
+ return preAggStatus;
+ }
+
+ private Pair<Set<SlotReference>, Set<SlotReference>>
splitKeyValueSlots(Set<Slot> slots) {
+ Set<SlotReference> keySlots =
com.google.common.collect.Sets.newHashSetWithExpectedSize(slots.size());
+ Set<SlotReference> valueSlots =
com.google.common.collect.Sets.newHashSetWithExpectedSize(slots.size());
+ for (Slot slot : slots) {
+ if (slot instanceof SlotReference && ((SlotReference)
slot).getColumn().isPresent()) {
+ if (((SlotReference) slot).getColumn().get().isKey()) {
+ keySlots.add((SlotReference) slot);
+ } else {
+ valueSlots.add((SlotReference) slot);
+ }
+ }
+ }
+ return Pair.of(keySlots, valueSlots);
+ }
+
+ private PreAggStatus checkAggWithKeyAndValueSlots(AggregateFunction
aggFunc,
+ Set<SlotReference> keySlots, Set<SlotReference> valueSlots) {
+ Expression child = aggFunc.child(0);
+ List<Expression> conditionExps = new ArrayList<>();
+ List<Expression> returnExps = new ArrayList<>();
+
+ // strip numeric casts; any non-numeric cast turns preagg off
+ while (child instanceof Cast) {
+ if (!((Cast) child).getDataType().isNumericType()) {
+ return PreAggStatus.off(String.format("%s is not numeric
CAST.", child.toSql()));
+ }
+ child = child.child(0);
+ }
+ // step 1: extract all condition exprs and return exprs
+ if (child instanceof If) {
+ conditionExps.add(child.child(0));
+ returnExps.add(removeCast(child.child(1)));
+ returnExps.add(removeCast(child.child(2)));
+ } else if (child instanceof CaseWhen) {
+ CaseWhen caseWhen = (CaseWhen) child;
+ // WHEN THEN
+ for (WhenClause whenClause : caseWhen.getWhenClauses()) {
+ conditionExps.add(whenClause.getOperand());
+ returnExps.add(removeCast(whenClause.getResult()));
+ }
+ // ELSE
+
returnExps.add(removeCast(caseWhen.getDefaultValue().orElse(new
NullLiteral())));
+ } else {
+ // currently, only IF and CASE WHEN are supported as conditionals; anything else is the single return value
+ returnExps.add(removeCast(child));
+ }
+
+ // step 2: check condition expressions
+ Set<Slot> inputSlots =
ExpressionUtils.getInputSlotSet(conditionExps);
+ if (!keySlots.containsAll(inputSlots)) {
+ return PreAggStatus
+ .off(String.format("some columns in condition %s is
not key.", conditionExps));
+ }
+
+ return KeyAndValueSlotsAggChecker.INSTANCE.check(aggFunc,
returnExps);
+ }
+
+ private static Expression removeCast(Expression expression) {
+ while (expression instanceof Cast) {
+ expression = ((Cast) expression).child();
+ }
+ return expression;
+ }
+
+ private static class OneValueSlotAggChecker
+ extends ExpressionVisitor<PreAggStatus, AggregateType> {
+ public static final OneValueSlotAggChecker INSTANCE = new
OneValueSlotAggChecker();
+
+ public PreAggStatus check(AggregateFunction aggFun, AggregateType
aggregateType) {
+ return aggFun.accept(INSTANCE, aggregateType);
+ }
+
+ @Override
+ public PreAggStatus visit(Expression expr, AggregateType
aggregateType) {
+ return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
+ }
+
+ @Override
+ public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
+ AggregateType aggregateType) {
+ return PreAggStatus
+ .off(String.format("%s is not supported.",
aggregateFunction.toSql()));
+ }
+
+ @Override
+ public PreAggStatus visitMax(Max max, AggregateType aggregateType)
{
+ if (aggregateType == AggregateType.MAX && !max.isDistinct()) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus
+ .off(String.format("%s is not match agg mode %s or
has distinct param",
+ max.toSql(), aggregateType));
+ }
+ }
+
+ @Override
+ public PreAggStatus visitMin(Min min, AggregateType aggregateType)
{
+ if (aggregateType == AggregateType.MIN && !min.isDistinct()) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus
+ .off(String.format("%s is not match agg mode %s or
has distinct param",
+ min.toSql(), aggregateType));
+ }
+ }
+
+ @Override
+ public PreAggStatus visitSum(Sum sum, AggregateType aggregateType)
{
+ if (aggregateType == AggregateType.SUM && !sum.isDistinct()) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus
+ .off(String.format("%s is not match agg mode %s or
has distinct param",
+ sum.toSql(), aggregateType));
+ }
+ }
+
+ @Override
+ public PreAggStatus visitBitmapUnionCount(BitmapUnionCount
bitmapUnionCount,
+ AggregateType aggregateType) {
+ if (aggregateType == AggregateType.BITMAP_UNION) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off("invalid bitmap_union_count: " +
bitmapUnionCount.toSql());
+ }
+ }
+
+ @Override
+ public PreAggStatus visitBitmapUnion(BitmapUnion bitmapUnion,
AggregateType aggregateType) {
+ if (aggregateType == AggregateType.BITMAP_UNION) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off("invalid bitmapUnion: " +
bitmapUnion.toSql());
+ }
+ }
+
+ @Override
+ public PreAggStatus visitHllUnionAgg(HllUnionAgg hllUnionAgg,
AggregateType aggregateType) {
+ if (aggregateType == AggregateType.HLL_UNION) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off("invalid hllUnionAgg: " +
hllUnionAgg.toSql());
+ }
+ }
+
+ @Override
+ public PreAggStatus visitHllUnion(HllUnion hllUnion, AggregateType
aggregateType) {
+ if (aggregateType == AggregateType.HLL_UNION) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off("invalid hllUnion: " +
hllUnion.toSql());
+ }
+ }
+ }
+
+ private static class OneKeySlotAggChecker extends
ExpressionVisitor<PreAggStatus, Void> {
+ public static final OneKeySlotAggChecker INSTANCE = new
OneKeySlotAggChecker();
+
+ public PreAggStatus check(AggregateFunction aggFun) {
+ return aggFun.accept(INSTANCE, null);
+ }
+
+ @Override
+ public PreAggStatus visit(Expression expr, Void context) {
+ return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
+ }
+
+ @Override
+ public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
+ Void context) {
+ if (aggregateFunction.isDistinct()) {
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off(String.format("%s is not
distinct.", aggregateFunction.toSql()));
+ }
+ }
+
+ @Override
+ public PreAggStatus visitMax(Max max, Void context) {
+ return PreAggStatus.on();
+ }
+
+ @Override
+ public PreAggStatus visitMin(Min min, Void context) {
+ return PreAggStatus.on();
+ }
+ }
+
+ private static class KeyAndValueSlotsAggChecker
+ extends ExpressionVisitor<PreAggStatus, List<Expression>> {
+ public static final KeyAndValueSlotsAggChecker INSTANCE = new
KeyAndValueSlotsAggChecker();
+
+ public PreAggStatus check(AggregateFunction aggFun,
List<Expression> returnValues) {
+ return aggFun.accept(INSTANCE, returnValues);
+ }
+
+ @Override
+ public PreAggStatus visit(Expression expr, List<Expression>
returnValues) {
+ return PreAggStatus.off(String.format("%s is not aggregate
function.", expr.toSql()));
+ }
+
+ @Override
+ public PreAggStatus visitAggregateFunction(AggregateFunction
aggregateFunction,
+ List<Expression> returnValues) {
+ return PreAggStatus
+ .off(String.format("%s is not supported.",
aggregateFunction.toSql()));
+ }
+
+ @Override
+ public PreAggStatus visitSum(Sum sum, List<Expression>
returnValues) {
+ for (Expression value : returnValues) {
+ if (!(isAggTypeMatched(value, AggregateType.SUM) ||
value.isZeroLiteral()
+ || value.isNullLiteral())) {
+ return PreAggStatus.off(String.format("%s is not
supported.", sum.toSql()));
+ }
+ }
+ return PreAggStatus.on();
+ }
+
+ @Override
+ public PreAggStatus visitMax(Max max, List<Expression>
returnValues) {
+ for (Expression value : returnValues) {
+ if (!(isAggTypeMatched(value, AggregateType.MAX) ||
isKeySlot(value)
+ || value.isNullLiteral())) {
+ return PreAggStatus.off(String.format("%s is not
supported.", max.toSql()));
+ }
+ }
+ return PreAggStatus.on();
+ }
+
+ @Override
+ public PreAggStatus visitMin(Min min, List<Expression>
returnValues) {
+ for (Expression value : returnValues) {
+ if (!(isAggTypeMatched(value, AggregateType.MIN) ||
isKeySlot(value)
+ || value.isNullLiteral())) {
+ return PreAggStatus.off(String.format("%s is not
supported.", min.toSql()));
+ }
+ }
+ return PreAggStatus.on();
+ }
+
+ @Override
+ public PreAggStatus visitCount(Count count, List<Expression>
returnValues) {
+ if (count.isDistinct()) {
+ for (Expression value : returnValues) {
+ if (!(isKeySlot(value) || value.isZeroLiteral() ||
value.isNullLiteral())) {
+ return PreAggStatus
+ .off(String.format("%s is not supported.",
count.toSql()));
+ }
+ }
+ return PreAggStatus.on();
+ } else {
+ return PreAggStatus.off(String.format("%s is not
supported.", count.toSql()));
+ }
+ }
+
+ private boolean isKeySlot(Expression expression) {
+ return expression instanceof SlotReference
+ && ((SlotReference) expression).getColumn().isPresent()
+ && ((SlotReference)
expression).getColumn().get().isKey();
+ }
+
+ private boolean isAggTypeMatched(Expression expression,
AggregateType aggregateType) {
+ return expression instanceof SlotReference
+ && ((SlotReference) expression).getColumn().isPresent()
+ && ((SlotReference) expression).getColumn().get()
+ .getAggregationType() == aggregateType;
+ }
+ }
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java
index ecced02bf35..a67a5f4dd95 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectRollupIndexTest.java
@@ -19,9 +19,9 @@ package org.apache.doris.nereids.rules.rewrite.mv;
import org.apache.doris.common.FeConstants;
import
org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
-import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject;
+import org.apache.doris.nereids.rules.rewrite.SetPreAggStatus;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
@@ -112,7 +112,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze(" select k1, sum(v1) from t group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("t",
scan.getSelectedMaterializedIndexName().get());
@@ -125,7 +125,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze("select k2, sum(v1) from t where k3=0 group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("r2",
scan.getSelectedMaterializedIndexName().get());
@@ -153,7 +153,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze("select k2, sum(v1) from t where k3=0 group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("r2",
scan.getSelectedMaterializedIndexName().get());
@@ -166,7 +166,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze("select k2, sum(v1) from t where k3>0 group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("r2",
scan.getSelectedMaterializedIndexName().get());
@@ -179,7 +179,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
PlanChecker.from(connectContext)
.analyze("select k2, sum(v1) from t where k2>3 group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("r1",
scan.getSelectedMaterializedIndexName().get());
@@ -199,7 +199,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.applyBottomUp(new MergeProjects())
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
Assertions.assertTrue(scan.getPreAggStatus().isOn());
Assertions.assertEquals("r2",
scan.getSelectedMaterializedIndexName().get());
@@ -217,11 +217,11 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, v1 from t")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOff());
- Assertions.assertEquals("No aggregate on scan.",
preAgg.getOffReason());
+ Assertions.assertEquals("No valid aggregate on scan.",
preAgg.getOffReason());
return true;
}));
}
@@ -232,7 +232,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, min(v1) from t group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOff());
@@ -247,7 +247,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, sum(v1 + 1) from t group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOff());
@@ -263,11 +263,11 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, sum(k2) from t group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOff());
- Assertions.assertEquals("Aggregate function sum(k2)
contains key column k2",
+ Assertions.assertEquals("sum(k2) is not distinct.",
preAgg.getOffReason());
return true;
}));
@@ -279,7 +279,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k2, max(k3) from t group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
@@ -294,7 +294,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k2, min(k3) from t group by k2")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
@@ -309,7 +309,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, min(k2), max(k2) from t group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
@@ -325,7 +325,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
.applyTopDown(new MergeProjects())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
@@ -340,7 +340,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, sum(k1) from duplicate_tbl group by k1")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
@@ -354,7 +354,7 @@ class SelectRollupIndexTest extends
BaseMaterializedIndexSelectTest implements M
.analyze("select k1, v1 from duplicate_tbl")
.applyTopDown(new SelectMaterializedIndexWithAggregate())
.applyTopDown(new SelectMaterializedIndexWithoutAggregate())
- .applyTopDown(new AdjustPreAggStatus())
+ .customRewrite(new SetPreAggStatus())
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOn());
diff --git
a/regression-test/suites/nereids_rules_p0/set_preagg/set_preagg.groovy
b/regression-test/suites/nereids_rules_p0/set_preagg/set_preagg.groovy
new file mode 100644
index 00000000000..106f05b8f13
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/set_preagg/set_preagg.groovy
@@ -0,0 +1,312 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("set_preagg") {
+ multi_sql """
+ set disable_nereids_rules='PRUNE_EMPTY_PARTITION';
+ set forbid_unknown_col_stats=false;
+ set enable_stats=false;
+ drop table if exists preagg_t1;
+ drop table if exists preagg_t2;
+ drop table if exists preagg_t3;
+
+ create table preagg_t1(
+ k1 int null,
+ k2 int null,
+ k3 int null,
+ k4 int null,
+ k5 int null,
+ k6 int null,
+ v7 bigint SUM,
+ v8 bigint SUM,
+ v9 bigint MAX
+ )
+ aggregate key (k1,k2,k3,k4,k5,k6)
+ distributed BY hash(k1) buckets 3
+ properties("replication_num" = "1");
+
+ create table preagg_t2(
+ k1 int null,
+ k2 int null,
+ k3 int null,
+ k4 int null,
+ k5 int null,
+ k6 int null,
+ v7 bigint SUM,
+ v8 bigint SUM,
+ v9 bigint MAX
+ )
+ aggregate key (k1,k2,k3,k4,k5,k6)
+ distributed BY hash(k1) buckets 3
+ properties("replication_num" = "1");
+ create table preagg_t3(
+ k1 int null,
+ k2 int null,
+ k3 int null,
+ k4 int null,
+ k5 int null,
+ k6 int null,
+ v7 bigint SUM,
+ v8 bigint SUM,
+ v9 bigint MAX
+ )
+ aggregate key (k1,k2,k3,k4,k5,k6)
+ distributed BY hash(k1) buckets 3
+ properties("replication_num" = "1");
+ """
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, sum(t12.v1), max(preagg_t3.v9)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
sum(ta1.t1_sum_v7) v1, sum(ta2.t2_sum_v7) v2
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, sum(v7) t2_sum_v7 from
preagg_t2 group by k1, k2, k3, k4, k5) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: ON"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(preagg_t3.v9)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, sum(ta2.t2_sum_v7) v2
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, sum(v7) t2_sum_v7 from
preagg_t2 group by k1, k2, k3, k4, k5) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ notContains "PREAGGREGATION: OFF"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
sum(t12.v3)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, max(ta2.k4) v2, count(distinct ta2.k5) v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v7 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: ON"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, sum(t12.v2), max(preagg_t3.v9)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, max(ta2.v7) v2
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v7 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: OFF. Reason: max(v7) is not match agg mode SUM"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
sum(t12.v3)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4, max(case
when ta2.k1 > 0 then ta2.v9 when ta2.k1 = 0 then null when ta2.k1 < 0 then
ta2.v9 else null end) v2, count(distinct ta2.k5) v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v7, v8, v9 from preagg_t2) as
ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: ON"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
sum(t12.v3)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4, sum(case
when ta2.k1 > 0 then ta2.v7 when ta2.k1 = 0 then 0 when ta2.k1 < 0 then ta2.v8
else 0 end) v2, count(distinct ta2.k5) v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v7, v8 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: ON"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, sum(t12.v2), max(preagg_t3.v9)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, sum(ta2.v7) v2
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v7 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: ON"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
min(t12.v3)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, count(distinct ta2.k4) v2, count(distinct ta2.k5) v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ left join
+ (select k1, k2, k3, k4, k5, v7, v8 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ notContains "PREAGGREGATION: OFF"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
sum(t12.v3)
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
max(ta1.t1_sum_v7) v1, count(case when ta2.k1 > 0 then ta2.v7 when ta2.k1 = 0
then 0 when ta1.k1 < 0 then ta2.v8 else 0 end) v2, sum(ta2.v7) v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ left join
+ (select k1, k2, k3, k4, k5, v7, v8 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ group by k1, k2, k3, k4
+ ) t12 inner join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: OFF. Reason: count("
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(t12.v2), max(preagg_t3.v9),
count(distinct t12.v3), count(distinct t12.k4) v3
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
ta1.t1_sum_v7 v1, ta2.v9 v2, ta2.k5 v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v9 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ ) t12 right join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ notContains "PREAGGREGATION: OFF"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, max(preagg_t3.v9), count(distinct
t12.v3), count(distinct t12.k4) v3
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
ta1.t1_sum_v7 v1, ta1.k5 v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v9 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ ) t12 right join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ notContains "PREAGGREGATION: OFF"
+ }
+
+ explain {
+ sql("""
+ select preagg_t3.k2, t12.k2, sum(t12.v1), max(preagg_t3.v9),
count(distinct t12.v3), count(distinct t12.k4) v3
+ from
+ (
+ select ta1.k1 k1, ta1.k2 k2, ta2.k1 k3, ta2.k2 k4,
ta1.t1_sum_v7 v1, ta1.k5 v3
+ from
+ (select k1, k2, k3, k4, k5, sum(v7) t1_sum_v7 from
preagg_t1 group by k1, k2, k3, k4, k5) as ta1
+ inner join
+ (select k1, k2, k3, k4, k5, v9 from preagg_t2) as ta2
+ on ta1.k3 = ta2.k3
+ ) t12 right join preagg_t3 on t12.k1 = preagg_t3.k1
+ group by preagg_t3.k2, t12.k2
+ order by 1, 2;
+ """)
+ contains "(preagg_t1), PREAGGREGATION: ON"
+ contains "(preagg_t2), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ contains "(preagg_t3), PREAGGREGATION: OFF. Reason: can't turn preAgg on because aggregate function sum"
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]