This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit bbb66b95fd420c3ff6e76580ad78c90a0e9ed144
Author: LiBinfeng <46676950+libinfeng...@users.noreply.github.com>
AuthorDate: Mon Sep 11 14:30:31 2023 +0800

    [Fix](Nereids) fix infer predicate lost cast of source expression (#23692)
    
    Problem:
    When inferring predicates, we lose the cast on the source expression and
    some data type derivation.
    
    Example:
    a = b and cast(a as targetType) = constant
    The expression (cast(a as targetType) = constant) is defined as the source
    expression.
    We expect to get cast(b as targetType) = constant instead of b = constant.
    
    Reason:
    When inferring a predicate, we compare the original types of a and b. If
    they can be cast without loss of precision, a new predicate is created.
    But the created predicate forgot to cast to the target type.
    
    Solved:
    Add the cast to the target type, and also treat other data types as valid
    when both sides of the equality have the same data type.
---
 .../rules/rewrite/PredicatePropagation.java        | 39 +++++++++++++---------
 .../nereids/rules/rewrite/InferPredicatesTest.java | 30 +++++++++++++++++
 .../infer_predicate/infer_predicate.groovy         | 18 ++++++++++
 3 files changed, 71 insertions(+), 16 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index cc45952817..7181896669 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -59,12 +59,12 @@ public class PredicatePropagation {
     }
 
     /**
-     * Use the left or right child of `leftSlotEqualToRightSlot` to replace 
the left or right child of `expression`
+     * Use the left or right child of `equalExpr` to replace the left or right 
child of `expression`
      * Now only support infer `ComparisonPredicate`.
      * TODO: We should determine whether `expression` satisfies the condition 
for replacement
      *       eg: Satisfy `expression` is non-deterministic
      */
-    private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression 
expression) {
+    private Expression doInfer(Expression equalExpr, Expression expression) {
         return expression.accept(new DefaultExpressionRewriter<Void>() {
 
             @Override
@@ -76,36 +76,43 @@ public class PredicatePropagation {
             public Expression visitComparisonPredicate(ComparisonPredicate cp, 
Void context) {
                 // we need to get expression covered by cast, because we want 
to infer different datatype
                 if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) 
&& (cp.right().isConstant())) {
-                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
+                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.left()), equalExpr);
                 } else if 
(ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && 
cp.left().isConstant()) {
-                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
+                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.right()), equalExpr);
                 }
                 return super.visit(cp, context);
             }
 
             private boolean isDataTypeValid(DataType originDataType, 
Expression expr) {
-                if ((leftSlotEqualToRightSlot.child(0).getDataType() 
instanceof IntegralType)
-                        && (leftSlotEqualToRightSlot.child(1).getDataType() 
instanceof IntegralType)
+                if ((expr.child(0).getDataType() instanceof IntegralType)
+                        && (expr.child(1).getDataType() instanceof 
IntegralType)
                                 && (originDataType instanceof IntegralType)) {
                     // infer filter can not be lower than original datatype, 
or dataset would be wrong
                     if (!((IntegralType) originDataType).widerThan(
-                            (IntegralType) 
leftSlotEqualToRightSlot.child(0).getDataType())
+                            (IntegralType) expr.child(0).getDataType())
                                     && !((IntegralType) 
originDataType).widerThan(
-                                            (IntegralType) 
leftSlotEqualToRightSlot.child(1).getDataType())) {
+                                            (IntegralType) 
expr.child(1).getDataType())) {
                         return true;
                     }
+                } else if 
(expr.child(0).getDataType().equals(expr.child(1).getDataType())) {
+                    return true;
                 }
                 return false;
             }
 
