This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 60a852ee574 [fix](Nereids) simplify range result wrong when reference is nullable (#41356) (#42990)
60a852ee574 is described below

commit 60a852ee574403f869a99abfc83b49fb352c7e63
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Fri Nov 1 14:55:04 2024 +0800

    [fix](Nereids) simplify range result wrong when reference is nullable (#41356) (#42990)
    
    pick from master #41356
    
    If the reference is nullable and the simplified result is a boolean
    literal, the real result should be:
    
    IF(${reference} IS NULL, NULL, ${not_null_result})
---
 .../expression/rules/FoldConstantRuleOnFE.java     |   8 +-
 .../rules/expression/rules/SimplifyRange.java      | 244 ++++++++++++---------
 .../apache/doris/nereids/util/ExpressionUtils.java |  13 ++
 .../rules/expression/SimplifyRangeTest.java        |  68 ++++--
 4 files changed, 206 insertions(+), 127 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
index 71377c021b2..b84012ffd64 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
@@ -93,7 +93,13 @@ public class FoldConstantRuleOnFE extends 
AbstractExpressionRewriteRule {
         } else if (expr instanceof AggregateExpression && 
((AggregateExpression) expr).getFunction().isDistinct()) {
             return expr;
         }
-        return expr.accept(this, ctx);
+        // ATTN: we must return original expr, because OrToIn is implemented 
with MutableState,
+        //   newExpr will lose these states leading to dead loop by OrToIn -> 
SimplifyRange -> FoldConstantByFE
+        Expression newExpr = expr.accept(this, ctx);
+        if (newExpr.equals(expr)) {
+            return expr;
+        }
+        return newExpr;
     }
 
     /**
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
index 2c673aabd23..111b164a459 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
@@ -35,7 +35,9 @@ import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.BoundType;
@@ -82,85 +84,85 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
     public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
         if (expr instanceof CompoundPredicate) {
             ValueDesc valueDesc = expr.accept(new RangeInference(), null);
-            Expression simplifiedExpr = valueDesc.toExpression();
-            return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr;
+            Expression exprForNonNull = valueDesc.toExpressionForNonNull();
+            if (exprForNonNull == null) {
+                // this mean cannot simplify
+                return valueDesc.exprForNonNull;
+            }
+            return exprForNonNull;
         }
         return expr;
     }
 
-    private static class RangeInference extends ExpressionVisitor<ValueDesc, 
Void> {
+    private static class RangeInference extends ExpressionVisitor<ValueDesc, 
ExpressionRewriteContext> {
 
         @Override
-        public ValueDesc visit(Expression expr, Void context) {
-            return new UnknownValue(expr);
+        public ValueDesc visit(Expression expr, ExpressionRewriteContext 
context) {
+            return new UnknownValue(context, expr);
         }
 
-        private ValueDesc buildRange(ComparisonPredicate predicate) {
+        private ValueDesc buildRange(ExpressionRewriteContext context, 
ComparisonPredicate predicate) {
             Expression rewrite = ExpressionRuleExecutor.normalize(predicate);
             Expression right = rewrite.child(1);
             if (right.isNullLiteral()) {
-                // it's safe to return empty value if >, >=, <, <= and = with 
null
-                if ((predicate instanceof GreaterThan || predicate instanceof 
GreaterThanEqual
-                        || predicate instanceof LessThan || predicate 
instanceof LessThanEqual
-                        || predicate instanceof EqualTo)) {
-                    return new EmptyValue(rewrite.child(0), rewrite);
-                } else {
-                    return new UnknownValue(predicate);
-                }
+                return new UnknownValue(context, predicate);
             }
             // only handle `NumericType`
             if (right.isLiteral() && right.getDataType().isNumericType()) {
-                return ValueDesc.range((ComparisonPredicate) rewrite);
+                return ValueDesc.range(context, (ComparisonPredicate) rewrite);
             }
-            return new UnknownValue(predicate);
+            return new UnknownValue(context, predicate);
         }
 
         @Override
-        public ValueDesc visitGreaterThan(GreaterThan greaterThan, Void 
context) {
-            return buildRange(greaterThan);
+        public ValueDesc visitGreaterThan(GreaterThan greaterThan, 
ExpressionRewriteContext context) {
+            return buildRange(context, greaterThan);
         }
 
         @Override
-        public ValueDesc visitGreaterThanEqual(GreaterThanEqual 
greaterThanEqual, Void context) {
-            return buildRange(greaterThanEqual);
+        public ValueDesc visitGreaterThanEqual(GreaterThanEqual 
greaterThanEqual, ExpressionRewriteContext context) {
+            return buildRange(context, greaterThanEqual);
         }
 
         @Override
-        public ValueDesc visitLessThan(LessThan lessThan, Void context) {
-            return buildRange(lessThan);
+        public ValueDesc visitLessThan(LessThan lessThan, 
ExpressionRewriteContext context) {
+            return buildRange(context, lessThan);
         }
 
         @Override
-        public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, Void 
context) {
-            return buildRange(lessThanEqual);
+        public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, 
ExpressionRewriteContext context) {
+            return buildRange(context, lessThanEqual);
         }
 
         @Override
-        public ValueDesc visitEqualTo(EqualTo equalTo, Void context) {
-            return buildRange(equalTo);
+        public ValueDesc visitEqualTo(EqualTo equalTo, 
ExpressionRewriteContext context) {
+            return buildRange(context, equalTo);
         }
 
         @Override
-        public ValueDesc visitInPredicate(InPredicate inPredicate, Void 
context) {
+        public ValueDesc visitInPredicate(InPredicate inPredicate, 
ExpressionRewriteContext context) {
             // only handle `NumericType`
-            if (ExpressionUtils.isAllLiteral(inPredicate.getOptions())
+            if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
                     && 
ExpressionUtils.matchNumericType(inPredicate.getOptions())) {
-                return ValueDesc.discrete(inPredicate);
+                return ValueDesc.discrete(context, inPredicate);
             }
-            return new UnknownValue(inPredicate);
+            return new UnknownValue(context, inPredicate);
         }
 
         @Override
-        public ValueDesc visitAnd(And and, Void context) {
-            return simplify(and, ExpressionUtils.extractConjunction(and), 
ValueDesc::intersect, ExpressionUtils::and);
+        public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
+            return simplify(context, and, 
ExpressionUtils.extractConjunction(and),
+                    ValueDesc::intersect, ExpressionUtils::and);
         }
 
         @Override
-        public ValueDesc visitOr(Or or, Void context) {
-            return simplify(or, ExpressionUtils.extractDisjunction(or), 
ValueDesc::union, ExpressionUtils::or);
+        public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
+            return simplify(context, or, 
ExpressionUtils.extractDisjunction(or),
+                    ValueDesc::union, ExpressionUtils::or);
         }
 
-        private ValueDesc simplify(Expression originExpr, List<Expression> 
predicates,
+        private ValueDesc simplify(ExpressionRewriteContext context,
+                Expression originExpr, List<Expression> predicates,
                 BinaryOperator<ValueDesc> op, BinaryOperator<Expression> 
exprOp) {
 
             Map<Expression, List<ValueDesc>> groupByReference = 
predicates.stream()
@@ -184,52 +186,58 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
 
             // use UnknownValue to wrap different references
-            return new UnknownValue(valuePerRefs, originExpr, exprOp);
+            return new UnknownValue(context, valuePerRefs, originExpr, exprOp);
         }
     }
 
     private abstract static class ValueDesc {
-        Expression expr;
+        ExpressionRewriteContext context;
+        Expression exprForNonNull;
         Expression reference;
 
-        public ValueDesc(Expression reference, Expression expr) {
-            this.expr = expr;
+        public ValueDesc(ExpressionRewriteContext context, Expression 
reference, Expression exprForNonNull) {
+            this.context = context;
+            this.exprForNonNull = exprForNonNull;
             this.reference = reference;
         }
 
         public abstract ValueDesc union(ValueDesc other);
 
-        public static ValueDesc union(RangeValue range, DiscreteValue 
discrete, boolean reverseOrder) {
+        public static ValueDesc union(ExpressionRewriteContext context,
+                RangeValue range, DiscreteValue discrete, boolean 
reverseOrder) {
             long count = discrete.values.stream().filter(x -> 
range.range.test(x)).count();
             if (count == discrete.values.size()) {
                 return range;
             }
-            Expression originExpr = ExpressionUtils.or(range.expr, 
discrete.expr);
+            Expression exprForNonNull = FoldConstantRuleOnFE.INSTANCE.rewrite(
+                    ExpressionUtils.or(range.exprForNonNull, 
discrete.exprForNonNull), context);
             List<ValueDesc> sourceValues = reverseOrder
                     ? ImmutableList.of(discrete, range)
                     : ImmutableList.of(range, discrete);
-            return new UnknownValue(sourceValues, originExpr, 
ExpressionUtils::or);
+            return new UnknownValue(context, sourceValues, exprForNonNull, 
ExpressionUtils::or);
         }
 
         public abstract ValueDesc intersect(ValueDesc other);
 
-        public static ValueDesc intersect(RangeValue range, DiscreteValue 
discrete) {
-            DiscreteValue result = new DiscreteValue(discrete.reference, 
discrete.expr);
+        public static ValueDesc intersect(ExpressionRewriteContext context, 
RangeValue range, DiscreteValue discrete) {
+            DiscreteValue result = new DiscreteValue(context, 
discrete.reference, discrete.exprForNonNull);
             discrete.values.stream().filter(x -> 
range.range.contains(x)).forEach(result.values::add);
             if (!result.values.isEmpty()) {
                 return result;
             }
-            return new EmptyValue(range.reference, 
ExpressionUtils.and(range.expr, discrete.expr));
+            Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                    ExpressionUtils.and(range.exprForNonNull, 
discrete.exprForNonNull), context);
+            return new EmptyValue(context, range.reference, 
originExprForNonNull);
         }
 
-        public abstract Expression toExpression();
+        public abstract Expression toExpressionForNonNull();
 
-        public static ValueDesc range(ComparisonPredicate predicate) {
+        public static ValueDesc range(ExpressionRewriteContext context, 
ComparisonPredicate predicate) {
             Literal value = (Literal) predicate.right();
             if (predicate instanceof EqualTo) {
-                return new DiscreteValue(predicate.left(), predicate, value);
+                return new DiscreteValue(context, predicate.left(), predicate, 
value);
             }
-            RangeValue rangeValue = new RangeValue(predicate.left(), 
predicate);
+            RangeValue rangeValue = new RangeValue(context, predicate.left(), 
predicate);
             if (predicate instanceof GreaterThanEqual) {
                 rangeValue.range = Range.atLeast(value);
             } else if (predicate instanceof GreaterThan) {
@@ -243,16 +251,16 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             return rangeValue;
         }
 
-        public static ValueDesc discrete(InPredicate in) {
+        public static ValueDesc discrete(ExpressionRewriteContext context, 
InPredicate in) {
             Set<Literal> literals = 
in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet());
-            return new DiscreteValue(in.getCompareExpr(), in, literals);
+            return new DiscreteValue(context, in.getCompareExpr(), in, 
literals);
         }
     }
 
     private static class EmptyValue extends ValueDesc {
 
-        public EmptyValue(Expression reference, Expression expr) {
-            super(reference, expr);
+        public EmptyValue(ExpressionRewriteContext context, Expression 
reference, Expression exprForNonNull) {
+            super(context, reference, exprForNonNull);
         }
 
         @Override
@@ -266,8 +274,12 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
         }
 
         @Override
-        public Expression toExpression() {
-            return BooleanLiteral.FALSE;
+        public Expression toExpressionForNonNull() {
+            if (reference.nullable()) {
+                return new And(new IsNull(reference), new 
NullLiteral(BooleanType.INSTANCE));
+            } else {
+                return BooleanLiteral.FALSE;
+            }
         }
     }
 
@@ -279,8 +291,8 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
     private static class RangeValue extends ValueDesc {
         Range<Literal> range;
 
-        public RangeValue(Expression reference, Expression expr) {
-            super(reference, expr);
+        public RangeValue(ExpressionRewriteContext context, Expression 
reference, Expression exprForNonNull) {
+            super(context, reference, exprForNonNull);
         }
 
         @Override
@@ -290,19 +302,23 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
             try {
                 if (other instanceof RangeValue) {
+                    Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                            ExpressionUtils.or(exprForNonNull, 
other.exprForNonNull), context);
                     RangeValue o = (RangeValue) other;
                     if (range.isConnected(o.range)) {
-                        RangeValue rangeValue = new RangeValue(reference, 
ExpressionUtils.or(expr, other.expr));
+                        RangeValue rangeValue = new RangeValue(context, 
reference, originExprForNonNull);
                         rangeValue.range = range.span(o.range);
                         return rangeValue;
                     }
-                    Expression originExpr = ExpressionUtils.or(expr, 
other.expr);
-                    return new UnknownValue(ImmutableList.of(this, other), 
originExpr, ExpressionUtils::or);
+                    return new UnknownValue(context, ImmutableList.of(this, 
other),
+                            originExprForNonNull, ExpressionUtils::or);
                 }
-                return union(this, (DiscreteValue) other, false);
+                return union(context, this, (DiscreteValue) other, false);
             } catch (Exception e) {
-                Expression originExpr = ExpressionUtils.or(expr, other.expr);
-                return new UnknownValue(ImmutableList.of(this, other), 
originExpr, ExpressionUtils::or);
+                Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                        ExpressionUtils.or(exprForNonNull, 
other.exprForNonNull), context);
+                return new UnknownValue(context, ImmutableList.of(this, other),
+                        originExprForNonNull, ExpressionUtils::or);
             }
         }
 
@@ -313,23 +329,27 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
             try {
                 if (other instanceof RangeValue) {
+                    Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                            ExpressionUtils.and(exprForNonNull, 
other.exprForNonNull), context);
                     RangeValue o = (RangeValue) other;
                     if (range.isConnected(o.range)) {
-                        RangeValue rangeValue = new RangeValue(reference, 
ExpressionUtils.and(expr, other.expr));
+                        RangeValue rangeValue = new RangeValue(context, 
reference, originExprForNonNull);
                         rangeValue.range = range.intersection(o.range);
                         return rangeValue;
                     }
-                    return new EmptyValue(reference, ExpressionUtils.and(expr, 
other.expr));
+                    return new EmptyValue(context, reference, 
originExprForNonNull);
                 }
-                return intersect(this, (DiscreteValue) other);
+                return intersect(context, this, (DiscreteValue) other);
             } catch (Exception e) {
-                Expression originExpr = ExpressionUtils.and(expr, other.expr);
-                return new UnknownValue(ImmutableList.of(this, other), 
originExpr, ExpressionUtils::and);
+                Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                        ExpressionUtils.and(exprForNonNull, 
other.exprForNonNull), context);
+                return new UnknownValue(context, ImmutableList.of(this, other),
+                        originExprForNonNull, ExpressionUtils::and);
             }
         }
 
         @Override
-        public Expression toExpression() {
+        public Expression toExpressionForNonNull() {
             List<Expression> result = Lists.newArrayList();
             if (range.hasLowerBound()) {
                 if (range.lowerBoundType() == BoundType.CLOSED) {
@@ -347,11 +367,12 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
             if (!result.isEmpty()) {
                 return ExpressionUtils.and(result);
-            } else if (reference.nullable()) {
-                // when reference is nullable, we should filter null slot.
-                return new Not(new IsNull(reference));
             } else {
-                return BooleanLiteral.TRUE;
+                if (reference.nullable()) {
+                    return new Or(new Not(new IsNull(reference)), new 
NullLiteral(BooleanType.INSTANCE));
+                } else {
+                    return BooleanLiteral.TRUE;
+                }
             }
         }
 
@@ -369,12 +390,14 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
     private static class DiscreteValue extends ValueDesc {
         Set<Literal> values;
 
-        public DiscreteValue(Expression reference, Expression expr, Literal... 
values) {
-            this(reference, expr, Arrays.asList(values));
+        public DiscreteValue(ExpressionRewriteContext context,
+                Expression reference, Expression exprForNonNull, Literal... 
values) {
+            this(context, reference, exprForNonNull, Arrays.asList(values));
         }
 
-        public DiscreteValue(Expression reference, Expression expr, 
Collection<Literal> values) {
-            super(reference, expr);
+        public DiscreteValue(ExpressionRewriteContext context,
+                Expression reference, Expression exprForNonNull, 
Collection<Literal> values) {
+            super(context, reference, exprForNonNull);
             this.values = Sets.newTreeSet(values);
         }
 
@@ -385,15 +408,19 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
             try {
                 if (other instanceof DiscreteValue) {
-                    DiscreteValue discreteValue = new DiscreteValue(reference, 
ExpressionUtils.or(expr, other.expr));
+                    Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                            ExpressionUtils.or(exprForNonNull, 
other.exprForNonNull), context);
+                    DiscreteValue discreteValue = new DiscreteValue(context, 
reference, originExprForNonNull);
                     discreteValue.values.addAll(((DiscreteValue) 
other).values);
                     discreteValue.values.addAll(this.values);
                     return discreteValue;
                 }
-                return union((RangeValue) other, this, true);
+                return union(context, (RangeValue) other, this, true);
             } catch (Exception e) {
-                Expression originExpr = ExpressionUtils.or(expr, other.expr);
-                return new UnknownValue(ImmutableList.of(this, other), 
originExpr, ExpressionUtils::or);
+                Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                        ExpressionUtils.or(exprForNonNull, 
other.exprForNonNull), context);
+                return new UnknownValue(context, ImmutableList.of(this, other),
+                        originExprForNonNull, ExpressionUtils::or);
             }
         }
 
@@ -404,24 +431,28 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
             }
             try {
                 if (other instanceof DiscreteValue) {
-                    DiscreteValue discreteValue = new DiscreteValue(reference, 
ExpressionUtils.and(expr, other.expr));
+                    Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                            ExpressionUtils.and(exprForNonNull, 
other.exprForNonNull), context);
+                    DiscreteValue discreteValue = new DiscreteValue(context, 
reference, originExprForNonNull);
                     discreteValue.values.addAll(((DiscreteValue) 
other).values);
                     discreteValue.values.retainAll(this.values);
                     if (discreteValue.values.isEmpty()) {
-                        return new EmptyValue(reference, 
ExpressionUtils.and(expr, other.expr));
+                        return new EmptyValue(context, reference, 
originExprForNonNull);
                     } else {
                         return discreteValue;
                     }
                 }
-                return intersect((RangeValue) other, this);
+                return intersect(context, (RangeValue) other, this);
             } catch (Exception e) {
-                Expression originExpr = ExpressionUtils.and(expr, other.expr);
-                return new UnknownValue(ImmutableList.of(this, other), 
originExpr, ExpressionUtils::and);
+                Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                        ExpressionUtils.and(exprForNonNull, 
other.exprForNonNull), context);
+                return new UnknownValue(context, ImmutableList.of(this, other),
+                        originExprForNonNull, ExpressionUtils::and);
             }
         }
 
         @Override
-        public Expression toExpression() {
+        public Expression toExpressionForNonNull() {
             // NOTICE: it's related with `InPredicateToEqualToRule`
             // They are same processes, so must change synchronously.
             if (values.size() == 1) {
@@ -447,40 +478,49 @@ public class SimplifyRange extends 
AbstractExpressionRewriteRule {
         private final List<ValueDesc> sourceValues;
         private final BinaryOperator<Expression> mergeExprOp;
 
-        private UnknownValue(Expression expr) {
-            super(expr, expr);
+        private UnknownValue(ExpressionRewriteContext context, Expression 
expr) {
+            super(context, expr, expr);
             sourceValues = ImmutableList.of();
             mergeExprOp = null;
         }
 
-        public UnknownValue(List<ValueDesc> sourceValues, Expression 
originExpr,
-                BinaryOperator<Expression> mergeExprOp) {
-            super(sourceValues.get(0).reference, originExpr);
+        public UnknownValue(ExpressionRewriteContext context,
+                List<ValueDesc> sourceValues, Expression exprForNonNull, 
BinaryOperator<Expression> mergeExprOp) {
+            super(context, sourceValues.get(0).reference, exprForNonNull);
             this.sourceValues = ImmutableList.copyOf(sourceValues);
             this.mergeExprOp = mergeExprOp;
         }
 
         @Override
         public ValueDesc union(ValueDesc other) {
-            Expression originExpr = ExpressionUtils.or(expr, other.expr);
-            return new UnknownValue(ImmutableList.of(this, other), originExpr, 
ExpressionUtils::or);
+            Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                    ExpressionUtils.or(exprForNonNull, other.exprForNonNull), 
context);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
originExprForNonNull, ExpressionUtils::or);
         }
 
         @Override
         public ValueDesc intersect(ValueDesc other) {
-            Expression originExpr = ExpressionUtils.and(expr, other.expr);
-            return new UnknownValue(ImmutableList.of(this, other), originExpr, 
ExpressionUtils::and);
+            Expression originExprForNonNull = 
FoldConstantRuleOnFE.INSTANCE.rewrite(
+                    ExpressionUtils.and(exprForNonNull, other.exprForNonNull), 
context);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
originExprForNonNull, ExpressionUtils::and);
         }
 
         @Override
-        public Expression toExpression() {
+        public Expression toExpressionForNonNull() {
             if (sourceValues.isEmpty()) {
-                return expr;
+                return exprForNonNull;
+            }
+            Expression result = sourceValues.get(0).toExpressionForNonNull();
+            for (int i = 1; i < sourceValues.size(); i++) {
+                result = mergeExprOp.apply(result, 
sourceValues.get(i).toExpressionForNonNull());
+            }
+            result = FoldConstantRuleOnFE.INSTANCE.rewrite(result, context);
+            // ATTN: we must return original expr, because OrToIn is 
implemented with MutableState,
+            //   newExpr will lose these states leading to dead loop by OrToIn 
-> SimplifyRange -> FoldConstantByFE
+            if (result.equals(exprForNonNull)) {
+                return exprForNonNull;
             }
-            return sourceValues.stream()
-                    .map(ValueDesc::toExpression)
-                    .reduce(mergeExprOp)
-                    .get();
+            return result;
         }
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 8bec208896a..739a0e9a3cc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -373,6 +373,19 @@ public class ExpressionUtils {
         return children.stream().allMatch(c -> c instanceof Literal);
     }
 
+    /**
+     * return true if all children are literal but not null literal.
+     */
+    public static boolean isAllNonNullLiteral(List<Expression> children) {
+        for (Expression child : children) {
+            if ((!(child instanceof Literal)) || (child instanceof 
NullLiteral)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** matchNumericType */
     public static boolean matchNumericType(List<Expression> children) {
         return children.stream().allMatch(c -> 
c.getDataType().isNumericType());
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
index 74843a26b21..a668cc79925 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
@@ -30,7 +30,9 @@ import org.apache.doris.nereids.trees.plans.RelationId;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DataType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DateV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.nereids.types.TinyIntType;
@@ -61,34 +63,39 @@ public class SimplifyRangeTest {
     public void testSimplify() {
         executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
         assertRewrite("TA", "TA");
-        assertRewrite("TA > 3 or TA > null", "TA > 3");
-        assertRewrite("TA > 3 or TA < null", "TA > 3");
-        assertRewrite("TA > 3 or TA = null", "TA > 3");
-        assertRewrite("TA > 3 or TA <> null", "TA > 3 or TA <> null");
+        assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL");
+        assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL");
+        assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL");
+        assertRewrite("TA > 3 or TA <> null", "TA > 3 or null");
         assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null");
-        assertRewrite("TA > 3 and TA > null", "false");
-        assertRewrite("TA > 3 and TA < null", "false");
-        assertRewrite("TA > 3 and TA = null", "false");
-        assertRewrite("TA > 3 and TA <> null", "TA > 3 and TA <> null");
+        assertRewriteNotNull("TA > 3 and TA > null", "TA > 3 and NULL");
+        assertRewriteNotNull("TA > 3 and TA < null", "TA > 3 and NULL");
+        assertRewriteNotNull("TA > 3 and TA = null", "TA > 3 and NULL");
+        assertRewrite("TA > 3 and TA > null", "TA > 3 and null");
+        assertRewrite("TA > 3 and TA < null", "TA > 3 and null");
+        assertRewrite("TA > 3 and TA = null", "TA > 3 and null");
+        assertRewrite("TA > 3 and TA <> null", "TA > 3 and null");
         assertRewrite("TA > 3 and TA <=> null", "TA > 3 and TA <=> null");
         assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)");
-        assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE");
-        assertRewrite("TA > 3 and TA < 1", "FALSE");
+        assertRewriteNotNull("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE");
+        assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "TA is null and null");
+        assertRewriteNotNull("TA > 3 and TA < 1", "FALSE");
+        assertRewrite("TA > 3 and TA < 1", "TA is null and null");
         assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3");
-        assertRewrite("TA = 1 and TA > 10", "FALSE");
+        assertRewriteNotNull("TA = 1 and TA > 10", "FALSE");
+        assertRewrite("TA = 1 and TA > 10", "TA is null and null");
         assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1");
         assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
-        assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA IS NOT NULL");
+        assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null");
         assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE");
         assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
         assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
         assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1");
-        assertRewrite("TA > 1 or TA < 10", "TA IS NOT NULL");
+        assertRewrite("TA > 1 or TA < 10", "TA is not null or null");
         assertRewriteNotNull("TA > 1 or TA < 10", "TRUE");
         assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10");
         assertRewrite("TA > 5 and TA > 10", "TA > 10");
-        assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
-        assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
+        assertRewrite("TA > 5 + 1 and TA > 10", "cast(TA as smallint) > 6 and TA > 10");
         assertRewrite("(TA > 1 and TA > 10) or TA > 20", "TA > 10");
         assertRewrite("(TA > 1 or TA > 10) and TA > 20", "TA > 20");
         assertRewrite("(TA + TB > 1 or TA + TB > 10) and TA + TB > 20", "TA + TB > 20");
@@ -98,25 +105,32 @@ public class SimplifyRangeTest {
         assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA > 40");
         assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 10 and TB > 10 or TB > 20");
         assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
-        assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
+        assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE");
+        assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null");
         assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
         assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
         assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
         assertRewrite("TA in (1)", "TA in (1)");
         assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
-        assertRewrite("TA in (1,2,3) and TA < 1", "FALSE");
+        assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE");
+        assertRewrite("TA in (1,2,3) and TA < 1", "TA is null and null");
         assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1");
         assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)");
         assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)");
-        assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
+        assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "TA is null and null");
+        assertRewriteNotNull("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
         assertRewrite("TA in (1,2,3) and TA in (3,4,5)", "TA = 3");
         assertRewrite("TA + TB in (1,2,3) and TA + TB in (3,4,5)", "TA + TB = 3");
         assertRewrite("TA in (1,2,3) and DA > 1.5", "TA in (1,2,3) and DA > 1.5");
-        assertRewrite("TA = 1 and TA = 3", "FALSE");
-        assertRewrite("TA in (1) and TA in (3)", "FALSE");
+        assertRewriteNotNull("TA = 1 and TA = 3", "FALSE");
+        assertRewrite("TA = 1 and TA = 3", "TA is null and null");
+        assertRewriteNotNull("TA in (1) and TA in (3)", "FALSE");
+        assertRewrite("TA in (1) and TA in (3)", "TA is null and null");
         assertRewrite("TA in (1) and TA in (1)", "TA = 1");
-        assertRewrite("(TA > 3 and TA < 1) and TB < 5", "FALSE");
-        assertRewrite("(TA > 3 and TA < 1) or TB < 5", "TB < 5");
+        assertRewriteNotNull("(TA > 3 and TA < 1) and TB < 5", "FALSE");
+        assertRewrite("(TA > 3 and TA < 1) and TB < 5", "TA is null and null and TB < 5");
+        assertRewrite("TA > 3 and TB < 5 and TA < 1", "TA is null and null and TB < 5");
+        assertRewrite("(TA > 3 and TA < 1) or TB < 5", "(TA is null and null) or TB < 5");
         assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1");
     }
 
@@ -133,7 +147,9 @@ public class SimplifyRangeTest {
     private void assertRewriteNotNull(String expression, String expected) {
         Map<String, Slot> mem = Maps.newHashMap();
         Expression needRewriteExpression = replaceNotNullUnboundSlot(PARSER.parseExpression(expression), mem);
+        needRewriteExpression = typeCoercion(needRewriteExpression);
         Expression expectedExpression = replaceNotNullUnboundSlot(PARSER.parseExpression(expected), mem);
+        expectedExpression = typeCoercion(expectedExpression);
         Expression rewrittenExpression = executor.rewrite(needRewriteExpression, context);
         Assertions.assertEquals(expectedExpression, rewrittenExpression);
     }
@@ -185,11 +201,15 @@ public class SimplifyRangeTest {
             case 'I':
                 return IntegerType.INSTANCE;
             case 'D':
-                return DoubleType.INSTANCE;
+                return DecimalV3Type.createDecimalV3Type(2, 1);
             case 'S':
                 return StringType.INSTANCE;
             case 'B':
                 return BooleanType.INSTANCE;
+            case 'C':
+                return DateTimeV2Type.SYSTEM_DEFAULT;
+            case 'A':
+                return DateV2Type.INSTANCE;
             default:
                 return BigIntType.INSTANCE;
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to