This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new a24c0356b0e [feature](Nereids): InferPredicates support In (#29458) (#30007)
a24c0356b0e is described below

commit a24c0356b0e8c4a17b10d82fb125c41778c262d8
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Tue Jan 16 16:21:03 2024 +0800

    [feature](Nereids): InferPredicates support In (#29458) (#30007)
    
    (cherry picked from commit 7a0734dbd60effa676d87bf5a5b7ca516e134d52)
---
 .../nereids/rules/rewrite/InferPredicates.java     |  11 +-
 .../rules/rewrite/PredicatePropagation.java        | 179 +++++++++++++--------
 .../nereids/rules/rewrite/PullUpPredicates.java    |   4 +-
 .../nereids/trees/expressions/InPredicate.java     |  34 ++--
 .../nereids/rules/rewrite/InferPredicatesTest.java |  60 ++++---
 .../rules/rewrite/PredicatePropagationTest.java    |  51 ++++++
 .../data/nereids_p0/hint/fix_leading.out           |   2 +-
 .../infer_predicate/infer_predicate.groovy         |   2 +-
 8 files changed, 220 insertions(+), 123 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
index 3c4593df54c..36236c3db8d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java
@@ -27,7 +27,6 @@ import 
org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.PlanUtils;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
@@ -37,6 +36,7 @@ import java.util.stream.Collectors;
 
 /**
  * infer additional predicates for `LogicalFilter` and `LogicalJoin`.
+ * <pre>
  * The logic is as follows:
  * 1. poll up bottom predicate then infer additional predicates
  *   for example:
@@ -49,9 +49,9 @@ import java.util.stream.Collectors;
  *      select * from (select * from t1 where t1.id = 1) t join t2 on t.id = 
t2.id and t2.id = 1
  * 2. put these predicates into `otherJoinConjuncts` , these predicates are 
processed in the next
  *   round of predicate push-down
+ * </pre>
  */
 public class InferPredicates extends DefaultPlanRewriter<JobContext> 
implements CustomRewriter {
-    private final PredicatePropagation propagation = new 
PredicatePropagation();
     private final PullUpPredicates pollUpPredicates = new PullUpPredicates();
 
     @Override
@@ -62,6 +62,9 @@ public class InferPredicates extends 
DefaultPlanRewriter<JobContext> implements
     @Override
     public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, JobContext context) {
         join = visitChildren(this, join, context);
+        if (join.isMarkJoin()) {
+            return join;
+        }
         Plan left = join.left();
         Plan right = join.right();
         Set<Expression> expressions = getAllExpressions(left, right, 
join.getOnClauseCondition());
@@ -86,7 +89,7 @@ public class InferPredicates extends 
DefaultPlanRewriter<JobContext> implements
                 break;
         }
         if (left != join.left() || right != join.right()) {
-            return join.withChildren(ImmutableList.of(left, right));
+            return join.withChildren(left, right);
         } else {
             return join;
         }
@@ -109,7 +112,7 @@ public class InferPredicates extends 
DefaultPlanRewriter<JobContext> implements
         Set<Expression> baseExpressions = pullUpPredicates(left);
         baseExpressions.addAll(pullUpPredicates(right));
         condition.ifPresent(on -> 
baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
-        baseExpressions.addAll(propagation.infer(baseExpressions));
+        baseExpressions.addAll(PredicatePropagation.infer(baseExpressions));
         return baseExpressions;
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index 2317da427ea..aa520362a77 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.types.DataType;
@@ -55,8 +56,7 @@ public class PredicatePropagation {
         INTEGRAL(IntegralType.class),
         STRING(CharacterType.class),
         DATE(DateLikeType.class),
-        OTHER(DataType.class)
-        ;
+        OTHER(DataType.class);
 
         private final Class<? extends DataType> superClazz;
 
@@ -65,15 +65,15 @@ public class PredicatePropagation {
         }
     }
 
-    private class ComparisonInferInfo {
+    private static class EqualInferInfo {
 
         public final InferType inferType;
-        public final Optional<Expression> left;
-        public final Optional<Expression> right;
+        public final Expression left;
+        public final Expression right;
         public final ComparisonPredicate comparisonPredicate;
 
-        public ComparisonInferInfo(InferType inferType,
-                Optional<Expression> left, Optional<Expression> right,
+        public EqualInferInfo(InferType inferType,
+                Expression left, Expression right,
                 ComparisonPredicate comparisonPredicate) {
             this.inferType = inferType;
             this.left = left;
@@ -85,23 +85,24 @@ public class PredicatePropagation {
     /**
      * infer additional predicates.
      */
-    public Set<Expression> infer(Set<Expression> predicates) {
+    public static Set<Expression> infer(Set<Expression> predicates) {
         Set<Expression> inferred = Sets.newHashSet();
         for (Expression predicate : predicates) {
-            if (!(predicate instanceof ComparisonPredicate)) {
+            if (!(predicate instanceof ComparisonPredicate
+                    || (predicate instanceof InPredicate && ((InPredicate) 
predicate).isLiteralChildren()))) {
                 continue;
             }
-            ComparisonInferInfo equalInfo = 
getEquivalentInferInfo((ComparisonPredicate) predicate);
+            if (predicate instanceof InPredicate) {
+                continue;
+            }
+            EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) 
predicate);
             if (equalInfo.inferType == InferType.NONE) {
                 continue;
             }
             Set<Expression> newInferred = predicates.stream()
-                    .filter(ComparisonPredicate.class::isInstance)
                     .filter(p -> !p.equals(predicate))
-                    .map(ComparisonPredicate.class::cast)
-                    .map(this::inferInferInfo)
-                    .filter(predicateInfo -> predicateInfo.inferType != 
InferType.NONE)
-                    .map(predicateInfo -> doInfer(equalInfo, predicateInfo))
+                    .filter(p -> p instanceof ComparisonPredicate || p 
instanceof InPredicate)
+                    .map(predicateInfo -> doInferPredicate(equalInfo, 
predicateInfo))
                     .filter(Objects::nonNull)
                     .collect(Collectors.toSet());
             inferred.addAll(newInferred);
@@ -110,17 +111,66 @@ public class PredicatePropagation {
         return inferred;
     }
 
+    private static Expression doInferPredicate(EqualInferInfo equalInfo, 
Expression predicate) {
+        Expression equalLeft = equalInfo.left;
+        Expression equalRight = equalInfo.right;
+
+        DataType leftType = predicate.child(0).getDataType();
+        InferType inferType;
+        if (leftType instanceof CharacterType) {
+            inferType = InferType.STRING;
+        } else if (leftType instanceof IntegralType) {
+            inferType = InferType.INTEGRAL;
+        } else if (leftType instanceof DateLikeType) {
+            inferType = InferType.DATE;
+        } else {
+            inferType = InferType.OTHER;
+        }
+        if (predicate instanceof ComparisonPredicate) {
+            ComparisonPredicate comparisonPredicate = (ComparisonPredicate) 
predicate;
+            Optional<Expression> left = 
validForInfer(comparisonPredicate.left(), inferType);
+            Optional<Expression> right = 
validForInfer(comparisonPredicate.right(), inferType);
+            if (!left.isPresent() || !right.isPresent()) {
+                return null;
+            }
+        } else if (predicate instanceof InPredicate) {
+            InPredicate inPredicate = (InPredicate) predicate;
+            Optional<Expression> left = 
validForInfer(inPredicate.getCompareExpr(), inferType);
+            if (!left.isPresent()) {
+                return null;
+            }
+        }
+
+        Expression newPredicate = predicate.rewriteUp(e -> {
+            if (e.equals(equalLeft)) {
+                return equalRight;
+            } else if (e.equals(equalRight)) {
+                return equalLeft;
+            } else {
+                return e;
+            }
+        });
+        if (predicate instanceof ComparisonPredicate) {
+            return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate,
+                    newPredicate.child(0),
+                    newPredicate.child(1));
+        } else {
+            return TypeCoercionUtils.processInPredicate((InPredicate) 
newPredicate);
+        }
+    }
+
     /**
      * Use the left or right child of `leftSlotEqualToRightSlot` to replace 
the left or right child of `expression`
      * Now only support infer `ComparisonPredicate`.
      * TODO: We should determine whether `expression` satisfies the condition 
for replacement
      *       eg: Satisfy `expression` is non-deterministic
      */
-    private Expression doInfer(ComparisonInferInfo equalInfo, 
ComparisonInferInfo predicateInfo) {
-        Expression predicateLeft = predicateInfo.left.get();
-        Expression predicateRight = predicateInfo.right.get();
-        Expression equalLeft = equalInfo.left.get();
-        Expression equalRight = equalInfo.right.get();
+    private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo 
predicateInfo) {
+        Expression equalLeft = equalInfo.left;
+        Expression equalRight = equalInfo.right;
+
+        Expression predicateLeft = predicateInfo.left;
+        Expression predicateRight = predicateInfo.right;
         Expression newLeft = inferOneSide(predicateLeft, equalLeft, 
equalRight);
         Expression newRight = inferOneSide(predicateRight, equalLeft, 
equalRight);
         if (newLeft == null || newRight == null) {
@@ -133,7 +183,7 @@ public class PredicatePropagation {
         return DateFunctionRewrite.INSTANCE.rewrite(expr, null);
     }
 
-    private Expression inferOneSide(Expression predicateOneSide, Expression 
equalLeft, Expression equalRight) {
+    private static Expression inferOneSide(Expression predicateOneSide, 
Expression equalLeft, Expression equalRight) {
         if (predicateOneSide instanceof SlotReference) {
             if (predicateOneSide.equals(equalLeft)) {
                 return equalRight;
@@ -150,60 +200,55 @@ public class PredicatePropagation {
         return null;
     }
 
-    private Optional<Expression> validForInfer(Expression expression, 
InferType inferType) {
+    private static Optional<Expression> validForInfer(Expression expression, 
InferType inferType) {
         if 
(!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
             return Optional.empty();
         }
         if (expression instanceof SlotReference || expression.isConstant()) {
             return Optional.of(expression);
         }
+        if (!(expression instanceof Cast)) {
+            return Optional.empty();
+        }
+        Cast cast = (Cast) expression;
+        Expression child = cast.child();
+        DataType dataType = cast.getDataType();
+        DataType childType = child.getDataType();
         if (inferType == InferType.INTEGRAL) {
-            if (expression instanceof Cast) {
-                // avoid cast from wider type to narrower type, such as 
cast(int as smallint)
-                // IntegralType dataType = (IntegralType) 
expression.getDataType();
-                // DataType childType = ((Cast) 
expression).child().getDataType();
-                // if (childType instanceof IntegralType && 
dataType.widerThan((IntegralType) childType)) {
-                //     return validForInfer(((Cast) expression).child(), 
inferType);
-                // }
-                return validForInfer(((Cast) expression).child(), inferType);
-            }
+            // avoid cast from wider type to narrower type, such as cast(int 
as smallint)
+            // IntegralType dataType = (IntegralType) expression.getDataType();
+            // DataType childType = ((Cast) expression).child().getDataType();
+            // if (childType instanceof IntegralType && 
dataType.widerThan((IntegralType) childType)) {
+            //     return validForInfer(((Cast) expression).child(), 
inferType);
+            // }
+            return validForInfer(child, inferType);
         } else if (inferType == InferType.DATE) {
-            if (expression instanceof Cast) {
-                DataType dataType = expression.getDataType();
-                DataType childType = ((Cast) expression).child().getDataType();
-                // avoid lost precision
-                if (dataType instanceof DateType) {
-                    if (childType instanceof DateV2Type || childType 
instanceof DateType) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateV2Type) {
-                    if (childType instanceof DateType || childType instanceof 
DateV2Type) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateTimeType) {
-                    if (!(childType instanceof DateTimeV2Type)) {
-                        return validForInfer(((Cast) expression).child(), 
inferType);
-                    }
-                } else if (dataType instanceof DateTimeV2Type) {
-                    return validForInfer(((Cast) expression).child(), 
inferType);
+            // avoid lost precision
+            if (dataType instanceof DateType) {
+                if (childType instanceof DateV2Type || childType instanceof 
DateType) {
+                    return validForInfer(child, inferType);
+                }
+            } else if (dataType instanceof DateV2Type) {
+                if (childType instanceof DateType || childType instanceof 
DateV2Type) {
+                    return validForInfer(child, inferType);
                 }
+            } else if (dataType instanceof DateTimeType) {
+                if (!(childType instanceof DateTimeV2Type)) {
+                    return validForInfer(child, inferType);
+                }
+            } else if (dataType instanceof DateTimeV2Type) {
+                return validForInfer(child, inferType);
             }
         } else if (inferType == InferType.STRING) {
-            if (expression instanceof Cast) {
-                DataType dataType = expression.getDataType();
-                DataType childType = ((Cast) expression).child().getDataType();
-                // avoid substring cast such as cast(char(3) as char(2))
-                if (dataType.width() <= 0 || (dataType.width() >= 
childType.width() && childType.width() >= 0)) {
-                    return validForInfer(((Cast) expression).child(), 
inferType);
-                }
+            // avoid substring cast such as cast(char(3) as char(2))
+            if (dataType.width() <= 0 || (dataType.width() >= 
childType.width() && childType.width() >= 0)) {
+                return validForInfer(child, inferType);
             }
-        } else {
-            return Optional.empty();
         }
         return Optional.empty();
     }
 
-    private ComparisonInferInfo inferInferInfo(ComparisonPredicate 
comparisonPredicate) {
+    private static EqualInferInfo inferInferInfo(ComparisonPredicate 
comparisonPredicate) {
         DataType leftType = comparisonPredicate.left().getDataType();
         InferType inferType;
         if (leftType instanceof CharacterType) {
@@ -220,25 +265,27 @@ public class PredicatePropagation {
         if (!left.isPresent() || !right.isPresent()) {
             inferType = InferType.NONE;
         }
-        return new ComparisonInferInfo(inferType, left, right, 
comparisonPredicate);
+        return new EqualInferInfo(inferType, 
left.orElse(comparisonPredicate.left()),
+                right.orElse(comparisonPredicate.right()), 
comparisonPredicate);
     }
 
     /**
      * Currently only equivalence derivation is supported
      * and requires that the left and right sides of an expression must be slot
+     * <p>
+     * TODO: NullSafeEqual
      */
-    private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate 
predicate) {
+    private static EqualInferInfo getEqualInferInfo(ComparisonPredicate 
predicate) {
         if (!(predicate instanceof EqualTo)) {
-            return new ComparisonInferInfo(InferType.NONE,
-                    Optional.of(predicate.left()), 
Optional.of(predicate.right()), predicate);
+            return new EqualInferInfo(InferType.NONE, predicate.left(), 
predicate.right(), predicate);
         }
-        ComparisonInferInfo info = inferInferInfo(predicate);
+        EqualInferInfo info = inferInferInfo(predicate);
         if (info.inferType == InferType.NONE) {
             return info;
         }
-        if (info.left.get() instanceof SlotReference && info.right.get() 
instanceof SlotReference) {
+        if (info.left instanceof SlotReference && info.right instanceof 
SlotReference) {
             return info;
         }
-        return new ComparisonInferInfo(InferType.NONE, info.left, info.right, 
info.comparisonPredicate);
+        return new EqualInferInfo(InferType.NONE, info.left, info.right, 
info.comparisonPredicate);
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 1a198c76ea5..26e1358c2e5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -47,7 +47,6 @@ import java.util.stream.Collectors;
  */
 public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, 
Void> {
 
-    PredicatePropagation propagation = new PredicatePropagation();
     Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
 
     @Override
@@ -99,6 +98,7 @@ public class PullUpPredicates extends 
PlanVisitor<ImmutableSet<Expression>, Void
     public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? 
extends Plan> aggregate, Void context) {
         return cacheOrElse(aggregate, () -> {
             ImmutableSet<Expression> childPredicates = 
aggregate.child().accept(this, context);
+            // TODO
             Map<Expression, Slot> expressionSlotMap = 
aggregate.getOutputExpressions()
                     .stream()
                     .filter(this::hasAgg)
@@ -130,7 +130,7 @@ public class PullUpPredicates extends 
PlanVisitor<ImmutableSet<Expression>, Void
 
     private ImmutableSet<Expression> 
getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
         Set<Expression> expressions = Sets.newHashSet(predicates);
-        expressions.addAll(propagation.infer(expressions));
+        expressions.addAll(PredicatePropagation.infer(expressions));
         return expressions.stream()
                 .filter(p -> 
plan.getOutputSet().containsAll(p.getInputSlots()))
                 .collect(ImmutableSet.toImmutableSet());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index d08d8abff73..0bffb1c73ab 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -68,18 +68,30 @@ public class InPredicate extends Expression {
         return children().stream().anyMatch(Expression::nullable);
     }
 
+    @Override
+    public void checkLegalityBeforeTypeCoercion() {
+        children().forEach(c -> {
+            if (c.getDataType().isObjectType()) {
+                throw new AnalysisException("in predicate could not contains 
object type: " + this.toSql());
+            }
+            if (c.getDataType().isComplexType()) {
+                throw new AnalysisException("in predicate could not contains 
complex type: " + this.toSql());
+            }
+        });
+    }
+
     @Override
     public String toString() {
         return compareExpr + " IN " + options.stream()
-            .map(Expression::toString)
-            .collect(Collectors.joining(", ", "(", ")"));
+                .map(Expression::toString)
+                .collect(Collectors.joining(", ", "(", ")"));
     }
 
     @Override
     public String toSql() {
         return compareExpr.toSql() + " IN " + options.stream()
-            .map(Expression::toSql)
-            .collect(Collectors.joining(", ", "(", ")"));
+                .map(Expression::toSql)
+                .collect(Collectors.joining(", ", "(", ")"));
     }
 
     @Override
@@ -92,7 +104,7 @@ public class InPredicate extends Expression {
         }
         InPredicate that = (InPredicate) o;
         return Objects.equals(compareExpr, that.getCompareExpr())
-            && Objects.equals(options, that.getOptions());
+                && Objects.equals(options, that.getOptions());
     }
 
     @Override
@@ -119,16 +131,4 @@ public class InPredicate extends Expression {
         }
         return true;
     }
-
-    @Override
-    public void checkLegalityBeforeTypeCoercion() {
-        children().forEach(c -> {
-            if (c.getDataType().isObjectType()) {
-                throw new AnalysisException("in predicate could not contains 
object type: " + this.toSql());
-            }
-            if (c.getDataType().isComplexType()) {
-                throw new AnalysisException("in predicate could not contains 
complex type: " + this.toSql());
-            }
-        });
-    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index 5aff44f9411..243466b13c0 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -24,7 +24,7 @@ import org.apache.doris.utframe.TestWithFeService;
 
 import org.junit.jupiter.api.Test;
 
-public class InferPredicatesTest extends TestWithFeService implements 
MemoPatternMatchSupported {
+class InferPredicatesTest extends TestWithFeService implements 
MemoPatternMatchSupported {
 
     @Override
     protected void runBeforeAll() throws Exception {
@@ -76,7 +76,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest01() {
+    void inferPredicatesTest01() {
         String sql = "select * from student join score on student.id = 
score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -97,7 +97,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest02() {
+    void inferPredicatesTest02() {
         String sql = "select * from student join score on student.id = 
score.sid";
 
         PlanChecker.from(connectContext)
@@ -114,7 +114,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest03() {
+    void inferPredicatesTest03() {
         String sql = "select * from student join score on student.id = 
score.sid where student.id in (1,2,3)";
 
         PlanChecker.from(connectContext)
@@ -123,17 +123,15 @@ public class InferPredicatesTest extends 
TestWithFeService implements MemoPatter
                 .matches(
                     logicalProject(
                         logicalJoin(
-                            logicalFilter(
-                                    logicalOlapScan()
-                            ).when(filter -> 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
-                            logicalOlapScan()
+                            logicalFilter(logicalOlapScan()).when(filter -> 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
+                            logicalFilter(logicalOlapScan()).when(filter -> 
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
                         )
                     )
                 );
     }
 
     @Test
-    public void inferPredicatesTest04() {
+    void inferPredicatesTest04() {
         String sql = "select * from student join score on student.id = 
score.sid and student.id in (1,2,3)";
 
         PlanChecker.from(connectContext)
@@ -142,17 +140,15 @@ public class InferPredicatesTest extends 
TestWithFeService implements MemoPatter
                 .matches(
                     logicalProject(
                         logicalJoin(
-                            logicalFilter(
-                                    logicalOlapScan()
-                            ).when(filter -> 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
-                            logicalOlapScan()
+                            logicalFilter(logicalOlapScan()).when(filter -> 
filter.getPredicate().toSql().contains("id IN (1, 2, 3)")),
+                            logicalFilter(logicalOlapScan()).when(filter -> 
filter.getPredicate().toSql().contains("sid IN (1, 2, 3)"))
                         )
                     )
                 );
     }
 
     @Test
-    public void inferPredicatesTest05() {
+    void inferPredicatesTest05() {
         String sql = "select * from student join score on student.id = 
score.sid join course on score.sid = course.id where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -178,7 +174,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest06() {
+    void inferPredicatesTest06() {
         String sql = "select * from student join score on student.id = 
score.sid join course on score.sid = course.id and score.sid > 1";
 
         PlanChecker.from(connectContext)
@@ -204,7 +200,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest07() {
+    void inferPredicatesTest07() {
         String sql = "select * from student left join score on student.id = 
score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -225,7 +221,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest08() {
+    void inferPredicatesTest08() {
         String sql = "select * from student left join score on student.id = 
score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -244,7 +240,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest09() {
+    void inferPredicatesTest09() {
         // convert left join to inner join
         String sql = "select * from student left join score on student.id = 
score.sid where score.sid > 1";
 
@@ -266,7 +262,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest10() {
+    void inferPredicatesTest10() {
         String sql = "select * from (select id as nid, name from student) t 
left join score on t.nid = score.sid where t.nid > 1";
 
         PlanChecker.from(connectContext)
@@ -289,7 +285,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest11() {
+    void inferPredicatesTest11() {
         String sql = "select * from (select id as nid, name from student) t 
left join score on t.nid = score.sid and t.nid > 1";
 
         PlanChecker.from(connectContext)
@@ -310,7 +306,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest12() {
+    void inferPredicatesTest12() {
         String sql = "select * from student left join (select sid as nid, 
sum(grade) from score group by sid) s on s.nid = student.id where student.id > 
1";
 
         PlanChecker.from(connectContext)
@@ -337,7 +333,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest13() {
+    void inferPredicatesTest13() {
         String sql = "select * from (select id, name from student where id = 
1) t left join score on t.id = score.sid";
 
         PlanChecker.from(connectContext)
@@ -360,7 +356,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest14() {
+    void inferPredicatesTest14() {
         String sql = "select * from student left semi join score on student.id 
= score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -383,7 +379,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest15() {
+    void inferPredicatesTest15() {
         String sql = "select * from student left semi join score on student.id 
= score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -406,7 +402,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest16() {
+    void inferPredicatesTest16() {
         String sql = "select * from student left anti join score on student.id 
= score.sid and student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -427,7 +423,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest17() {
+    void inferPredicatesTest17() {
         String sql = "select * from student left anti join score on student.id 
= score.sid and score.sid > 1";
 
         PlanChecker.from(connectContext)
@@ -448,7 +444,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest18() {
+    void inferPredicatesTest18() {
         String sql = "select * from student left anti join score on student.id 
= score.sid where student.id > 1";
 
         PlanChecker.from(connectContext)
@@ -471,7 +467,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest19() {
+    void inferPredicatesTest19() {
         String sql = "select * from subquery1\n"
                 + "left semi join (\n"
                 + "  select t1.k3\n"
@@ -532,7 +528,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest20() {
+    void inferPredicatesTest20() {
         String sql = "select * from student left join score on student.id = 
score.sid and score.sid > 1 inner join course on course.id = score.sid";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -558,7 +554,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
     }
 
     @Test
-    public void inferPredicatesTest21() {
+    void inferPredicatesTest21() {
         String sql = "select * from student,score,course where student.id = 
score.sid and score.sid = course.id and score.sid > 1";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -587,7 +583,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
      * test for #15310
      */
     @Test
-    public void inferPredicatesTest22() {
+    void inferPredicatesTest22() {
         String sql = "select * from student join (select sid as id1, sid as 
id2, grade from score) s on student.id = s.id1 where s.id1 > 1";
         PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
         PlanChecker.from(connectContext)
@@ -613,7 +609,7 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
      * in this case, filter on relation s1 should not contain s1.id = 1.
      */
     @Test
-    public void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
+    void innerJoinShouldNotInferUnderLeftJoinOnClausePredicates() {
         String sql = "select * from student s1"
                 + " left join (select sid as id1, sid as id2, grade from 
score) s2 on s1.id = s2.id1 and s1.id = 1"
                 + " join (select sid as id1, sid as id2, grade from score) s3 
on s1.id = s3.id1 where s1.id = 2";
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
new file mode 100644
index 00000000000..b1aa25df1b1
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+class PredicatePropagationTest {
+    private final SlotReference a = new SlotReference("a", 
SmallIntType.INSTANCE);
+    private final SlotReference b = new SlotReference("b", 
BigIntType.INSTANCE);
+
+    @Test
+    void equal() {
+        Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new 
EqualTo(a, Literal.of(1)));
+        Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+        System.out.println(inferExprs);
+    }
+
+    @Test
+    void in() {
+        Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new 
InPredicate(a, ImmutableList.of(Literal.of(1))));
+        Set<Expression> inferExprs = PredicatePropagation.infer(exprs);
+        System.out.println(inferExprs);
+    }
+}
diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out 
b/regression-test/data/nereids_p0/hint/fix_leading.out
index 54da890cced..a71ca311e75 100644
--- a/regression-test/data/nereids_p0/hint/fix_leading.out
+++ b/regression-test/data/nereids_p0/hint/fix_leading.out
@@ -9,7 +9,7 @@ PhysicalResultSink
 ----------PhysicalDistribute
 ------------PhysicalOlapScan[t2]
 --------PhysicalDistribute
-----------NestedLoopJoin[CROSS_JOIN]
+----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4)
 ------------PhysicalOlapScan[t3]
 ------------PhysicalDistribute
 --------------PhysicalOlapScan[t4]
diff --git 
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy 
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index c5942680ea7..55645ed8ea0 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
 
     explain {
         sql "select * from infer_tb1 inner join infer_tb2 where 
cast(infer_tb2.k4 as int) = infer_tb1.k2  and infer_tb2.k4 = 1;"
-        contains "PREDICATES: k2"
+        contains "PREDICATES: CAST(k2"
     }
 
     explain {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to