This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 22eb6803a8f [feat](unique function) add project for unique function
(#48449)
22eb6803a8f is described below
commit 22eb6803a8f2c8b9299e2bf2063eb7d62b72f3ec
Author: yujun <[email protected]>
AuthorDate: Sat Sep 6 00:34:33 2025 +0800
[feat](unique function) add project for unique function (#48449)
### What problem does this PR solve?
If a unique function occurs multiple times, the BE evaluates each
occurrence separately for every row, which produces wrong results.
For example, `filter(random() between 10 and 20)` is rewritten to
`filter(random() >= 10 and random() <= 20)`. This expression contains two
random calls, but the two RANDOM calls must yield a single value. So we
add a project below the filter, giving `filter(k >= 10 and k <= 20)
-> project(random() as k)`.
This PR also fixes a BETWEEN expression bug introduced by #55407:
- BETWEEN shouldn't be PropagateNullable. FoldConstantRuleOnFE rewrites a
PropagateNullable expression to NULL if any child is NULL, but the SQL
`10 between null and 5` should evaluate to `FALSE`, not `NULL`;
- after analysis, a BETWEEN expression becomes an AND expression. So when
analyzing the join's other conjuncts, we need to extract the conjuncts of
each analyzed other conjunct (that is, flatten the AND).
---
.../doris/nereids/jobs/executor/Rewriter.java | 7 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../nereids/rules/analysis/BindExpression.java | 9 +-
.../rules/rewrite/AddProjectForUniqueFunction.java | 299 +++++++++++++++++++++
.../rewrite/MergeOneRowRelationIntoUnion.java | 11 +-
.../doris/nereids/trees/expressions/Between.java | 17 +-
.../apache/doris/nereids/util/ExpressionUtils.java | 14 +
.../rules/expression/SimplifyRangeTest.java | 12 +-
.../rewrite/AddProjectForUniqueFunctionTest.java | 143 ++++++++++
.../add_project_for_unique_function.out | Bin 0 -> 13431 bytes
.../add_project_for_unique_function.groovy | 137 ++++++++++
.../nereids_rules_p0/unique_function/load.groovy | 13 +
12 files changed, 655 insertions(+), 8 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index f73b75d2dd6..f2204fc561d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -34,6 +34,7 @@ import
org.apache.doris.nereids.rules.expression.NullableDependentExpressionRewr
import org.apache.doris.nereids.rules.expression.QueryColumnCollector;
import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
import org.apache.doris.nereids.rules.rewrite.AddProjectForJoin;
+import org.apache.doris.nereids.rules.rewrite.AddProjectForUniqueFunction;
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import
org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
@@ -739,6 +740,12 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new SumLiteralRewrite(),
new MergePercentileToArray())
),
+ topic("add projection for unique function",
+ // separate AddProjectForUniqueFunction and
MergeProjectable
+ // to avoid dead loop if code has bug
+ topDown(new AddProjectForUniqueFunction()),
+ topDown(new MergeProjectable())
+ ),
topic("collect scan filter for hbo",
// this rule is to collect filter on basic table for hbo
usage
topDown(new CollectPredicateOnScan())
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index cd7bbd89712..478e6f35745 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -214,6 +214,7 @@ public enum RuleType {
PUSH_DOWN_DISTINCT_THROUGH_JOIN(RuleTypeClass.REWRITE),
ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE),
+ ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE),
VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE),
CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 915501c8dbe..cc144702fbd 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -607,9 +607,12 @@ public class BindExpression implements AnalysisRuleFactory
{
Builder<Expression> otherJoinConjuncts =
ImmutableList.builderWithExpectedSize(
join.getOtherJoinConjuncts().size());
for (Expression otherJoinConjunct : join.getOtherJoinConjuncts()) {
- otherJoinConjunct = analyzer.analyze(otherJoinConjunct);
- otherJoinConjunct =
TypeCoercionUtils.castIfNotSameType(otherJoinConjunct, BooleanType.INSTANCE);
- otherJoinConjuncts.add(otherJoinConjunct);
+ // after analyzed, 'a between 1 and 10' will rewrite to 'a >= 1
and a <= 10'
+ Expression boundExpr = analyzer.analyze(otherJoinConjunct);
+ for (Expression conjunct :
ExpressionUtils.extractConjunction(boundExpr)) {
+ conjunct = TypeCoercionUtils.castIfNotSameType(conjunct,
BooleanType.INSTANCE);
+ otherJoinConjuncts.add(conjunct);
+ }
}
return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts.build(), otherJoinConjuncts.build(),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java
new file mode 100644
index 00000000000..2c2537ac476
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunction.java
@@ -0,0 +1,299 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.JoinUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Optional;
+
+/** extract unique function expression which exist multiple times, and add
them to a new project child.
+ * for example:
+ * before rewrite: filter(random() >= 5 and random() <= 10), suppose the two
random have the same unique expr id.
+ * after rewrite: filter(k >= 5 and k <= 10) -> project(random() as k)
+ */
+public class AddProjectForUniqueFunction implements RewriteRuleFactory {
+
+ @Override
+ public List<Rule> buildRules() {
+ return ImmutableList.of(
+ new GenerateRewrite().build(),
+ new OneRowRelationRewrite().build(),
+ new ProjectRewrite().build(),
+ new FilterRewrite().build(),
+ new HavingRewrite().build(),
+ new AggregateRewrite().build(),
+ new JoinRewrite().build()
+ );
+ }
+
+ private class GenerateRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalGenerate().thenApply(ctx -> {
+ LogicalGenerate<Plan> generate = ctx.root;
+ Optional<Pair<List<Function>, LogicalProject<Plan>>>
+ rewrittenOpt = rewriteExpressions(generate,
generate.getGenerators());
+ if (rewrittenOpt.isPresent()) {
+ return generate.withGenerators(rewrittenOpt.get().first)
+ .withChildren(rewrittenOpt.get().second);
+ } else {
+ return generate;
+ }
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class OneRowRelationRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalOneRowRelation().thenApply(ctx -> {
+ LogicalOneRowRelation oneRowRelation = ctx.root;
+ List<NamedExpression> uniqueFunctionAlias =
tryGenUniqueFunctionAlias(oneRowRelation.getProjects());
+ if (uniqueFunctionAlias.isEmpty()) {
+ return oneRowRelation;
+ }
+
+ Map<Expression, Slot> replaceMap = Maps.newHashMap();
+ for (NamedExpression alias : uniqueFunctionAlias) {
+ replaceMap.put(alias.child(0), alias.toSlot());
+ }
+ ImmutableList.Builder<NamedExpression> newProjectBuilder
+ =
ImmutableList.builderWithExpectedSize(oneRowRelation.getProjects().size());
+ for (NamedExpression expr : oneRowRelation.getProjects()) {
+ newProjectBuilder.add((NamedExpression)
ExpressionUtils.replace(expr, replaceMap));
+ }
+ return new LogicalProject<>(
+ newProjectBuilder.build(),
+ oneRowRelation.withProjects(uniqueFunctionAlias));
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class ProjectRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalProject().thenApply(ctx -> {
+ LogicalProject<Plan> project = ctx.root;
+ Optional<Pair<List<NamedExpression>, LogicalProject<Plan>>>
+ rewrittenOpt = rewriteExpressions(project,
project.getProjects());
+ if (rewrittenOpt.isPresent()) {
+ return
project.withProjectsAndChild(rewrittenOpt.get().first,
rewrittenOpt.get().second);
+ } else {
+ return project;
+ }
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class FilterRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalFilter().thenApply(ctx -> {
+ LogicalFilter<Plan> filter = ctx.root;
+ Optional<Pair<List<Expression>, LogicalProject<Plan>>>
+ rewrittenOpt = rewriteExpressions(filter,
filter.getConjuncts());
+ if (rewrittenOpt.isPresent()) {
+ return filter.withConjunctsAndChild(
+ ImmutableSet.copyOf(rewrittenOpt.get().first),
+ rewrittenOpt.get().second);
+ } else {
+ return filter;
+ }
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class HavingRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalHaving().thenApply(ctx -> {
+ LogicalHaving<Plan> having = ctx.root;
+ Optional<Pair<List<Expression>, LogicalProject<Plan>>>
+ rewrittenOpt = rewriteExpressions(having,
having.getConjuncts());
+ if (rewrittenOpt.isPresent()) {
+ return
having.withConjuncts(ImmutableSet.copyOf(rewrittenOpt.get().first))
+ .withChildren(rewrittenOpt.get().second);
+ } else {
+ return having;
+ }
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class AggregateRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalAggregate().thenApply(ctx -> {
+ LogicalAggregate<Plan> aggregate = ctx.root;
+ List<Expression> targets = Lists.newArrayList();
+ targets.addAll(aggregate.getGroupByExpressions());
+ targets.addAll(aggregate.getOutputExpressions());
+ Optional<Pair<List<Expression>, LogicalProject<Plan>>>
rewrittenOpt
+ = rewriteExpressions(aggregate, targets);
+ if (!rewrittenOpt.isPresent()) {
+ return aggregate;
+ }
+
+ LogicalProject<Plan> newChild = rewrittenOpt.get().second;
+ List<Expression> newTargets = rewrittenOpt.get().first;
+ int groupBySize = aggregate.getGroupByExpressions().size();
+ ImmutableList<Expression> newGroupBy = ImmutableList.copyOf(
+ newTargets.subList(0, groupBySize));
+ ImmutableList.Builder<NamedExpression> newOutputBuilder
+ =
ImmutableList.builderWithExpectedSize(aggregate.getOutputExpressions().size());
+ for (int i = groupBySize; i < newTargets.size(); i++) {
+ newOutputBuilder.add((NamedExpression) newTargets.get(i));
+ }
+ return aggregate.withChildGroupByAndOutput(newGroupBy,
newOutputBuilder.build(), newChild);
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ private class JoinRewrite extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalJoin().thenApply(ctx -> {
+ LogicalJoin<Plan, Plan> join = ctx.root;
+ int hashOtherConjunctsSize =
join.getHashJoinConjuncts().size() + join.getOtherJoinConjuncts().size();
+ int totalConjunctsSize = hashOtherConjunctsSize +
join.getMarkJoinConjuncts().size();
+ List<Expression> allConjuncts =
Lists.newArrayListWithExpectedSize(totalConjunctsSize);
+ allConjuncts.addAll(join.getHashJoinConjuncts());
+ allConjuncts.addAll(join.getOtherJoinConjuncts());
+ allConjuncts.addAll(join.getMarkJoinConjuncts());
+ Optional<Pair<List<Expression>, LogicalProject<Plan>>>
rewrittenOpt
+ = rewriteExpressions(join, allConjuncts);
+ if (!rewrittenOpt.isPresent()) {
+ return join;
+ }
+
+ LogicalProject<Plan> newLeftChild = rewrittenOpt.get().second;
+ List<Expression> newAllConjuncts = rewrittenOpt.get().first;
+ List<Expression> newHashOtherConjuncts =
newAllConjuncts.subList(0, hashOtherConjunctsSize);
+ List<Expression> newMarkJoinConjuncts = ImmutableList.copyOf(
+ newAllConjuncts.subList(hashOtherConjunctsSize,
totalConjunctsSize));
+ // TODO: code from FindHashConditionForJoin
+ Pair<List<Expression>, List<Expression>> pair =
JoinUtils.extractExpressionForHashTable(
+ newLeftChild.getOutput(), join.right().getOutput(),
newHashOtherConjuncts);
+ List<Expression> newHashJoinConjuncts = pair.first;
+ List<Expression> newOtherJoinConjuncts = pair.second;
+ JoinType joinType = join.getJoinType();
+ if (joinType == JoinType.CROSS_JOIN &&
!newHashJoinConjuncts.isEmpty()) {
+ joinType = JoinType.INNER_JOIN;
+ }
+ return new LogicalJoin<>(joinType,
+ newHashJoinConjuncts,
+ newOtherJoinConjuncts,
+ newMarkJoinConjuncts,
+ join.getDistributeHint(),
+ join.getMarkJoinSlotReference(),
+ ImmutableList.of(newLeftChild, join.right()),
+ join.getJoinReorderContext());
+ }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
+ }
+ }
+
+ /**
+ * extract unique function which exist multiple times from targets,
+ * then alias the unique function and put them into a child project,
+ * then rewrite targets with the alias names.
+ */
+ @VisibleForTesting
+ public <T extends Expression> Optional<Pair<List<T>,
LogicalProject<Plan>>> rewriteExpressions(
+ LogicalPlan plan, Collection<T> targets) {
+ List<NamedExpression> uniqueFunctionAlias =
tryGenUniqueFunctionAlias(targets);
+ if (uniqueFunctionAlias.isEmpty()) {
+ return Optional.empty();
+ }
+
+ List<NamedExpression> projects =
ImmutableList.<NamedExpression>builder()
+ .addAll(plan.child(0).getOutputSet())
+ .addAll(uniqueFunctionAlias)
+ .build();
+
+ Map<Expression, Slot> replaceMap = Maps.newHashMap();
+ for (NamedExpression alias : uniqueFunctionAlias) {
+ replaceMap.put(alias.child(0), alias.toSlot());
+ }
+ ImmutableList.Builder<T> newTargetsBuilder =
ImmutableList.builderWithExpectedSize(targets.size());
+ for (T target : targets) {
+ newTargetsBuilder.add((T) ExpressionUtils.replace(target,
replaceMap));
+ }
+
+ return Optional.of(Pair.of(newTargetsBuilder.build(), new
LogicalProject<>(projects, plan.child(0))));
+ }
+
+ /**
+ * if a unique function exists multiple times in the targets, then add a
project to alias it.
+ */
+ @VisibleForTesting
+ public List<NamedExpression> tryGenUniqueFunctionAlias(Collection<?
extends Expression> targets) {
+ Map<UniqueFunction, Integer> unqiueFunctionCounter =
Maps.newLinkedHashMap();
+ for (Expression target : targets) {
+ target.foreach(e -> {
+ Expression expr = (Expression) e;
+ if (expr instanceof UniqueFunction) {
+ unqiueFunctionCounter.merge((UniqueFunction) expr, 1,
Integer::sum);
+ }
+ });
+ }
+
+ ImmutableList.Builder<NamedExpression> builder
+ =
ImmutableList.builderWithExpectedSize(unqiueFunctionCounter.size());
+ for (Entry<UniqueFunction, Integer> entry :
unqiueFunctionCounter.entrySet()) {
+ if (entry.getValue() > 1) {
+ ExprId exprId = StatementScopeIdGenerator.newExprId();
+ String name = "$_" + entry.getKey().getName() + "_" +
exprId.asInt() + "_$";
+ builder.add(new Alias(exprId, entry.getKey(), name));
+ }
+ }
+
+ return builder.build();
+ }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java
index 06341a96038..abda37ff0b4 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeOneRowRelationIntoUnion.java
@@ -24,6 +24,7 @@ import
org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
@@ -45,7 +46,11 @@ public class MergeOneRowRelationIntoUnion extends
OneRewriteRuleFactory {
ImmutableList.Builder<List<SlotReference>>
newChildrenOutputs = ImmutableList.builder();
for (int i = 0; i < u.arity(); i++) {
Plan child = u.child(i);
- if (!(child instanceof LogicalOneRowRelation)) {
+ // if one row relation contains unique function which
exist multiple times,
+ // don't merge it, later AddProjectForUniqueFunction
will handle this one row relation.
+ if (!(child instanceof LogicalOneRowRelation)
+ ||
ExpressionUtils.containUniqueFunctionExistMultiple(
+ ((LogicalOneRowRelation)
child).getProjects())) {
newChildren.add(child);
newChildrenOutputs.add(u.getRegularChildOutput(i));
} else {
@@ -64,6 +69,10 @@ public class MergeOneRowRelationIntoUnion extends
OneRewriteRuleFactory {
constantExprsList.add(constantExprs.build());
}
}
+ // no change
+ if (newChildren.size() == u.arity()) {
+ return u;
+ }
return u.withChildrenAndConstExprsList(newChildren,
newChildrenOutputs.build(),
constantExprsList.build());
}).toRule(RuleType.MERGE_ONE_ROW_RELATION_INTO_UNION);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java
index 14cc02dadff..2c28c47a85a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Between.java
@@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
-import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
@@ -32,7 +31,7 @@ import java.util.List;
/**
* Between predicate expression.
*/
-public class Between extends Expression implements TernaryExpression,
PropagateNullable {
+public class Between extends Expression implements TernaryExpression {
private final Expression compareExpr;
private final Expression lowerBound;
@@ -77,6 +76,20 @@ public class Between extends Expression implements
TernaryExpression, PropagateN
return compareExpr + " BETWEEN " + lowerBound + " AND " + upperBound;
}
+ // nullable is true if any children is nullable,
+ // but between is not PropagateNullable,
+ // because FoldConstantRuleOnFE will fold a PropagateNullable expression
to NULL if any children is NULL.
+ // but `4 BETWEEN NULL AND 3` should fold to FALSE, not NULL.
+ @Override
+ public boolean nullable() {
+ for (Expression child : children()) {
+ if (child.nullable()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitBetween(this, context);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 4230240b82f..cd04627ee30 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -1254,4 +1254,18 @@ public class ExpressionUtils {
public static boolean hasNonWindowAggregateFunction(Expression expression)
{
return
expression.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null);
}
+
+ /**
+ * check if the expressions contain a unique function which exists
multiple times
+ */
+ public static boolean containUniqueFunctionExistMultiple(Collection<?
extends Expression> expressions) {
+ Set<UniqueFunction> counterSet = Sets.newHashSet();
+ for (Expression expression : expressions) {
+ if (expression.anyMatch(
+ expr -> expr instanceof UniqueFunction &&
!counterSet.add((UniqueFunction) expr))) {
+ return true;
+ }
+ }
+ return false;
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
index 36c96179f0f..e6e856852d0 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
@@ -240,8 +240,8 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewrite("SA > '20250101' and SA > '20260110'", "SA >
'20260110'");
// random is non-foldable, so the two random(1, 10) are distinct,
cann't merge range for them.
- Expression expr = rewrite("TA + random(1, 10) > 10 AND TA + random(1,
10) < 1", Maps.newHashMap());
- Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) >
10),((cast(TA as BIGINT) + random(1, 10)) < 1)]", expr.toSql());
+ Expression expr = rewriteExpression("X + random(1, 10) > 10 AND X +
random(1, 10) < 1", true);
+ Assertions.assertEquals("AND[((X + random(1, 10)) > 10),((X +
random(1, 10)) < 1)]", expr.toSql());
expr = rewrite("TA + random(1, 10) between 10 and 20",
Maps.newHashMap());
Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) >=
10),((cast(TA as BIGINT) + random(1, 10)) <= 20)]", expr.toSql());
@@ -446,6 +446,14 @@ public class SimplifyRangeTest extends ExpressionRewrite {
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
+ private Expression rewriteExpression(String expression, boolean nullable) {
+ Map<String, Slot> mem = Maps.newHashMap();
+ Expression needRewriteExpression = PARSER.parseExpression(expression);
+ needRewriteExpression = nullable ?
replaceUnboundSlot(needRewriteExpression, mem) :
replaceNotNullUnboundSlot(needRewriteExpression, mem);
+ needRewriteExpression = typeCoercion(needRewriteExpression);
+ return executor.rewrite(needRewriteExpression, context);
+ }
+
private Expression replaceUnboundSlot(Expression expression, Map<String,
Slot> mem) {
List<Expression> children = Lists.newArrayList();
boolean hasNewChildren = false;
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java
new file mode 100644
index 00000000000..715c4c3f1c5
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AddProjectForUniqueFunctionTest.java
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Optional;
+
+public class AddProjectForUniqueFunctionTest implements
MemoPatternMatchSupported {
+ private final LogicalOlapScan studentOlapScan
+ = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
PlanConstructor.student);
+
+ @Test
+ void testGenUniqueFunctionAlias() {
+ Random random1 = new Random();
+ Random random2 = new Random();
+ Random random3 = new Random();
+ List<Expression> expressions = ImmutableList.of(
+ new Add(random1, new Add(random1, new DoubleLiteral(1.0))),
+ new Add(random2, random3),
+ random3);
+ List<NamedExpression> namedExpressions = new
AddProjectForUniqueFunction().tryGenUniqueFunctionAlias(expressions);
+ Assertions.assertEquals(2, namedExpressions.size());
+ Assertions.assertInstanceOf(Alias.class, namedExpressions.get(0));
+ Assertions.assertEquals(((Alias) namedExpressions.get(0)).child(),
random1);
+ Assertions.assertInstanceOf(Alias.class, namedExpressions.get(1));
+ Assertions.assertEquals(((Alias) namedExpressions.get(1)).child(),
random3);
+ }
+
+ @Test
+ void testRewriteExpressionNoChange() {
+ Random random1 = new Random();
+ Random random2 = new Random();
+ Random random3 = new Random();
+ List<NamedExpression> projections = ImmutableList.of(
+ new Alias(new Add(random1, new Add(new DoubleLiteral(1.0), new
DoubleLiteral(1.0)))),
+ new Alias(new Add(random2, new DoubleLiteral(1.0))),
+ new Alias(random3));
+ LogicalProject<?> project = new LogicalProject<Plan>(projections,
studentOlapScan);
+ Optional<Pair<List<NamedExpression>, LogicalProject<Plan>>> result =
new AddProjectForUniqueFunction()
+ .rewriteExpressions(project, project.getProjects());
+ Assertions.assertEquals(Optional.empty(), result);
+ }
+
+ @Test
+ void testRewriteExpressionProjectSucc() {
+ Random random1 = new Random();
+ Random random2 = new Random();
+ List<NamedExpression> projections = ImmutableList.of(
+ new Alias(new Add(random1, new Add(new DoubleLiteral(1.0), new
DoubleLiteral(1.0)))),
+ new Alias(new Add(random2, new DoubleLiteral(1.0))),
+ new Alias(random2));
+ LogicalProject<?> project = new LogicalProject<Plan>(projections,
studentOlapScan);
+ Optional<Pair<List<NamedExpression>, LogicalProject<Plan>>> result =
new AddProjectForUniqueFunction()
+ .rewriteExpressions(project, project.getProjects());
+ Assertions.assertTrue(result.isPresent());
+ Assertions.assertInstanceOf(LogicalProject.class, result.get().second);
+ LogicalProject<?> bottomProject = (LogicalProject<?>)
result.get().second;
+ List<NamedExpression> bottomProjections = bottomProject.getProjects();
+ Assertions.assertEquals(studentOlapScan.getOutput().size() + 1,
bottomProjections.size());
+ Assertions.assertEquals(studentOlapScan.getOutput(),
bottomProjections.subList(0, studentOlapScan.getOutput().size()));
+ Alias alis = (Alias) bottomProjections.get(bottomProjections.size() -
1);
+ Assertions.assertEquals(alis.child(), random2);
+ List<NamedExpression> expectedTopProjections = ImmutableList.of(
+ projections.get(0),
+ new Alias(projections.get(1).getExprId(), new
Add(alis.toSlot(), new DoubleLiteral(1.0))),
+ new Alias(projections.get(2).getExprId(), alis.toSlot())
+ );
+ Assertions.assertEquals(expectedTopProjections, result.get().first);
+ }
+
+ @Test
+ void testRewriteJoin() {
+ LogicalOlapScan scoreOlapScan
+ = new
LogicalOlapScan(StatementScopeIdGenerator.newRelationId(),
PlanConstructor.score);
+ SlotReference sid = (SlotReference) scoreOlapScan.getOutput().get(0);
+ Random random = new Random();
+ LogicalJoin<?, ?> join = new LogicalJoin<Plan,
Plan>(JoinType.CROSS_JOIN,
+ ImmutableList.of(),
+ ImmutableList.of(new EqualTo(random, sid)),
+ ImmutableList.of(new EqualTo(random, new DoubleLiteral(1.0))),
+ new DistributeHint(DistributeType.NONE),
+ Optional.empty(),
+ ImmutableList.of(studentOlapScan, scoreOlapScan),
+ null);
+
+ Plan root = PlanChecker.from(MemoTestUtils.createConnectContext(),
join)
+ .applyTopDown(new AddProjectForUniqueFunction())
+ .getPlan();
+ Assertions.assertInstanceOf(LogicalJoin.class, root);
+ LogicalJoin<?, ?> newJoin = (LogicalJoin<?, ?>) root;
+ Assertions.assertInstanceOf(LogicalProject.class, newJoin.left());
+ LogicalProject<?> leftProject = (LogicalProject<?>) newJoin.left();
+ Assertions.assertEquals(studentOlapScan, leftProject.child());
+ Assertions.assertEquals(scoreOlapScan, newJoin.right());
+ Alias alias = (Alias)
leftProject.getProjects().get(leftProject.getProjects().size() - 1);
+ Assertions.assertEquals(alias.child(), random);
+ Assertions.assertEquals(ImmutableList.of(new EqualTo(alias.toSlot(),
sid)), newJoin.getHashJoinConjuncts());
+ Assertions.assertEquals(ImmutableList.of(),
newJoin.getOtherJoinConjuncts());
+ Assertions.assertEquals(ImmutableList.of(new EqualTo(alias.toSlot(),
new DoubleLiteral(1.0))), newJoin.getMarkJoinConjuncts());
+ Assertions.assertEquals(JoinType.INNER_JOIN, newJoin.getJoinType());
+ }
+}
diff --git
a/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out
b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out
new file mode 100644
index 00000000000..456ee6b09bf
Binary files /dev/null and
b/regression-test/data/nereids_rules_p0/unique_function/add_project_for_unique_function.out
differ
diff --git
a/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy
b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy
new file mode 100644
index 00000000000..9a400900ca3
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/unique_function/add_project_for_unique_function.groovy
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite('add_project_for_unique_function') { /*
+ * Shape-plan regression tests for the AddProjectForUniqueFunction rewrite.
+ * A unique function such as random() must be evaluated once per row, but an
+ * expression can end up referencing it several times -- e.g. rewriting
+ * `random() BETWEEN 10 AND 20` yields `random() >= 10 AND random() <= 20`.
+ * The rule computes the function once in a project placed below the consuming
+ * node (e.g. filter(k >= 10 and k <= 20) -> project(random() as k)) and makes
+ * the consumer reference the projected slot instead.
+ * Cases below cover: one-row relations, project, filter, union, union all,
+ * intersect, except, sort, aggregate, window and join.
+ */
+ sql 'SET enable_nereids_planner=true'
+ sql 'SET runtime_filter_mode=OFF'
+ sql 'SET enable_fallback_to_original_planner=false'
+ sql "SET ignore_shape_nodes='PhysicalDistribute'" // omit PhysicalDistribute nodes from the shape output
+ sql "SET
detail_shape_nodes='PhysicalProject,PhysicalOneRowRelation,PhysicalUnion,PhysicalQuickSort,PhysicalHashAggregate'"
+ sql 'SET disable_nereids_rules=PRUNE_EMPTY_PARTITION'
+
+ // no project: one-row relation cases (SELECT without a FROM clause)
+ qt_one_row_relation_1 '''
+ explain shape plan select random(1, 100), uuid_to_int(uuid())
+ '''
+
+ qt_one_row_relation_2 '''
+ explain shape plan select random(1, 100) between 10 and 20,
uuid_to_int(uuid())
+ '''
+
+ qt_one_row_relation_3 '''
+ explain shape plan select random(1, 100) between 10 and 20,
uuid_to_int(uuid()) between 111 and 222
+ '''
+
+ qt_project_1 '''
+ explain shape plan select id + random(1, 100) > 20, id * 200 from t1
+ '''
+
+ qt_project_2 '''
+ explain shape plan select id + random(1, 100) between 10 and 20, id *
200 from t1
+ '''
+
+ qt_filter_1 '''
+ explain shape plan select id from t1 where id + random(1, 100) >= 10
+ '''
+
+ qt_filter_2 '''
+ explain shape plan select id from t1 where id + random(1, 100) between
10 and 20
+ '''
+
+ qt_union_1 '''
+ explain shape plan select (random() between 0.1 and 0.5) as k union
select true
+ '''
+
+ qt_union_2 '''
+ explain shape plan select (id + random() between 0.1 and 0.5) as k
from t1 union select true
+ '''
+
+ qt_union_all_1 '''
+ explain shape plan select (random() between 0.1 and 0.5) as k union
all select true
+ '''
+
+ qt_union_all_2 '''
+ explain shape plan select (id + random() between 0.1 and 0.5) as k
from t1 union all select true
+ '''
+
+ qt_intersect_1 '''
+ explain shape plan select (random() between 0.1 and 0.5) as k
intersect select true
+ '''
+
+ qt_intersect_2 '''
+ explain shape plan select (id + random() between 0.1 and 0.5) as k
from t1 intersect select true
+ '''
+
+ qt_except_1 '''
+ explain shape plan select (random() between 0.1 and 0.5) as k except
select true
+ '''
+
+ qt_except_2 '''
+ explain shape plan select (id + random() between 0.1 and 0.5) as k
from t1 except select true
+ '''
+
+ qt_sort_1 '''
+ explain shape plan select * from (select (random() between 0.1 and
0.5) as k) t
+ order by k + random(100) between 0.6 and 0.7
+ '''
+
+ qt_sort_2 '''
+ explain shape plan select * from (select (id + random() between 0.1
and 0.5) as k from t1) t
+ order by k + random(100) between 0.6 and 0.7
+ '''
+
+ qt_agg_1 '''
+ explain shape plan select sum(random(100) between 0.6 and 0.7)
+ '''
+
+ qt_agg_2 '''
+ explain shape plan select sum(id), sum(random(100) between 0.6 and
0.7) from t1
+ '''
+
+ qt_agg_3 '''
+ explain shape plan select sum(id), sum(random(100) between 0.6 and
0.7) from t1
+ group by random() between 0.1 and 0.5
+ '''
+
+ qt_agg_4 '''
+ explain shape plan select sum(id), sum(random(100) between 0.6 and
0.7) from t1
+ group by id + random() between 0.1 and 0.5
+ '''
+
+ qt_window_1 '''
+ explain shape plan select sum(random(1) between 0.1 and 0.11)
+ over(partition by random(2) between 0.2 and 0.22)
+ '''
+
+ qt_window_2 '''
+ explain shape plan select sum(random(1) between 0.1 and 0.11)
+ over(partition by random(2) between 0.2 and 0.22 order by
random(3) between 0.3 and 0.33)
+ '''
+
+ qt_window_3 '''
+ explain shape plan select sum(id + random(1) between 0.1 and 0.11)
+ over(partition by id + random(2) between 0.2 and 0.22 order by id
+ random(3) between 0.3 and 0.33)
+ from t1
+ '''
+
+ qt_join_1 '''
+ explain shape plan select * from t1 join t2 on
+ t1.id + t2.id + random(1, 100) between 10 and 20
+ and t2.id * random(1, 100) between 100 and 200
+ and random(1, 100) between 1 and 10
+ '''
+}
diff --git
a/regression-test/suites/nereids_rules_p0/unique_function/load.groovy
b/regression-test/suites/nereids_rules_p0/unique_function/load.groovy
index cf4746b4976..65bb565ff6b 100644
--- a/regression-test/suites/nereids_rules_p0/unique_function/load.groovy
+++ b/regression-test/suites/nereids_rules_p0/unique_function/load.groovy
@@ -33,6 +33,19 @@ suite("load") {
"replication_allocation" = "tag.location.default: 1"
);
"""
+ sql """
+ DROP TABLE IF EXISTS t2
+ """
+ sql """
+ CREATE TABLE IF NOT EXISTS t2(
+ `id` int(11) NULL,
+ `msg` text NULL
+ ) ENGINE = OLAP
+ DISTRIBUTED BY HASH(id) BUCKETS 4
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+ """
def tbl = "tbl_unique_function_with_one_row"
sql "drop table if exists ${tbl} force"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]