-            private Expression replaceSlot(Expression expr, DataType 
originDataType) {
-                return expr.rewriteUp(e -> {
-                    if (isDataTypeValid(originDataType, 
leftSlotEqualToRightSlot)) {
-                        if (ExpressionUtils.isTwoExpressionEqualWithCast(e, 
leftSlotEqualToRightSlot.child(0))) {
-                            return leftSlotEqualToRightSlot.child(1);
-                        } else if 
(ExpressionUtils.isTwoExpressionEqualWithCast(e, 
leftSlotEqualToRightSlot.child(1))) {
-                            return leftSlotEqualToRightSlot.child(0);
-                        }
+            private Expression replaceSlot(Expression sourcePredicate, 
DataType originDataType, Expression equal) {
+                if (!isDataTypeValid(originDataType, equal)) {
+                    return sourcePredicate;
+                }
+                return sourcePredicate.rewriteUp(e -> {
+                    // we can not replace Cast expression to slot because when 
rewrite up, we have replace child of cast
+                    if (e instanceof Cast) {
+                        return e;
+                    }
+                    if (ExpressionUtils.isTwoExpressionEqualWithCast(e, 
equal.child(0))) {
+                        return equal.child(1);
+                    } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, 
equal.child(1))) {
+                        return equal.child(0);
                     }
                     return e;
                 });
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
index adc67ca835..b7b235d2b4 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicatesTest.java
@@ -17,15 +17,33 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
 import org.apache.doris.utframe.TestWithFeService;
 
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.Optional;
+import java.util.Set;
+
 public class InferPredicatesTest extends TestWithFeService implements 
MemoPatternMatchSupported {
 
+    private final LogicalOlapScan scan1 = 
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+
+    private final LogicalOlapScan scan2 = 
PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+
+    private final PredicatePropagation propagation = new 
PredicatePropagation();
+
     @Override
     protected void runBeforeAll() throws Exception {
         createDatabase("test");
@@ -628,4 +646,16 @@ public class InferPredicatesTest extends TestWithFeService 
implements MemoPatter
                         ).when(join -> join.getJoinType() == 
JoinType.LEFT_OUTER_JOIN)
                 );
     }
+
+    @Test
+    void testInfer() {
+        EqualTo equalTo = new EqualTo(new Cast(scan1.getOutput().get(0), 
BigIntType.INSTANCE), Literal.of(1));
+        EqualTo equalTo2 = new EqualTo(scan2.getOutput().get(0), 
scan1.getOutput().get(0));
+        Set<Expression> predicates = Sets.newHashSet();
+        predicates.add(equalTo2);
+        predicates.add(equalTo);
+        Set<Expression> newPredicates = propagation.infer(predicates);
+        Optional<Expression> newPredicate = newPredicates.stream().findFirst();
+        Assertions.assertTrue(newPredicate.get().equals(new EqualTo(new 
Cast(scan2.getOutput().get(0), BigIntType.INSTANCE), Literal.of(1))));
+    }
 }
diff --git 
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy 
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index a1621f1c23..120c9a8f67 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -22,6 +22,8 @@ suite("test_infer_predicate") {
     sql 'drop table if exists infer_tb1;'
     sql 'drop table if exists infer_tb2;'
     sql 'drop table if exists infer_tb3;'
+    sql 'drop table if exists infer_tb4;'
+    sql 'drop table if exists infer_tb5;'
 
     sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1) 
buckets 3 properties('replication_num' = '1');'''
 
@@ -29,6 +31,10 @@ suite("test_infer_predicate") {
 
     sql '''create table infer_tb3 (k1 varchar(100), k2 int) distributed by 
hash(k1) buckets 3 properties('replication_num' = '1');'''
 
+    sql '''create table infer_tb4 (k1 varchar(100), k2 date) distributed by 
hash(k1) buckets 3 properties('replication_num' = '1');'''
+
+    sql '''create table infer_tb5 (k1 varchar(100), k3 date) distributed by 
hash(k1) buckets 3 properties('replication_num' = '1');'''
+
     explain {
         sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 = 
infer_tb1.k2  and infer_tb2.k1 = 1;"
         contains "PREDICATES: k2"
@@ -55,4 +61,16 @@ suite("test_infer_predicate") {
         contains "PREDICATES: k3"
         contains "PREDICATES: k2"
     }
+
+    explain {
+        sql "select * from infer_tb4 left join infer_tb5 on infer_tb4.k2 = 
infer_tb5.k3 where infer_tb4.k2 = '20230901';"
+        contains "PREDICATES: k3"
+        contains "PREDICATES: k2"
+    }
+
+    sql 'drop table if exists infer_tb1;'
+    sql 'drop table if exists infer_tb2;'
+    sql 'drop table if exists infer_tb3;'
+    sql 'drop table if exists infer_tb4;'
+    sql 'drop table if exists infer_tb5;'
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to