This is an automated email from the ASF dual-hosted git repository.
morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 5db57341993 [Enhancement](mv): Improve MV predicate compensation and
keep original min-max predicates non-inferred (#61345)
5db57341993 is described below
commit 5db57341993609fab70e090115818ce4bee51799
Author: foxtail463 <[email protected]>
AuthorDate: Mon Jun 1 16:04:36 2026 +0800
[Enhancement](mv): Improve MV predicate compensation and keep original
min-max predicates non-inferred (#61345)
Problem Summary:
The old residual compensation simply checked whether the query residual
set contained all view residual predicates. This breaks on real-world MV
rewrites where the query and view residuals are structurally different
even though the query residual logically implies the view residual, for
example:
query residual: A OR (B AND C)
view residual : A OR B
The old code sees two different expression trees and bails out, even
though the query side is strictly stronger.
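This patch replaces the set-containment check with DNF-based implication
checking (impliesByDnf): compensation succeeds when the query candidates
can be proven to imply the view residual, even if the expression trees
differ. Each query DNF branch must be covered by some view DNF branch,
and individual predicates are matched either exactly or as ranges over
the same input (e.g. a > 20 implies a > 10). The check is conservative:
a hard cap (MAX_DNF_BRANCHES=1024) guards against exponential expansion,
and when the proof is too expensive or cannot be established,
compensation fails rather than hanging the optimizer or producing wrong
results.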
This patch also fixes predicate provenance in AddMinMax. AddMinMax may
derive min/max predicates and then move equivalent boundary predicates
from the original expression into the generated min/max list. If the
boundary predicate already existed in the original SQL, it must remain
non-inferred even after being moved; otherwise MV compensation may later
filter it out as an inferred predicate and lose a real query boundary.
Purely generated min/max predicates are still marked as inferred.
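The three separate compensate calls in AbstractMaterializedViewRule are
collapsed into a single Predicates.compensatePredicates entry point that
encapsulates candidate collection and residual finalization. As part of
this, compensateEquivalence and compensateRangePredicate are renamed to
collectEquivalenceCandidates and collectRangeCandidates, and
compensateResidualPredicate is removed.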
---------
Co-authored-by: yangtao555 <[email protected]>
---
.../mv/AbstractMaterializedViewRule.java | 20 +-
.../nereids/rules/exploration/mv/Predicates.java | 525 ++++++++++++++++++---
.../nereids/rules/expression/rules/AddMinMax.java | 25 +-
.../nereids/trees/expressions/Expression.java | 3 +
.../apache/doris/nereids/mv/PredicatesTest.java | 389 ++++++++++++---
.../rules/expression/ExpressionRewriteTest.java | 61 ++-
.../mv/negative/negative_test.groovy | 7 +-
7 files changed, 895 insertions(+), 135 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java
index ef94ad458bf..38c4fade7ff 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java
@@ -35,6 +35,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory;
import org.apache.doris.nereids.rules.exploration.mv.Predicates.ExpressionInfo;
+import
org.apache.doris.nereids.rules.exploration.mv.Predicates.PredicateCompensation;
import org.apache.doris.nereids.rules.exploration.mv.Predicates.SplitPredicate;
import
org.apache.doris.nereids.rules.exploration.mv.StructInfo.PartitionRemover;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
@@ -854,22 +855,13 @@ public abstract class AbstractMaterializedViewRule
implements ExplorationRuleFac
if (couldNotPulledUpCompensateConjunctions == null) {
return SplitPredicate.INVALID_INSTANCE;
}
- // viewEquivalenceClass to query based
- // equal predicate compensate
- final Map<Expression, ExpressionInfo> equalCompensateConjunctions =
Predicates.compensateEquivalence(
- queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
comparisonResult);
- // range compensate
- final Map<Expression, ExpressionInfo> rangeCompensatePredicates =
- Predicates.compensateRangePredicate(queryStructInfo,
viewStructInfo, viewToQuerySlotMapping,
- comparisonResult, cascadesContext);
- // residual compensate
- final Map<Expression, ExpressionInfo> residualCompensatePredicates =
Predicates.compensateResidualPredicate(
- queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
comparisonResult);
- if (equalCompensateConjunctions == null || rangeCompensatePredicates
== null
- || residualCompensatePredicates == null) {
+ PredicateCompensation finalPredicateCompensation =
+ Predicates.compensatePredicates(queryStructInfo,
viewStructInfo,
+ viewToQuerySlotMapping, comparisonResult,
cascadesContext);
+ if (finalPredicateCompensation == null) {
return SplitPredicate.INVALID_INSTANCE;
}
- return SplitPredicate.of(equalCompensateConjunctions,
rangeCompensatePredicates, residualCompensatePredicates);
+ return finalPredicateCompensation.toSplitPredicate();
}
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/Predicates.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/Predicates.java
index bbd9f6181da..f22429dc08a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/Predicates.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/Predicates.java
@@ -23,14 +23,20 @@ import
org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
+import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
@@ -45,10 +51,13 @@ import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* This record the predicates which can be pulled up or some other type
predicates.
@@ -56,6 +65,12 @@ import java.util.Set;
*/
public class Predicates {
+ // Guard DNF expansion from exponential branch blow-up.
+ private static final int MAX_DNF_BRANCHES = 1024;
+ private static final List<PredicateImplicationRule>
PREDICATE_IMPLICATION_RULES = ImmutableList.of(
+ new SameExpressionImplicationRule(),
+ new ComparablePredicateImplicationRule());
+
// Predicates that can be pulled up
private final Set<Expression> pulledUpPredicates;
// Predicates that can not be pulled up, should be equals between query
and view
@@ -117,28 +132,28 @@ public class Predicates {
List<? extends Expression> viewPredicatesShuttled =
ExpressionUtils.shuttleExpressionWithLineage(
Lists.newArrayList(viewStructInfoPredicates.getCouldNotPulledUpPredicates()),
viewStructInfo.getTopPlan());
- List<Expression> viewPredicatesQueryBased =
ExpressionUtils.replace((List<Expression>) viewPredicatesShuttled,
+ List<Expression> queryBasedViewPredicates =
ExpressionUtils.replace((List<Expression>) viewPredicatesShuttled,
viewToQuerySlotMapping.toSlotReferenceMap());
// could not be pulled up predicates in query and view should be same
if (queryStructInfoPredicates.getCouldNotPulledUpPredicates().equals(
- Sets.newHashSet(viewPredicatesQueryBased))) {
+ Sets.newHashSet(queryBasedViewPredicates))) {
return ImmutableMap.of();
}
return null;
}
/**
- * compensate equivalence predicates
+ * collect equivalence candidates from query/view difference.
*/
- public static Map<Expression, ExpressionInfo>
compensateEquivalence(StructInfo queryStructInfo,
+ public static Map<Expression, ExpressionInfo>
collectEquivalenceCandidates(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult) {
EquivalenceClass queryEquivalenceClass =
queryStructInfo.getEquivalenceClass();
EquivalenceClass viewEquivalenceClass =
viewStructInfo.getEquivalenceClass();
Map<SlotReference, SlotReference> viewToQuerySlotMap =
viewToQuerySlotMapping.toSlotReferenceMap();
- EquivalenceClass viewEquivalenceClassQueryBased =
viewEquivalenceClass.permute(viewToQuerySlotMap);
- if (viewEquivalenceClassQueryBased == null) {
+ EquivalenceClass queryBasedViewEquivalenceClass =
viewEquivalenceClass.permute(viewToQuerySlotMap);
+ if (queryBasedViewEquivalenceClass == null) {
return null;
}
final Map<Expression, ExpressionInfo> equalCompensateConjunctions =
new HashMap<>();
@@ -149,7 +164,7 @@ public class Predicates {
return null;
}
EquivalenceClassMapping queryToViewEquivalenceMapping =
- EquivalenceClassMapping.generate(queryEquivalenceClass,
viewEquivalenceClassQueryBased);
+ EquivalenceClassMapping.generate(queryEquivalenceClass,
queryBasedViewEquivalenceClass);
// can not map all target equivalence class, can not compensate
if (queryToViewEquivalenceMapping.getEquivalenceClassSetMap().size()
< viewEquivalenceClass.getEquivalenceSetList().size()) {
@@ -191,9 +206,9 @@ public class Predicates {
}
/**
- * compensate range predicates
+ * collect range candidates from query/view difference.
*/
- public static Map<Expression, ExpressionInfo>
compensateRangePredicate(StructInfo queryStructInfo,
+ public static Map<Expression, ExpressionInfo>
collectRangeCandidates(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult,
@@ -201,19 +216,14 @@ public class Predicates {
SplitPredicate querySplitPredicate =
queryStructInfo.getSplitPredicate();
SplitPredicate viewSplitPredicate = viewStructInfo.getSplitPredicate();
- Set<Expression> viewRangeQueryBasedSet = new HashSet<>();
- for (Expression viewExpression :
viewSplitPredicate.getRangePredicateMap().keySet()) {
- viewRangeQueryBasedSet.add(
- ExpressionUtils.replace(viewExpression,
viewToQuerySlotMapping.toSlotReferenceMap()));
- }
- viewRangeQueryBasedSet.remove(BooleanLiteral.TRUE);
-
- Set<Expression> queryRangeSet =
querySplitPredicate.getRangePredicateMap().keySet();
- queryRangeSet.remove(BooleanLiteral.TRUE);
+ Set<Expression> queryBasedViewRangeSet =
collectNonInferredQueryBasedExpressions(
+ viewSplitPredicate.getRangePredicateMap().keySet(),
viewToQuerySlotMapping);
+ Set<Expression> queryRangeSet = collectNonInferredExpressions(
+ querySplitPredicate.getRangePredicateMap().keySet());
Set<Expression> differentExpressions = new HashSet<>();
- Sets.difference(queryRangeSet,
viewRangeQueryBasedSet).copyInto(differentExpressions);
- Sets.difference(viewRangeQueryBasedSet,
queryRangeSet).copyInto(differentExpressions);
+ Sets.difference(queryRangeSet,
queryBasedViewRangeSet).copyInto(differentExpressions);
+ Sets.difference(queryBasedViewRangeSet,
queryRangeSet).copyInto(differentExpressions);
// the range predicate in query and view is same, don't need to
compensate
if (differentExpressions.isEmpty()) {
return ImmutableMap.of();
@@ -225,24 +235,25 @@ public class Predicates {
// normalized expressions is not in query, can not compensate
return null;
}
- Map<Expression, ExpressionInfo> normalizedExpressionsWithLiteral = new
HashMap<>();
+ Map<Expression, ExpressionInfo> normalizedExpressionsWithLiteral = new
LinkedHashMap<>();
for (Expression expression : normalizedExpressions) {
- Set<Literal> literalSet = expression.collect(expressionTreeNode ->
expressionTreeNode instanceof Literal);
- if (!(expression instanceof ComparisonPredicate)
- || (expression instanceof GreaterThan || expression
instanceof LessThanEqual)
- || literalSet.size() != 1) {
- if (expression.anyMatch(AggregateFunction.class::isInstance)) {
- return null;
- }
- normalizedExpressionsWithLiteral.put(expression,
ExpressionInfo.EMPTY);
- continue;
- }
if (expression.anyMatch(AggregateFunction.class::isInstance)) {
return null;
}
- normalizedExpressionsWithLiteral.put(expression, new
ExpressionInfo(literalSet.iterator().next()));
+ normalizedExpressionsWithLiteral.put(expression,
buildRangeExpressionInfo(expression));
+ }
+ return ImmutableMap.copyOf(normalizedExpressionsWithLiteral);
+ }
+
+ private static ExpressionInfo buildRangeExpressionInfo(Expression
expression) {
+ Set<Literal> literalSet = expression.collect(expressionTreeNode ->
expressionTreeNode instanceof Literal);
+ if (!(expression instanceof ComparisonPredicate)
+ || expression instanceof GreaterThan
+ || expression instanceof LessThanEqual
+ || literalSet.size() != 1) {
+ return ExpressionInfo.EMPTY;
}
- return normalizedExpressionsWithLiteral;
+ return new ExpressionInfo(literalSet.iterator().next());
}
private static Set<Expression> normalizeExpression(Expression expression,
CascadesContext cascadesContext) {
@@ -254,41 +265,402 @@ public class Predicates {
return ExpressionUtils.extractConjunctionToSet(expression);
}
- /**
- * compensate residual predicates
- */
- public static Map<Expression, ExpressionInfo>
compensateResidualPredicate(StructInfo queryStructInfo,
+ /** Collect all predicate compensation candidates before residual
finalization. */
+ public static PredicateCompensation collectCompensationCandidates(
+ StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
- ComparisonResult comparisonResult) {
- // TODO Residual predicates compensate, simplify implementation
currently.
- SplitPredicate querySplitPredicate =
queryStructInfo.getSplitPredicate();
- SplitPredicate viewSplitPredicate = viewStructInfo.getSplitPredicate();
+ ComparisonResult comparisonResult,
+ CascadesContext cascadesContext) {
+ Map<Expression, ExpressionInfo> equalCandidates =
collectEquivalenceCandidates(
+ queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
comparisonResult);
+ Map<Expression, ExpressionInfo> rangeCandidates =
collectRangeCandidates(
+ queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
comparisonResult, cascadesContext);
+ Map<Expression, ExpressionInfo> residualCandidates =
collectResidualCandidates(queryStructInfo);
+ if (equalCandidates == null || rangeCandidates == null ||
residualCandidates == null) {
+ return null;
+ }
+ return new PredicateCompensation(equalCandidates, rangeCandidates,
residualCandidates);
+ }
- Set<Expression> viewResidualQueryBasedSet = new HashSet<>();
- for (Expression viewExpression :
viewSplitPredicate.getResidualPredicateMap().keySet()) {
- viewResidualQueryBasedSet.add(
- ExpressionUtils.replace(viewExpression,
viewToQuerySlotMapping.toSlotReferenceMap()));
+ /** Compensate predicates in one step. */
+ public static PredicateCompensation compensatePredicates(
+ StructInfo queryStructInfo,
+ StructInfo viewStructInfo,
+ SlotMapping viewToQuerySlotMapping,
+ ComparisonResult comparisonResult,
+ CascadesContext cascadesContext) {
+ PredicateCompensation compensationCandidates =
collectCompensationCandidates(
+ queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
comparisonResult, cascadesContext);
+ if (compensationCandidates == null) {
+ return null;
+ }
+ Set<Expression> queryBasedViewResidualPredicates =
collectNonInferredQueryBasedExpressions(
+
viewStructInfo.getSplitPredicate().getResidualPredicateMap().keySet(),
viewToQuerySlotMapping);
+ Set<Expression> exactCoveredPredicates = new LinkedHashSet<>();
+ exactCoveredPredicates.addAll(Sets.intersection(
+ compensationCandidates.getEquals().keySet(),
queryBasedViewResidualPredicates));
+ exactCoveredPredicates.addAll(Sets.intersection(
+ compensationCandidates.getRanges().keySet(),
queryBasedViewResidualPredicates));
+ exactCoveredPredicates.addAll(Sets.intersection(
+ compensationCandidates.getResiduals().keySet(),
queryBasedViewResidualPredicates));
+ Set<Expression> remainingQueryBasedViewResidualPredicates = new
LinkedHashSet<>(
+ Sets.difference(queryBasedViewResidualPredicates,
exactCoveredPredicates));
+
+ // Exact-covered predicates are enforced by both query and view.
Preserve that fast path before
+ // proving implication for the remaining residuals, because DNF
expansion is only needed for
+ // non-exact implication.
+ PredicateCompensation exactPrunedCompensationCandidates = new
PredicateCompensation(
+
removeExactCoveredPredicates(compensationCandidates.getEquals(),
exactCoveredPredicates),
+
removeExactCoveredPredicates(compensationCandidates.getRanges(),
exactCoveredPredicates),
+
removeExactCoveredPredicates(compensationCandidates.getResiduals(),
exactCoveredPredicates));
+
+ Expression combinedCompensationCandidates = buildCombinedPredicate(
+ exactPrunedCompensationCandidates.getEquals().keySet(),
+ exactPrunedCompensationCandidates.getRanges().keySet(),
+ exactPrunedCompensationCandidates.getResiduals().keySet());
+ Expression combinedQueryBasedViewResidual =
buildCombinedPredicate(remainingQueryBasedViewResidualPredicates);
+ if (BooleanLiteral.TRUE.equals(combinedQueryBasedViewResidual)) {
+ // The target residual is TRUE, so implication always holds and
DNF expansion is unnecessary.
+ return
rejectUnsafeResidualCompensation(exactPrunedCompensationCandidates);
}
- viewResidualQueryBasedSet.remove(BooleanLiteral.TRUE);
- Set<Expression> queryResidualSet =
querySplitPredicate.getResidualPredicateMap().keySet();
- // remove unnecessary literal BooleanLiteral.TRUE
- queryResidualSet.remove(BooleanLiteral.TRUE);
- // query residual predicate can not contain all view residual
predicate when view have residual predicate,
- // bail out
- if (!queryResidualSet.containsAll(viewResidualQueryBasedSet)) {
+ try {
+ // The compensation must not widen the view result:
+ // compensationCandidates => combinedQueryBasedViewResidual.
+ if (!impliesByDnf(combinedCompensationCandidates,
combinedQueryBasedViewResidual)) {
+ return null;
+ }
+
+ PredicateCompensation finalCompensation = new
PredicateCompensation(
+ removePredicatesImpliedByViewResidual(
+ exactPrunedCompensationCandidates.getEquals(),
combinedQueryBasedViewResidual),
+ removePredicatesImpliedByViewResidual(
+ exactPrunedCompensationCandidates.getRanges(),
combinedQueryBasedViewResidual),
+ removePredicatesImpliedByViewResidual(
+ exactPrunedCompensationCandidates.getResiduals(),
combinedQueryBasedViewResidual));
+ return rejectUnsafeResidualCompensation(finalCompensation);
+ } catch (DnfBranchOverflowException e) {
+ // DNF branch expansion may explode exponentially; fail
compensation conservatively.
return null;
}
- queryResidualSet.removeAll(viewResidualQueryBasedSet);
- Map<Expression, ExpressionInfo> expressionExpressionInfoMap = new
HashMap<>();
- for (Expression needCompensate : queryResidualSet) {
- if (needCompensate.anyMatch(AggregateFunction.class::isInstance)) {
+ }
+
+ private static Map<Expression, ExpressionInfo>
collectResidualCandidates(StructInfo queryStructInfo) {
+ Set<Expression> expressions = collectNonInferredExpressions(
+
queryStructInfo.getSplitPredicate().getResidualPredicateMap().keySet());
+ Map<Expression, ExpressionInfo> residualCandidates = new
LinkedHashMap<>();
+ for (Expression expression : expressions) {
+ residualCandidates.put(expression, ExpressionInfo.EMPTY);
+ }
+ return ImmutableMap.copyOf(residualCandidates);
+ }
+
+ private static PredicateCompensation
rejectUnsafeResidualCompensation(PredicateCompensation compensation) {
+ for (Expression expression : compensation.getResiduals().keySet()) {
+ if (expression.anyMatch(WindowExpression.class::isInstance)
+ ||
expression.anyMatch(AggregateFunction.class::isInstance)) {
+ // Aggregate and window residuals are not safe as detail-MV
compensation predicates.
return null;
}
- expressionExpressionInfoMap.put(needCompensate,
ExpressionInfo.EMPTY);
}
- return expressionExpressionInfoMap;
+ return compensation;
+ }
+
+ private static Set<Expression>
collectNonInferredExpressions(Collection<Expression> expressions) {
+ return expressions.stream()
+ .filter(expression -> !ExpressionUtils.isInferred(expression))
+ .filter(expression -> !BooleanLiteral.TRUE.equals(expression))
+ .collect(Collectors.toCollection(LinkedHashSet::new));
+ }
+
+ private static Set<Expression>
collectNonInferredQueryBasedExpressions(Collection<Expression> expressions,
+ SlotMapping viewToQuerySlotMapping) {
+ Map<SlotReference, SlotReference> slotReferenceMap =
viewToQuerySlotMapping.toSlotReferenceMap();
+ return expressions.stream()
+ .filter(expression -> !ExpressionUtils.isInferred(expression))
+ .map(expression -> ExpressionUtils.replace(expression,
slotReferenceMap))
+ .filter(expression -> !BooleanLiteral.TRUE.equals(expression))
+ .collect(Collectors.toCollection(LinkedHashSet::new));
+ }
+
+ @SafeVarargs
+ private static Expression buildCombinedPredicate(Collection<Expression>...
predicateCollections) {
+ List<Expression> combinedPredicates = new ArrayList<>();
+ for (Collection<Expression> predicateCollection :
predicateCollections) {
+ for (Expression predicate : predicateCollection) {
+ if (!BooleanLiteral.TRUE.equals(predicate)) {
+ combinedPredicates.add(predicate);
+ }
+ }
+ }
+ return ExpressionUtils.and(combinedPredicates);
+ }
+
+ private static Map<Expression, ExpressionInfo>
removePredicatesImpliedByViewResidual(
+ Map<Expression, ExpressionInfo> predicates,
+ Expression viewResidual) {
+ // Remove candidates already implied by the view residual predicate.
+ Map<Expression, ExpressionInfo> remainingPredicates = new
LinkedHashMap<>();
+ for (Map.Entry<Expression, ExpressionInfo> entry :
predicates.entrySet()) {
+ // If viewResidual => candidate, candidate is redundant and can be
dropped.
+ if (!impliesByDnf(viewResidual, entry.getKey())) {
+ remainingPredicates.put(entry.getKey(), entry.getValue());
+ }
+ }
+ return ImmutableMap.copyOf(remainingPredicates);
+ }
+
+ private static Map<Expression, ExpressionInfo>
removeExactCoveredPredicates(
+ Map<Expression, ExpressionInfo> predicates,
+ Set<Expression> exactCoveredPredicates) {
+ Map<Expression, ExpressionInfo> remainingPredicates = new
LinkedHashMap<>();
+ for (Map.Entry<Expression, ExpressionInfo> entry :
predicates.entrySet()) {
+ if (exactCoveredPredicates.contains(entry.getKey())) {
+ continue;
+ }
+ remainingPredicates.put(entry.getKey(), entry.getValue());
+ }
+ return ImmutableMap.copyOf(remainingPredicates);
+ }
+
+ private static boolean impliesByDnf(Expression source, Expression target) {
+ // Check whether source => target.
+ List<Set<Expression>> sourceBranches = extractDnfBranches(source);
+ List<Set<Expression>> targetBranches = extractDnfBranches(target);
+ if (sourceBranches.isEmpty()) {
+ return true;
+ }
+ if (targetBranches.isEmpty()) {
+ return false;
+ }
+ for (Set<Expression> sourceBranch : sourceBranches) {
+ boolean branchMatched = false;
+ for (Set<Expression> targetBranch : targetBranches) {
+ if (branchImplies(sourceBranch, targetBranch)) {
+ branchMatched = true;
+ break;
+ }
+ }
+ if (!branchMatched) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static boolean branchImplies(Set<Expression> sourceBranch,
Set<Expression> targetBranch) {
+ for (Expression targetPredicate : targetBranch) {
+ boolean predicateMatched = false;
+ for (Expression sourcePredicate : sourceBranch) {
+ if (predicateImplies(sourcePredicate, targetPredicate)) {
+ predicateMatched = true;
+ break;
+ }
+ }
+ if (!predicateMatched) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static boolean predicateImplies(Expression source, Expression
target) {
+ for (PredicateImplicationRule rule : PREDICATE_IMPLICATION_RULES) {
+ if (rule.proves(source, target)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private static ComparisonPredicate
normalizeComparisonPredicate(ComparisonPredicate predicate) {
+ if (predicate.left() instanceof Literal && !(predicate.right()
instanceof Literal)) {
+ return (ComparisonPredicate) predicate.commute();
+ }
+ return predicate;
+ }
+
+ private interface PredicateImplicationRule {
+ boolean proves(Expression source, Expression target);
+ }
+
+ private static final class SameExpressionImplicationRule implements
PredicateImplicationRule {
+ @Override
+ public boolean proves(Expression source, Expression target) {
+ return source.equals(target);
+ }
+ }
+
+ /*
+ * Prove implication between comparable predicates on the same input by
converting them to ranges.
+ * For example, "a > 20" implies "a > 10", and "a = 10" implies "a >= 10".
+ */
+ private static final class ComparablePredicateImplicationRule implements
PredicateImplicationRule {
+ @Override
+ public boolean proves(Expression source, Expression target) {
+ ComparablePredicateRange sourceRange =
ComparablePredicateRange.from(source);
+ ComparablePredicateRange targetRange =
ComparablePredicateRange.from(target);
+ return sourceRange != null && targetRange != null &&
sourceRange.implies(targetRange);
+ }
+
+ private static final class ComparablePredicateRange {
+ private final Expression input;
+ private final Bound lowerBound;
+ private final Bound upperBound;
+
+ private ComparablePredicateRange(Expression input, Bound
lowerBound, Bound upperBound) {
+ this.input = input;
+ this.lowerBound = lowerBound;
+ this.upperBound = upperBound;
+ }
+
+ private static ComparablePredicateRange from(Expression
expression) {
+ if (!(expression instanceof ComparisonPredicate)) {
+ return null;
+ }
+ ComparisonPredicate normalized =
normalizeComparisonPredicate((ComparisonPredicate) expression);
+ if (!(normalized.right() instanceof ComparableLiteral)) {
+ return null;
+ }
+ ComparableLiteral literal = (ComparableLiteral)
normalized.right();
+ if (normalized instanceof EqualTo) {
+ Bound bound = new Bound(literal, true);
+ return new ComparablePredicateRange(normalized.left(),
bound, bound);
+ }
+ if (normalized instanceof GreaterThan) {
+ return new ComparablePredicateRange(normalized.left(), new
Bound(literal, false), null);
+ }
+ if (normalized instanceof GreaterThanEqual) {
+ return new ComparablePredicateRange(normalized.left(), new
Bound(literal, true), null);
+ }
+ if (normalized instanceof LessThan) {
+ return new ComparablePredicateRange(normalized.left(),
null, new Bound(literal, false));
+ }
+ if (normalized instanceof LessThanEqual) {
+ return new ComparablePredicateRange(normalized.left(),
null, new Bound(literal, true));
+ }
+ return null;
+ }
+
+ private boolean implies(ComparablePredicateRange target) {
+ return input.equals(target.input)
+ && impliesLowerBound(target.lowerBound)
+ && impliesUpperBound(target.upperBound);
+ }
+
+ private boolean impliesLowerBound(Bound targetLowerBound) {
+ if (targetLowerBound == null) {
+ return true;
+ }
+ if (lowerBound == null) {
+ return false;
+ }
+ int compareResult =
lowerBound.literal.compareTo(targetLowerBound.literal);
+ if (compareResult != 0) {
+ return compareResult > 0;
+ }
+ return !lowerBound.inclusive || targetLowerBound.inclusive;
+ }
+
+ private boolean impliesUpperBound(Bound targetUpperBound) {
+ if (targetUpperBound == null) {
+ return true;
+ }
+ if (upperBound == null) {
+ return false;
+ }
+ int compareResult =
upperBound.literal.compareTo(targetUpperBound.literal);
+ if (compareResult != 0) {
+ return compareResult < 0;
+ }
+ return !upperBound.inclusive || targetUpperBound.inclusive;
+ }
+ }
+
+ private static final class Bound {
+ private final ComparableLiteral literal;
+ private final boolean inclusive;
+
+ private Bound(ComparableLiteral literal, boolean inclusive) {
+ this.literal = literal;
+ this.inclusive = inclusive;
+ }
+ }
+ }
+
+ /*
+ * Example:
+ * (A OR B) AND C
+ * becomes:
+ * [{A, C}, {B, C}]
+ */
+ private static List<Set<Expression>> extractDnfBranches(Expression
expression) {
+ if (BooleanLiteral.TRUE.equals(expression)) {
+ List<Set<Expression>> trueBranches = new ArrayList<>();
+ trueBranches.add(new LinkedHashSet<>());
+ return trueBranches;
+ }
+ if (BooleanLiteral.FALSE.equals(expression)) {
+ return ImmutableList.of();
+ }
+ if (expression instanceof Or) {
+ List<Set<Expression>> branches = new ArrayList<>();
+ for (Expression child :
ExpressionUtils.extractDisjunction(expression)) {
+ List<Set<Expression>> childBranches =
extractDnfBranches(child);
+ long expectedSize = (long) branches.size() +
childBranches.size();
+ if (expectedSize > MAX_DNF_BRANCHES) {
+ throw DnfBranchOverflowException.INSTANCE;
+ }
+ branches.addAll(childBranches);
+ }
+ return branches;
+ }
+ if (expression instanceof And) {
+ List<Set<Expression>> branches = new ArrayList<>();
+ branches.add(new LinkedHashSet<>());
+ for (Expression child :
ExpressionUtils.extractConjunction(expression)) {
+ List<Set<Expression>> childBranches =
extractDnfBranches(child);
+ branches = crossProductBranches(branches, childBranches);
+ if (branches.isEmpty()) {
+ return branches;
+ }
+ }
+ return branches;
+ }
+ List<Set<Expression>> branches = new ArrayList<>();
+ Set<Expression> branch = new LinkedHashSet<>();
+ branch.add(expression);
+ branches.add(branch);
+ return branches;
+ }
+
+ private static List<Set<Expression>>
crossProductBranches(List<Set<Expression>> leftBranches,
+ List<Set<Expression>> rightBranches) {
+ if (leftBranches.isEmpty() || rightBranches.isEmpty()) {
+ return ImmutableList.of();
+ }
+ long expectedSize = (long) leftBranches.size() * rightBranches.size();
+ if (expectedSize > MAX_DNF_BRANCHES) {
+ throw DnfBranchOverflowException.INSTANCE;
+ }
+ List<Set<Expression>> mergedBranches = new ArrayList<>((int)
expectedSize);
+ for (Set<Expression> leftBranch : leftBranches) {
+ for (Set<Expression> rightBranch : rightBranches) {
+ Set<Expression> mergedBranch = new LinkedHashSet<>(leftBranch);
+ mergedBranch.addAll(rightBranch);
+ mergedBranches.add(mergedBranch);
+ }
+ }
+ return mergedBranches;
+ }
+
+ private static final class DnfBranchOverflowException extends
RuntimeException {
+ private static final DnfBranchOverflowException INSTANCE = new
DnfBranchOverflowException();
+
+ private DnfBranchOverflowException() {
+ super(null, null, true, false);
+ }
}
@Override
@@ -311,6 +683,43 @@ public class Predicates {
}
}
+ /** Predicate compensation result holding equals, ranges and residuals. */
+ public static final class PredicateCompensation {
+ private final Map<Expression, ExpressionInfo> equals;
+ private final Map<Expression, ExpressionInfo> ranges;
+ private final Map<Expression, ExpressionInfo> residuals;
+
+ public PredicateCompensation(Map<Expression, ExpressionInfo> equals,
+ Map<Expression, ExpressionInfo> ranges,
+ Map<Expression, ExpressionInfo> residuals) {
+ this.equals = ImmutableMap.copyOf(equals);
+ this.ranges = ImmutableMap.copyOf(ranges);
+ this.residuals = ImmutableMap.copyOf(residuals);
+ }
+
+ public Map<Expression, ExpressionInfo> getEquals() {
+ return equals;
+ }
+
+ public Map<Expression, ExpressionInfo> getRanges() {
+ return ranges;
+ }
+
+ public Map<Expression, ExpressionInfo> getResiduals() {
+ return residuals;
+ }
+
+ public SplitPredicate toSplitPredicate() {
+ return SplitPredicate.of(equals, ranges, residuals);
+ }
+
+ @Override
+ public String toString() {
+ return Utils.toSqlString("PredicateCompensation",
+ "equals", equals, "ranges", ranges, "residuals",
residuals);
+ }
+ }
+
/**
* The split different representation for predicate expression, such as
equal, range and residual predicate.
*/
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
index bf42a796846..3317a2deec4 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
@@ -54,6 +54,8 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -177,6 +179,7 @@ public class AddMinMax implements
ExpressionPatternRuleFactory, ValueDescVisitor
List<Map.Entry<Expression, MinMaxValue>> minMaxExprs =
exprMinMaxValues.entrySet().stream()
.sorted((a, b) -> Integer.compare(a.getValue().exprOrderIndex,
b.getValue().exprOrderIndex))
.collect(Collectors.toList());
+ Set<Expression> nonInferredOriginPredicates =
collectNonInferredPredicates(expr);
List<Expression> addExprs =
Lists.newArrayListWithExpectedSize(minMaxExprs.size() * 2);
for (Map.Entry<Expression, MinMaxValue> entry : minMaxExprs) {
Expression targetExpr = entry.getKey();
@@ -186,6 +189,7 @@ public class AddMinMax implements
ExpressionPatternRuleFactory, ValueDescVisitor
&& range.lowerBoundType() == BoundType.CLOSED
&& range.upperBoundType() == BoundType.CLOSED) {
Expression cmp = new EqualTo(targetExpr, (Literal)
range.lowerEndpoint());
+ cmp =
cmp.withInferred(!nonInferredOriginPredicates.contains(cmp));
addExprs.add(cmp);
continue;
}
@@ -194,6 +198,7 @@ public class AddMinMax implements
ExpressionPatternRuleFactory, ValueDescVisitor
Expression cmp = range.lowerBoundType() == BoundType.CLOSED
? new GreaterThanEqual(targetExpr, (Literal) literal)
: new GreaterThan(targetExpr, (Literal) literal);
+ cmp =
cmp.withInferred(!nonInferredOriginPredicates.contains(cmp));
addExprs.add(cmp);
}
if (range.hasUpperBound()) {
@@ -201,6 +206,7 @@ public class AddMinMax implements
ExpressionPatternRuleFactory, ValueDescVisitor
Expression cmp = range.upperBoundType() == BoundType.CLOSED
? new LessThanEqual(targetExpr, (Literal) literal)
: new LessThan(targetExpr, (Literal) literal);
+ cmp =
cmp.withInferred(!nonInferredOriginPredicates.contains(cmp));
addExprs.add(cmp);
}
}
@@ -217,8 +223,25 @@ public class AddMinMax implements
ExpressionPatternRuleFactory, ValueDescVisitor
return result;
}
+ private Set<Expression> collectNonInferredPredicates(Expression expr) {
+ Set<Expression> predicates = Sets.newHashSet();
+ Deque<Expression> expressions = new ArrayDeque<>();
+ expressions.add(expr);
+ while (!expressions.isEmpty()) {
+ Expression current = expressions.removeLast();
+ if (ExpressionUtils.isInferred(current)) {
+ continue;
+ }
+ if (current instanceof CompoundPredicate) {
+ expressions.addAll(current.children());
+ } else {
+ predicates.add(current);
+ }
+ }
+ return predicates;
+ }
+
private Expression replaceCmpMinMax(Expression expr, Set<Expression>
cmpMinMaxExprs) {
- // even if expr is nullable, replace it to true is ok because
expression will 'AND' it later
if (cmpMinMaxExprs.contains(expr)) {
return BooleanLiteral.TRUE;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index fabadac3ffe..3ae49261636 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -302,6 +302,9 @@ public abstract class Expression extends
AbstractTreeNode<Expression> implements
throw new RuntimeException();
}
+ // `inferred` means this predicate was derived by optimizer rules and did
not exist in
+ // the original SQL. If an equivalent predicate already exists in the
original SQL,
+ // it is not inferred even when the optimizer can derive the same
predicate again.
public Expression withInferred(boolean inferred) {
throw new RuntimeException("current expression has not impl the
withInferred method");
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/mv/PredicatesTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/mv/PredicatesTest.java
index 473ebb4908f..a324e8e9961 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/mv/PredicatesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/mv/PredicatesTest.java
@@ -22,18 +22,23 @@ import
org.apache.doris.nereids.rules.exploration.mv.ComparisonResult;
import org.apache.doris.nereids.rules.exploration.mv.HyperGraphComparator;
import org.apache.doris.nereids.rules.exploration.mv.Predicates;
import org.apache.doris.nereids.rules.exploration.mv.Predicates.ExpressionInfo;
+import
org.apache.doris.nereids.rules.exploration.mv.Predicates.PredicateCompensation;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo;
import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.sqltest.SqlTestBase;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.util.Arrays;
+import java.util.Locale;
import java.util.Map;
+import java.util.stream.Collectors;
/** Test the method in Predicates*/
public class PredicatesTest extends SqlTestBase {
@@ -45,6 +50,16 @@ public class PredicatesTest extends SqlTestBase {
createTables(
"CREATE TABLE IF NOT EXISTS T1 (\n"
+ + " id bigint,\n"
+ + " score bigint\n"
+ + ")\n"
+ + "DUPLICATE KEY(id)\n"
+ + "DISTRIBUTED BY HASH(id, score) BUCKETS 10\n"
+ + "PROPERTIES (\n"
+ + " \"replication_num\" = \"1\", \n"
+ + " \"colocate_with\" = \"T0\"\n"
+ + ")\n",
+ "CREATE TABLE IF NOT EXISTS T2 (\n"
+ " id bigint,\n"
+ " score bigint\n"
+ ")\n"
@@ -63,13 +78,12 @@ public class PredicatesTest extends SqlTestBase {
+ ",ELIMINATE_GROUP_BY_KEY_BY_UNIFORM"
+ ",ELIMINATE_CONST_JOIN_CONDITION"
+ ",CONSTANT_PROPAGATION"
- + ",INFER_PREDICATES"
);
}
@Test
public void testCompensateCouldNotPullUpPredicatesFail() {
- CascadesContext mvContext = createCascadesContext(
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
"select \n"
+ "id,\n"
+ "FIRST_VALUE(id) OVER (\n"
@@ -80,14 +94,6 @@ public class PredicatesTest extends SqlTestBase {
+ "from \n"
+ "T1\n"
+ "where score > 10 and id < 5;",
- connectContext
- );
- Plan mvPlan = PlanChecker.from(mvContext)
- .analyze()
- .rewrite()
- .getPlan().child(0);
-
- CascadesContext queryContext = createCascadesContext(
"select \n"
+ "id,\n"
+ "FIRST_VALUE(id) OVER (\n"
@@ -97,33 +103,19 @@ public class PredicatesTest extends SqlTestBase {
+ " ) AS first_value\n"
+ "from \n"
+ "T1\n"
- + "where score > 10 and id < 1;",
- connectContext
- );
- Plan queryPlan = PlanChecker.from(queryContext)
- .analyze()
- .rewrite()
- .getAllPlan().get(0).child(0);
-
- StructInfo mvStructInfo = StructInfo.of(mvPlan, mvPlan, mvContext);
- StructInfo queryStructInfo = StructInfo.of(queryPlan, queryPlan,
queryContext);
- RelationMapping relationMapping =
RelationMapping.generate(mvStructInfo.getRelations(),
- queryStructInfo.getRelations(), 16).get(0);
-
- SlotMapping mvToQuerySlotMapping =
SlotMapping.generate(relationMapping);
- ComparisonResult comparisonResult =
HyperGraphComparator.isLogicCompatible(
- queryStructInfo.getHyperGraph(),
- mvStructInfo.getHyperGraph(),
- constructContext(queryPlan, mvPlan, queryContext));
+ + "where score > 10 and id < 1;");
Map<Expression, ExpressionInfo> expressionExpressionInfoMap =
Predicates.compensateCouldNotPullUpPredicates(
- queryStructInfo, mvStructInfo, mvToQuerySlotMapping,
comparisonResult);
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult);
Assertions.assertNull(expressionExpressionInfoMap);
}
@Test
- public void testCompensateCouldNotPullUpPredicatesSuccess() {
- CascadesContext mvContext = createCascadesContext(
+ public void testFinalizeCompensationByResidualKeepsRangeCandidate() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
"select \n"
+ "id,\n"
+ "FIRST_VALUE(id) OVER (\n"
@@ -134,14 +126,6 @@ public class PredicatesTest extends SqlTestBase {
+ "from \n"
+ "T1\n"
+ "where score > 10 and id < 5;",
- connectContext
- );
- Plan mvPlan = PlanChecker.from(mvContext)
- .analyze()
- .rewrite()
- .getPlan().child(0);
-
- CascadesContext queryContext = createCascadesContext(
"select \n"
+ "id,\n"
+ "FIRST_VALUE(id) OVER (\n"
@@ -151,34 +135,325 @@ public class PredicatesTest extends SqlTestBase {
+ " ) AS first_value\n"
+ "from \n"
+ "T1\n"
- + "where score > 15 and id < 5;",
- connectContext
- );
+ + "where score > 15 and id < 5;");
+
+ Map<Expression, ExpressionInfo> compensateCouldNotPullUpPredicates =
Predicates.compensateCouldNotPullUpPredicates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult);
+ Assertions.assertNotNull(compensateCouldNotPullUpPredicates);
+ Assertions.assertTrue(compensateCouldNotPullUpPredicates.isEmpty());
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertTrue(compensationCandidates.getResiduals().isEmpty());
+ Assertions.assertEquals(1, compensationCandidates.getRanges().size());
+ assertPredicateSqlEquals(compensationCandidates.getRanges(),
+ "(score > 15)");
+ Assertions.assertThrows(UnsupportedOperationException.class,
+ compensationCandidates.getRanges()::clear);
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+ Assertions.assertEquals(1,
finalPredicateCompensation.getRanges().size());
+ assertPredicateSqlEquals(finalPredicateCompensation.getRanges(),
+ "(score > 15)");
+
Assertions.assertTrue(finalPredicateCompensation.getResiduals().isEmpty());
+ }
+
+ @Test
+ public void testResidualCompensateSupportsDnfBranchImplication() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1 where id = 5 or id > 10",
+ "select id, score from T1 where id > 10 or (score = 1 and id =
5)");
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertTrue(compensationCandidates.getEquals().isEmpty());
+ Assertions.assertTrue(compensationCandidates.getRanges().isEmpty());
+ Assertions.assertEquals(1,
compensationCandidates.getResiduals().size());
+ assertPredicateSqlEquals(compensationCandidates.getResiduals(),
+ "OR[(id > 10),AND[(score = 1),(id = 5)]]");
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+
Assertions.assertTrue(finalPredicateCompensation.getEquals().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getRanges().isEmpty());
+ Assertions.assertEquals(1,
finalPredicateCompensation.getResiduals().size());
+ assertPredicateSqlEquals(finalPredicateCompensation.getResiduals(),
+ "OR[(id > 10),AND[(score = 1),(id = 5)]]");
+ }
+
+ @Test
+ public void
testResidualCompensateSupportsStrongerRangeInDnfBranchImplication() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1 where id > 10 or (score = 1 and id =
5)",
+ "select id, score from T1 where id > 15 or (score = 1 and id =
5)");
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertTrue(compensationCandidates.getEquals().isEmpty());
+ Assertions.assertTrue(compensationCandidates.getRanges().isEmpty());
+ Assertions.assertEquals(1,
compensationCandidates.getResiduals().size());
+ assertPredicateSqlEquals(compensationCandidates.getResiduals(),
+ "OR[(id > 15),AND[(score = 1),(id = 5)]]");
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+
Assertions.assertTrue(finalPredicateCompensation.getEquals().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getRanges().isEmpty());
+ Assertions.assertEquals(1,
finalPredicateCompensation.getResiduals().size());
+ assertPredicateSqlEquals(finalPredicateCompensation.getResiduals(),
+ "OR[(id > 15),AND[(score = 1),(id = 5)]]");
+ }
+
+ @Test
+ public void testFinalizeCompensationByResidualConsumesCoveredResidual() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1 where id = 5 or id > 10",
+ "select id, score from T1 where id = 5 or id > 10");
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertEquals(1,
compensationCandidates.getResiduals().size());
+ assertPredicateSqlEquals(compensationCandidates.getResiduals(),
+ "OR[(id = 5),(id > 10)]");
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+
Assertions.assertTrue(finalPredicateCompensation.getEquals().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getRanges().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getResiduals().isEmpty());
+ }
+
+ @Test
+ public void testCoveredWindowResidualIsAllowedAfterFinalization() {
+ String sql = "select id, score from T1 qualify sum(score) over
(partition by id) > 10";
+ PredicateRewriteContext rewriteContext = buildRewriteContext(sql, sql);
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+
Assertions.assertTrue(compensationCandidates.getResiduals().keySet().stream()
+ .anyMatch(expression ->
expression.anyMatch(WindowExpression.class::isInstance)));
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+
Assertions.assertTrue(finalPredicateCompensation.getResiduals().isEmpty());
+ }
+
+ @Test
+ public void testQueryOnlyWindowResidualIsRejectedAsCompensation() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1",
+ "select id, score from T1 qualify sum(score) over (partition
by id) > 10");
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+
Assertions.assertTrue(compensationCandidates.getResiduals().keySet().stream()
+ .anyMatch(expression ->
expression.anyMatch(WindowExpression.class::isInstance)));
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNull(finalPredicateCompensation);
+ }
+
+ @Test
+ public void
testCompensateCandidatesByViewResidualKeepsExactResidualWhenDnfBranchesOverflow()
{
+ String overflowResidual = buildDnfOverflowResidual();
+ String sql = "select id, score from T1 where " + overflowResidual;
+ PredicateRewriteContext rewriteContext = buildRewriteContext(sql, sql);
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertEquals(11,
compensationCandidates.getResiduals().size());
+ assertPredicateSqlEquals(compensationCandidates.getResiduals(),
+ "OR[(id = 1),(score = 101)]",
+ "OR[(id = 2),(score = 102)]",
+ "OR[(id = 3),(score = 103)]",
+ "OR[(id = 4),(score = 104)]",
+ "OR[(id = 5),(score = 105)]",
+ "OR[(id = 6),(score = 106)]",
+ "OR[(id = 7),(score = 107)]",
+ "OR[(id = 8),(score = 108)]",
+ "OR[(id = 9),(score = 109)]",
+ "OR[(id = 10),(score = 110)]",
+ "OR[(id = 11),(score = 111)]");
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+
Assertions.assertTrue(finalPredicateCompensation.getEquals().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getRanges().isEmpty());
+
Assertions.assertTrue(finalPredicateCompensation.getResiduals().isEmpty());
+ }
+
+ @Test
+ public void
testCompensateCandidatesByViewResidualReturnsNullWhenDnfBranchesOverflow() {
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1 where id = 999 or score = 999",
+ "select id, score from T1 where " +
buildDnfOverflowResidual());
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertEquals(11,
compensationCandidates.getResiduals().size());
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ // Non-exact DNF branches exceed the guard threshold, so implication
falls back conservatively.
+ Assertions.assertNull(finalPredicateCompensation);
+ }
+
+ @Test
+ public void
testCompensateCandidatesByViewResidualSkipsDnfWhenViewResidualEmpty() {
+ String overflowResidual = buildDnfOverflowResidual();
+ PredicateRewriteContext rewriteContext = buildRewriteContext(
+ "select id, score from T1",
+ "select id, score from T1 where " + overflowResidual);
+
+ PredicateCompensation compensationCandidates =
Predicates.collectCompensationCandidates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ Assertions.assertNotNull(compensationCandidates);
+ Assertions.assertEquals(11,
compensationCandidates.getResiduals().size());
+
+ PredicateCompensation finalPredicateCompensation =
compensatePredicates(rewriteContext);
+ Assertions.assertNotNull(finalPredicateCompensation);
+ Assertions.assertEquals(compensationCandidates.getEquals(),
finalPredicateCompensation.getEquals());
+ Assertions.assertEquals(compensationCandidates.getRanges(),
finalPredicateCompensation.getRanges());
+ Assertions.assertEquals(compensationCandidates.getResiduals(),
finalPredicateCompensation.getResiduals());
+ }
+
+ private String buildDnfOverflowResidual() {
+ return "(id = 1 or score = 101)"
+ + " and (id = 2 or score = 102)"
+ + " and (id = 3 or score = 103)"
+ + " and (id = 4 or score = 104)"
+ + " and (id = 5 or score = 105)"
+ + " and (id = 6 or score = 106)"
+ + " and (id = 7 or score = 107)"
+ + " and (id = 8 or score = 108)"
+ + " and (id = 9 or score = 109)"
+ + " and (id = 10 or score = 110)"
+ + " and (id = 11 or score = 111)";
+ }
+
+ private PredicateRewriteContext buildRewriteContext(String viewSql, String
querySql) {
+ CascadesContext viewContext = createCascadesContext(viewSql,
connectContext);
+ Plan viewPlan = PlanChecker.from(viewContext)
+ .analyze()
+ .rewrite()
+ .getPlan().child(0);
+
+ CascadesContext queryContext = createCascadesContext(querySql,
connectContext);
Plan queryPlan = PlanChecker.from(queryContext)
.analyze()
.rewrite()
.getAllPlan().get(0).child(0);
- StructInfo mvStructInfo = StructInfo.of(mvPlan, mvPlan, mvContext);
+ StructInfo viewStructInfo = StructInfo.of(viewPlan, viewPlan,
viewContext);
StructInfo queryStructInfo = StructInfo.of(queryPlan, queryPlan,
queryContext);
- RelationMapping relationMapping =
RelationMapping.generate(mvStructInfo.getRelations(),
+ RelationMapping relationMapping =
RelationMapping.generate(viewStructInfo.getRelations(),
queryStructInfo.getRelations(), 16).get(0);
-
- SlotMapping mvToQuerySlotMapping =
SlotMapping.generate(relationMapping);
+ SlotMapping viewToQuerySlotMapping =
SlotMapping.generate(relationMapping);
ComparisonResult comparisonResult =
HyperGraphComparator.isLogicCompatible(
queryStructInfo.getHyperGraph(),
- mvStructInfo.getHyperGraph(),
- constructContext(queryPlan, mvPlan, queryContext));
+ viewStructInfo.getHyperGraph(),
+ constructContext(queryPlan, viewPlan, queryContext));
+ return new PredicateRewriteContext(
+ viewStructInfo,
+ queryStructInfo,
+ viewToQuerySlotMapping,
+ comparisonResult,
+ queryContext);
+ }
- Map<Expression, ExpressionInfo> compensateCouldNotPullUpPredicates =
Predicates.compensateCouldNotPullUpPredicates(
- queryStructInfo, mvStructInfo, mvToQuerySlotMapping,
comparisonResult);
- Assertions.assertNotNull(compensateCouldNotPullUpPredicates);
- Assertions.assertTrue(compensateCouldNotPullUpPredicates.isEmpty());
+ private PredicateCompensation compensatePredicates(PredicateRewriteContext
rewriteContext) {
+ return Predicates.compensatePredicates(
+ rewriteContext.queryStructInfo,
+ rewriteContext.viewStructInfo,
+ rewriteContext.viewToQuerySlotMapping,
+ rewriteContext.comparisonResult,
+ rewriteContext.queryContext);
+ }
- Map<Expression, ExpressionInfo> compensateRangePredicates =
Predicates.compensateRangePredicate(
- queryStructInfo, mvStructInfo, mvToQuerySlotMapping,
comparisonResult,
- queryContext);
- Assertions.assertNotNull(compensateRangePredicates);
- Assertions.assertEquals(1, compensateRangePredicates.size());
+ private static final class PredicateRewriteContext {
+ private final StructInfo viewStructInfo;
+ private final StructInfo queryStructInfo;
+ private final SlotMapping viewToQuerySlotMapping;
+ private final ComparisonResult comparisonResult;
+ private final CascadesContext queryContext;
+
+ private PredicateRewriteContext(StructInfo viewStructInfo, StructInfo
queryStructInfo,
+ SlotMapping viewToQuerySlotMapping, ComparisonResult
comparisonResult,
+ CascadesContext queryContext) {
+ this.viewStructInfo = viewStructInfo;
+ this.queryStructInfo = queryStructInfo;
+ this.viewToQuerySlotMapping = viewToQuerySlotMapping;
+ this.comparisonResult = comparisonResult;
+ this.queryContext = queryContext;
+ }
+ }
+
+ private static void assertPredicateSqlEquals(Map<Expression,
ExpressionInfo> predicates,
+ String... expectedPredicates) {
+ String actual = predicates.keySet().stream()
+ .map(Expression::toSql)
+ .map(PredicatesTest::normalizeSql)
+ .sorted()
+ .collect(Collectors.joining(";"));
+ String expected = Arrays.stream(expectedPredicates)
+ .map(PredicatesTest::normalizeSql)
+ .sorted()
+ .collect(Collectors.joining(";"));
+ Assertions.assertEquals(expected, actual, "predicate sql mismatch");
+ }
+
+ // Normalize SQL text to remove non-semantic noise before string
comparison.
+ private static String normalizeSql(String sql) {
+ return sql.replace("`", "")
+ .replaceAll("\\s+", " ")
+ .trim()
+ .toLowerCase(Locale.ROOT);
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index d3a59ea6d1d..05baed1e043 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -32,6 +32,7 @@ import
org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
@@ -48,12 +49,17 @@ import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
+import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.math.BigDecimal;
+import java.util.List;
+import java.util.Map;
/**
* all expr rewrite rule test case.
@@ -323,9 +329,9 @@ class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
@Test
void testAddMinMax() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
- bottomUp(
- AddMinMax.INSTANCE
- )
+ bottomUp(
+ AddMinMax.INSTANCE
+ )
));
assertRewriteAfterTypeCoercion("5 * 100 >= 10 and 5 * 100 <= 5", "5 *
100 >= 10 and 5 * 100 <= 5");
@@ -386,6 +392,55 @@ class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
+ @Test
+ void testAddMinMaxMarksOriginalAndInferredBoundaryPredicates() {
+ executor = new ExpressionRuleExecutor(ImmutableList.of(
+ bottomUp(
+ AddMinMax.INSTANCE
+ )
+ ));
+
+ Map<String, Slot> slots = Maps.newHashMap();
+ Expression expression =
typeCoercion(replaceUnboundSlot(PARSER.parseExpression(
+ "TA > 10 and TA < 20 or TA > 30 and TA < 40"), slots));
+ Expression lowerPredicate =
typeCoercion(replaceUnboundSlot(PARSER.parseExpression("TA > 10"), slots));
+ Expression upperPredicate =
typeCoercion(replaceUnboundSlot(PARSER.parseExpression("TA < 40"), slots));
+
+ Expression rewrittenExpression = executor.rewrite(expression, context);
+ List<Expression> rewrittenConjuncts =
ExpressionUtils.extractConjunction(rewrittenExpression);
+ Assertions.assertTrue(rewrittenConjuncts.stream()
+ .anyMatch(conjunct -> conjunct.equals(lowerPredicate) &&
!conjunct.isInferred()));
+ Assertions.assertTrue(rewrittenConjuncts.stream()
+ .anyMatch(conjunct -> conjunct.equals(upperPredicate) &&
!conjunct.isInferred()));
+
+ Expression inferredExpression = expression.withInferred(true);
+ rewrittenExpression = executor.rewrite(inferredExpression, context);
+ rewrittenConjuncts =
ExpressionUtils.extractConjunction(rewrittenExpression);
+ Assertions.assertTrue(rewrittenConjuncts.stream()
+ .anyMatch(conjunct -> conjunct.equals(lowerPredicate) &&
conjunct.isInferred()));
+ Assertions.assertTrue(rewrittenConjuncts.stream()
+ .anyMatch(conjunct -> conjunct.equals(upperPredicate) &&
conjunct.isInferred()));
+ }
+
+ @Test
+ void testAddMinMaxMarksGeneratedPredicateInferred() {
+ executor = new ExpressionRuleExecutor(ImmutableList.of(
+ bottomUp(
+ AddMinMax.INSTANCE
+ )
+ ));
+
+ Map<String, Slot> slots = Maps.newHashMap();
+ Expression expression =
typeCoercion(replaceUnboundSlot(PARSER.parseExpression(
+ "TA = 4 or (TA > 4 and TB is null)"), slots));
+ Expression generatedPredicate =
typeCoercion(replaceUnboundSlot(PARSER.parseExpression("TA >= 4"), slots));
+
+ Expression rewrittenExpression = executor.rewrite(expression, context);
+ List<Expression> rewrittenConjuncts =
ExpressionUtils.extractConjunction(rewrittenExpression);
+ Assertions.assertTrue(rewrittenConjuncts.stream()
+ .anyMatch(conjunct -> conjunct.equals(generatedPredicate) &&
conjunct.isInferred()));
+ }
+
@Test
void testSimplifyRangeAndAddMinMax() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
diff --git
a/regression-test/suites/nereids_rules_p0/mv/negative/negative_test.groovy
b/regression-test/suites/nereids_rules_p0/mv/negative/negative_test.groovy
index ea45b806529..69931c7a7b0 100644
--- a/regression-test/suites/nereids_rules_p0/mv/negative/negative_test.groovy
+++ b/regression-test/suites/nereids_rules_p0/mv/negative/negative_test.groovy
@@ -421,7 +421,10 @@ suite("negative_partition_mv_rewrite") {
"""
mv_rewrite_fail(query_sql, mv_name)
- // filter include and or
+ // filter include and or. Query predicate is stricter than MV predicate
and can be
+ // compensated by DNF implication:
+ // query: o_orderkey > 2 AND (o_orderdate >= ... OR l_partkey > 1)
+ // mv : (o_orderkey > 2 AND o_orderdate >= ...) OR l_partkey > 1
mtmv_sql = """
select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey
from lineitem_1
@@ -437,7 +440,7 @@ suite("negative_partition_mv_rewrite") {
on lineitem_1.l_orderkey = orders_1.o_orderkey
where orders_1.o_orderkey > 2 and (orders_1.o_orderdate >=
"2023-10-17" or l_partkey > 1)
"""
- mv_rewrite_fail(query_sql, mv_name)
+ mv_rewrite_success(query_sql, mv_name)
// group by under group by
mtmv_sql = """
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]