This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new fa36652c8f7 [opt](nereids) opt for adjusting slot nullable and add
exception for changing slot nullable (#52748)
fa36652c8f7 is described below
commit fa36652c8f7723a4f602f94537e8671baca0c6c9
Author: yujun <[email protected]>
AuthorDate: Fri Jul 18 14:25:02 2025 +0800
[opt](nereids) opt for adjusting slot nullable and add exception for
changing slot nullable (#52748)
### What problem does this PR solve?
This PR optimizes how slot nullability is adjusted, including:
1) Apply the AdjustNullable rule at the end of both the analysis phase and
the rewrite phase. In the rewrite phase, if a slot would be adjusted from
not-nullable to nullable, an exception is thrown.
2) For LogicalApply, fix the computation of its output slots, and handle it in
the AdjustNullable rule;
3) Optimize the nullability of NullableAggregateFunction, including:
- for aggregate, its constructor adjusts the nullability of its
NullableAggregateFunctions;
- for project and one-row relation, NullableAggregateFunctions are changed
to nullable;
- for having and sort, the rule AdjustAggregateNullableForEmptySet is used to
update the nullability of their aggregate functions.
---
.../doris/nereids/jobs/executor/Analyzer.java | 23 +-
.../doris/nereids/jobs/executor/Rewriter.java | 9 +-
.../org/apache/doris/nereids/rules/RuleType.java | 4 +-
.../AdjustAggregateNullableForEmptySet.java | 109 ++++--
.../nereids/rules/analysis/BindExpression.java | 40 ++-
.../nereids/rules/analysis/ExpressionAnalyzer.java | 12 +
.../nereids/rules/analysis/FillUpMissingSlots.java | 8 +-
.../nereids/rules/rewrite/AdjustNullable.java | 160 +++++++--
.../nereids/rules/rewrite/EliminateGroupBy.java | 15 +-
.../nereids/rules/rewrite/EliminateJoinByFK.java | 67 +++-
.../doris/nereids/rules/rewrite/SaltJoin.java | 6 +-
.../trees/expressions/WindowExpression.java | 6 +
.../trees/plans/logical/LogicalAggregate.java | 7 +-
.../nereids/trees/plans/logical/LogicalApply.java | 39 +-
.../rules/analysis/AnalyzeSubQueryTest.java | 32 +-
.../rules/analysis/AnalyzeWhereSubqueryTest.java | 4 +-
.../rules/analysis/NormalizeAggregateTest.java | 392 +++++++++++++++++++++
.../AggScalarSubQueryToWindowFunctionTest.java | 1 +
.../rules/rewrite/EliminateJoinByFkTest.java | 96 ++++-
.../rules/rewrite/PullUpProjectUnderApplyTest.java | 2 +
.../trees/plans/logical/LogicalAggregateTest.java | 101 ++++++
.../apache/doris/utframe/TestWithFeService.java | 1 +
.../adjust_nullable/test_agg_nullable.out | Bin 0 -> 297 bytes
.../data/nereids_rules_p0/salt_join/salt_join.out | Bin 19356 -> 19356 bytes
.../adjust_nullable/test_agg_nullable.groovy | 30 ++
25 files changed, 1010 insertions(+), 154 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index e2c95aecca2..4dd7fabf861 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.executor;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
+import org.apache.doris.nereids.rules.RuleType;
import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AnalyzeCTE;
import org.apache.doris.nereids.rules.analysis.BindExpression;
@@ -49,6 +50,7 @@ import
org.apache.doris.nereids.rules.analysis.QualifyToFilter;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
import org.apache.doris.nereids.rules.analysis.VariableToLiteral;
+import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;
import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy;
@@ -120,7 +122,19 @@ public class Analyzer extends AbstractBatchJobExecutor {
new EliminateDistinctConstant(),
new ProjectWithDistinctToAggregate(),
new ReplaceExpressionByChildOutput(),
- new OneRowRelationExtractAggregate()
+ new OneRowRelationExtractAggregate(),
+
+ // ProjectToGlobalAggregate may generate an aggregate with
empty group by expressions.
+ // for sort / having, need to adjust their agg functions'
nullable.
+ // for example: select sum(a) from t having sum(b) > 10
order by sum(c),
+ // then will have:
+ // sort(sum(c)) sort(sum(c))
+ // | |
+ // having(sum(b) > 10) ==> having(sum(b) > 10)
+ // | |
+ // project(sum(a)) agg(sum(a))
+ // then need to adjust SORT and HAVING's sum to nullable.
+ new AdjustAggregateNullableForEmptySet()
),
topDown(
new FillUpMissingSlots(),
@@ -129,7 +143,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
// LogicalProject for normalize. This rule depends on
FillUpMissingSlots to fill up slots.
new NormalizeRepeat()
),
- bottomUp(new AdjustAggregateNullableForEmptySet()),
// consider sql with user defined var @t_zone
// set @t_zone='GMT';
// SELECT
@@ -169,7 +182,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
topDown(new LeadingJoin()),
topDown(new BindSkewExpr()),
bottomUp(new NormalizeGenerate()),
- bottomUp(new SubqueryToApply()),
/*
* Notice, MergeProjects rule should NOT be placed after
SubqueryToApply in analyze phase.
* because in SubqueryToApply, we may add assert_true function
with subquery output slot in projects list.
@@ -182,7 +194,10 @@ public class Analyzer extends AbstractBatchJobExecutor {
topDown(
// merge normal filter and hidden column filter
new MergeFilters()
- )
+ ),
+ // for cte: analyze producer -> analyze consumer -> rewrite
consumer -> rewrite producer,
+ // in order to ensure cte consumer had right nullable attribute,
need adjust nullable at analyze phase.
+ custom(RuleType.ADJUST_NULLABLE, () -> new AdjustNullable(true))
);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index dd592d107fc..7e0a7a21f79 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -23,7 +23,6 @@ import
org.apache.doris.nereids.jobs.rewrite.CostBasedRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
-import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import
org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
@@ -271,8 +270,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
new SimplifyEncodeDecode()
)
),
- // please note: this rule must run before NormalizeAggregate
- topDown(new AdjustAggregateNullableForEmptySet()),
// The rule modification needs to be done after the subquery
is unnested,
// because for scalarSubQuery, the connection condition is
stored in apply in the analyzer phase,
// but when normalizeAggregate/normalizeSort is performed, the
members in apply cannot be obtained,
@@ -405,9 +402,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
cascadesContext ->
cascadesContext.rewritePlanContainsTypes(LogicalAggregate.class),
topDown(
new EliminateGroupBy(),
- new MergeAggregate(),
- // need to adjust min/max/sum nullable
attribute after merge aggregate
- new AdjustAggregateNullableForEmptySet()
+ new MergeAggregate()
)
),
topic("Eager aggregation",
@@ -659,7 +654,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
)
),
topic("whole plan check",
- custom(RuleType.ADJUST_NULLABLE,
AdjustNullable::new)
+ custom(RuleType.ADJUST_NULLABLE, () -> new
AdjustNullable(false))
),
// NullableDependentExpressionRewrite need to be done
after nullable fixed
topic("condition function", bottomUp(ImmutableList.of(
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 75cc22c5036..83afc17ff7a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -100,10 +100,8 @@ public enum RuleType {
ANALYZE_CTE(RuleTypeClass.REWRITE),
RELATION_AUTHENTICATION(RuleTypeClass.VALIDATION),
- ADJUST_NULLABLE_FOR_PROJECT_SLOT(RuleTypeClass.REWRITE),
- ADJUST_NULLABLE_FOR_AGGREGATE_SLOT(RuleTypeClass.REWRITE),
+ ADJUST_NULLABLE_FOR_SORT_SLOT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_FOR_HAVING_SLOT(RuleTypeClass.REWRITE),
- ADJUST_NULLABLE_FOR_REPEAT_SLOT(RuleTypeClass.REWRITE),
ADD_DEFAULT_LIMIT(RuleTypeClass.REWRITE),
CHECK_ROW_POLICY(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
index 5543341ae27..947b42bb80b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
@@ -17,23 +17,24 @@
package org.apache.doris.nereids.rules.analysis;
+import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
/**
* adjust aggregate nullable when: group expr list is empty and function is
NullableAggregateFunction,
@@ -43,57 +44,87 @@ public class AdjustAggregateNullableForEmptySet implements
RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
- RuleType.ADJUST_NULLABLE_FOR_AGGREGATE_SLOT.build(
- logicalAggregate()
- .then(agg -> {
- List<NamedExpression> outputExprs =
agg.getOutputExpressions();
- boolean noGroupBy =
agg.getGroupByExpressions().isEmpty();
- ImmutableList.Builder<NamedExpression>
newOutput
- =
ImmutableList.builderWithExpectedSize(outputExprs.size());
- for (NamedExpression ne : outputExprs) {
- NamedExpression newExpr =
- ((NamedExpression)
FunctionReplacer.INSTANCE.replace(ne, noGroupBy));
- newOutput.add(newExpr);
- }
- return
agg.withAggOutput(newOutput.build());
- })
- ),
RuleType.ADJUST_NULLABLE_FOR_HAVING_SLOT.build(
logicalHaving(logicalAggregate())
- .then(having -> {
- Set<Expression> conjuncts =
having.getConjuncts();
- boolean noGroupBy =
having.child().getGroupByExpressions().isEmpty();
- ImmutableSet.Builder<Expression>
newConjuncts
- =
ImmutableSet.builderWithExpectedSize(conjuncts.size());
- for (Expression expr : conjuncts) {
- Expression newExpr =
FunctionReplacer.INSTANCE.replace(expr, noGroupBy);
- newConjuncts.add(newExpr);
- }
- return new
LogicalHaving<>(newConjuncts.build(), having.child());
- })
+ .then(having -> replaceHaving(having,
having.child().getGroupByExpressions().isEmpty()))
+ ),
+ RuleType.ADJUST_NULLABLE_FOR_SORT_SLOT.build(
+ logicalSort(logicalAggregate())
+ .then(sort -> replaceSort(sort,
sort.child().getGroupByExpressions().isEmpty()))
+ ),
+ RuleType.ADJUST_NULLABLE_FOR_SORT_SLOT.build(
+ logicalSort(logicalHaving(logicalAggregate()))
+ .then(sort -> replaceSort(sort,
sort.child().child().getGroupByExpressions().isEmpty()))
)
);
}
+ public static Expression replaceExpression(Expression expression, boolean
alwaysNullable) {
+ return FunctionReplacer.INSTANCE.replace(expression, alwaysNullable);
+ }
+
+ private LogicalPlan replaceSort(LogicalSort<?> sort, boolean
alwaysNullable) {
+ ImmutableList.Builder<OrderKey> newOrderKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size());
+ sort.getOrderKeys().forEach(
+ key ->
newOrderKeysBuilder.add(key.withExpression(replaceExpression(key.getExpr(),
alwaysNullable))));
+ List<OrderKey> newOrderKeys = newOrderKeysBuilder.build();
+ if (newOrderKeys.equals(sort.getOrderKeys())) {
+ return null;
+ }
+ return sort.withOrderKeys(newOrderKeys);
+ }
+
+ private LogicalPlan replaceHaving(LogicalHaving<?> having, boolean
alwaysNullable) {
+ Set<Expression> conjuncts = having.getConjuncts();
+ ImmutableSet.Builder<Expression> newConjunctsBuilder
+ = ImmutableSet.builderWithExpectedSize(conjuncts.size());
+ for (Expression expr : conjuncts) {
+ Expression newExpr = replaceExpression(expr, alwaysNullable);
+ newConjunctsBuilder.add(newExpr);
+ }
+ ImmutableSet<Expression> newConjuncts = newConjunctsBuilder.build();
+ if (newConjuncts.equals(having.getConjuncts())) {
+ return null;
+ }
+ return having.withConjuncts(newConjuncts);
+ }
+
+ /**
+ * replace NullableAggregateFunction nullable
+ */
private static class FunctionReplacer extends
DefaultExpressionRewriter<Boolean> {
public static final FunctionReplacer INSTANCE = new FunctionReplacer();
public Expression replace(Expression expression, boolean
alwaysNullable) {
- return expression.accept(INSTANCE, alwaysNullable);
+ return expression.accept(this, alwaysNullable);
}
@Override
public Expression visitWindow(WindowExpression windowExpression,
Boolean alwaysNullable) {
- return windowExpression.withPartitionKeysOrderKeys(
- windowExpression.getPartitionKeys().stream()
- .map(k -> k.accept(INSTANCE, alwaysNullable))
- .collect(Collectors.toList()),
- windowExpression.getOrderKeys().stream()
- .map(k -> (OrderExpression)
k.withChildren(k.children().stream()
- .map(c -> c.accept(INSTANCE,
alwaysNullable))
- .collect(Collectors.toList())))
- .collect(Collectors.toList())
- );
+ ImmutableList.Builder<Expression> newFunctionChildrenBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getFunction().children().size());
+ for (Expression child : windowExpression.getFunction().children())
{
+ newFunctionChildrenBuilder.add(child.accept(this,
alwaysNullable));
+ }
+ Expression newFunction =
windowExpression.getFunction().withChildren(newFunctionChildrenBuilder.build());
+ ImmutableList.Builder<Expression> newPartitionKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getPartitionKeys().size());
+ for (Expression partitionKey :
windowExpression.getPartitionKeys()) {
+ newPartitionKeysBuilder.add(partitionKey.accept(this,
alwaysNullable));
+ }
+ ImmutableList.Builder<OrderExpression> newOrderKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getOrderKeys().size());
+ for (OrderExpression orderKey : windowExpression.getOrderKeys()) {
+ ImmutableList.Builder<Expression> newChildrenBuilder
+ =
ImmutableList.builderWithExpectedSize(orderKey.children().size());
+ for (Expression child : orderKey.children()) {
+ newChildrenBuilder.add(child.accept(this, alwaysNullable));
+ }
+ newOrderKeysBuilder.add((OrderExpression)
orderKey.withChildren(newChildrenBuilder.build()));
+ }
+ return windowExpression.withFunctionPartitionKeysOrderKeys(
+ newFunction, newPartitionKeysBuilder.build(),
newOrderKeysBuilder.build());
}
@Override
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 5c0a3a94b9f..90614b29135 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -55,6 +55,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import
org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
@@ -363,6 +364,7 @@ public class BindExpression implements AnalysisRuleFactory {
CascadesContext cascadesContext = ctx.cascadesContext;
SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(oneRowRelation,
cascadesContext, ImmutableList.of());
List<NamedExpression> projects =
analyzer.analyzeToList(oneRowRelation.getProjects());
+ projects = adjustProjectionAggNullable(projects);
return new LogicalOneRowRelation(oneRowRelation.getRelationId(),
projects);
}
@@ -732,24 +734,32 @@ public class BindExpression implements
AnalysisRuleFactory {
});
}
}
- List<NamedExpression> boundProjections =
boundProjectionsBuilder.build();
- if (!SqlModeHelper.hasOnlyFullGroupBy()) {
- boolean hasAggregation = boundProjections
- .stream()
- .anyMatch(e ->
e.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
- if (hasAggregation) {
- boundProjectionsBuilder
- =
ImmutableList.builderWithExpectedSize(project.getProjects().size());
- for (NamedExpression expr : boundProjections) {
- if (expr instanceof SlotReference) {
- expr = new Alias(expr, expr.getName());
- }
- boundProjectionsBuilder.add(expr);
+ List<NamedExpression> projects =
adjustProjectionAggNullable(boundProjectionsBuilder.build());
+ return project.withProjects(projects);
+ }
+
+ private List<NamedExpression>
adjustProjectionAggNullable(List<NamedExpression> expressions) {
+ boolean hasAggregation = expressions.stream()
+ .anyMatch(expr ->
expr.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
+ if (!hasAggregation) {
+ return expressions;
+ }
+ boolean hasOnlyFullGroupBy = SqlModeHelper.hasOnlyFullGroupBy();
+ Builder<NamedExpression> newExpressionsBuilder =
ImmutableList.builderWithExpectedSize(expressions.size());
+ for (NamedExpression expr : expressions) {
+ expr = (NamedExpression) expr.rewriteDownShortCircuit(e -> {
+ // for `select sum(a) from t`, sum(a) is nullable
+ if (e instanceof NullableAggregateFunction) {
+ return ((NullableAggregateFunction)
e).withAlwaysNullable(true);
}
- boundProjections = boundProjectionsBuilder.build();
+ return e;
+ });
+ if (!hasOnlyFullGroupBy && expr instanceof SlotReference) {
+ expr = new Alias(expr, expr.getName());
}
+ newExpressionsBuilder.add(expr);
}
- return project.withProjects(boundProjections);
+ return newExpressionsBuilder.build();
}
private Plan bindLoadProject(MatchingContext<LogicalLoadProject<Plan>>
ctx) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
index ac687a6d181..63c88aa0e8a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
@@ -65,9 +65,11 @@ import
org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.Variable;
import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import
org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import
org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
@@ -523,6 +525,16 @@ public class ExpressionAnalyzer extends
SubExprAnalyzer<ExpressionRewriteContext
return TypeCoercionUtils.processBoundFunction(boundFunction);
}
+ @Override
+ public Expression visitWindow(WindowExpression windowExpression,
ExpressionRewriteContext context) {
+ windowExpression = (WindowExpression)
super.visitWindow(windowExpression, context);
+ Expression function = windowExpression.getFunction();
+ if (function instanceof NullableAggregateFunction) {
+ return windowExpression.withFunction(((NullableAggregateFunction)
function).withAlwaysNullable(true));
+ }
+ return windowExpression;
+ }
+
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
index cc213070e9f..3f92f612152 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
@@ -222,10 +222,14 @@ public class FillUpMissingSlots implements
AnalysisRuleFactory {
// avoid throw exception even if having have slot
from its child.
// because we will add a project between having
and project.
Resolver resolver = new Resolver(agg, false,
outerScope);
- having.getConjuncts().forEach(resolver::resolve);
+ Set<Expression> adjustAggNullableConjuncts =
having.getConjuncts().stream()
+ .map(conjunct ->
AdjustAggregateNullableForEmptySet.replaceExpression(
+ conjunct, true))
+ .collect(Collectors.toSet());
+
adjustAggNullableConjuncts.forEach(resolver::resolve);
agg =
agg.withAggOutput(resolver.getNewOutputSlots());
Set<Expression> newConjuncts =
ExpressionUtils.replace(
- having.getConjuncts(),
resolver.getSubstitution());
+ adjustAggNullableConjuncts,
resolver.getSubstitution());
ImmutableList.Builder<NamedExpression> projects =
ImmutableList.builder();
projects.addAll(project.getOutputs()).addAll(agg.getOutput());
return new LogicalHaving<>(newConjuncts, new
LogicalProject<>(projects.build(), agg));
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
index 08f8cfb4f80..547fa753cb2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
@@ -17,11 +17,14 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.util.DebugUtil;
+import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@@ -29,6 +32,7 @@ import
org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
@@ -46,14 +50,16 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.qe.ConnectContext;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
import java.util.LinkedHashMap;
import java.util.List;
@@ -68,6 +74,14 @@ import java.util.concurrent.atomic.AtomicBoolean;
*/
public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>>
implements CustomRewriter {
+ private static final Logger LOG =
LogManager.getLogger(AdjustNullable.class);
+
+ private final boolean isAnalyzedPhase;
+
+ public AdjustNullable(boolean isAnalyzedPhase) {
+ this.isAnalyzedPhase = isAnalyzedPhase;
+ }
+
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return plan.accept(this, Maps.newHashMap());
@@ -86,7 +100,8 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink,
Map<ExprId, Slot> replaceMap) {
logicalSink = (LogicalSink<? extends Plan>) super.visit(logicalSink,
replaceMap);
- Optional<List<NamedExpression>> newOutputExprs =
updateExpressions(logicalSink.getOutputExprs(), replaceMap);
+ Optional<List<NamedExpression>> newOutputExprs
+ = updateExpressions(logicalSink.getOutputExprs(), replaceMap,
true);
if (!newOutputExprs.isPresent()) {
return logicalSink;
} else {
@@ -98,8 +113,8 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan>
aggregate, Map<ExprId, Slot> replaceMap) {
aggregate = (LogicalAggregate<? extends Plan>) super.visit(aggregate,
replaceMap);
Optional<List<NamedExpression>> newOutputs
- = updateExpressions(aggregate.getOutputExpressions(),
replaceMap);
- Optional<List<Expression>> newGroupBy =
updateExpressions(aggregate.getGroupByExpressions(), replaceMap);
+ = updateExpressions(aggregate.getOutputExpressions(),
replaceMap, true);
+ Optional<List<Expression>> newGroupBy =
updateExpressions(aggregate.getGroupByExpressions(), replaceMap, true);
for (NamedExpression newOutput :
newOutputs.orElse(aggregate.getOutputExpressions())) {
replaceMap.put(newOutput.getExprId(), newOutput.toSlot());
}
@@ -115,7 +130,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter,
Map<ExprId, Slot> replaceMap) {
filter = (LogicalFilter<? extends Plan>) super.visit(filter,
replaceMap);
- Optional<Set<Expression>> conjuncts =
updateExpressions(filter.getConjuncts(), replaceMap);
+ Optional<Set<Expression>> conjuncts =
updateExpressions(filter.getConjuncts(), replaceMap, true);
if (!conjuncts.isPresent()) {
return filter;
}
@@ -125,7 +140,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate,
Map<ExprId, Slot> replaceMap) {
generate = (LogicalGenerate<? extends Plan>) super.visit(generate,
replaceMap);
- Optional<List<Function>> newGenerators =
updateExpressions(generate.getGenerators(), replaceMap);
+ Optional<List<Function>> newGenerators =
updateExpressions(generate.getGenerators(), replaceMap, true);
Plan newGenerate = generate;
if (newGenerators.isPresent()) {
newGenerate =
generate.withGenerators(newGenerators.get()).recomputeLogicalProperties();
@@ -139,28 +154,37 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
join, Map<ExprId, Slot> replaceMap) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join,
replaceMap);
- Optional<List<Expression>> hashConjuncts =
updateExpressions(join.getHashJoinConjuncts(), replaceMap);
+ Optional<List<Expression>> hashConjuncts =
updateExpressions(join.getHashJoinConjuncts(), replaceMap, true);
Optional<List<Expression>> markConjuncts = Optional.empty();
- boolean needCheckHashConjuncts = false;
- if (!hashConjuncts.isPresent() || hashConjuncts.get().isEmpty()) {
- // if hashConjuncts is empty, mark join conjuncts may used to
build hash table
+ boolean hadUpdatedMarkConjuncts = false;
+ if (isAnalyzedPhase || join.getHashJoinConjuncts().isEmpty()) {
+ // if hashConjuncts is empty, mark join conjuncts may use to build
hash table
// so need call updateExpressions for mark join conjuncts before
adjust nullable by output slot
- markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap);
- } else {
- needCheckHashConjuncts = true;
+ markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap, true);
+ hadUpdatedMarkConjuncts = true;
+ }
+ // in fact, otherConjuncts shouldn't use join output nullable
attribute,
+ // it should use left and right tables' origin nullable attribute.
+ // but for history reason, BE use join output nullable attribute for
evaluating the other conditions.
+ // so here, we make a difference:
+ // 1) when at analyzed phase, still update other conjuncts without
using join output nullables.
+ // then later at rewrite phase, the join conditions may push down,
and the push down condition with proper
+ // nullable attribute.
+ // 2) when at the end of rewrite phase, update other conjuncts with
join output nullables.
+ // Just change it to be consistent with BE.
+ Optional<List<Expression>> otherConjuncts = Optional.empty();
+ if (isAnalyzedPhase) {
+ otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(),
replaceMap, true);
}
for (Slot slot : join.getOutput()) {
replaceMap.put(slot.getExprId(), slot);
}
- if (needCheckHashConjuncts) {
- // hashConjuncts is not empty, mark join conjuncts are processed
like other join conjuncts
- Preconditions.checkState(
-
!hashConjuncts.orElse(join.getHashJoinConjuncts()).isEmpty(),
- "hash conjuncts should not be empty"
- );
- markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap);
+ if (!hadUpdatedMarkConjuncts) {
+ markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap, false);
+ }
+ if (!isAnalyzedPhase) {
+ otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(),
replaceMap, false);
}
- Optional<List<Expression>> otherConjuncts =
updateExpressions(join.getOtherJoinConjuncts(), replaceMap);
if (!hashConjuncts.isPresent() && !markConjuncts.isPresent() &&
!otherConjuncts.isPresent()) {
return join;
}
@@ -175,7 +199,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project,
Map<ExprId, Slot> replaceMap) {
project = (LogicalProject<? extends Plan>) super.visit(project,
replaceMap);
- Optional<List<NamedExpression>> newProjects =
updateExpressions(project.getProjects(), replaceMap);
+ Optional<List<NamedExpression>> newProjects =
updateExpressions(project.getProjects(), replaceMap, true);
for (NamedExpression newProject :
newProjects.orElse(project.getProjects())) {
replaceMap.put(newProject.getExprId(), newProject.toSlot());
}
@@ -185,6 +209,38 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
return project.withProjects(newProjects.get());
}
+ @Override
+ public Plan visitLogicalApply(LogicalApply<? extends Plan, ? extends Plan>
apply, Map<ExprId, Slot> replaceMap) {
+ apply = (LogicalApply<? extends Plan, ? extends Plan>)
super.visit(apply, replaceMap);
+ Optional<Expression> newCompareExpr =
updateExpression(apply.getCompareExpr(), replaceMap, true);
+ Optional<Expression> newTypeCoercionExpr =
updateExpression(apply.getTypeCoercionExpr(), replaceMap, true);
+ Optional<List<Slot>> newCorrelationSlot =
updateExpressions(apply.getCorrelationSlot(), replaceMap, true);
+ Optional<Expression> newCorrelationFilter =
updateExpression(apply.getCorrelationFilter(), replaceMap, true);
+ Optional<MarkJoinSlotReference> newMarkJoinSlotReference =
+ updateExpression(apply.getMarkJoinSlotReference(), replaceMap,
true);
+
+ for (Slot slot : apply.getOutput()) {
+ replaceMap.put(slot.getExprId(), slot);
+ }
+ if (!newCompareExpr.isPresent() && !newTypeCoercionExpr.isPresent() &&
!newCorrelationSlot.isPresent()
+ && !newCorrelationFilter.isPresent() &&
!newMarkJoinSlotReference.isPresent()) {
+ return apply;
+ }
+
+ return new LogicalApply<>(
+ newCorrelationSlot.orElse(apply.getCorrelationSlot()),
+ apply.getSubqueryType(),
+ apply.isNot(),
+ newCompareExpr.isPresent() ? newCompareExpr :
apply.getCompareExpr(),
+ newTypeCoercionExpr.isPresent() ? newTypeCoercionExpr :
apply.getTypeCoercionExpr(),
+ newCorrelationFilter.isPresent() ? newCorrelationFilter :
apply.getCorrelationFilter(),
+ newMarkJoinSlotReference.isPresent() ?
newMarkJoinSlotReference : apply.getMarkJoinSlotReference(),
+ apply.isNeedAddSubOutputToProjects(),
+ apply.isMarkJoinSlotNotNull(),
+ apply.left(),
+ apply.right());
+ }
+
@Override
public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat,
Map<ExprId, Slot> replaceMap) {
repeat = (LogicalRepeat<? extends Plan>) super.visit(repeat,
replaceMap);
@@ -196,7 +252,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
if (flattenGroupingSetExpr.contains(output)) {
newOutput = output;
} else {
- newOutput = updateExpression(output,
replaceMap).orElse(output);
+ newOutput = updateExpression(output, replaceMap,
true).orElse(output);
}
newOutputs.add(newOutput);
replaceMap.put(newOutput.getExprId(), newOutput.toSlot());
@@ -274,7 +330,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
boolean changed = false;
ImmutableList.Builder<OrderKey> newOrderKeys = ImmutableList.builder();
for (OrderKey orderKey : sort.getOrderKeys()) {
- Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap);
+ Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap, true);
if (!newOrderKey.isPresent()) {
newOrderKeys.add(orderKey);
} else {
@@ -295,7 +351,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
boolean changed = false;
ImmutableList.Builder<OrderKey> newOrderKeys = ImmutableList.builder();
for (OrderKey orderKey : topN.getOrderKeys()) {
- Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap);
+ Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap, true);
if (!newOrderKey.isPresent()) {
newOrderKeys.add(orderKey);
} else {
@@ -313,7 +369,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window,
Map<ExprId, Slot> replaceMap) {
window = (LogicalWindow<? extends Plan>) super.visit(window,
replaceMap);
Optional<List<NamedExpression>> windowExpressions =
- updateExpressions(window.getWindowExpressions(), replaceMap);
+ updateExpressions(window.getWindowExpressions(), replaceMap,
true);
for (NamedExpression w :
windowExpressions.orElse(window.getWindowExpressions())) {
replaceMap.put(w.getExprId(), w.toSlot());
}
@@ -327,8 +383,9 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan>
partitionTopN,
Map<ExprId, Slot> replaceMap) {
partitionTopN = (LogicalPartitionTopN<? extends Plan>)
super.visit(partitionTopN, replaceMap);
- Optional<List<Expression>> partitionKeys =
updateExpressions(partitionTopN.getPartitionKeys(), replaceMap);
- Optional<List<OrderExpression>> orderKeys =
updateExpressions(partitionTopN.getOrderKeys(), replaceMap);
+ Optional<List<Expression>> partitionKeys
+ = updateExpressions(partitionTopN.getPartitionKeys(),
replaceMap, true);
+ Optional<List<OrderExpression>> orderKeys =
updateExpressions(partitionTopN.getOrderKeys(), replaceMap, true);
if (!partitionKeys.isPresent() && !orderKeys.isPresent()) {
return partitionTopN;
}
@@ -342,7 +399,7 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
Map<Slot, Slot> consumerToProducerOutputMap = new LinkedHashMap<>();
Multimap<Slot, Slot> producerToConsumerOutputMap =
LinkedHashMultimap.create();
for (Slot producerOutputSlot :
cteConsumer.getConsumerToProducerOutputMap().values()) {
- Optional<Slot> newProducerOutputSlot =
updateExpression(producerOutputSlot, replaceMap);
+ Optional<Slot> newProducerOutputSlot =
updateExpression(producerOutputSlot, replaceMap, true);
for (Slot consumerOutputSlot :
cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot)) {
Slot slot = newProducerOutputSlot.orElse(producerOutputSlot);
Slot newConsumerOutputSlot =
consumerOutputSlot.withNullable(slot.nullable());
@@ -354,11 +411,18 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
return cteConsumer.withTwoMaps(consumerToProducerOutputMap,
producerToConsumerOutputMap);
}
- private <T extends Expression> Optional<T> updateExpression(T input,
Map<ExprId, Slot> replaceMap) {
+ private <T extends Expression> Optional<T> updateExpression(Optional<T>
input,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
+ return input.isPresent() ? updateExpression(input.get(), replaceMap,
debugCheck) : Optional.empty();
+ }
+
+ private <T extends Expression> Optional<T> updateExpression(T input,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
AtomicBoolean changed = new AtomicBoolean(false);
Expression replaced = input.rewriteDownShortCircuit(e -> {
if (e instanceof SlotReference) {
SlotReference slotReference = (SlotReference) e;
+ Slot newSlotReference = slotReference;
Slot replacedSlot = replaceMap.get(slotReference.getExprId());
if (replacedSlot != null) {
if (replacedSlot.getDataType().isAggStateType()) {
@@ -369,16 +433,32 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
// TODO: remove if statement after we ensure be
constant folding do not change
// expr type at all.
changed.set(true);
- return slotReference.withNullableAndDataType(
- replacedSlot.nullable(),
replacedSlot.getDataType()
- );
+ newSlotReference =
slotReference.withNullableAndDataType(
+ replacedSlot.nullable(),
replacedSlot.getDataType());
}
} else if (slotReference.nullable() !=
replacedSlot.nullable()) {
changed.set(true);
- return
slotReference.withNullable(replacedSlot.nullable());
+ newSlotReference =
slotReference.withNullable(replacedSlot.nullable());
+ }
+ }
+ // for join other conditions, debugCheck = false, for other
case, debugCheck is always true.
+ // Because join other condition use join output's nullable
attribute, outer join may check fail.
+ // At analyzed phase, the slot reference nullable may change,
for example, NormalRepeat may adjust some
+ // slot reference to nullable, after this rule, node above
repeat need adjust.
+ // so analyzed phase don't assert not-nullable -> nullable,
otherwise adjust plan above
+ // repeat may check fail.
+ if (!slotReference.nullable() && newSlotReference.nullable()
+ && !isAnalyzedPhase && debugCheck &&
ConnectContext.get() != null) {
+ if (ConnectContext.get().getSessionVariable().feDebug) {
+ throw new AnalysisException("AdjustNullable convert
slot " + slotReference
+ + " from not-nullable to nullable. You can
disable check by set fe_debug = false.");
+ } else {
+ LOG.warn("adjust nullable convert slot '" +
slotReference
+ + "' from not-nullable to nullable for query "
+ +
DebugUtil.printId(ConnectContext.get().queryId()));
}
}
- return slotReference;
+ return newSlotReference;
} else {
return e;
}
@@ -386,22 +466,24 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
return changed.get() ? Optional.of((T) replaced) : Optional.empty();
}
- private <T extends Expression> Optional<List<T>> updateExpressions(List<T>
inputs, Map<ExprId, Slot> replaceMap) {
+ private <T extends Expression> Optional<List<T>> updateExpressions(List<T>
inputs,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
ImmutableList.Builder<T> result =
ImmutableList.builderWithExpectedSize(inputs.size());
boolean changed = false;
for (T input : inputs) {
- Optional<T> newInput = updateExpression(input, replaceMap);
+ Optional<T> newInput = updateExpression(input, replaceMap,
debugCheck);
changed |= newInput.isPresent();
result.add(newInput.orElse(input));
}
return changed ? Optional.of(result.build()) : Optional.empty();
}
- private <T extends Expression> Optional<Set<T>> updateExpressions(Set<T>
inputs, Map<ExprId, Slot> replaceMap) {
+ private <T extends Expression> Optional<Set<T>> updateExpressions(Set<T>
inputs,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
boolean changed = false;
ImmutableSet.Builder<T> result =
ImmutableSet.builderWithExpectedSize(inputs.size());
for (T input : inputs) {
- Optional<T> newInput = updateExpression(input, replaceMap);
+ Optional<T> newInput = updateExpression(input, replaceMap,
debugCheck);
changed |= newInput.isPresent();
result.add(newInput.orElse(input));
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
index 65345fb37e6..628bd3cd823 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
@@ -118,8 +118,7 @@ public class EliminateGroupBy extends OneRewriteRuleFactory
{
.castIfNotSameType(new BigIntLiteral(1),
f.getDataType())));
} else {
newOutput.add((NamedExpression) ne.withChildren(
- new If(new IsNull(f.child(0)), new
BigIntLiteral(0),
- new BigIntLiteral(1))));
+ ifNullElse(f.child(0), new BigIntLiteral(0),
new BigIntLiteral(1))));
}
} else if (f instanceof Sum0) {
Coalesce coalesce = new Coalesce(f.child(0),
@@ -127,14 +126,14 @@ public class EliminateGroupBy extends
OneRewriteRuleFactory {
newOutput.add((NamedExpression) ne.withChildren(
TypeCoercionUtils.castIfNotSameType(coalesce,
f.getDataType())));
} else if (supportedTwoArgsFunctions.contains(f.getClass())) {
- If ifFunc = new If(new IsNull(f.child(1)), new
NullLiteral(f.child(0).getDataType()),
+ Expression expr = ifNullElse(f.child(1), new
NullLiteral(f.child(0).getDataType()),
f.child(0));
newOutput.add((NamedExpression) ne.withChildren(
- TypeCoercionUtils.castIfNotSameType(ifFunc,
f.getDataType())));
+ TypeCoercionUtils.castIfNotSameType(expr,
f.getDataType())));
} else if (supportedDevLikeFunctions.contains(f.getClass())) {
- If ifFunc = new If(new IsNull(f.child(0)), new
NullLiteral(DoubleType.INSTANCE),
+ Expression expr = ifNullElse(f.child(0), new
NullLiteral(DoubleType.INSTANCE),
new DoubleLiteral(0));
- newOutput.add((NamedExpression) ne.withChildren(ifFunc));
+ newOutput.add((NamedExpression) ne.withChildren(expr));
} else {
return null;
}
@@ -154,4 +153,8 @@ public class EliminateGroupBy extends OneRewriteRuleFactory
{
}
return false;
}
+
+ private Expression ifNullElse(Expression conditionExpr, Expression ifExpr,
Expression elseExpr) {
+ return conditionExpr.nullable() ? new If(new IsNull(conditionExpr),
ifExpr, elseExpr) : elseExpr;
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
index 44c24a3a004..b8e2ab6a003 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@@ -25,7 +26,7 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@@ -33,9 +34,11 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.JoinUtils;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.List;
@@ -79,24 +82,46 @@ public class EliminateJoinByFK extends
OneRewriteRuleFactory {
}
Set<Slot> output = project.getInputSlots();
Set<Slot> foreignKeys = Sets.intersection(foreign.getOutputSet(),
equalSet.getAllItemSet());
- Map<Expression, Expression> outputToForeign =
- tryMapOutputToForeignPlan(foreign, output, equalSet);
+ Map<Slot, Slot> outputToForeign = tryMapOutputToForeignPlan(foreign,
output, equalSet);
if (outputToForeign != null) {
+ Pair<Plan, Set<Slot>> newChildPair =
applyNullCompensationFilter(foreign, foreignKeys);
+ Map<Slot, Expression> replacedSlots =
getReplaceSlotMap(outputToForeign, newChildPair.second);
List<NamedExpression> newProjects = project.getProjects().stream()
- .map(e -> outputToForeign.containsKey(e)
- ? new Alias(e.getExprId(), outputToForeign.get(e),
e.toSql())
- : (NamedExpression) e.rewriteUp(s ->
outputToForeign.getOrDefault(s, s)))
+ .map(e -> replacedSlots.containsKey(e)
+ ? new Alias(e.getExprId(), replacedSlots.get(e),
e.toSql())
+ : (NamedExpression) e.rewriteUp(s ->
replacedSlots.getOrDefault(s, s)))
.collect(ImmutableList.toImmutableList());
- return project.withProjects(newProjects)
- .withChildren(applyNullCompensationFilter(foreign,
foreignKeys));
+ return
project.withProjects(newProjects).withChildren(newChildPair.first);
}
return project;
}
- private @Nullable Map<Expression, Expression>
tryMapOutputToForeignPlan(Plan foreignPlan,
+ /**
+ * get replace slots, include replace the primary slots and replace the
nullable foreign slots.
+ * @param outputToForeign primary slot to foreign slot map
+ * @param compensationForeignSlots foreign slots which are nullable but
add a filter 'slot is not null'
+ * @return the replaced map, include primary slot to foreign slot, and
foreign nullable slot to non-nullable(slot)
+ */
+ @VisibleForTesting
+ public Map<Slot, Expression> getReplaceSlotMap(Map<Slot, Slot>
outputToForeign,
+ Set<Slot> compensationForeignSlots) {
+ Map<Slot, Expression> replacedSlots = Maps.newHashMap();
+ for (Map.Entry<Slot, Slot> entry : outputToForeign.entrySet()) {
+ Slot forgeinSlot = entry.getValue();
+ Expression replacedExpr =
compensationForeignSlots.contains(forgeinSlot)
+ ? new NonNullable(forgeinSlot) : forgeinSlot;
+ replacedSlots.put(entry.getKey(), replacedExpr);
+ }
+ for (Slot forgeinSlot : compensationForeignSlots) {
+ replacedSlots.put(forgeinSlot, new NonNullable(forgeinSlot));
+ }
+ return replacedSlots;
+ }
+
+ private @Nullable Map<Slot, Slot> tryMapOutputToForeignPlan(Plan
foreignPlan,
Set<Slot> output, ImmutableEqualSet<Slot> equalSet) {
Set<Slot> residualPrimary = Sets.difference(output,
foreignPlan.getOutputSet());
- ImmutableMap.Builder<Expression, Expression> builder = new
ImmutableMap.Builder<>();
+ ImmutableMap.Builder<Slot, Slot> builder = new
ImmutableMap.Builder<>();
for (Slot primarySlot : residualPrimary) {
Optional<Slot> replacedForeign =
equalSet.calEqualSet(primarySlot).stream()
.filter(foreignPlan.getOutputSet()::contains)
@@ -109,14 +134,20 @@ public class EliminateJoinByFK extends
OneRewriteRuleFactory {
return builder.build();
}
- private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots)
{
- Set<Expression> predicates = childSlots.stream()
- .filter(ExpressionTrait::nullable)
- .map(s -> new Not(new IsNull(s)))
- .collect(ImmutableSet.toImmutableSet());
- if (predicates.isEmpty()) {
- return child;
+ /**
+ * add a filter for foreign slots which is nullable, the filter is 'slot
is not null'
+ */
+ @VisibleForTesting
+ public Pair<Plan, Set<Slot>> applyNullCompensationFilter(Plan child,
Set<Slot> childSlots) {
+ ImmutableSet.Builder<Expression> predicatesBuilder =
ImmutableSet.builder();
+ Set<Slot> filterNotNullSlots = Sets.newHashSet();
+ for (Slot slot : childSlots) {
+ if (slot.nullable()) {
+ filterNotNullSlots.add(slot);
+ predicatesBuilder.add(new Not(new IsNull(slot)));
+ }
}
- return new LogicalFilter<>(predicates, child);
+ Plan newChild = filterNotNullSlots.isEmpty() ? child : new
LogicalFilter<>(predicatesBuilder.build(), child);
+ return Pair.of(newChild, filterNotNullSlots);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java
index b62fb64737a..00421ac0a6c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java
@@ -277,8 +277,6 @@ public class SaltJoin extends OneRewriteRuleFactory {
// lateral view explode_numbers(1000) tmp1 as explodeColumn
ImmutableList.Builder<List<NamedExpression>> constantExprsList =
ImmutableList.builderWithExpectedSize(
saltedSkewValues.size());
- List<NamedExpression> outputs = ImmutableList.of(new SlotReference(
- SKEW_VALUE_COLUMN_NAME + ctx.generateColumnName(),
skewExpr.getDataType(), false));
boolean saltedSkewValuesHasNull = false;
for (Expression skewValue : saltedSkewValues) {
constantExprsList.add(ImmutableList.of(new Alias(skewValue,
SKEW_VALUE_COLUMN_NAME
@@ -287,11 +285,13 @@ public class SaltJoin extends OneRewriteRuleFactory {
saltedSkewValuesHasNull = true;
}
}
+ List<NamedExpression> outputs = ImmutableList.of(new SlotReference(
+ SKEW_VALUE_COLUMN_NAME + ctx.generateColumnName(),
skewExpr.getDataType(), saltedSkewValuesHasNull));
LogicalUnion union = new LogicalUnion(Qualifier.ALL, outputs,
ImmutableList.of(), constantExprsList.build(),
false, ImmutableList.of());
List<Function> generators = ImmutableList.of(new ExplodeNumbers(new
IntegerLiteral(factor)));
SlotReference generateSlot = new
SlotReference(EXPLODE_NUMBER_COLUMN_NAME + ctx.generateColumnName(),
- IntegerType.INSTANCE, false);
+ IntegerType.INSTANCE, true);
LogicalGenerate<Plan> generate = new LogicalGenerate<>(generators,
ImmutableList.of(generateSlot), union);
ImmutableList.Builder<NamedExpression> projectsBuilder =
ImmutableList.builderWithExpectedSize(
union.getOutput().size() + 1);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
index 0e9756e6dda..659e4066113 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
@@ -133,6 +133,12 @@ public class WindowExpression extends Expression {
.orElseGet(() -> new WindowExpression(function, partitionKeys,
orderKeys, isSkew));
}
+ public WindowExpression withFunctionPartitionKeysOrderKeys(Expression
function,
+ List<Expression> partitionKeys, List<OrderExpression> orderKeys) {
+ return windowFrame.map(frame -> new WindowExpression(function,
partitionKeys, orderKeys, frame, isSkew))
+ .orElseGet(() -> new WindowExpression(function, partitionKeys,
orderKeys, isSkew));
+ }
+
@Override
public boolean nullable() {
return function.nullable();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index b24c9d59431..3ca6d433228 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.LogicalProperties;
+import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -146,7 +147,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
CHILD_TYPE child) {
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties,
child);
this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
- this.outputExpressions = ImmutableList.copyOf(outputExpressions);
+ boolean noGroupby = groupByExpressions.isEmpty();
+ ImmutableList.Builder<NamedExpression> builder =
ImmutableList.builder();
+ outputExpressions.forEach(output -> builder.add(
+ (NamedExpression)
AdjustAggregateNullableForEmptySet.replaceExpression(output, noGroupby)));
+ this.outputExpressions = builder.build();
this.normalized = normalized;
this.ordinalIsResolved = ordinalIsResolved;
this.generated = generated;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
index f6d13f4acbe..6c73f5ac977 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
@@ -184,11 +184,44 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
if (markJoinSlotReference.isPresent()) {
builder.add(markJoinSlotReference.get());
}
+ // only scalar apply can be needAddSubOutputToProjects = true
if (needAddSubOutputToProjects) {
- if (isScalar()) {
-
builder.add(ScalarSubquery.getScalarQueryOutputAdjustNullable(right(),
correlationSlot));
+ // correlated apply right child may contain multiple output slots
+ // in rule ScalarApplyToJoin, only '(isCorrelated() &&
correlationFilter.isPresent())'
+ // but at analyzed phase, the correlationFilter is empty, only
after rule UnCorrelatedApplyAggregateFilter
+ // correlationFilter will be set, so we skip check
correlationFilter here.
+ // correlated apply will change to a left outer join, then all the
right child output will be nullable.
+ if (isCorrelated()) {
+ // for sql:
+ // `select t1.a,
+ // (select if(sum(t2.a) > 10, count(t2.b), max(t2.c))
as k from t2 where t1.a = t2.a)
+ // from t1`,
+ // its plan is:
+ // LogicalProject(t1.a, if(sum(t2.a) > 10, count(t2.b),
max(t2.c)) as k)
+ // |-- LogicalProject(..., if(sum(t2.a > 10),
ifnull(count(t2.b), 0), max(t2.c)) as k)
+ // |-- LogicalApply(correlationSlot = [t1.a])
+ // |-- LogicalOlapScan(t1)
+ // |-- LogicalAggregate(output = [sum(t2.a),
count(t2.b), max(t2.c)])
+ for (Slot slot : right().getOutput()) {
+ // in fact some slots may non-nullable, like count.
+ // but after convert correlated apply to left outer join,
all the join right child's slots
+ // will become nullable, so we let all slots be nullable,
then they wouldn't change nullable
+ // even after convert to join.
+ builder.add(slot.toSlot().withNullable(true));
+ }
} else {
- builder.add(right().getOutput().get(0));
+ // uncorrelated apply right child always contains one output
slot.
+ // for sql:
+ // `select t1.a,
+ // (select if(sum(t2.a) > 10, count(t2.b), max(t2.c))
as k from t2)
+ // from t1`,
+ // its plan is:
+ // LogicalProject(t1.a, k)
+ // |--LogicalApply(correlationSlot = [])
+ // |- LogicalOlapScan(t1)
+ // |- LogicalProject(if(sum(t2.a) > 10, count(t2.b),
max(t2.c)) as k)
+ // |-- LogicalAggregate(output = [sum(t2.a),
count(t2.b), max(t2.c)])
+
builder.add(ScalarSubquery.getScalarQueryOutputAdjustNullable(right(),
correlationSlot));
}
}
return builder.build();
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
index d0ce4bedc91..f73ee3425af 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
@@ -26,10 +26,12 @@ import
org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.types.BigIntType;
@@ -247,15 +249,15 @@ public class AnalyzeSubQueryTest extends
TestWithFeService implements MemoPatter
private void checkScalarSubquerySlotNullable(String sql, boolean
outputNullable) {
Plan root = PlanChecker.from(connectContext)
.analyze(sql)
- .applyTopDown(new LogicalSubQueryAliasToLogicalProject())
.getPlan();
List<LogicalProject<?>> projectList = Lists.newArrayList();
+ List<LogicalPlan> plansAboveApply = Lists.newArrayList();
root.foreach(plan -> {
if (plan instanceof LogicalProject && plan.child(0) instanceof
LogicalApply) {
projectList.add((LogicalProject<?>) plan);
- return true;
- } else {
- return false;
+ }
+ if (!(plan instanceof LogicalApply) && plan.anyMatch(p -> p
instanceof LogicalApply)) {
+ plansAboveApply.add((LogicalPlan) plan);
}
});
@@ -272,10 +274,26 @@ public class AnalyzeSubQueryTest extends
TestWithFeService implements MemoPatter
.findFirst().orElse(null);
Assertions.assertNotNull(output);
Assertions.assertEquals(outputNullable, output.nullable());
- output = apply.getOutput().stream()
+
+ Slot applySubqueySlot = apply.getOutput().stream()
.filter(e -> slotKName.contains(e.getName()))
.findFirst().orElse(null);
- Assertions.assertNotNull(output);
- Assertions.assertEquals(outputNullable, output.nullable());
+ Assertions.assertNotNull(applySubqueySlot);
+ if (apply.isCorrelated()) {
+ // apply will change to outer join
+ Assertions.assertTrue(applySubqueySlot.nullable());
+ } else {
+ Assertions.assertEquals(outputNullable,
applySubqueySlot.nullable());
+ }
+
+ for (LogicalPlan plan : plansAboveApply) {
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot ->
slot.getExprId().equals(applySubqueySlot.getExprId()))
+ .allMatch(slot -> slot.nullable() ==
applySubqueySlot.nullable()));
+
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot ->
slot.getExprId().equals(applySubqueySlot.getExprId()))
+ .allMatch(slot -> slot.nullable() ==
applySubqueySlot.nullable()));
+ }
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
index edcd9279343..a1935ad2984 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
@@ -194,7 +194,7 @@ public class AnalyzeWhereSubqueryTest extends
TestWithFeService implements MemoP
logicalAggregate().when(FieldChecker.check("outputExpressions",
ImmutableList.of(
new Alias(new ExprId(7), (new
Sum(
new SlotReference(new
ExprId(4), "k3", BigIntType.INSTANCE, true,
-
ImmutableList.of("test", "t7")))).withAlwaysNullable(true),
+
ImmutableList.of("test", "t7")))),
"sum(t7.k3)"),
new SlotReference(new
ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("test", "t7"))
@@ -473,7 +473,7 @@ public class AnalyzeWhereSubqueryTest extends
TestWithFeService implements MemoP
logicalProject()
).when(FieldChecker.check("outputExpressions",
ImmutableList.of(
new Alias(new ExprId(8), (new Max(new
SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, true,
-
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)"),
+ ImmutableList.of("t2")))),
"max(aa)"),
new SlotReference(new ExprId(6), "v2",
BigIntType.INSTANCE, true,
ImmutableList.of("test",
"t7")))))
.when(FieldChecker.check("groupByExpressions", ImmutableList.of(
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
index 2fa945b0011..6a66c09b445 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
@@ -19,9 +19,11 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -29,8 +31,10 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.FieldChecker;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@@ -41,10 +45,14 @@ import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
+import java.util.stream.Collectors;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class NormalizeAggregateTest extends TestWithFeService implements
MemoPatternMatchSupported {
@@ -59,6 +67,15 @@ public class NormalizeAggregateTest extends
TestWithFeService implements MemoPat
createTables(
"CREATE TABLE IF NOT EXISTS t1 (\n"
+ " id int not null,\n"
+ + " no int not null,\n"
+ + " name char\n"
+ + ")\n"
+ + "DUPLICATE KEY(id)\n"
+ + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ + "PROPERTIES (\"replication_num\" = \"1\")\n",
+ "CREATE TABLE IF NOT EXISTS t2 (\n"
+ + " id int not null,\n"
+ + " no int not null,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
@@ -301,4 +318,379 @@ public class NormalizeAggregateTest extends
TestWithFeService implements MemoPat
agg.getGroupByExpressions().size() == 1
&&
agg.getOutputExpressions().stream().anyMatch(e ->
e.toString().contains("COUNT"))));
}
+
+ @Test
+ void testAggFunctionNullabe() {
+ List<String> aggNullableSqls = ImmutableList.of(
+ // one row relation
+ "select sum(1) as k",
+
+ "select sum(id) as k from t1",
+ "select sum(id) as k from t1 where id > 10",
+
+ // sub query alias
+ "select * from (select sum(id) as k from t1) t",
+ "select * from (select sum(id) as k from t1 where id > 10) t",
+
+ // project sub query
+ "select id, (select sum(t2.id) as k from t2) from t1",
+ "select id, (select sum(t2.id) as k from t2 where t2.id > 10)
from t1",
+ "select id, (select sum(t2.id) as k from t2 where t1.id =
t2.id) from t1",
+
+ // filter sub query
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2)",
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2 where t2.id > 10)",
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2 where t1.name = t2.name)"
+ );
+ for (String sql : aggNullableSqls) {
+ checkAggFunctionNullable(sql, true);
+ }
+
+ List<String> aggNotNullableSqls = ImmutableList.of(
+ "select sum(id) as k from t1 group by name",
+ "select sum(id) as k from t1 group by 'abcde' ",
+ "select sum(id) as k from t1 where id > 10 group by name",
+ "select sum(id) as k from t1 where id > 10 group by 'abcde' ",
+
+ // sub query alias
+ "select * from (select sum(id) as k from t1 group by name) t",
+ "select * from (select sum(id) as k from t1 group by 'abcde')
t",
+ "select * from (select sum(id) as k from t1 where id > 10
group by name) t",
+ "select * from (select sum(id) as k from t1 where id > 10
group by 'abcde') t"
+ );
+ for (String sql : aggNotNullableSqls) {
+ checkAggFunctionNullable(sql, false);
+ }
+ }
+
+ private void checkAggFunctionNullable(String sql, boolean nullable) {
+ List<LogicalAggregate<?>> aggList = Lists.newArrayList();
+ List<LogicalProject<?>> projectList = Lists.newArrayList();
+ List<LogicalApply<?, ?>> applyList = Lists.newArrayList();
+ List<LogicalPlan> planAboveApply = Lists.newArrayList();
+ List<LogicalPlan> planAboveAgg = Lists.newArrayList();
+ Plan root = PlanChecker.from(connectContext)
+ .analyze(sql).getPlan();
+ root.foreach(plan -> {
+ if (plan instanceof LogicalAggregate) {
+ aggList.add((LogicalAggregate<?>) plan);
+ } else if (plan instanceof LogicalProject) {
+ projectList.add((LogicalProject<?>) plan);
+ } else if (plan instanceof LogicalApply) {
+ applyList.add((LogicalApply<?, ?>) plan);
+ }
+
+ if (!(plan instanceof LogicalApply) && plan.anyMatch(p -> p
instanceof LogicalApply)) {
+ planAboveApply.add((LogicalPlan) plan);
+ }
+ if (!(plan instanceof LogicalAggregate)
+ && plan.anyMatch(p -> p instanceof LogicalAggregate)
+ && !(plan.anyMatch(p -> p instanceof LogicalApply))) {
+ planAboveAgg.add((LogicalPlan) plan);
+ }
+ });
+ List<String> slotKName = ImmutableList.of("k");
+
+ Assertions.assertEquals(1, aggList.size());
+ LogicalAggregate<?> agg = aggList.get(0);
+ NamedExpression slotK = agg.getOutputExpressions().stream()
+ .filter(output -> slotKName.contains(output.getName()))
+ .findFirst().orElse(null);
+ Assertions.assertNotNull(slotK);
+ Assertions.assertEquals(nullable, slotK.nullable());
+
+ Assertions.assertTrue(applyList.size() <= 1);
+ Slot applySlot = null;
+ if (applyList.size() == 1) {
+ LogicalApply<?, ?> apply = applyList.get(0);
+ applySlot = apply.getOutput().stream()
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ Assertions.assertNotNull(applySlot);
+ Assertions.assertTrue(applySlot.nullable());
+ }
+ for (LogicalProject<?> project : projectList) {
+ if (!project.anyMatch(plan -> plan instanceof LogicalAggregate)) {
+ continue;
+ }
+
+ NamedExpression expr = project.getProjects().stream()
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ if (expr == null) {
+ expr = project.getProjects().stream()
+ .map(output -> output instanceof Alias &&
output.child(0) instanceof SlotReference
+ ? (SlotReference) output.child(0) : output)
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ }
+ if (expr == null) {
+ continue;
+ }
+
+ boolean aboveApply = project.anyMatch(plan -> plan instanceof
LogicalApply);
+ if (aboveApply) {
+ Assertions.assertTrue(expr.nullable());
+ } else {
+ Assertions.assertEquals(nullable, expr.nullable());
+ }
+ }
+
+ if (applySlot != null) {
+ ExprId applySlotExprId = applySlot.getExprId();
+ boolean applySlotNullable = applySlot.nullable();
+ for (LogicalPlan plan : planAboveApply) {
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot ->
slot.getExprId().equals(applySlotExprId))
+ .allMatch(slot -> slot.nullable() ==
applySlotNullable));
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot ->
slot.getExprId().equals(applySlotExprId))
+ .allMatch(slot -> slot.nullable() ==
applySlotNullable));
+ }
+ }
+ for (LogicalPlan plan : planAboveAgg) {
+ ExprId kSlotExprId = slotK.getExprId();
+ boolean kSlotNullable = slotK.nullable();
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot -> slot.getExprId().equals(kSlotExprId))
+ .allMatch(slot -> slot.nullable() == kSlotNullable));
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot -> slot.getExprId().equals(kSlotExprId))
+ .allMatch(slot -> slot.nullable() == kSlotNullable));
+ }
+ }
+
+ @Test
+ void testAggFunctionNullabe2() {
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id) from t1")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+ checkExprsToSql(output, "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select 1 from t1 having sum(id) > 10")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalFilter(
+ logicalProject(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output
= agg.getOutput();
+
checkExprsToSql(output, "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression>
projects = project.getProjects();
+ checkExprsToSql(projects,
"sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "1 AS
`1`", "sum(id)");
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(filter -> {
+ List<Expression> conjuncts =
filter.getExpressions();
+ checkExprsToSql(conjuncts, "(sum(id) >
10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id), sum(no) from t1 having sum(id) > 10")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalProject(
+ logicalFilter(
+
logicalAggregate().when(agg -> {
+ List<Slot> output
= agg.getOutput();
+
checkExprsToSql(output, "sum(id)", "sum(no)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(filter -> {
+ List<Expression> conjuncts
= filter.getExpressions();
+ checkExprsToSql(conjuncts,
"(sum(id) > 10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects,
"sum(id)", "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(id)",
"sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id), sum(no) from t1 order by sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalSort(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+
checkExprsToSql(output, "sum(id)", "sum(no)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects
= project.getProjects();
+ checkExprsToSql(projects,
"sum(id)", "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(no) from t1 order by sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalSort(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+
checkExprsToSql(output, "sum(no)", "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects
= project.getProjects();
+ checkExprsToSql(projects,
"sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(no) from t1 having sum(no) > 10 order by
sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalSort(
+ logicalProject(
+ logicalProject(
+ logicalFilter(
+
logicalAggregate().when(agg -> {
+
List<Slot> output = agg.getOutput();
+
checkExprsToSql(output, "sum(no)", "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+
return true;
+ })
+ ).when(filter
-> {
+
List<Expression> conjuncts = filter.getExpressions();
+
checkExprsToSql(conjuncts, "(sum(no) > 10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return
true;
+ })
+ ).when(project -> {
+
List<NamedExpression> projects = project.getProjects();
+
checkExprsToSql(projects, "sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression>
projects = project.getProjects();
+ checkExprsToSql(projects,
"sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ // a window function, not agg
+ PlanChecker.from(connectContext)
+ .analyze("select sum(1) over()")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalOneRowRelation()
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(1) OVER()
AS `sum(1) over()`");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ ).when(sink -> {
+
Assertions.assertTrue(sink.getOutput().get(0).nullable());
+ return true;
+ })
+ );
+ }
+
+ private void checkExprsToSql(Collection<? extends Expression> expressions,
String... exprsToSql) {
+ Assertions.assertEquals(Arrays.asList(exprsToSql),
+
expressions.stream().map(Expression::toSql).collect(Collectors.toList()));
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
index d7dad886ac3..443dbaebd8f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
@@ -74,6 +74,7 @@ public class AggScalarSubQueryToWindowFunctionTest extends
TPCHTestBase implemen
@Test
public void testRuleOnTPCHTest() {
+ connectContext.getSessionVariable().feDebug = false;
check(TPCHUtils.Q2);
check(TPCHUtils.Q17);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
index ab1665c1c30..7622ae601ca 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
@@ -17,13 +17,34 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
class EliminateJoinByFkTest extends TestWithFeService implements
MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -113,15 +134,37 @@ class EliminateJoinByFkTest extends TestWithFeService
implements MemoPatternMatc
@Test
void testNull() throws Exception {
- String sql = "select pri.id1 from pri inner join foreign_null on
pri.id1 = foreign_null.id3";
+ String sql = "select pri.id1, 1 + foreign_null.id3 as k from pri inner
join foreign_null on pri.id1 = foreign_null.id3";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin())
- .matches(logicalFilter().when(f -> {
- Assertions.assertTrue(f.getPredicate().toSql().contains("(
not id3 IS NULL)"));
- return true;
- }))
+ .matches(
+ logicalResultSink(
+ logicalProject(
+ logicalFilter().when(f -> {
+
Assertions.assertTrue(f.getPredicate().toSql().contains("( not id3 IS NULL)"));
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ Assertions.assertEquals(2, projects.size());
+ Assertions.assertEquals("non_nullable(id3) AS
`id1`", projects.get(0).toSql());
+ Assertions.assertEquals("(cast(non_nullable(id3)
as BIGINT) + 1) AS `k`", projects.get(1).toSql());
+ Assertions.assertFalse(projects.get(0).nullable());
+ Assertions.assertFalse(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sink -> {
+ List<NamedExpression> projects = sink.getOutputExprs();
+ Assertions.assertEquals(2, projects.size());
+ Assertions.assertEquals("id1",
projects.get(0).toSql());
+ Assertions.assertFalse(projects.get(0).nullable());
+ Assertions.assertEquals("k", projects.get(1).toSql());
+ Assertions.assertFalse(projects.get(1).nullable());
+ return true;
+ })
+ )
.printlnTree();
sql = "select foreign_null.id3 from pri inner join foreign_null on
pri.id1 = foreign_null.id3";
PlanChecker.from(connectContext)
@@ -204,4 +247,47 @@ class EliminateJoinByFkTest extends TestWithFeService
implements MemoPatternMatc
.rewrite()
.matches(logicalOlapScan().when(scan ->
scan.getTable().getName().equals("pri")));
}
+
+ @Test
+ void testReplaceMap() {
+ Slot a = new SlotReference("a", IntegerType.INSTANCE);
+ Slot b = new SlotReference("b", IntegerType.INSTANCE);
+ Slot x = new SlotReference("x", IntegerType.INSTANCE);
+ Slot y = new SlotReference("y", IntegerType.INSTANCE);
+ Slot z = new SlotReference("z", IntegerType.INSTANCE);
+ Map<Slot, Slot> outputToForeign = Maps.newHashMap();
+ outputToForeign.put(a, x);
+ outputToForeign.put(b, y);
+
+ Set<Slot> compensationForeignSlots = Sets.newHashSet();
+ compensationForeignSlots.add(x);
+ compensationForeignSlots.add(z);
+
+ Map<Slot, Expression> replacedSlots = new
EliminateJoinByFK().getReplaceSlotMap(outputToForeign,
compensationForeignSlots);
+ Map<Slot, Expression> expectedReplacedSlots = Maps.newHashMap();
+ expectedReplacedSlots.put(a, new NonNullable(x));
+ expectedReplacedSlots.put(b, y);
+ expectedReplacedSlots.put(x, new NonNullable(x));
+ expectedReplacedSlots.put(z, new NonNullable(z));
+ Assertions.assertEquals(expectedReplacedSlots, replacedSlots);
+ }
+
+ @Test
+ void testyNullCompensationFilter() {
+ EliminateJoinByFK instance = new EliminateJoinByFK();
+ SlotReference notNull1 = new SlotReference("notNull1",
IntegerType.INSTANCE, false);
+ SlotReference notNull2 = new SlotReference("notNull2",
IntegerType.INSTANCE, false);
+ SlotReference null1 = new SlotReference("null1", IntegerType.INSTANCE,
true);
+ SlotReference null2 = new SlotReference("null2", IntegerType.INSTANCE,
true);
+ LogicalOneRowRelation oneRowRelation = new LogicalOneRowRelation(new
RelationId(100), ImmutableList.of());
+ Pair<Plan, Set<Slot>> result1 =
instance.applyNullCompensationFilter(oneRowRelation, ImmutableSet.of(notNull1,
notNull2));
+ Assertions.assertEquals(ImmutableSet.of(), result1.second);
+ Assertions.assertEquals(oneRowRelation, result1.first);
+ Pair<Plan, Set<Slot>> result2 =
instance.applyNullCompensationFilter(oneRowRelation, ImmutableSet.of(notNull1,
notNull2, null1, null2));
+ Assertions.assertEquals(ImmutableSet.of(null1, null2), result2.second);
+ LogicalFilter<?> expectFilter = new LogicalFilter<>(
+ ImmutableSet.of(new Not(new IsNull(null1)), new Not(new
IsNull(null2))),
+ oneRowRelation);
+ Assertions.assertEquals(expectFilter, result2.first);
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
index 3e1c4d58c40..6cb8f9afb8a 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
@@ -56,6 +56,8 @@ class PullUpProjectUnderApplyTest extends TestWithFeService
implements MemoPatte
@Test
void testPullUpProjectUnderApply() {
+ connectContext.getSessionVariable().feDebug = false;
+
List<String> testSql = ImmutableList.of(
"select * from T as T1 where id = (select max(id) from T as T2
where T1.score = T2.score)",
"select * from T as T1 where id = (select max(id) + 1 from T
as T2 where T1.score = T2.score)"
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
new file mode 100644
index 00000000000..eb6e3ea5fc8
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans.logical;
+
+import org.apache.doris.nereids.properties.OrderKey;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.OrderExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
+import org.apache.doris.nereids.types.IntegerType;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class LogicalAggregateTest {
+
+ @Test
+ void testAdjustAggNullableWithEmptyGroupBy() {
+ SlotReference a = new SlotReference("a", IntegerType.INSTANCE, false);
+ SlotReference b = new SlotReference("b", IntegerType.INSTANCE, false);
+
+ LogicalOneRowRelation oneRowRelation = new
LogicalOneRowRelation(StatementScopeIdGenerator.newRelationId(),
+ ImmutableList.of(a, b));
+
+ // agg with empty group by
+ NamedExpression originOutput1 = new Alias(new Add(new Sum(a), new
IntegerLiteral(1)));
+ NamedExpression originOutput2 = new Alias(new WindowExpression(
+ new Sum(false, true, new Add(new Sum(b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true)))));
+ Assertions.assertFalse(originOutput1.nullable());
+ LogicalAggregate<LogicalOneRowRelation> agg = new LogicalAggregate<>(
+ ImmutableList.of(), ImmutableList.of(originOutput1,
originOutput2), oneRowRelation);
+ NamedExpression output1 = agg.getOutputs().get(0);
+ NamedExpression output2 = agg.getOutputs().get(1);
+ Assertions.assertNotEquals(originOutput1, output1);
+ Assertions.assertNotEquals(originOutput2, output2);
+ Assertions.assertTrue(output1.nullable());
+ Expression expectOutput1Child = new Add(new Sum(false, true, a), new
IntegerLiteral(1));
+ Expression expectOutput2Child = new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, true, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true))));
+ Assertions.assertEquals(expectOutput1Child, output1.child(0));
+ Assertions.assertEquals(expectOutput2Child, output2.child(0));
+ }
+
+ @Test
+ void testAdjustAggNullableWithNotEmptyGroupBy() {
+ SlotReference a = new SlotReference("a", IntegerType.INSTANCE, false);
+ SlotReference b = new SlotReference("b", IntegerType.INSTANCE, false);
+
+ LogicalOneRowRelation oneRowRelation = new
LogicalOneRowRelation(StatementScopeIdGenerator.newRelationId(),
+ ImmutableList.of(a, b));
+
+ // agg with not empty group by
+ NamedExpression originOutput1 = new Alias(new Add(new Sum(false, true,
a), new IntegerLiteral(1)));
+ NamedExpression originOutput2 = new Alias(new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, true, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true)))));
+ Assertions.assertTrue(originOutput1.nullable());
+ LogicalAggregate<LogicalOneRowRelation> agg = new LogicalAggregate<>(
+ ImmutableList.of(new TinyIntLiteral((byte) 1)),
ImmutableList.of(originOutput1, originOutput2), oneRowRelation);
+ NamedExpression output1 = agg.getOutputs().get(0);
+ NamedExpression output2 = agg.getOutputs().get(1);
+ Assertions.assertNotEquals(originOutput1, output1);
+ Assertions.assertNotEquals(originOutput2, output2);
+ Assertions.assertFalse(output1.nullable());
+ Expression expectOutput1Child = new Add(new Sum(false, false, a), new
IntegerLiteral(1));
+ Expression expectOutput2Child = new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, false, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true))));
+ Assertions.assertEquals(expectOutput1Child, output1.child(0));
+ Assertions.assertEquals(expectOutput2Child, output2.child(0));
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
index 430d4cdc5a7..074849218d7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
@@ -160,6 +160,7 @@ public abstract class TestWithFeService {
FeConstants.disableWGCheckerForUT = true;
beforeCreatingConnectContext();
connectContext = createDefaultCtx();
+ connectContext.getSessionVariable().feDebug = true;
beforeCluster();
createDorisCluster();
Env.getCurrentEnv().getWorkloadGroupMgr().createNormalWorkloadGroupForUT();
diff --git
a/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
new file mode 100644
index 00000000000..7ec09edaf5d
Binary files /dev/null and
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
differ
diff --git a/regression-test/data/nereids_rules_p0/salt_join/salt_join.out
b/regression-test/data/nereids_rules_p0/salt_join/salt_join.out
index fba356d3583..7a8aa656ca6 100644
Binary files a/regression-test/data/nereids_rules_p0/salt_join/salt_join.out
and b/regression-test/data/nereids_rules_p0/salt_join/salt_join.out differ
diff --git
a/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
new file mode 100644
index 00000000000..60e3342ea0f
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite('test_agg_nullable') {
+ sql 'DROP TABLE IF EXISTS test_agg_nullable_t1 FORCE'
+ sql "CREATE TABLE test_agg_nullable_t1(a int not null, b int not null, c
int not null) distributed by hash(a) properties('replication_num' = '1')"
+ sql "SET detail_shape_nodes='PhysicalProject'"
+ order_qt_agg_nullable '''
+ select k > 10 and k < 5 from (select sum(a) as k from
test_agg_nullable_t1) s
+ '''
+ qt_agg_nullable_shape '''explain shape plan
+ select k > 10 and k < 5 from (select sum(a) as k from
test_agg_nullable_t1) s
+ '''
+ sql 'DROP TABLE IF EXISTS test_agg_nullable_t1 FORCE'
+}
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]